ixa/
runner.rs

1use std::path::PathBuf;
2use std::str::FromStr;
3
4use crate::context::Context;
5#[cfg(feature = "debugger")]
6use crate::debugger::enter_debugger;
7use crate::error::IxaError;
8use crate::global_properties::ContextGlobalPropertiesExt;
9use crate::random::ContextRandomExt;
10use crate::report::ContextReportExt;
11#[cfg(feature = "web_api")]
12use crate::web_api::ContextWebApiExt;
13use crate::{info, set_log_level, set_module_filters, LevelFilter};
14use clap::{Args, Command, FromArgMatches as _};
15#[cfg(not(feature = "web_api"))]
16use log::warn;
17
18/// Custom parser for log levels
19fn parse_log_levels(s: &str) -> Result<Vec<(String, LevelFilter)>, String> {
20    s.split(',')
21        .map(|pair| {
22            let mut iter = pair.split('=');
23            let key = iter
24                .next()
25                .ok_or_else(|| format!("Invalid key in pair: {pair}"))?;
26            let value = iter
27                .next()
28                .ok_or_else(|| format!("Invalid value in pair: {pair}"))?;
29            let level =
30                LevelFilter::from_str(value).map_err(|_| format!("Invalid log level: {value}"))?;
31            Ok((key.to_string(), level))
32        })
33        .collect()
34}
35
36/// Default cli arguments for ixa runner
37#[derive(Args, Debug)]
38pub struct BaseArgs {
39    /// Random seed
40    #[arg(short, long, default_value = "0")]
41    pub random_seed: u64,
42
43    /// Optional path for a global properties config file
44    #[arg(short, long)]
45    pub config: Option<PathBuf>,
46
47    /// Optional path for report output
48    #[arg(short, long = "output")]
49    pub output_dir: Option<PathBuf>,
50
51    /// Optional prefix for report files
52    #[arg(long = "prefix")]
53    pub file_prefix: Option<String>,
54
55    /// Overwrite existing report files?
56    #[arg(short, long)]
57    pub force_overwrite: bool,
58
59    /// Enable logging
60    #[arg(short, long)]
61    pub log_level: Option<String>,
62
63    /// Set a breakpoint at a given time and start the debugger. Defaults to t=0.0
64    #[arg(short, long)]
65    pub debugger: Option<Option<f64>>,
66
67    /// Enable the Web API at a given time. Defaults to t=0.0
68    #[arg(short, long)]
69    pub web: Option<Option<u16>>,
70}
71
72impl BaseArgs {
73    fn new() -> Self {
74        BaseArgs {
75            random_seed: 0,
76            config: None,
77            output_dir: None,
78            file_prefix: None,
79            force_overwrite: false,
80            log_level: None,
81            debugger: None,
82            web: None,
83        }
84    }
85}
86
87impl Default for BaseArgs {
88    fn default() -> Self {
89        BaseArgs::new()
90    }
91}
92
93#[derive(Args)]
94pub struct PlaceholderCustom {}
95
96fn create_ixa_cli() -> Command {
97    let cli = Command::new("ixa");
98    BaseArgs::augment_args(cli)
99}
100
101/// Runs a simulation with custom cli arguments.
102///
103/// This function allows you to define custom arguments and a setup function
104///
105/// # Parameters
106/// - `setup_fn`: A function that takes a mutable reference to a `Context`, a `BaseArgs` struct,
107///   a Option<A> where A is the custom cli arguments struct
108///
109/// # Errors
110/// Returns an error if argument parsing or the setup function fails
111#[allow(clippy::missing_errors_doc)]
112pub fn run_with_custom_args<A, F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
113where
114    A: Args,
115    F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
116{
117    let mut cli = create_ixa_cli();
118    cli = A::augment_args(cli);
119    let matches = cli.get_matches();
120
121    let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
122    let custom_matches = A::from_arg_matches(&matches)?;
123    run_with_args_internal(base_args_matches, Some(custom_matches), setup_fn)
124}
125
126/// Runs a simulation with default cli arguments
127///
128/// This function parses command line arguments allows you to define a setup function
129///
130/// # Parameters
131/// - `setup_fn`: A function that takes a mutable reference to a `Context` and `BaseArgs` struct
132///
133/// # Errors
134/// Returns an error if argument parsing or the setup function fails
135#[allow(clippy::missing_errors_doc)]
136pub fn run_with_args<F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
137where
138    F: Fn(&mut Context, BaseArgs, Option<PlaceholderCustom>) -> Result<(), IxaError>,
139{
140    let cli = create_ixa_cli();
141    let matches = cli.get_matches();
142
143    let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
144    run_with_args_internal(base_args_matches, None, setup_fn)
145}
146
147fn run_with_args_internal<A, F>(
148    args: BaseArgs,
149    custom_args: Option<A>,
150    setup_fn: F,
151) -> Result<Context, Box<dyn std::error::Error>>
152where
153    F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
154{
155    // Instantiate a context
156    let mut context = Context::new();
157
158    // Optionally set global properties from a file
159    if args.config.is_some() {
160        let config_path = args.config.clone().unwrap();
161        println!("Loading global properties from: {config_path:?}");
162        context.load_global_properties(&config_path)?;
163    }
164
165    // Configure report options
166    let report_config = context.report_options();
167    if args.output_dir.is_some() {
168        report_config.directory(args.output_dir.clone().unwrap());
169    }
170    if args.file_prefix.is_some() {
171        report_config.file_prefix(args.file_prefix.clone().unwrap());
172    }
173    if args.force_overwrite {
174        report_config.overwrite(true);
175    }
176    if let Some(log_level) = args.log_level.as_ref() {
177        if let Ok(level) = LevelFilter::from_str(log_level) {
178            set_log_level(level);
179            info!("Logging enabled at level {level}");
180        } else if let Ok(log_levels) = parse_log_levels(log_level) {
181            let log_levels_slice: Vec<(&String, LevelFilter)> =
182                log_levels.iter().map(|(k, v)| (k, *v)).collect();
183            set_module_filters(log_levels_slice.as_slice());
184            for (key, value) in log_levels {
185                println!("Logging enabled for {key} at level {value}");
186                // Here you can set the log level for each key-value pair as needed
187            }
188        } else {
189            return Err(format!("Invalid log level format: {log_level}").into());
190        }
191    } else {
192        info!("Logging disabled.");
193    }
194
195    context.init_random(args.random_seed);
196
197    // If a breakpoint is provided, stop at that time
198    #[cfg(feature = "debugger")]
199    if let Some(t) = args.debugger {
200        assert!(
201            args.web.is_none(),
202            "Cannot run with both the debugger and the Web API"
203        );
204        match t {
205            None => {
206                context.request_debugger();
207            }
208            Some(time) => {
209                context.schedule_debugger(time, None, Box::new(enter_debugger));
210            }
211        }
212    }
213    #[cfg(not(feature = "debugger"))]
214    if args.debugger.is_some() {
215        warn!("Ixa was not compiled with the debugger feature, but a debugger option was provided");
216    }
217
218    // If the Web API is provided, stop there.
219    #[cfg(feature = "web_api")]
220    if let Some(t) = args.web {
221        let port = t.unwrap_or(33334);
222        let url = context.setup_web_api(port).unwrap();
223        println!("Web API active on {url}");
224        context.schedule_web_api(0.0);
225    }
226    #[cfg(not(feature = "web_api"))]
227    if args.web.is_some() {
228        warn!("Ixa was not compiled with the web_api feature, but a web_api option was provided");
229    }
230
231    // Run the provided Fn
232    setup_fn(&mut context, args, custom_args)?;
233
234    // Execute the context
235    context.execute();
236    Ok(context)
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{define_global_property, define_rng};
    use clap::FromArgMatches;
    use serde::{Deserialize, Serialize};

    #[derive(Args, Debug)]
    struct CustomArgs {
        #[arg(short, long, default_value = "0")]
        a: u32,
    }

    // The public `run_with_args` / `run_with_custom_args` parse the test binary's own
    // argv. Under `cargo test <filter>`, clap would reject the filter and exit the whole
    // test process. So these two tests parse an explicit argv with the same clap command
    // and then drive the shared internal runner.

    #[test]
    fn test_run_with_custom_args() {
        let matches = CustomArgs::augment_args(create_ixa_cli())
            .try_get_matches_from(["ixa", "-a", "42"])
            .expect("command line should parse");
        let base_args = BaseArgs::from_arg_matches(&matches).unwrap();
        let custom_args = CustomArgs::from_arg_matches(&matches).unwrap();
        let result = run_with_args_internal(base_args, Some(custom_args), |_, _, c| {
            assert_eq!(c.unwrap().a, 42);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_args() {
        let matches = create_ixa_cli()
            .try_get_matches_from(["ixa", "--random-seed", "7"])
            .expect("command line should parse");
        let base_args = BaseArgs::from_arg_matches(&matches).unwrap();
        assert_eq!(base_args.random_seed, 7);
        let result =
            run_with_args_internal(base_args, None, |_, _, _: Option<PlaceholderCustom>| Ok(()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_random_seed() {
        let test_args = BaseArgs {
            random_seed: 42,
            ..Default::default()
        };

        // Use a comparison context to verify the random seed was set
        let mut compare_ctx = Context::new();
        compare_ctx.init_random(42);
        define_rng!(TestRng);
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            assert_eq!(
                ctx.sample_range(TestRng, 0..100),
                compare_ctx.sample_range(TestRng, 0..100)
            );
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[derive(Serialize, Deserialize)]
    pub struct RunnerPropertyType {
        field_int: u32,
    }
    define_global_property!(RunnerProperty, RunnerPropertyType);

    #[test]
    fn test_run_with_config_path() {
        let test_args = BaseArgs {
            config: Some(PathBuf::from("tests/data/global_properties_runner.json")),
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let property = ctx.get_global_property_value(RunnerProperty).unwrap();
            assert_eq!(property.field_int, 0);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_report_options() {
        let test_args = BaseArgs {
            output_dir: Some(PathBuf::from("data")),
            file_prefix: Some("test".to_string()),
            force_overwrite: true,
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let opts = &ctx.report_options();
            assert_eq!(opts.output_dir, PathBuf::from("data"));
            assert_eq!(opts.file_prefix, "test".to_string());
            assert!(opts.overwrite);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_custom() {
        let test_args = BaseArgs::new();
        let custom = CustomArgs { a: 42 };
        let result = run_with_args_internal(test_args, Some(custom), |_, _, c| {
            assert_eq!(c.unwrap().a, 42);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_logging_enabled() {
        let mut test_args = BaseArgs::new();
        test_args.log_level = Some(LevelFilter::Info.to_string());
        let result = run_with_args_internal(test_args, None, |_, _, _: Option<()>| Ok(()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_invalid_log_level() {
        let test_args = BaseArgs {
            log_level: Some("not-a-level".to_string()),
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |_, _, _: Option<()>| Ok(()));
        assert!(result.is_err());
    }
}