1use std::any::TypeId;
2use std::cell::{RefCell, RefMut};
3use std::env;
4use std::fs::File;
5use std::path::PathBuf;
6
7use csv::Writer;
8use serde::Serializer;
9
10use crate::context::Context;
11use crate::error::IxaError;
12use crate::people::ContextPeopleExt;
13use crate::{define_data_plugin, error, trace, ContextBase, HashMap, HashMapExt, Tabulator};
14
/// Options controlling where and how report CSV files are created.
///
/// A context's options are obtained with `ContextReportExt::report_options`
/// and configured through the chained setters. They are read at the moment a
/// report is added, so they must be set before adding the reports they should
/// affect.
#[derive(Debug, Clone)]
pub struct ConfigReportOptions {
    /// String prepended to each report's short name to form its file name.
    pub file_prefix: String,
    /// Directory in which report files are created.
    pub output_dir: PathBuf,
    /// When `true`, an existing report file is truncated and reused; when
    /// `false`, adding a report whose file already exists fails.
    pub overwrite: bool,
}
25
26impl ConfigReportOptions {
27 #[must_use]
28 #[allow(clippy::missing_panics_doc)]
29 pub fn new() -> Self {
30 trace!("new ConfigReportOptions");
31 ConfigReportOptions {
33 file_prefix: String::new(),
34 output_dir: env::current_dir().unwrap(),
35 overwrite: false,
36 }
37 }
38 pub fn file_prefix(&mut self, file_prefix: impl Into<String>) -> &mut ConfigReportOptions {
40 let file_prefix = file_prefix.into();
41 trace!("setting report prefix to {file_prefix}");
42 self.file_prefix = file_prefix;
43 self
44 }
45 pub fn directory(&mut self, directory: impl Into<PathBuf>) -> &mut ConfigReportOptions {
47 let directory = directory.into();
48 trace!("setting report directory to {directory:?}");
49 self.output_dir = directory;
50 self
51 }
52 pub fn overwrite(&mut self, overwrite: bool) -> &mut ConfigReportOptions {
54 trace!("setting report overwrite {overwrite}");
55 self.overwrite = overwrite;
56 self
57 }
58}
59
60impl Default for ConfigReportOptions {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
/// A record type that can be written as one entry of a CSV report.
///
/// In this module's tests the implementation is generated with
/// `define_report!`, after which values can be passed to
/// `ContextReportExt::send_report`.
pub trait Report: 'static {
    /// Key used to look up this report's CSV writer. `send_report` passes it
    /// to `get_writer`, so it must equal the `TypeId` the report was
    /// registered under; `add_report::<T>` registers `TypeId::of::<T>()`.
    fn type_id(&self) -> TypeId;
    /// Writes `self` to the report's CSV `writer`.
    fn serialize(&self, writer: &mut Writer<File>);
}
72
73#[allow(clippy::trivially_copy_pass_by_ref)]
76#[allow(dead_code)]
77pub fn serialize_f64<S, const N: usize>(value: &f64, serializer: S) -> Result<S::Ok, S::Error>
78where
79 S: Serializer,
80{
81 let formatted = format!("{value:.N$}");
82 serializer.serialize_str(&formatted)
83}
84
85#[allow(clippy::trivially_copy_pass_by_ref)]
88#[allow(dead_code)]
89pub fn serialize_f32<S, const N: usize>(value: &f32, serializer: S) -> Result<S::Ok, S::Error>
90where
91 S: Serializer,
92{
93 let formatted = format!("{value:.N$}");
94 serializer.serialize_str(&formatted)
95}
96
/// Per-context state of the report plugin.
struct ReportData {
    // One CSV writer per registered report, keyed by the `TypeId` given to
    // `add_report_by_type_id` (the report type for `add_report`, the tabulator
    // type for `add_periodic_report`). Wrapped in `RefCell` because
    // `get_writer` only has `&self` but must hand out a mutable writer.
    file_writers: RefCell<HashMap<TypeId, Writer<File>>>,
    // Naming, directory and overwrite settings read when a report is added.
    config: ConfigReportOptions,
}
101
// Registers `ReportData` as a context data plugin under the key `ReportPlugin`.
// The initial value has no registered writers and default report options.
// NOTE(review): presumably the initializer runs once per context on first
// access — confirm against `define_data_plugin!`.
define_data_plugin!(
    ReportPlugin,
    ReportData,
    ReportData {
        file_writers: RefCell::new(HashMap::new()),
        config: ConfigReportOptions::new(),
    }
);
113
114pub trait ContextReportExt: ContextBase {
115 fn generate_filename(&mut self, short_name: &str) -> PathBuf {
119 let data_container = self.get_data_mut(ReportPlugin);
120 let prefix = &data_container.config.file_prefix;
121 let directory = &data_container.config.output_dir;
122 let short_name = short_name.to_string();
123 let basename = format!("{prefix}{short_name}");
124 directory.join(basename).with_extension("csv")
125 }
126
127 fn add_report_by_type_id(&mut self, type_id: TypeId, short_name: &str) -> Result<(), IxaError> {
134 trace!("adding report {short_name} by type_id {type_id:?}");
135 let path = self.generate_filename(short_name);
136
137 let data_container = self.get_data_mut(ReportPlugin);
138
139 let file_creation_result = File::create_new(&path);
140 let created_file = match file_creation_result {
141 Ok(file) => file,
142 Err(e) => match e.kind() {
143 std::io::ErrorKind::AlreadyExists => {
144 if data_container.config.overwrite {
145 File::create(&path)?
146 } else {
147 error!("File already exists: {}. Please set `overwrite` to true in the file configuration and rerun.", path.display());
148 return Err(IxaError::IoError(e));
149 }
150 }
151 _ => {
152 return Err(IxaError::IoError(e));
153 }
154 },
155 };
156 let writer = Writer::from_writer(created_file);
157 let mut file_writer = data_container.file_writers.borrow_mut();
158 file_writer.insert(type_id, writer);
159 Ok(())
160 }
161
162 fn add_report<T: Report + 'static>(&mut self, short_name: &str) -> Result<(), IxaError> {
169 trace!("Adding report {short_name}");
170 self.add_report_by_type_id(TypeId::of::<T>(), short_name)
171 }
172
173 fn add_periodic_report<T: Tabulator + Clone + 'static>(
179 &mut self,
180 short_name: &str,
181 period: f64,
182 tabulator: T,
183 ) -> Result<(), IxaError> {
184 trace!("Adding periodic report {short_name}");
185
186 self.add_report_by_type_id(TypeId::of::<T>(), short_name)?;
187
188 {
189 let mut writer = self.get_writer(TypeId::of::<T>());
191 let columns = tabulator.get_columns();
192 let mut header = vec!["t".to_string()];
193 header.extend(columns);
194 header.push("count".to_string());
195 writer
196 .write_record(&header)
197 .expect("Failed to write header");
198 }
199
200 self.add_periodic_plan_with_phase(
201 period,
202 move |context: &mut Context| {
203 context.tabulate_person_properties(&tabulator, move |context, values, count| {
204 let mut writer = context.get_writer(TypeId::of::<T>());
205 let mut row = vec![context.get_current_time().to_string()];
206 row.extend(values.to_owned());
207 row.push(count.to_string());
208
209 writer.write_record(&row).expect("Failed to write row");
210 });
211 },
212 crate::context::ExecutionPhase::Last,
213 );
214
215 Ok(())
216 }
217
218 fn get_writer(&self, type_id: TypeId) -> RefMut<Writer<File>> {
219 let data_container = self.get_data(ReportPlugin);
221 let writers = data_container.file_writers.try_borrow_mut().unwrap();
222 RefMut::map(writers, |writers| {
223 writers
224 .get_mut(&type_id)
225 .expect("No writer found for the report type")
226 })
227 }
228
229 fn send_report<T: Report>(&self, report: T) {
231 let writer = &mut self.get_writer(report.type_id());
232 report.serialize(writer);
233 }
234
235 fn report_options(&mut self) -> &mut ConfigReportOptions {
237 let data_container = self.get_data_mut(ReportPlugin);
238 &mut data_container.config
239 }
240}
241impl ContextReportExt for Context {}
242
#[cfg(test)]
mod test {
    use core::convert::TryInto;
    use std::thread;

    use serde_derive::{Deserialize, Serialize};
    use tempfile::tempdir;

    use super::*;
    use crate::{define_person_property_with_default, define_report, info};

    define_person_property_with_default!(IsRunner, bool, false);

    #[derive(Serialize, Deserialize)]
    struct SampleReport {
        id: u32,
        value: String,
    }

    define_report!(SampleReport);

    /// Reads every `SampleReport` row from the CSV file at `path`.
    ///
    /// The test fails if the file cannot be opened or any row fails to
    /// deserialize. Returning all rows lets callers assert the exact row count,
    /// so a test cannot pass vacuously when only the header was written.
    fn read_sample_reports(path: &std::path::Path) -> Vec<SampleReport> {
        csv::Reader::from_path(path)
            .expect("Failed to open CSV file")
            .deserialize::<SampleReport>()
            .collect::<Result<Vec<_>, _>>()
            .expect("Failed to deserialize record")
    }

    #[test]
    fn add_and_send_report() {
        let temp_dir = tempdir().unwrap();
        let path = PathBuf::from(&temp_dir.path());
        {
            // Scoped so the context (and its CSV writers) is dropped before
            // the file is read back.
            let mut context = Context::new();
            let config = context.report_options();
            config
                .file_prefix("prefix1_".to_string())
                .directory(path.clone());
            context.add_report::<SampleReport>("sample_report").unwrap();
            let report = SampleReport {
                id: 1,
                value: "Test Value".to_string(),
            };

            context.send_report(report);
        }

        let file_path = path.join("prefix1_sample_report.csv");
        assert!(file_path.exists(), "CSV file should exist");
        assert!(file_path.metadata().unwrap().len() > 0);

        let records = read_sample_reports(&file_path);
        assert_eq!(records.len(), 1, "Exactly one record should be written");
        assert_eq!(records[0].id, 1);
        assert_eq!(records[0].value, "Test Value");
    }

    #[test]
    fn add_report_empty_prefix() {
        let temp_dir = tempdir().unwrap();
        let path = PathBuf::from(&temp_dir.path());
        {
            // Scoped so the context (and its CSV writers) is dropped before
            // the file is read back.
            let mut context = Context::new();
            let config = context.report_options();
            config.directory(path.clone());
            context.add_report::<SampleReport>("sample_report").unwrap();
            let report = SampleReport {
                id: 1,
                value: "Test Value".to_string(),
            };

            context.send_report(report);
        }
        let file_path = path.join("sample_report.csv");
        assert!(file_path.exists(), "CSV file should exist");
        assert!(file_path.metadata().unwrap().len() > 0);

        let records = read_sample_reports(&file_path);
        assert_eq!(records.len(), 1, "Exactly one record should be written");
        assert_eq!(records[0].id, 1);
        assert_eq!(records[0].value, "Test Value");
    }

    /// Removes the wrapped file when dropped, so tests that write into the
    /// current directory clean up after themselves even if an assertion fails.
    struct PathBufWithDrop {
        file: PathBuf,
    }

    impl Drop for PathBufWithDrop {
        fn drop(&mut self) {
            std::fs::remove_file(&self.file).unwrap();
        }
    }

    #[test]
    fn add_report_no_dir() {
        {
            // No directory configured: the report goes to the current directory.
            let mut context = Context::new();
            let config = context.report_options();
            config.file_prefix("test_prefix_".to_string());
            context.add_report::<SampleReport>("sample_report").unwrap();
            let report = SampleReport {
                id: 1,
                value: "Test Value".to_string(),
            };

            context.send_report(report);
        }

        let path = env::current_dir().unwrap();
        let file_path = PathBufWithDrop {
            file: path.join("test_prefix_sample_report.csv"),
        };
        assert!(file_path.file.exists(), "CSV file should exist");
        assert!(file_path.file.metadata().unwrap().len() > 0);

        let records = read_sample_reports(&file_path.file);
        assert_eq!(records.len(), 1, "Exactly one record should be written");
        assert_eq!(records[0].id, 1);
        assert_eq!(records[0].value, "Test Value");
    }

    #[test]
    #[should_panic(expected = "No writer found for the report type")]
    fn send_report_without_adding_report() {
        let context = Context::new();
        let report = SampleReport {
            id: 1,
            value: "Test Value".to_string(),
        };

        context.send_report(report);
    }

    #[test]
    fn multiple_reports_one_context() {
        let temp_dir = tempdir().unwrap();
        let path = PathBuf::from(&temp_dir.path());
        {
            // Values containing a comma and a newline exercise CSV quoting.
            let mut context = Context::new();
            let config = context.report_options();
            config
                .file_prefix("mult_report_".to_string())
                .directory(path.clone());
            context.add_report::<SampleReport>("sample_report").unwrap();
            let report1 = SampleReport {
                id: 1,
                value: "Value,1".to_string(),
            };
            let report2 = SampleReport {
                id: 2,
                value: "Value\n2".to_string(),
            };

            context.send_report(report1);
            context.send_report(report2);
        }

        let file_path = path.join("mult_report_sample_report.csv");
        assert!(file_path.exists(), "CSV file should exist");

        let records = read_sample_reports(&file_path);
        assert_eq!(records.len(), 2, "Exactly two records should be written");

        assert_eq!(records[0].id, 1);
        assert_eq!(records[0].value, "Value,1");

        assert_eq!(records[1].id, 2);
        assert_eq!(records[1].value, "Value\n2");
    }

    #[test]
    fn multithreaded_report_generation_thread_local() {
        let num_threads = 10;
        let num_reports_per_thread = 5;

        let mut handles = vec![];
        let temp_dir = tempdir().unwrap();
        let base_path = temp_dir.path().to_path_buf();

        for i in 0..num_threads {
            let path = base_path.clone();
            let handle = thread::spawn(move || {
                let mut context = Context::new();
                let config = context.report_options();
                config.file_prefix(i.to_string()).directory(path);
                context.add_report::<SampleReport>("sample_report").unwrap();

                for j in 0..num_reports_per_thread {
                    let report = SampleReport {
                        id: u32::try_from(i * num_reports_per_thread + j).unwrap(),
                        value: format!("Thread {i} Report {j}"),
                    };
                    context.send_report(report);
                }
            });

            handles.push(handle);
        }

        for handle in handles {
            handle.join().expect("Thread failed");
        }

        for i in 0..num_threads {
            let file_name = format!("{i}sample_report.csv");
            let file_path = base_path.join(file_name);
            assert!(file_path.exists(), "CSV file should exist");

            let records = read_sample_reports(&file_path);
            assert_eq!(
                records.len(),
                num_reports_per_thread,
                "Each thread should write all of its records"
            );

            for (j, record) in records.iter().enumerate() {
                let id_expected = TryInto::<u32>::try_into(i * num_reports_per_thread + j).unwrap();
                assert_eq!(record.id, id_expected);
                assert_eq!(record.value, format!("Thread {i} Report {j}"));
            }
        }
    }

    #[test]
    fn dont_overwrite_report() {
        let mut context1 = Context::new();
        let temp_dir = tempdir().unwrap();
        let path = PathBuf::from(&temp_dir.path());
        let config = context1.report_options();
        config
            .file_prefix("prefix1_".to_string())
            .directory(path.clone());
        context1
            .add_report::<SampleReport>("sample_report")
            .unwrap();
        let report = SampleReport {
            id: 1,
            value: "Test Value".to_string(),
        };

        context1.send_report(report);

        let file_path = path.join("prefix1_sample_report.csv");
        assert!(file_path.exists(), "CSV file should exist");

        let mut context2 = Context::new();
        let config = context2.report_options();
        config.file_prefix("prefix1_".to_string()).directory(path);
        info!("The next 'file already exists' error is intended for a passing test.");
        let result = context2.add_report::<SampleReport>("sample_report");
        assert!(result.is_err());
        let error = result.err().unwrap();
        match error {
            IxaError::IoError(e) => {
                assert_eq!(e.kind(), std::io::ErrorKind::AlreadyExists);
            }
            _ => {
                panic!("Unexpected error type");
            }
        }
    }

    #[test]
    fn overwrite_report() {
        let mut context1 = Context::new();
        let temp_dir = tempdir().unwrap();
        let path = PathBuf::from(&temp_dir.path());
        let config = context1.report_options();
        config
            .file_prefix("prefix1_".to_string())
            .directory(path.clone());
        context1
            .add_report::<SampleReport>("sample_report")
            .unwrap();
        let report = SampleReport {
            id: 1,
            value: "Test Value".to_string(),
        };

        context1.send_report(report);

        let file_path = path.join("prefix1_sample_report.csv");
        assert!(file_path.exists(), "CSV file should exist");

        let mut context2 = Context::new();
        let config = context2.report_options();
        config
            .file_prefix("prefix1_".to_string())
            .directory(path)
            .overwrite(true);
        let result = context2.add_report::<SampleReport>("sample_report");
        assert!(result.is_ok());
        // Overwriting truncates the file, so no rows from context1 remain.
        let file = File::open(file_path).unwrap();
        let reader = csv::Reader::from_reader(file);
        let records = reader.into_records();
        assert_eq!(records.count(), 0);
    }

    #[derive(PartialEq, Copy, Clone, Debug, Serialize, Deserialize)]
    pub enum SymptomValue {
        Presymptomatic,
        Category1,
        Category2,
        Category3,
        Category4,
    }

    define_person_property_with_default!(Symptoms, Option<SymptomValue>, None);

    #[test]
    fn add_periodic_report() {
        let temp_dir = tempdir().unwrap();
        let path = PathBuf::from(&temp_dir.path());

        {
            // Scoped so the context (and its CSV writers) is dropped before
            // the file is read back.
            let mut context = Context::new();
            let config = context.report_options();
            config
                .file_prefix("test_".to_string())
                .directory(path.clone());
            let _ = context.add_periodic_report("periodic", 1.2, (IsRunner, Symptoms));
            let person = context.add_person(()).unwrap();
            context.add_person(()).unwrap();

            context.add_plan(1.2, move |context: &mut Context| {
                context.set_person_property(person, IsRunner, true);
                context.set_person_property(person, Symptoms, Some(SymptomValue::Category1));
            });
            context.execute();
        }
        let file_path = path.join("test_periodic.csv");
        assert!(file_path.exists(), "CSV file should exist");

        let mut reader = csv::Reader::from_path(file_path).unwrap();

        assert_eq!(
            reader.headers().unwrap(),
            vec!["t", "IsRunner", "Symptoms", "count"]
        );

        let mut actual: Vec<Vec<String>> = reader
            .records()
            .map(|result| result.unwrap().iter().map(String::from).collect())
            .collect();
        let mut expected = vec![
            vec!["0", "false", "None", "2"],
            vec!["1.2", "false", "Category1", "0"],
            vec!["1.2", "false", "None", "1"],
            vec!["1.2", "true", "Category1", "1"],
            vec!["1.2", "true", "None", "0"],
        ];

        actual.sort();
        expected.sort();

        assert_eq!(actual, expected, "CSV file should contain the correct data");
    }
}