1use std::path::{Path, PathBuf};
2use std::str::FromStr;
3
4use crate::context::Context;
5#[cfg(feature = "debugger")]
6use crate::debugger::enter_debugger;
7use crate::error::IxaError;
8use crate::global_properties::ContextGlobalPropertiesExt;
9use crate::log::level_to_string_list;
10#[cfg(feature = "progress_bar")]
11use crate::progress::init_timeline_progress_bar;
12use crate::random::ContextRandomExt;
13use crate::report::ContextReportExt;
14#[cfg(feature = "web_api")]
15use crate::web_api::ContextWebApiExt;
16use crate::{set_log_level, set_module_filters, warn, LevelFilter};
17use clap::{ArgAction, Args, Command, FromArgMatches as _};
18#[cfg(feature = "write_cli_usage")]
19use clap_markdown::{help_markdown_command_custom, MarkdownOptions};
20
21fn parse_log_levels(s: &str) -> Result<Vec<(String, LevelFilter)>, String> {
23 s.split(',')
24 .map(|pair| {
25 let mut iter = pair.split('=');
26 let key = iter
27 .next()
28 .ok_or_else(|| format!("Invalid key in pair: {pair}"))?;
29 let value = iter
30 .next()
31 .ok_or_else(|| format!("Invalid value in pair: {pair}"))?;
32 let level =
33 LevelFilter::from_str(value).map_err(|_| format!("Invalid log level: {value}"))?;
34 Ok((key.to_string(), level))
35 })
36 .collect()
37}
38
39#[derive(Args, Debug)]
41pub struct BaseArgs {
42 #[cfg(feature = "write_cli_usage")]
43 #[arg(long, hide = true)]
47 markdown_help: bool,
48
49 #[arg(short, long, default_value = "0")]
51 pub random_seed: u64,
52
53 #[arg(short, long)]
55 pub config: Option<PathBuf>,
56
57 #[arg(short, long = "output")]
59 pub output_dir: Option<PathBuf>,
60
61 #[arg(long = "prefix")]
63 pub file_prefix: Option<String>,
64
65 #[arg(short, long)]
67 pub force_overwrite: bool,
68
69 #[arg(short, long)]
71 pub log_level: Option<String>,
72
73 #[arg(
74 short,
75 long,
76 action = ArgAction::Count,
77 long_help = r#"Increase logging verbosity (-v, -vv, -vvv, etc.)
78
79| Level | ERROR | WARN | INFO | DEBUG | TRACE |
80|---------|-------|------|------|-------|-------|
81| Default | ✓ | | | | |
82| -v | ✓ | ✓ | ✓ | | |
83| -vv | ✓ | ✓ | ✓ | ✓ | |
84| -vvv | ✓ | ✓ | ✓ | ✓ | ✓ |
85"#)]
86 pub verbose: u8,
87
88 #[arg(long)]
90 pub warn: bool,
91
92 #[arg(long)]
94 pub debug: bool,
95
96 #[arg(long)]
98 pub trace: bool,
99
100 #[arg(short, long)]
102 pub debugger: Option<Option<f64>>,
103
104 #[arg(short, long)]
106 pub web: Option<Option<u16>>,
107
108 #[arg(short, long)]
110 pub timeline_progress_max: Option<f64>,
111
112 #[arg(long)]
114 pub no_stats: bool,
115}
116
117impl BaseArgs {
118 fn new() -> Self {
119 BaseArgs {
120 #[cfg(feature = "write_cli_usage")]
121 markdown_help: false,
122 random_seed: 0,
123 config: None,
124 output_dir: None,
125 file_prefix: None,
126 force_overwrite: false,
127 log_level: None,
128 verbose: 0,
129 warn: false,
130 debug: false,
131 trace: false,
132 debugger: None,
133 web: None,
134 timeline_progress_max: None,
135 no_stats: false,
136 }
137 }
138}
139
140impl Default for BaseArgs {
141 fn default() -> Self {
142 BaseArgs::new()
143 }
144}
145
146#[derive(Args)]
147pub struct PlaceholderCustom {}
148
149fn create_ixa_cli() -> Command {
150 let cli = Command::new("ixa");
151 BaseArgs::augment_args(cli)
152}
153
154#[allow(clippy::missing_errors_doc)]
165pub fn run_with_custom_args<A, F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
166where
167 A: Args,
168 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
169{
170 let mut cli = create_ixa_cli();
171 cli = A::augment_args(cli);
172 let matches = cli.get_matches();
173
174 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
175 let custom_matches = A::from_arg_matches(&matches)?;
176 run_with_args_internal(base_args_matches, Some(custom_matches), setup_fn)
177}
178
179#[allow(clippy::missing_errors_doc)]
189pub fn run_with_args<F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
190where
191 F: Fn(&mut Context, BaseArgs, Option<PlaceholderCustom>) -> Result<(), IxaError>,
192{
193 let cli = create_ixa_cli();
194 let matches = cli.get_matches();
195
196 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
197 run_with_args_internal(base_args_matches, None, setup_fn)
198}
199
200fn run_with_args_internal<A, F>(
201 args: BaseArgs,
202 custom_args: Option<A>,
203 setup_fn: F,
204) -> Result<Context, Box<dyn std::error::Error>>
205where
206 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
207{
208 #[cfg(feature = "write_cli_usage")]
209 if args.markdown_help {
211 let cli = create_ixa_cli();
212 let md_options = MarkdownOptions::new()
213 .show_footer(false)
214 .show_aliases(true)
215 .show_table_of_contents(false)
216 .title("Command Line Usage".to_string());
217 let markdown = help_markdown_command_custom(&cli, &md_options);
218 let path =
219 PathBuf::from(option_env!("CARGO_WORKSPACE_DIR").unwrap_or(env!("CARGO_MANIFEST_DIR")))
220 .join("docs")
221 .join("cli-usage.md");
222 std::fs::write(&path, markdown).unwrap_or_else(|e| {
223 panic!(
224 "Failed to write CLI help Markdown to file {}: {}",
225 path.display(),
226 e
227 );
228 });
229 }
230
231 let mut context = Context::new();
233
234 if args.config.is_some() {
236 let config_path = args.config.clone().unwrap();
237 println!("Loading global properties from: {config_path:?}");
238 context.load_global_properties(&config_path)?;
239 }
240
241 let report_config = context.report_options();
243 if args.output_dir.is_some() {
244 report_config.directory(args.output_dir.clone().unwrap());
245 }
246 if args.file_prefix.is_some() {
247 report_config.file_prefix(args.file_prefix.clone().unwrap());
248 }
249 if args.force_overwrite {
250 report_config.overwrite(true);
251 }
252
253 let mut current_log_level = crate::log::DEFAULT_LOG_LEVEL;
257
258 if let Some(log_level) = args.log_level.as_ref() {
260 if let Ok(level) = LevelFilter::from_str(log_level) {
261 current_log_level = level;
262 } else if let Ok(log_levels) = parse_log_levels(log_level) {
263 let log_levels_slice: Vec<(&String, LevelFilter)> =
264 log_levels.iter().map(|(k, v)| (k, *v)).collect();
265 set_module_filters(log_levels_slice.as_slice());
266 for (key, value) in log_levels {
267 println!("Logging enabled for {key} at level {value}");
268 }
270 } else {
271 return Err(format!("Invalid log level format: {log_level}").into());
272 }
273 }
274
275 if args.verbose > 0 {
277 let new_level = match args.verbose {
278 1 => LevelFilter::Info,
279 2 => LevelFilter::Debug,
280 _ => LevelFilter::Trace,
281 };
282 current_log_level = current_log_level.max(new_level);
283 }
284
285 if args.warn {
287 current_log_level = current_log_level.max(LevelFilter::Warn);
288 }
289 if args.debug {
290 current_log_level = current_log_level.max(LevelFilter::Debug);
291 }
292 if args.trace {
293 current_log_level = LevelFilter::Trace;
294 }
295
296 let binary_name = std::env::args().next();
298 let binary_name = binary_name
299 .as_deref()
300 .map(Path::new)
301 .and_then(Path::file_name)
302 .and_then(|s| s.to_str())
303 .unwrap_or("[model]");
304 println!(
305 "Current log levels enabled: {}",
306 level_to_string_list(current_log_level)
307 );
308 println!("Run {binary_name} --help -v to see more options");
309
310 if current_log_level != crate::log::DEFAULT_LOG_LEVEL {
312 set_log_level(current_log_level);
313 }
314
315 context.init_random(args.random_seed);
316
317 #[cfg(feature = "debugger")]
319 if let Some(t) = args.debugger {
320 assert!(
321 args.web.is_none(),
322 "Cannot run with both the debugger and the Web API"
323 );
324 match t {
325 None => {
326 context.request_debugger();
327 }
328 Some(time) => {
329 context.schedule_debugger(time, None, Box::new(enter_debugger));
330 }
331 }
332 }
333 #[cfg(not(feature = "debugger"))]
334 if args.debugger.is_some() {
335 warn!("Ixa was not compiled with the debugger feature, but a debugger option was provided");
336 }
337
338 #[cfg(feature = "web_api")]
340 if let Some(t) = args.web {
341 let port = t.unwrap_or(33334);
342 let url = context.setup_web_api(port).unwrap();
343 println!("Web API active on {url}");
344 context.schedule_web_api(0.0);
345 }
346 #[cfg(not(feature = "web_api"))]
347 if args.web.is_some() {
348 warn!("Ixa was not compiled with the web_api feature, but a web_api option was provided");
349 }
350
351 if let Some(max_time) = args.timeline_progress_max {
352 if cfg!(not(feature = "progress_bar")) && max_time > 0.0 {
354 warn!("Ixa was not compiled with the progress_bar feature, but a progress_bar option was provided");
355 } else if max_time < 0.0 {
356 warn!("timeline progress maximum must be nonnegative");
357 }
358 #[cfg(feature = "progress_bar")]
359 if max_time > 0.0 {
360 println!("ProgressBar max set to {}", max_time);
361 init_timeline_progress_bar(max_time);
362 }
363 }
364
365 if args.no_stats {
366 context.print_execution_statistics = false;
367 } else {
368 if cfg!(target_family = "wasm") {
369 warn!("the print-stats option is enabled; some statistics are not supported for the wasm target family");
370 }
371 context.print_execution_statistics = true;
372 }
373
374 setup_fn(&mut context, args, custom_args)?;
376
377 context.execute();
379 Ok(context)
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{define_global_property, define_rng};
    use serde::{Deserialize, Serialize};

    // Minimal model-specific argument set used to exercise the custom-argument code paths.
    // Plain `//` comments are used here because clap turns `///` comments into help text.
    #[derive(Args, Debug)]
    struct CustomArgs {
        #[arg(short, long, default_value = "0")]
        a: u32,
    }

    // NOTE(review): this test and `test_run_with_args` go through clap's `get_matches`, which
    // parses the real argv of the test binary. BaseArgs declares no positional arguments, so
    // extra arguments passed to the test binary (e.g. a filter from `cargo test <name>`) would
    // be rejected, and `get_matches` exits the process on a parse error. Confirm whether these
    // should parse an explicit argv (e.g. via `try_get_matches_from`) instead.
    #[test]
    fn test_run_with_custom_args() {
        let result = run_with_custom_args(|_, _, _: Option<CustomArgs>| Ok(()));
        assert!(result.is_ok());
    }

    // See the NOTE on `test_run_with_custom_args` about parsing the real process argv.
    #[test]
    fn test_run_with_args() {
        let result = run_with_args(|_, _, _| Ok(()));
        assert!(result.is_ok());
    }

    // A context seeded through `random_seed` must draw the same value as a context seeded
    // directly with `init_random(42)`.
    #[test]
    fn test_run_with_random_seed() {
        let test_args = BaseArgs {
            random_seed: 42,
            ..Default::default()
        };

        let mut compare_ctx = Context::new();
        compare_ctx.init_random(42);
        define_rng!(TestRng);
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            assert_eq!(
                ctx.sample_range(TestRng, 0..100),
                compare_ctx.sample_range(TestRng, 0..100)
            );
            Ok(())
        });
        assert!(result.is_ok());
    }

    // Global property type loaded from the JSON fixture in `test_run_with_config_path`.
    #[derive(Serialize, Deserialize)]
    pub struct RunnerPropertyType {
        field_int: u32,
    }
    define_global_property!(RunnerProperty, RunnerPropertyType);

    // `--config` must load the fixture before `setup_fn` runs. The assertion expects the
    // fixture to set `RunnerProperty.field_int` to 0.
    #[test]
    fn test_run_with_config_path() {
        let test_args = BaseArgs {
            config: Some(PathBuf::from("tests/data/global_properties_runner.json")),
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let p3 = ctx.get_global_property_value(RunnerProperty).unwrap();
            assert_eq!(p3.field_int, 0);
            Ok(())
        });
        assert!(result.is_ok());
    }

    // Output directory, file prefix and overwrite flag must reach the context's report options.
    #[test]
    fn test_run_with_report_options() {
        let test_args = BaseArgs {
            output_dir: Some(PathBuf::from("data")),
            file_prefix: Some("test".to_string()),
            force_overwrite: true,
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let opts = &ctx.report_options();
            assert_eq!(opts.output_dir, PathBuf::from("data"));
            assert_eq!(opts.file_prefix, "test".to_string());
            assert!(opts.overwrite);
            Ok(())
        });
        assert!(result.is_ok());
    }

    // Custom arguments are forwarded to `setup_fn` unchanged.
    #[test]
    fn test_run_with_custom() {
        let test_args = BaseArgs::new();
        let custom = CustomArgs { a: 42 };
        let result = run_with_args_internal(test_args, Some(custom), |_, _, c| {
            assert_eq!(c.unwrap().a, 42);
            Ok(())
        });
        assert!(result.is_ok());
    }

    // A single level name (here `LevelFilter::Info`'s string form) is accepted as `--log-level`.
    // NOTE(review): this sets the process-global log level, which may leak into other tests.
    #[test]
    fn test_run_with_logging_enabled() {
        let mut test_args = BaseArgs::new();
        test_args.log_level = Some(LevelFilter::Info.to_string());
        let result = run_with_args_internal(test_args, None, |_, _, _: Option<()>| Ok(()));
        assert!(result.is_ok());
    }
}
483}