1use std::path::{Path, PathBuf};
2use std::str::FromStr;
3
4use clap::{ArgAction, Args, Command, FromArgMatches as _};
5#[cfg(feature = "write_cli_usage")]
6use clap_markdown::{help_markdown_command_custom, MarkdownOptions};
7
8use crate::context::Context;
9#[cfg(feature = "debugger")]
10use crate::debugger::enter_debugger;
11use crate::error::IxaError;
12use crate::global_properties::ContextGlobalPropertiesExt;
13use crate::log::level_to_string_list;
14#[cfg(feature = "progress_bar")]
15use crate::progress::init_timeline_progress_bar;
16use crate::random::ContextRandomExt;
17use crate::report::ContextReportExt;
18use crate::{set_log_level, set_module_filters, warn, LevelFilter};
19
20fn parse_log_levels(s: &str) -> Result<Vec<(String, LevelFilter)>, IxaError> {
22 s.split(',')
23 .map(|pair| {
24 let mut iter = pair.split('=');
25 let key = iter.next().ok_or_else(|| IxaError::InvalidLogLevelKey {
26 pair: pair.to_string(),
27 })?;
28 let value = iter.next().ok_or_else(|| IxaError::InvalidLogLevelValue {
29 pair: pair.to_string(),
30 })?;
31 let level = LevelFilter::from_str(value).map_err(|_| IxaError::InvalidLogLevel {
32 level: value.to_string(),
33 })?;
34 Ok((key.to_string(), level))
35 })
36 .collect()
37}
38
39#[derive(Args, Debug)]
41pub struct BaseArgs {
42 #[cfg(feature = "write_cli_usage")]
43 #[arg(long, hide = true)]
47 markdown_help: bool,
48
49 #[arg(short, long, default_value = "0")]
51 pub random_seed: u64,
52
53 #[arg(short, long)]
55 pub config: Option<PathBuf>,
56
57 #[arg(short, long = "output")]
59 pub output_dir: Option<PathBuf>,
60
61 #[arg(long = "prefix")]
63 pub file_prefix: Option<String>,
64
65 #[arg(short, long)]
67 pub force_overwrite: bool,
68
69 #[arg(short, long)]
71 pub log_level: Option<String>,
72
73 #[arg(
74 short,
75 long,
76 action = ArgAction::Count,
77 long_help = r#"Increase logging verbosity (-v, -vv, -vvv, etc.)
78
79| Level | ERROR | WARN | INFO | DEBUG | TRACE |
80|---------|-------|------|------|-------|-------|
81| Default | ✓ | | | | |
82| -v | ✓ | ✓ | ✓ | | |
83| -vv | ✓ | ✓ | ✓ | ✓ | |
84| -vvv | ✓ | ✓ | ✓ | ✓ | ✓ |
85"#)]
86 pub verbose: u8,
87
88 #[arg(long)]
90 pub warn: bool,
91
92 #[arg(long)]
94 pub debug: bool,
95
96 #[arg(long)]
98 pub trace: bool,
99
100 #[arg(short, long)]
102 pub debugger: Option<Option<f64>>,
103
104 #[arg(short, long)]
106 pub web: Option<Option<u16>>,
107
108 #[arg(short, long)]
110 pub timeline_progress_max: Option<f64>,
111
112 #[arg(long)]
114 pub no_stats: bool,
115}
116
117impl BaseArgs {
118 fn new() -> Self {
119 BaseArgs {
120 #[cfg(feature = "write_cli_usage")]
121 markdown_help: false,
122 random_seed: 0,
123 config: None,
124 output_dir: None,
125 file_prefix: None,
126 force_overwrite: false,
127 log_level: None,
128 verbose: 0,
129 warn: false,
130 debug: false,
131 trace: false,
132 debugger: None,
133 web: None,
134 timeline_progress_max: None,
135 no_stats: false,
136 }
137 }
138}
139
140impl Default for BaseArgs {
141 fn default() -> Self {
142 BaseArgs::new()
143 }
144}
145
// Empty custom-argument type used as the `A` parameter of `run_with_args`, which takes no
// model-specific arguments (its setup function always receives `None`).
// A `//` comment is used instead of `///` because clap's derive can turn a doc comment on an
// `Args` struct into command help text.
#[derive(Args)]
pub struct PlaceholderCustom {}
148
149fn create_ixa_cli() -> Command {
150 let cli = Command::new("ixa");
151 BaseArgs::augment_args(cli)
152}
153
154#[allow(clippy::missing_errors_doc)]
165pub fn run_with_custom_args<A, F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
166where
167 A: Args,
168 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
169{
170 let mut cli = create_ixa_cli();
171 cli = A::augment_args(cli);
172 let matches = cli.get_matches();
173
174 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
175 let custom_matches = A::from_arg_matches(&matches)?;
176 run_with_args_internal(base_args_matches, Some(custom_matches), setup_fn)
177}
178
179#[allow(clippy::missing_errors_doc)]
189pub fn run_with_args<F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
190where
191 F: Fn(&mut Context, BaseArgs, Option<PlaceholderCustom>) -> Result<(), IxaError>,
192{
193 let cli = create_ixa_cli();
194 let matches = cli.get_matches();
195
196 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
197 run_with_args_internal(base_args_matches, None, setup_fn)
198}
199
200fn run_with_args_internal<A, F>(
201 args: BaseArgs,
202 custom_args: Option<A>,
203 setup_fn: F,
204) -> Result<Context, Box<dyn std::error::Error>>
205where
206 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
207{
208 #[cfg(feature = "write_cli_usage")]
209 if args.markdown_help {
211 let cli = create_ixa_cli();
212 let md_options = MarkdownOptions::new()
213 .show_footer(false)
214 .show_aliases(true)
215 .show_table_of_contents(false)
216 .title("Command Line Usage".to_string());
217 let markdown = help_markdown_command_custom(&cli, &md_options);
218 let path =
219 PathBuf::from(option_env!("CARGO_WORKSPACE_DIR").unwrap_or(env!("CARGO_MANIFEST_DIR")))
220 .join("docs")
221 .join("book")
222 .join("src")
223 .join("cli-usage.md");
224 std::fs::write(&path, markdown).unwrap_or_else(|e| {
225 panic!(
226 "Failed to write CLI help Markdown to file {}: {}",
227 path.display(),
228 e
229 );
230 });
231 }
232
233 let mut context = Context::new();
235
236 if args.config.is_some() {
238 let config_path = args.config.clone().unwrap();
239 println!("Loading global properties from: {config_path:?}");
240 context.load_global_properties(&config_path)?;
241 }
242
243 let report_config = context.report_options();
245 if args.output_dir.is_some() {
246 report_config.directory(args.output_dir.clone().unwrap());
247 }
248 if args.file_prefix.is_some() {
249 report_config.file_prefix(args.file_prefix.clone().unwrap());
250 }
251 if args.force_overwrite {
252 report_config.overwrite(true);
253 }
254
255 let mut current_log_level = crate::log::DEFAULT_LOG_LEVEL;
259
260 if let Some(log_level) = args.log_level.as_ref() {
262 if let Ok(level) = LevelFilter::from_str(log_level) {
263 current_log_level = level;
264 } else {
265 match parse_log_levels(log_level) {
266 Ok(log_levels) => {
267 let log_levels_slice: Vec<(&String, LevelFilter)> =
268 log_levels.iter().map(|(k, v)| (k, *v)).collect();
269 set_module_filters(log_levels_slice.as_slice());
270 for (key, value) in log_levels {
271 println!("Logging enabled for {key} at level {value}");
272 }
274 }
275 Err(e) => return Err(Box::new(e)),
276 }
277 }
278 }
279
280 if args.verbose > 0 {
282 let new_level = match args.verbose {
283 1 => LevelFilter::Info,
284 2 => LevelFilter::Debug,
285 _ => LevelFilter::Trace,
286 };
287 current_log_level = current_log_level.max(new_level);
288 }
289
290 if args.warn {
292 current_log_level = current_log_level.max(LevelFilter::Warn);
293 }
294 if args.debug {
295 current_log_level = current_log_level.max(LevelFilter::Debug);
296 }
297 if args.trace {
298 current_log_level = LevelFilter::Trace;
299 }
300
301 let binary_name = std::env::args().next();
303 let binary_name = binary_name
304 .as_deref()
305 .map(Path::new)
306 .and_then(Path::file_name)
307 .and_then(|s| s.to_str())
308 .unwrap_or("[model]");
309 println!(
310 "Current log levels enabled: {}",
311 level_to_string_list(current_log_level)
312 );
313 println!("Run {binary_name} --help -v to see more options");
314
315 if current_log_level != crate::log::DEFAULT_LOG_LEVEL {
317 set_log_level(current_log_level);
318 }
319
320 context.init_random(args.random_seed);
321
322 #[cfg(feature = "debugger")]
324 if let Some(t) = args.debugger {
325 assert!(
326 args.web.is_none(),
327 "Cannot run with both the debugger and the Web API"
328 );
329 match t {
330 None => {
331 context.request_debugger();
332 }
333 Some(time) => {
334 context.schedule_debugger(time, None, Box::new(enter_debugger));
335 }
336 }
337 }
338 #[cfg(not(feature = "debugger"))]
339 if args.debugger.is_some() {
340 warn!("Ixa was not compiled with the debugger feature, but a debugger option was provided");
341 }
342
343 if let Some(max_time) = args.timeline_progress_max {
344 if cfg!(not(feature = "progress_bar")) && max_time > 0.0 {
346 warn!("Ixa was not compiled with the progress_bar feature, but a progress_bar option was provided");
347 } else if max_time < 0.0 {
348 warn!("timeline progress maximum must be nonnegative");
349 }
350 #[cfg(feature = "progress_bar")]
351 if max_time > 0.0 {
352 println!("ProgressBar max set to {}", max_time);
353 init_timeline_progress_bar(max_time);
354 }
355 }
356
357 if args.no_stats {
358 context.print_execution_statistics = false;
359 } else {
360 if cfg!(target_family = "wasm") {
361 warn!("the print-stats option is enabled; some statistics are not supported for the wasm target family");
362 }
363 context.print_execution_statistics = true;
364 }
365
366 setup_fn(&mut context, args, custom_args)?;
368
369 context.execute();
371 Ok(context)
372}
373
// Unit tests for the runner. Most tests call `run_with_args_internal` directly with a
// hand-built `BaseArgs`, bypassing command-line parsing.
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{define_global_property, define_rng};

    // Model-specific arguments used to exercise the custom-arguments path. (`//` rather than
    // `///` so the comment does not become clap help text.)
    #[derive(Args, Debug)]
    struct CustomArgs {
        #[arg(short, long, default_value = "0")]
        a: u32,
    }

    // NOTE(review): `run_with_custom_args` parses the real process arguments via
    // `get_matches`. This test therefore presumably requires the test binary to be run
    // without extra command-line arguments (e.g. a test-name filter); confirm.
    #[test]
    fn test_run_with_custom_args() {
        let result = run_with_custom_args(|_, _, _: Option<CustomArgs>| Ok(()));
        assert!(result.is_ok());
    }

    // Same caveat as above: `run_with_args` also reads the real process arguments.
    #[test]
    fn test_run_with_args() {
        let result = run_with_args(|_, _, _| Ok(()));
        assert!(result.is_ok());
    }

    // A context seeded through the runner must draw the same value as a context seeded
    // directly with the same seed.
    #[test]
    fn test_run_with_random_seed() {
        let test_args = BaseArgs {
            random_seed: 42,
            ..Default::default()
        };

        let mut compare_ctx = Context::new();
        compare_ctx.init_random(42);
        define_rng!(TestRng);
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            assert_eq!(
                ctx.sample_range(TestRng, 0..100),
                compare_ctx.sample_range(TestRng, 0..100)
            );
            Ok(())
        });
        assert!(result.is_ok());
    }

    // Global property loaded by `test_run_with_config_path`. NOTE(review): the name is
    // presumably referenced by the JSON fixture, so do not rename it without updating that file.
    #[derive(Serialize, Deserialize)]
    pub struct RunnerPropertyType {
        field_int: u32,
    }
    define_global_property!(RunnerProperty, RunnerPropertyType);

    // `--config` loads global properties before `setup_fn` runs. The assertion expects the
    // fixture to set `field_int` to 0.
    #[test]
    fn test_run_with_config_path() {
        let test_args = BaseArgs {
            config: Some(PathBuf::from("tests/data/global_properties_runner.json")),
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let p3 = ctx.get_global_property_value(RunnerProperty).unwrap();
            assert_eq!(p3.field_int, 0);
            Ok(())
        });
        assert!(result.is_ok());
    }

    // Output directory, file prefix, and overwrite flag reach the context's report options.
    #[test]
    fn test_run_with_report_options() {
        let test_args = BaseArgs {
            output_dir: Some(PathBuf::from("data")),
            file_prefix: Some("test".to_string()),
            force_overwrite: true,
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let opts = &ctx.report_options();
            assert_eq!(opts.output_dir, PathBuf::from("data"));
            assert_eq!(opts.file_prefix, "test".to_string());
            assert!(opts.overwrite);
            Ok(())
        });
        assert!(result.is_ok());
    }

    // Custom arguments are forwarded unchanged to `setup_fn`.
    #[test]
    fn test_run_with_custom() {
        let test_args = BaseArgs::new();
        let custom = CustomArgs { a: 42 };
        let result = run_with_args_internal(test_args, Some(custom), |_, _, c| {
            assert_eq!(c.unwrap().a, 42);
            Ok(())
        });
        assert!(result.is_ok());
    }

    // A single level name is accepted as `--log-level`. NOTE(review): this changes the
    // global log level, which may affect other tests running in the same process.
    #[test]
    fn test_run_with_logging_enabled() {
        let mut test_args = BaseArgs::new();
        test_args.log_level = Some(LevelFilter::Info.to_string());
        let result = run_with_args_internal(test_args, None, |_, _, _: Option<()>| Ok(()));
        assert!(result.is_ok());
    }
}