ixa/
runner.rs

1use std::path::PathBuf;
2use std::str::FromStr;
3
4use crate::context::Context;
5#[cfg(feature = "debugger")]
6use crate::debugger::enter_debugger;
7use crate::error::IxaError;
8use crate::global_properties::ContextGlobalPropertiesExt;
9#[cfg(feature = "progress_bar")]
10use crate::progress::init_timeline_progress_bar;
11use crate::random::ContextRandomExt;
12use crate::report::ContextReportExt;
13#[cfg(feature = "web_api")]
14use crate::web_api::ContextWebApiExt;
15use crate::{info, set_log_level, set_module_filters, warn, LevelFilter};
16use clap::{Args, Command, FromArgMatches as _};
17
18/// Custom parser for log levels
19fn parse_log_levels(s: &str) -> Result<Vec<(String, LevelFilter)>, String> {
20    s.split(',')
21        .map(|pair| {
22            let mut iter = pair.split('=');
23            let key = iter
24                .next()
25                .ok_or_else(|| format!("Invalid key in pair: {pair}"))?;
26            let value = iter
27                .next()
28                .ok_or_else(|| format!("Invalid value in pair: {pair}"))?;
29            let level =
30                LevelFilter::from_str(value).map_err(|_| format!("Invalid log level: {value}"))?;
31            Ok((key.to_string(), level))
32        })
33        .collect()
34}
35
36/// Default cli arguments for ixa runner
37#[derive(Args, Debug)]
38pub struct BaseArgs {
39    /// Random seed
40    #[arg(short, long, default_value = "0")]
41    pub random_seed: u64,
42
43    /// Optional path for a global properties config file
44    #[arg(short, long)]
45    pub config: Option<PathBuf>,
46
47    /// Optional path for report output
48    #[arg(short, long = "output")]
49    pub output_dir: Option<PathBuf>,
50
51    /// Optional prefix for report files
52    #[arg(long = "prefix")]
53    pub file_prefix: Option<String>,
54
55    /// Overwrite existing report files?
56    #[arg(short, long)]
57    pub force_overwrite: bool,
58
59    /// Enable logging
60    #[arg(short, long)]
61    pub log_level: Option<String>,
62
63    /// Set a breakpoint at a given time and start the debugger. Defaults to t=0.0
64    #[arg(short, long)]
65    pub debugger: Option<Option<f64>>,
66
67    /// Enable the Web API at a given time. Defaults to t=0.0
68    #[arg(short, long)]
69    pub web: Option<Option<u16>>,
70
71    /// Enable the timeline progress bar with a maximum time.
72    #[arg(short, long)]
73    pub timeline_progress_max: Option<f64>,
74
75    /// Suppresses the printout of summary statistics at the end of the simulation.
76    #[arg(long)]
77    pub no_stats: bool,
78}
79
80impl BaseArgs {
81    fn new() -> Self {
82        BaseArgs {
83            random_seed: 0,
84            config: None,
85            output_dir: None,
86            file_prefix: None,
87            force_overwrite: false,
88            log_level: None,
89            debugger: None,
90            web: None,
91            timeline_progress_max: None,
92            no_stats: false,
93        }
94    }
95}
96
97impl Default for BaseArgs {
98    fn default() -> Self {
99        BaseArgs::new()
100    }
101}
102
103#[derive(Args)]
104pub struct PlaceholderCustom {}
105
106fn create_ixa_cli() -> Command {
107    let cli = Command::new("ixa");
108    BaseArgs::augment_args(cli)
109}
110
111/// Runs a simulation with custom cli arguments.
112///
113/// This function allows you to define custom arguments and a setup function
114///
115/// # Parameters
116/// - `setup_fn`: A function that takes a mutable reference to a `Context`, a `BaseArgs` struct,
117///   a `Option<A>` where `A` is the custom cli arguments struct
118///
119/// # Errors
120/// Returns an error if argument parsing or the setup function fails
121#[allow(clippy::missing_errors_doc)]
122pub fn run_with_custom_args<A, F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
123where
124    A: Args,
125    F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
126{
127    let mut cli = create_ixa_cli();
128    cli = A::augment_args(cli);
129    let matches = cli.get_matches();
130
131    let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
132    let custom_matches = A::from_arg_matches(&matches)?;
133    run_with_args_internal(base_args_matches, Some(custom_matches), setup_fn)
134}
135
136/// Runs a simulation with default cli arguments
137///
138/// This function parses command line arguments allows you to define a setup function
139///
140/// # Parameters
141/// - `setup_fn`: A function that takes a mutable reference to a `Context` and `BaseArgs` struct
142///
143/// # Errors
144/// Returns an error if argument parsing or the setup function fails
145#[allow(clippy::missing_errors_doc)]
146pub fn run_with_args<F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
147where
148    F: Fn(&mut Context, BaseArgs, Option<PlaceholderCustom>) -> Result<(), IxaError>,
149{
150    let cli = create_ixa_cli();
151    let matches = cli.get_matches();
152
153    let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
154    run_with_args_internal(base_args_matches, None, setup_fn)
155}
156
157fn run_with_args_internal<A, F>(
158    args: BaseArgs,
159    custom_args: Option<A>,
160    setup_fn: F,
161) -> Result<Context, Box<dyn std::error::Error>>
162where
163    F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
164{
165    // Instantiate a context
166    let mut context = Context::new();
167
168    // Optionally set global properties from a file
169    if args.config.is_some() {
170        let config_path = args.config.clone().unwrap();
171        println!("Loading global properties from: {config_path:?}");
172        context.load_global_properties(&config_path)?;
173    }
174
175    // Configure report options
176    let report_config = context.report_options();
177    if args.output_dir.is_some() {
178        report_config.directory(args.output_dir.clone().unwrap());
179    }
180    if args.file_prefix.is_some() {
181        report_config.file_prefix(args.file_prefix.clone().unwrap());
182    }
183    if args.force_overwrite {
184        report_config.overwrite(true);
185    }
186    if let Some(log_level) = args.log_level.as_ref() {
187        if let Ok(level) = LevelFilter::from_str(log_level) {
188            set_log_level(level);
189            info!("Logging enabled at level {level}");
190        } else if let Ok(log_levels) = parse_log_levels(log_level) {
191            let log_levels_slice: Vec<(&String, LevelFilter)> =
192                log_levels.iter().map(|(k, v)| (k, *v)).collect();
193            set_module_filters(log_levels_slice.as_slice());
194            for (key, value) in log_levels {
195                println!("Logging enabled for {key} at level {value}");
196                // Here you can set the log level for each key-value pair as needed
197            }
198        } else {
199            return Err(format!("Invalid log level format: {log_level}").into());
200        }
201    } else {
202        info!("Logging disabled.");
203    }
204
205    context.init_random(args.random_seed);
206
207    // If a breakpoint is provided, stop at that time
208    #[cfg(feature = "debugger")]
209    if let Some(t) = args.debugger {
210        assert!(
211            args.web.is_none(),
212            "Cannot run with both the debugger and the Web API"
213        );
214        match t {
215            None => {
216                context.request_debugger();
217            }
218            Some(time) => {
219                context.schedule_debugger(time, None, Box::new(enter_debugger));
220            }
221        }
222    }
223    #[cfg(not(feature = "debugger"))]
224    if args.debugger.is_some() {
225        warn!("Ixa was not compiled with the debugger feature, but a debugger option was provided");
226    }
227
228    // If the Web API is provided, stop there.
229    #[cfg(feature = "web_api")]
230    if let Some(t) = args.web {
231        let port = t.unwrap_or(33334);
232        let url = context.setup_web_api(port).unwrap();
233        println!("Web API active on {url}");
234        context.schedule_web_api(0.0);
235    }
236    #[cfg(not(feature = "web_api"))]
237    if args.web.is_some() {
238        warn!("Ixa was not compiled with the web_api feature, but a web_api option was provided");
239    }
240
241    if let Some(max_time) = args.timeline_progress_max {
242        // We allow a `max_time` of `0.0` to mean "disable timeline progress bar".
243        if cfg!(not(feature = "progress_bar")) && max_time > 0.0 {
244            warn!("Ixa was not compiled with the progress_bar feature, but a progress_bar option was provided");
245        } else if max_time < 0.0 {
246            warn!("timeline progress maximum must be nonnegative");
247        }
248        #[cfg(feature = "progress_bar")]
249        if max_time > 0.0 {
250            println!("ProgressBar max set to {}", max_time);
251            init_timeline_progress_bar(max_time);
252        }
253    }
254
255    if args.no_stats {
256        context.print_execution_statistics = false;
257    } else {
258        if cfg!(target_family = "wasm") {
259            warn!("the print-stats option is enabled; some statistics are not supported for the wasm target family");
260        }
261        context.print_execution_statistics = true;
262    }
263
264    // Run the provided Fn
265    setup_fn(&mut context, args, custom_args)?;
266
267    // Execute the context
268    context.execute();
269    Ok(context)
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{define_global_property, define_rng};
    use serde::{Deserialize, Serialize};

    #[derive(Args, Debug)]
    struct CustomArgs {
        #[arg(short, long, default_value = "0")]
        a: u32,
    }

    // The next two tests parse an explicit argument vector instead of calling
    // `run_with_custom_args`/`run_with_args` directly. Those functions read the test
    // binary's real command line. There, libtest flags such as `--nocapture` or a
    // test-name filter would make clap print an error and exit the whole test process.
    #[test]
    fn test_run_with_custom_args() {
        let matches = CustomArgs::augment_args(create_ixa_cli())
            .try_get_matches_from(["ixa", "-a", "7"])
            .unwrap();
        let base_args = <BaseArgs as clap::FromArgMatches>::from_arg_matches(&matches).unwrap();
        let custom_args =
            <CustomArgs as clap::FromArgMatches>::from_arg_matches(&matches).unwrap();
        let result = run_with_args_internal(
            base_args,
            Some(custom_args),
            |_, _, c: Option<CustomArgs>| {
                assert_eq!(c.unwrap().a, 7);
                Ok(())
            },
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_args() {
        let matches = create_ixa_cli()
            .try_get_matches_from(["ixa", "--random-seed", "5", "--no-stats"])
            .unwrap();
        let base_args = <BaseArgs as clap::FromArgMatches>::from_arg_matches(&matches).unwrap();
        assert_eq!(base_args.random_seed, 5);
        assert!(base_args.no_stats);
        let result =
            run_with_args_internal(base_args, None, |_, _, _: Option<PlaceholderCustom>| Ok(()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_log_levels() {
        let levels = parse_log_levels("ixa=debug,my_model=warn").unwrap();
        assert_eq!(
            levels,
            vec![
                ("ixa".to_string(), LevelFilter::Debug),
                ("my_model".to_string(), LevelFilter::Warn),
            ]
        );
        assert!(parse_log_levels("ixa").is_err());
        assert!(parse_log_levels("ixa=loud").is_err());
    }

    #[test]
    fn test_run_with_random_seed() {
        let test_args = BaseArgs {
            random_seed: 42,
            ..Default::default()
        };

        // Use a comparison context to verify the random seed was set
        let mut compare_ctx = Context::new();
        compare_ctx.init_random(42);
        define_rng!(TestRng);
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            assert_eq!(
                ctx.sample_range(TestRng, 0..100),
                compare_ctx.sample_range(TestRng, 0..100)
            );
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[derive(Serialize, Deserialize)]
    pub struct RunnerPropertyType {
        field_int: u32,
    }
    define_global_property!(RunnerProperty, RunnerPropertyType);

    #[test]
    fn test_run_with_config_path() {
        let test_args = BaseArgs {
            config: Some(PathBuf::from("tests/data/global_properties_runner.json")),
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let p3 = ctx.get_global_property_value(RunnerProperty).unwrap();
            assert_eq!(p3.field_int, 0);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_report_options() {
        let test_args = BaseArgs {
            output_dir: Some(PathBuf::from("data")),
            file_prefix: Some("test".to_string()),
            force_overwrite: true,
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let opts = &ctx.report_options();
            assert_eq!(opts.output_dir, PathBuf::from("data"));
            assert_eq!(opts.file_prefix, "test".to_string());
            assert!(opts.overwrite);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_custom() {
        let test_args = BaseArgs::new();
        let custom = CustomArgs { a: 42 };
        let result = run_with_args_internal(test_args, Some(custom), |_, _, c| {
            assert_eq!(c.unwrap().a, 42);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_logging_enabled() {
        let mut test_args = BaseArgs::new();
        test_args.log_level = Some(LevelFilter::Info.to_string());
        let result = run_with_args_internal(test_args, None, |_, _, _: Option<()>| Ok(()));
        assert!(result.is_ok());
    }
}