1use std::path::{Path, PathBuf};
2use std::str::FromStr;
3
4use clap::{ArgAction, Args, Command, FromArgMatches as _};
5#[cfg(feature = "write_cli_usage")]
6use clap_markdown::{help_markdown_command_custom, MarkdownOptions};
7
8use crate::context::Context;
9use crate::error::IxaError;
10use crate::global_properties::ContextGlobalPropertiesExt;
11use crate::log::level_to_string_list;
12#[cfg(feature = "progress_bar")]
13use crate::progress::init_timeline_progress_bar;
14use crate::random::ContextRandomExt;
15use crate::report::ContextReportExt;
16use crate::{set_log_level, set_module_filters, warn, LevelFilter};
17
18fn parse_log_levels(s: &str) -> Result<Vec<(String, LevelFilter)>, IxaError> {
20 s.split(',')
21 .map(|pair| {
22 let mut iter = pair.split('=');
23 let key = iter.next().ok_or_else(|| IxaError::InvalidLogLevelKey {
24 pair: pair.to_string(),
25 })?;
26 let value = iter.next().ok_or_else(|| IxaError::InvalidLogLevelValue {
27 pair: pair.to_string(),
28 })?;
29 let level = LevelFilter::from_str(value).map_err(|_| IxaError::InvalidLogLevel {
30 level: value.to_string(),
31 })?;
32 Ok((key.to_string(), level))
33 })
34 .collect()
35}
36
37#[derive(Args, Debug)]
39pub struct BaseArgs {
40 #[cfg(feature = "write_cli_usage")]
41 #[arg(long, hide = true)]
45 markdown_help: bool,
46
47 #[arg(short, long, default_value = "0")]
49 pub random_seed: u64,
50
51 #[arg(short, long)]
53 pub config: Option<PathBuf>,
54
55 #[arg(short, long = "output")]
57 pub output_dir: Option<PathBuf>,
58
59 #[arg(long = "prefix")]
61 pub file_prefix: Option<String>,
62
63 #[arg(short, long)]
65 pub force_overwrite: bool,
66
67 #[arg(short, long)]
69 pub log_level: Option<String>,
70
71 #[arg(
72 short,
73 long,
74 action = ArgAction::Count,
75 long_help = r#"Increase logging verbosity (-v, -vv, -vvv, etc.)
76
77| Level | ERROR | WARN | INFO | DEBUG | TRACE |
78|---------|-------|------|------|-------|-------|
79| Default | ✓ | | | | |
80| -v | ✓ | ✓ | ✓ | | |
81| -vv | ✓ | ✓ | ✓ | ✓ | |
82| -vvv | ✓ | ✓ | ✓ | ✓ | ✓ |
83"#)]
84 pub verbose: u8,
85
86 #[arg(long)]
88 pub warn: bool,
89
90 #[arg(long)]
92 pub debug: bool,
93
94 #[arg(long)]
96 pub trace: bool,
97
98 #[arg(short, long)]
100 pub timeline_progress_max: Option<f64>,
101
102 #[arg(long)]
104 pub no_stats: bool,
105}
106
107impl BaseArgs {
108 fn new() -> Self {
109 BaseArgs {
110 #[cfg(feature = "write_cli_usage")]
111 markdown_help: false,
112 random_seed: 0,
113 config: None,
114 output_dir: None,
115 file_prefix: None,
116 force_overwrite: false,
117 log_level: None,
118 verbose: 0,
119 warn: false,
120 debug: false,
121 trace: false,
122 timeline_progress_max: None,
123 no_stats: false,
124 }
125 }
126}
127
128impl Default for BaseArgs {
129 fn default() -> Self {
130 BaseArgs::new()
131 }
132}
133
134#[derive(Args)]
135pub struct PlaceholderCustom {}
136
137fn create_ixa_cli() -> Command {
138 let cli = Command::new("ixa");
139 BaseArgs::augment_args(cli)
140}
141
142#[allow(clippy::missing_errors_doc)]
153pub fn run_with_custom_args<A, F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
154where
155 A: Args,
156 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
157{
158 let mut cli = create_ixa_cli();
159 cli = A::augment_args(cli);
160 let matches = cli.get_matches();
161
162 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
163 let custom_matches = A::from_arg_matches(&matches)?;
164 run_with_args_internal(base_args_matches, Some(custom_matches), setup_fn)
165}
166
167#[allow(clippy::missing_errors_doc)]
177pub fn run_with_args<F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
178where
179 F: Fn(&mut Context, BaseArgs, Option<PlaceholderCustom>) -> Result<(), IxaError>,
180{
181 let cli = create_ixa_cli();
182 let matches = cli.get_matches();
183
184 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
185 run_with_args_internal(base_args_matches, None, setup_fn)
186}
187
188fn run_with_args_internal<A, F>(
189 args: BaseArgs,
190 custom_args: Option<A>,
191 setup_fn: F,
192) -> Result<Context, Box<dyn std::error::Error>>
193where
194 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
195{
196 #[cfg(feature = "write_cli_usage")]
197 if args.markdown_help {
199 let cli = create_ixa_cli();
200 let md_options = MarkdownOptions::new()
201 .show_footer(false)
202 .show_aliases(true)
203 .show_table_of_contents(false)
204 .title("Command Line Usage".to_string());
205 let markdown = help_markdown_command_custom(&cli, &md_options);
206 let path =
207 PathBuf::from(option_env!("CARGO_WORKSPACE_DIR").unwrap_or(env!("CARGO_MANIFEST_DIR")))
208 .join("docs")
209 .join("book")
210 .join("src")
211 .join("cli-usage.md");
212 std::fs::write(&path, markdown).unwrap_or_else(|e| {
213 panic!(
214 "Failed to write CLI help Markdown to file {}: {}",
215 path.display(),
216 e
217 );
218 });
219 }
220
221 let mut context = Context::new();
223
224 if args.config.is_some() {
226 let config_path = args.config.clone().unwrap();
227 println!("Loading global properties from: {config_path:?}");
228 context.load_global_properties(&config_path)?;
229 }
230
231 let report_config = context.report_options();
233 if args.output_dir.is_some() {
234 report_config.directory(args.output_dir.clone().unwrap());
235 }
236 if args.file_prefix.is_some() {
237 report_config.file_prefix(args.file_prefix.clone().unwrap());
238 }
239 if args.force_overwrite {
240 report_config.overwrite(true);
241 }
242
243 let mut current_log_level = crate::log::DEFAULT_LOG_LEVEL;
247
248 if let Some(log_level) = args.log_level.as_ref() {
250 if let Ok(level) = LevelFilter::from_str(log_level) {
251 current_log_level = level;
252 } else {
253 match parse_log_levels(log_level) {
254 Ok(log_levels) => {
255 let log_levels_slice: Vec<(&String, LevelFilter)> =
256 log_levels.iter().map(|(k, v)| (k, *v)).collect();
257 set_module_filters(log_levels_slice.as_slice());
258 for (key, value) in log_levels {
259 println!("Logging enabled for {key} at level {value}");
260 }
262 }
263 Err(e) => return Err(Box::new(e)),
264 }
265 }
266 }
267
268 if args.verbose > 0 {
270 let new_level = match args.verbose {
271 1 => LevelFilter::Info,
272 2 => LevelFilter::Debug,
273 _ => LevelFilter::Trace,
274 };
275 current_log_level = current_log_level.max(new_level);
276 }
277
278 if args.warn {
280 current_log_level = current_log_level.max(LevelFilter::Warn);
281 }
282 if args.debug {
283 current_log_level = current_log_level.max(LevelFilter::Debug);
284 }
285 if args.trace {
286 current_log_level = LevelFilter::Trace;
287 }
288
289 let binary_name = std::env::args().next();
291 let binary_name = binary_name
292 .as_deref()
293 .map(Path::new)
294 .and_then(Path::file_name)
295 .and_then(|s| s.to_str())
296 .unwrap_or("[model]");
297 println!(
298 "Current log levels enabled: {}",
299 level_to_string_list(current_log_level)
300 );
301 println!("Run {binary_name} --help -v to see more options");
302
303 if current_log_level != crate::log::DEFAULT_LOG_LEVEL {
305 set_log_level(current_log_level);
306 }
307
308 context.init_random(args.random_seed);
309
310 if let Some(max_time) = args.timeline_progress_max {
311 if cfg!(not(feature = "progress_bar")) && max_time > 0.0 {
313 warn!("Ixa was not compiled with the progress_bar feature, but a progress_bar option was provided");
314 } else if max_time < 0.0 {
315 warn!("timeline progress maximum must be nonnegative");
316 }
317 #[cfg(feature = "progress_bar")]
318 if max_time > 0.0 {
319 println!("ProgressBar max set to {}", max_time);
320 init_timeline_progress_bar(max_time);
321 }
322 }
323
324 if args.no_stats {
325 context.print_execution_statistics = false;
326 } else {
327 if cfg!(target_family = "wasm") {
328 warn!("the print-stats option is enabled; some statistics are not supported for the wasm target family");
329 }
330 context.print_execution_statistics = true;
331 }
332
333 setup_fn(&mut context, args, custom_args)?;
335
336 context.execute();
338 Ok(context)
339}
340
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{define_global_property, define_rng};

    #[derive(Args, Debug)]
    struct CustomArgs {
        #[arg(short, long, default_value = "0")]
        a: u32,
    }

    // NOTE(review): the next two tests go through `get_matches()`, which parses the real
    // process arguments. If the test binary is invoked with extra arguments (e.g. a test-name
    // filter), clap will likely reject them and exit the whole test process. The explicit-argv
    // test further below covers argument parsing without that dependency.
    #[test]
    fn test_run_with_custom_args() {
        let result = run_with_custom_args(|_, _, _: Option<CustomArgs>| Ok(()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_args() {
        let result = run_with_args(|_, _, _| Ok(()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_cli_parses_explicit_arguments() {
        let cli = CustomArgs::augment_args(create_ixa_cli());
        let matches = cli
            .try_get_matches_from([
                "ixa",
                "--random-seed",
                "7",
                "-vv",
                "--prefix",
                "run_",
                "-a",
                "3",
            ])
            .unwrap();
        let base = BaseArgs::from_arg_matches(&matches).unwrap();
        let custom = CustomArgs::from_arg_matches(&matches).unwrap();
        assert_eq!(base.random_seed, 7);
        assert_eq!(base.verbose, 2);
        assert_eq!(base.file_prefix.as_deref(), Some("run_"));
        assert!(!base.no_stats);
        assert_eq!(custom.a, 3);
    }

    #[test]
    fn test_parse_log_levels_accepts_module_pairs() {
        let levels = parse_log_levels("ixa=debug,my_model=trace").unwrap();
        assert_eq!(
            levels,
            vec![
                ("ixa".to_string(), LevelFilter::Debug),
                ("my_model".to_string(), LevelFilter::Trace),
            ]
        );
    }

    #[test]
    fn test_parse_log_levels_rejects_malformed_input() {
        // Missing `=` separator.
        assert!(parse_log_levels("ixa").is_err());
        // Unknown level name.
        assert!(parse_log_levels("ixa=loudest").is_err());
    }

    #[test]
    fn test_run_with_random_seed() {
        let test_args = BaseArgs {
            random_seed: 42,
            ..Default::default()
        };

        let mut compare_ctx = Context::new();
        compare_ctx.init_random(42);
        define_rng!(TestRng);
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            assert_eq!(
                ctx.sample_range(TestRng, 0..100),
                compare_ctx.sample_range(TestRng, 0..100)
            );
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[derive(Serialize, Deserialize)]
    pub struct RunnerPropertyType {
        field_int: u32,
    }
    define_global_property!(RunnerProperty, RunnerPropertyType);

    #[test]
    fn test_run_with_config_path() {
        let test_args = BaseArgs {
            config: Some(PathBuf::from("tests/data/global_properties_runner.json")),
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let p3 = ctx.get_global_property_value(RunnerProperty).unwrap();
            assert_eq!(p3.field_int, 0);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_report_options() {
        let test_args = BaseArgs {
            output_dir: Some(PathBuf::from("data")),
            file_prefix: Some("test".to_string()),
            force_overwrite: true,
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let opts = &ctx.report_options();
            assert_eq!(opts.output_dir, PathBuf::from("data"));
            assert_eq!(opts.file_prefix, "test".to_string());
            assert!(opts.overwrite);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_custom() {
        let test_args = BaseArgs::new();
        let custom = CustomArgs { a: 42 };
        let result = run_with_args_internal(test_args, Some(custom), |_, _, c| {
            assert_eq!(c.unwrap().a, 42);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_logging_enabled() {
        let mut test_args = BaseArgs::new();
        test_args.log_level = Some(LevelFilter::Info.to_string());
        let result = run_with_args_internal(test_args, None, |_, _, _: Option<()>| Ok(()));
        assert!(result.is_ok());
    }
}