1use std::path::{Path, PathBuf};
2use std::str::FromStr;
3
4use clap::{ArgAction, Args, Command, FromArgMatches as _};
5#[cfg(feature = "write_cli_usage")]
6use clap_markdown::{help_markdown_command_custom, MarkdownOptions};
7
8use crate::context::Context;
9use crate::error::IxaError;
10use crate::global_properties::ContextGlobalPropertiesExt;
11use crate::log::level_to_string_list;
12use crate::random::ContextRandomExt;
13use crate::report::ContextReportExt;
14use crate::{info, set_log_level, set_module_filters, LevelFilter};
15
16fn parse_log_levels(s: &str) -> Result<Vec<(String, LevelFilter)>, IxaError> {
18 s.split(',')
19 .map(|pair| {
20 let mut iter = pair.split('=');
21 let key = iter.next().ok_or_else(|| IxaError::InvalidLogLevelKey {
22 pair: pair.to_string(),
23 })?;
24 let value = iter.next().ok_or_else(|| IxaError::InvalidLogLevelValue {
25 pair: pair.to_string(),
26 })?;
27 let level = LevelFilter::from_str(value).map_err(|_| IxaError::InvalidLogLevel {
28 level: value.to_string(),
29 })?;
30 Ok((key.to_string(), level))
31 })
32 .collect()
33}
34
// Command-line options accepted by every ixa model binary.
//
// These are deliberately `//` comments rather than `///` doc comments: this
// struct derives `clap::Args`, and clap turns doc comments into `--help` text,
// so adding them would change the CLI's user-facing output.
#[derive(Args, Debug)]
pub struct BaseArgs {
    // Hidden `--markdown-help` flag (only with the `write_cli_usage` feature):
    // makes `run_with_args_internal` write the Markdown CLI reference to
    // docs/book/src/cli-usage.md.
    #[cfg(feature = "write_cli_usage")]
    #[arg(long, hide = true)]
    markdown_help: bool,

    // Seed handed to `Context::init_random`; defaults to 0.
    #[arg(short, long, default_value = "0")]
    pub random_seed: u64,

    // Optional file passed to `load_global_properties` before setup runs.
    #[arg(short, long)]
    pub config: Option<PathBuf>,

    // Report output directory (flag spelled `--output`), forwarded to the
    // context's report options via `directory(...)`.
    #[arg(short, long = "output")]
    pub output_dir: Option<PathBuf>,

    // Report file-name prefix (flag spelled `--prefix`), forwarded to the
    // context's report options via `file_prefix(...)`.
    #[arg(long = "prefix")]
    pub file_prefix: Option<String>,

    // When set, `overwrite(true)` is called on the context's report options.
    #[arg(short, long)]
    pub force_overwrite: bool,

    // Either a single level (e.g. `info`) that replaces the default global
    // level, or a comma-separated `module=level` list that is installed as
    // per-module filters instead.
    #[arg(short, long)]
    pub log_level: Option<String>,

    // Number of `-v` occurrences: 1 raises the global level to at least INFO,
    // 2 to at least DEBUG, 3 or more to TRACE. The table in `long_help` is
    // user-facing help text.
    #[arg(
        short,
        long,
        action = ArgAction::Count,
        long_help = r#"Increase logging verbosity (-v, -vv, -vvv, etc.)

| Level | ERROR | WARN | INFO | DEBUG | TRACE |
|---------|-------|------|------|-------|-------|
| Default | ✓ | | | | |
| -v | ✓ | ✓ | ✓ | | |
| -vv | ✓ | ✓ | ✓ | ✓ | |
| -vvv | ✓ | ✓ | ✓ | ✓ | ✓ |
"#)]
    pub verbose: u8,

    // Raises the global log level to at least WARN.
    #[arg(long)]
    pub warn: bool,

    // Raises the global log level to at least DEBUG.
    #[arg(long)]
    pub debug: bool,

    // Sets the global log level to TRACE.
    #[arg(long)]
    pub trace: bool,

    // When set, `Context::print_execution_statistics` is set to false
    // (it is set to true otherwise).
    #[arg(long)]
    pub no_stats: bool,
}
100
101impl BaseArgs {
102 fn new() -> Self {
103 BaseArgs {
104 #[cfg(feature = "write_cli_usage")]
105 markdown_help: false,
106 random_seed: 0,
107 config: None,
108 output_dir: None,
109 file_prefix: None,
110 force_overwrite: false,
111 log_level: None,
112 verbose: 0,
113 warn: false,
114 debug: false,
115 trace: false,
116 no_stats: false,
117 }
118 }
119}
120
121impl Default for BaseArgs {
122 fn default() -> Self {
123 BaseArgs::new()
124 }
125}
126
// Empty custom-argument type used by `run_with_args`, which always passes
// `None` for the custom arguments. A `//` comment (not `///`) is used because
// the struct derives `clap::Args`.
#[derive(Args)]
pub struct PlaceholderCustom {}
129
130fn create_ixa_cli() -> Command {
131 let cli = Command::new("ixa");
132 BaseArgs::augment_args(cli)
133}
134
135#[allow(clippy::missing_errors_doc)]
146pub fn run_with_custom_args<A, F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
147where
148 A: Args,
149 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
150{
151 let mut cli = create_ixa_cli();
152 cli = A::augment_args(cli);
153 let matches = cli.get_matches();
154
155 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
156 let custom_matches = A::from_arg_matches(&matches)?;
157 run_with_args_internal(base_args_matches, Some(custom_matches), setup_fn)
158}
159
160#[allow(clippy::missing_errors_doc)]
170pub fn run_with_args<F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
171where
172 F: Fn(&mut Context, BaseArgs, Option<PlaceholderCustom>) -> Result<(), IxaError>,
173{
174 let cli = create_ixa_cli();
175 let matches = cli.get_matches();
176
177 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
178 run_with_args_internal(base_args_matches, None, setup_fn)
179}
180
181fn run_with_args_internal<A, F>(
182 args: BaseArgs,
183 custom_args: Option<A>,
184 setup_fn: F,
185) -> Result<Context, Box<dyn std::error::Error>>
186where
187 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
188{
189 #[cfg(feature = "write_cli_usage")]
190 if args.markdown_help {
192 let cli = create_ixa_cli();
193 let md_options = MarkdownOptions::new()
194 .show_footer(false)
195 .show_aliases(true)
196 .show_table_of_contents(false)
197 .title("Command Line Usage".to_string());
198 let markdown = help_markdown_command_custom(&cli, &md_options);
199 let path =
200 PathBuf::from(option_env!("CARGO_WORKSPACE_DIR").unwrap_or(env!("CARGO_MANIFEST_DIR")))
201 .join("docs")
202 .join("book")
203 .join("src")
204 .join("cli-usage.md");
205 std::fs::write(&path, markdown).unwrap_or_else(|e| {
206 panic!(
207 "Failed to write CLI help Markdown to file {}: {}",
208 path.display(),
209 e
210 );
211 });
212 }
213
214 let mut context = Context::new();
216
217 if args.config.is_some() {
219 let config_path = args.config.clone().unwrap();
220 println!("Loading global properties from: {config_path:?}");
221 context.load_global_properties(&config_path)?;
222 }
223
224 let report_config = context.report_options();
226 if args.output_dir.is_some() {
227 report_config.directory(args.output_dir.clone().unwrap());
228 }
229 if args.file_prefix.is_some() {
230 report_config.file_prefix(args.file_prefix.clone().unwrap());
231 }
232 if args.force_overwrite {
233 report_config.overwrite(true);
234 }
235
236 let mut current_log_level = crate::log::DEFAULT_LOG_LEVEL;
240
241 if let Some(log_level) = args.log_level.as_ref() {
243 if let Ok(level) = LevelFilter::from_str(log_level) {
244 current_log_level = level;
245 } else {
246 match parse_log_levels(log_level) {
247 Ok(log_levels) => {
248 let log_levels_slice: Vec<(&String, LevelFilter)> =
249 log_levels.iter().map(|(k, v)| (k, *v)).collect();
250 set_module_filters(log_levels_slice.as_slice());
251 for (key, value) in log_levels {
252 println!("Logging enabled for {key} at level {value}");
253 }
255 }
256 Err(e) => return Err(Box::new(e)),
257 }
258 }
259 }
260
261 if args.verbose > 0 {
263 let new_level = match args.verbose {
264 1 => LevelFilter::Info,
265 2 => LevelFilter::Debug,
266 _ => LevelFilter::Trace,
267 };
268 current_log_level = current_log_level.max(new_level);
269 }
270
271 if args.warn {
273 current_log_level = current_log_level.max(LevelFilter::Warn);
274 }
275 if args.debug {
276 current_log_level = current_log_level.max(LevelFilter::Debug);
277 }
278 if args.trace {
279 current_log_level = LevelFilter::Trace;
280 }
281
282 let binary_name = std::env::args().next();
284 let binary_name = binary_name
285 .as_deref()
286 .map(Path::new)
287 .and_then(Path::file_name)
288 .and_then(|s| s.to_str())
289 .unwrap_or("[model]");
290 println!(
291 "Current log levels enabled: {}",
292 level_to_string_list(current_log_level)
293 );
294 println!("Run {binary_name} --help -v to see more options");
295
296 if current_log_level != crate::log::DEFAULT_LOG_LEVEL {
298 set_log_level(current_log_level);
299 }
300
301 context.init_random(args.random_seed);
302
303 if args.no_stats {
304 context.print_execution_statistics = false;
305 } else {
306 if cfg!(target_family = "wasm") {
307 info!("the print-stats option is enabled; some statistics are not supported for the wasm target family");
308 }
309 context.print_execution_statistics = true;
310 }
311
312 setup_fn(&mut context, args, custom_args)?;
314
315 context.execute();
317 Ok(context)
318}
319
#[cfg(test)]
mod tests {
    use clap::FromArgMatches as _;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{define_global_property, define_rng};

    #[derive(Args, Debug)]
    struct CustomArgs {
        #[arg(short, long, default_value = "0")]
        a: u32,
    }

    // Parses an explicit argv. Calling `get_matches()` in a test would parse the
    // test harness's own arguments (test-name filters, `--nocapture`, ...), and
    // clap reacts to those by exiting the entire test process.
    fn parse_base_args(argv: &[&str]) -> BaseArgs {
        let matches = create_ixa_cli()
            .try_get_matches_from(argv.iter().copied())
            .expect("test argv should be valid");
        BaseArgs::from_arg_matches(&matches).expect("BaseArgs should build from matches")
    }

    #[test]
    fn test_run_with_custom_args() {
        let matches = CustomArgs::augment_args(create_ixa_cli())
            .try_get_matches_from(["ixa", "-a", "7"])
            .expect("test argv should be valid");
        let base_args = BaseArgs::from_arg_matches(&matches).unwrap();
        let custom_args = CustomArgs::from_arg_matches(&matches).unwrap();
        let result = run_with_args_internal(base_args, Some(custom_args), |_, _, c| {
            assert_eq!(c.unwrap().a, 7);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_args() {
        let base_args = parse_base_args(&["ixa"]);
        let result =
            run_with_args_internal(base_args, None, |_, _, _: Option<PlaceholderCustom>| Ok(()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_base_args_flags() {
        let args = parse_base_args(&[
            "ixa",
            "-r",
            "42",
            "-vv",
            "--no-stats",
            "--prefix",
            "run_",
            "-o",
            "out",
        ]);
        assert_eq!(args.random_seed, 42);
        assert_eq!(args.verbose, 2);
        assert!(args.no_stats);
        assert_eq!(args.file_prefix.as_deref(), Some("run_"));
        assert_eq!(args.output_dir, Some(PathBuf::from("out")));
    }

    #[test]
    fn test_parse_log_levels() {
        let levels = parse_log_levels("ixa=debug,my_model=trace").unwrap();
        assert_eq!(
            levels,
            vec![
                ("ixa".to_string(), LevelFilter::Debug),
                ("my_model".to_string(), LevelFilter::Trace),
            ]
        );
        assert!(parse_log_levels("ixa").is_err());
        assert!(parse_log_levels("ixa=loud").is_err());
    }

    #[test]
    fn test_run_with_random_seed() {
        let test_args = BaseArgs {
            random_seed: 42,
            ..Default::default()
        };

        let mut compare_ctx = Context::new();
        compare_ctx.init_random(42);
        define_rng!(TestRng);
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            assert_eq!(
                ctx.sample_range(TestRng, 0..100),
                compare_ctx.sample_range(TestRng, 0..100)
            );
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[derive(Serialize, Deserialize)]
    pub struct RunnerPropertyType {
        field_int: u32,
    }
    define_global_property!(RunnerProperty, RunnerPropertyType);

    #[test]
    fn test_run_with_config_path() {
        let test_args = BaseArgs {
            config: Some(PathBuf::from("tests/data/global_properties_runner.json")),
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let p3 = ctx.get_global_property_value(RunnerProperty).unwrap();
            assert_eq!(p3.field_int, 0);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_report_options() {
        let test_args = BaseArgs {
            output_dir: Some(PathBuf::from("data")),
            file_prefix: Some("test".to_string()),
            force_overwrite: true,
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let opts = &ctx.report_options();
            assert_eq!(opts.output_dir, PathBuf::from("data"));
            assert_eq!(opts.file_prefix, "test".to_string());
            assert!(opts.overwrite);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_custom() {
        let test_args = BaseArgs::new();
        let custom = CustomArgs { a: 42 };
        let result = run_with_args_internal(test_args, Some(custom), |_, _, c| {
            assert_eq!(c.unwrap().a, 42);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_logging_enabled() {
        let mut test_args = BaseArgs::new();
        test_args.log_level = Some(LevelFilter::Info.to_string());
        let result = run_with_args_internal(test_args, None, |_, _, _: Option<()>| Ok(()));
        assert!(result.is_ok());
    }
}