1use std::path::{Path, PathBuf};
2use std::str::FromStr;
3
4use clap::{ArgAction, Args, Command, FromArgMatches as _};
5#[cfg(feature = "write_cli_usage")]
6use clap_markdown::{help_markdown_command_custom, MarkdownOptions};
7
8use crate::context::Context;
9#[cfg(feature = "debugger")]
10use crate::debugger::enter_debugger;
11use crate::error::IxaError;
12use crate::global_properties::ContextGlobalPropertiesExt;
13use crate::log::level_to_string_list;
14#[cfg(feature = "progress_bar")]
15use crate::progress::init_timeline_progress_bar;
16use crate::random::ContextRandomExt;
17use crate::report::ContextReportExt;
18#[cfg(feature = "web_api")]
19use crate::web_api::ContextWebApiExt;
20use crate::{set_log_level, set_module_filters, warn, LevelFilter};
21
22fn parse_log_levels(s: &str) -> Result<Vec<(String, LevelFilter)>, String> {
24 s.split(',')
25 .map(|pair| {
26 let mut iter = pair.split('=');
27 let key = iter
28 .next()
29 .ok_or_else(|| format!("Invalid key in pair: {pair}"))?;
30 let value = iter
31 .next()
32 .ok_or_else(|| format!("Invalid value in pair: {pair}"))?;
33 let level =
34 LevelFilter::from_str(value).map_err(|_| format!("Invalid log level: {value}"))?;
35 Ok((key.to_string(), level))
36 })
37 .collect()
38}
39
40#[derive(Args, Debug)]
42pub struct BaseArgs {
43 #[cfg(feature = "write_cli_usage")]
44 #[arg(long, hide = true)]
48 markdown_help: bool,
49
50 #[arg(short, long, default_value = "0")]
52 pub random_seed: u64,
53
54 #[arg(short, long)]
56 pub config: Option<PathBuf>,
57
58 #[arg(short, long = "output")]
60 pub output_dir: Option<PathBuf>,
61
62 #[arg(long = "prefix")]
64 pub file_prefix: Option<String>,
65
66 #[arg(short, long)]
68 pub force_overwrite: bool,
69
70 #[arg(short, long)]
72 pub log_level: Option<String>,
73
74 #[arg(
75 short,
76 long,
77 action = ArgAction::Count,
78 long_help = r#"Increase logging verbosity (-v, -vv, -vvv, etc.)
79
80| Level | ERROR | WARN | INFO | DEBUG | TRACE |
81|---------|-------|------|------|-------|-------|
82| Default | ✓ | | | | |
83| -v | ✓ | ✓ | ✓ | | |
84| -vv | ✓ | ✓ | ✓ | ✓ | |
85| -vvv | ✓ | ✓ | ✓ | ✓ | ✓ |
86"#)]
87 pub verbose: u8,
88
89 #[arg(long)]
91 pub warn: bool,
92
93 #[arg(long)]
95 pub debug: bool,
96
97 #[arg(long)]
99 pub trace: bool,
100
101 #[arg(short, long)]
103 pub debugger: Option<Option<f64>>,
104
105 #[arg(short, long)]
107 pub web: Option<Option<u16>>,
108
109 #[arg(short, long)]
111 pub timeline_progress_max: Option<f64>,
112
113 #[arg(long)]
115 pub no_stats: bool,
116}
117
118impl BaseArgs {
119 fn new() -> Self {
120 BaseArgs {
121 #[cfg(feature = "write_cli_usage")]
122 markdown_help: false,
123 random_seed: 0,
124 config: None,
125 output_dir: None,
126 file_prefix: None,
127 force_overwrite: false,
128 log_level: None,
129 verbose: 0,
130 warn: false,
131 debug: false,
132 trace: false,
133 debugger: None,
134 web: None,
135 timeline_progress_max: None,
136 no_stats: false,
137 }
138 }
139}
140
141impl Default for BaseArgs {
142 fn default() -> Self {
143 BaseArgs::new()
144 }
145}
146
// Empty custom-argument type. `run_with_args` names it as the custom-argument
// type of its setup function, which therefore always receives `None` for the
// custom arguments.
// (Plain `//` comments are used deliberately: clap's derive can turn `///`
// docs on a derived type into help text.)
#[derive(Args)]
pub struct PlaceholderCustom {}
149
150fn create_ixa_cli() -> Command {
151 let cli = Command::new("ixa");
152 BaseArgs::augment_args(cli)
153}
154
155#[allow(clippy::missing_errors_doc)]
166pub fn run_with_custom_args<A, F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
167where
168 A: Args,
169 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
170{
171 let mut cli = create_ixa_cli();
172 cli = A::augment_args(cli);
173 let matches = cli.get_matches();
174
175 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
176 let custom_matches = A::from_arg_matches(&matches)?;
177 run_with_args_internal(base_args_matches, Some(custom_matches), setup_fn)
178}
179
180#[allow(clippy::missing_errors_doc)]
190pub fn run_with_args<F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
191where
192 F: Fn(&mut Context, BaseArgs, Option<PlaceholderCustom>) -> Result<(), IxaError>,
193{
194 let cli = create_ixa_cli();
195 let matches = cli.get_matches();
196
197 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
198 run_with_args_internal(base_args_matches, None, setup_fn)
199}
200
201fn run_with_args_internal<A, F>(
202 args: BaseArgs,
203 custom_args: Option<A>,
204 setup_fn: F,
205) -> Result<Context, Box<dyn std::error::Error>>
206where
207 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
208{
209 #[cfg(feature = "write_cli_usage")]
210 if args.markdown_help {
212 let cli = create_ixa_cli();
213 let md_options = MarkdownOptions::new()
214 .show_footer(false)
215 .show_aliases(true)
216 .show_table_of_contents(false)
217 .title("Command Line Usage".to_string());
218 let markdown = help_markdown_command_custom(&cli, &md_options);
219 let path =
220 PathBuf::from(option_env!("CARGO_WORKSPACE_DIR").unwrap_or(env!("CARGO_MANIFEST_DIR")))
221 .join("docs")
222 .join("book")
223 .join("src")
224 .join("cli-usage.md");
225 std::fs::write(&path, markdown).unwrap_or_else(|e| {
226 panic!(
227 "Failed to write CLI help Markdown to file {}: {}",
228 path.display(),
229 e
230 );
231 });
232 }
233
234 let mut context = Context::new();
236
237 if args.config.is_some() {
239 let config_path = args.config.clone().unwrap();
240 println!("Loading global properties from: {config_path:?}");
241 context.load_global_properties(&config_path)?;
242 }
243
244 let report_config = context.report_options();
246 if args.output_dir.is_some() {
247 report_config.directory(args.output_dir.clone().unwrap());
248 }
249 if args.file_prefix.is_some() {
250 report_config.file_prefix(args.file_prefix.clone().unwrap());
251 }
252 if args.force_overwrite {
253 report_config.overwrite(true);
254 }
255
256 let mut current_log_level = crate::log::DEFAULT_LOG_LEVEL;
260
261 if let Some(log_level) = args.log_level.as_ref() {
263 if let Ok(level) = LevelFilter::from_str(log_level) {
264 current_log_level = level;
265 } else if let Ok(log_levels) = parse_log_levels(log_level) {
266 let log_levels_slice: Vec<(&String, LevelFilter)> =
267 log_levels.iter().map(|(k, v)| (k, *v)).collect();
268 set_module_filters(log_levels_slice.as_slice());
269 for (key, value) in log_levels {
270 println!("Logging enabled for {key} at level {value}");
271 }
273 } else {
274 return Err(format!("Invalid log level format: {log_level}").into());
275 }
276 }
277
278 if args.verbose > 0 {
280 let new_level = match args.verbose {
281 1 => LevelFilter::Info,
282 2 => LevelFilter::Debug,
283 _ => LevelFilter::Trace,
284 };
285 current_log_level = current_log_level.max(new_level);
286 }
287
288 if args.warn {
290 current_log_level = current_log_level.max(LevelFilter::Warn);
291 }
292 if args.debug {
293 current_log_level = current_log_level.max(LevelFilter::Debug);
294 }
295 if args.trace {
296 current_log_level = LevelFilter::Trace;
297 }
298
299 let binary_name = std::env::args().next();
301 let binary_name = binary_name
302 .as_deref()
303 .map(Path::new)
304 .and_then(Path::file_name)
305 .and_then(|s| s.to_str())
306 .unwrap_or("[model]");
307 println!(
308 "Current log levels enabled: {}",
309 level_to_string_list(current_log_level)
310 );
311 println!("Run {binary_name} --help -v to see more options");
312
313 if current_log_level != crate::log::DEFAULT_LOG_LEVEL {
315 set_log_level(current_log_level);
316 }
317
318 context.init_random(args.random_seed);
319
320 #[cfg(feature = "debugger")]
322 if let Some(t) = args.debugger {
323 assert!(
324 args.web.is_none(),
325 "Cannot run with both the debugger and the Web API"
326 );
327 match t {
328 None => {
329 context.request_debugger();
330 }
331 Some(time) => {
332 context.schedule_debugger(time, None, Box::new(enter_debugger));
333 }
334 }
335 }
336 #[cfg(not(feature = "debugger"))]
337 if args.debugger.is_some() {
338 warn!("Ixa was not compiled with the debugger feature, but a debugger option was provided");
339 }
340
341 #[cfg(feature = "web_api")]
343 if let Some(t) = args.web {
344 let port = t.unwrap_or(33334);
345 let url = context.setup_web_api(port).unwrap();
346 println!("Web API active on {url}");
347 context.schedule_web_api(0.0);
348 }
349 #[cfg(not(feature = "web_api"))]
350 if args.web.is_some() {
351 warn!("Ixa was not compiled with the web_api feature, but a web_api option was provided");
352 }
353
354 if let Some(max_time) = args.timeline_progress_max {
355 if cfg!(not(feature = "progress_bar")) && max_time > 0.0 {
357 warn!("Ixa was not compiled with the progress_bar feature, but a progress_bar option was provided");
358 } else if max_time < 0.0 {
359 warn!("timeline progress maximum must be nonnegative");
360 }
361 #[cfg(feature = "progress_bar")]
362 if max_time > 0.0 {
363 println!("ProgressBar max set to {}", max_time);
364 init_timeline_progress_bar(max_time);
365 }
366 }
367
368 if args.no_stats {
369 context.print_execution_statistics = false;
370 } else {
371 if cfg!(target_family = "wasm") {
372 warn!("the print-stats option is enabled; some statistics are not supported for the wasm target family");
373 }
374 context.print_execution_statistics = true;
375 }
376
377 setup_fn(&mut context, args, custom_args)?;
379
380 context.execute();
382 Ok(context)
383}
384
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{define_global_property, define_rng};

    // Model-specific options used to exercise the custom-argument paths.
    #[derive(Args, Debug)]
    struct CustomArgs {
        #[arg(short, long, default_value = "0")]
        a: u32,
    }

    #[test]
    fn test_run_with_custom_args() {
        let outcome = run_with_custom_args(|_, _, _: Option<CustomArgs>| Ok(()));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_run_with_args() {
        let outcome = run_with_args(|_, _, _| Ok(()));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_run_with_random_seed() {
        define_rng!(TestRng);

        // A context seeded by hand must draw the same values as one seeded via the CLI args.
        let mut reference_ctx = Context::new();
        reference_ctx.init_random(42);

        let args = BaseArgs {
            random_seed: 42,
            ..Default::default()
        };
        let outcome = run_with_args_internal(args, None, |ctx, _, _: Option<()>| {
            let drawn = ctx.sample_range(TestRng, 0..100);
            let expected = reference_ctx.sample_range(TestRng, 0..100);
            assert_eq!(drawn, expected);
            Ok(())
        });
        assert!(outcome.is_ok());
    }

    #[derive(Serialize, Deserialize)]
    pub struct RunnerPropertyType {
        field_int: u32,
    }
    define_global_property!(RunnerProperty, RunnerPropertyType);

    #[test]
    fn test_run_with_config_path() {
        let args = BaseArgs {
            config: Some(PathBuf::from("tests/data/global_properties_runner.json")),
            ..Default::default()
        };
        let outcome = run_with_args_internal(args, None, |ctx, _, _: Option<()>| {
            let property = ctx.get_global_property_value(RunnerProperty).unwrap();
            assert_eq!(property.field_int, 0);
            Ok(())
        });
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_run_with_report_options() {
        let args = BaseArgs {
            force_overwrite: true,
            file_prefix: Some("test".to_string()),
            output_dir: Some(PathBuf::from("data")),
            ..Default::default()
        };
        let outcome = run_with_args_internal(args, None, |ctx, _, _: Option<()>| {
            let report_opts = &ctx.report_options();
            assert_eq!(report_opts.output_dir, PathBuf::from("data"));
            assert_eq!(report_opts.file_prefix, "test".to_string());
            assert!(report_opts.overwrite);
            Ok(())
        });
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_run_with_custom() {
        let custom = CustomArgs { a: 42 };
        let outcome = run_with_args_internal(BaseArgs::new(), Some(custom), |_, _, parsed| {
            assert_eq!(parsed.unwrap().a, 42);
            Ok(())
        });
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_run_with_logging_enabled() {
        let args = BaseArgs {
            log_level: Some(LevelFilter::Info.to_string()),
            ..BaseArgs::new()
        };
        let outcome = run_with_args_internal(args, None, |_, _, _: Option<()>| Ok(()));
        assert!(outcome.is_ok());
    }
}