1use std::path::PathBuf;
2use std::str::FromStr;
3
4use crate::context::Context;
5#[cfg(feature = "debugger")]
6use crate::debugger::enter_debugger;
7use crate::error::IxaError;
8use crate::global_properties::ContextGlobalPropertiesExt;
9#[cfg(feature = "progress_bar")]
10use crate::progress::init_timeline_progress_bar;
11use crate::random::ContextRandomExt;
12use crate::report::ContextReportExt;
13#[cfg(feature = "web_api")]
14use crate::web_api::ContextWebApiExt;
15use crate::{info, set_log_level, set_module_filters, warn, LevelFilter};
16use clap::{Args, Command, FromArgMatches as _};
17
18fn parse_log_levels(s: &str) -> Result<Vec<(String, LevelFilter)>, String> {
20 s.split(',')
21 .map(|pair| {
22 let mut iter = pair.split('=');
23 let key = iter
24 .next()
25 .ok_or_else(|| format!("Invalid key in pair: {pair}"))?;
26 let value = iter
27 .next()
28 .ok_or_else(|| format!("Invalid value in pair: {pair}"))?;
29 let level =
30 LevelFilter::from_str(value).map_err(|_| format!("Invalid log level: {value}"))?;
31 Ok((key.to_string(), level))
32 })
33 .collect()
34}
35
36#[derive(Args, Debug)]
38pub struct BaseArgs {
39 #[arg(short, long, default_value = "0")]
41 pub random_seed: u64,
42
43 #[arg(short, long)]
45 pub config: Option<PathBuf>,
46
47 #[arg(short, long = "output")]
49 pub output_dir: Option<PathBuf>,
50
51 #[arg(long = "prefix")]
53 pub file_prefix: Option<String>,
54
55 #[arg(short, long)]
57 pub force_overwrite: bool,
58
59 #[arg(short, long)]
61 pub log_level: Option<String>,
62
63 #[arg(short, long)]
65 pub debugger: Option<Option<f64>>,
66
67 #[arg(short, long)]
69 pub web: Option<Option<u16>>,
70
71 #[arg(short, long)]
73 pub timeline_progress_max: Option<f64>,
74
75 #[arg(long)]
77 pub no_stats: bool,
78}
79
80impl BaseArgs {
81 fn new() -> Self {
82 BaseArgs {
83 random_seed: 0,
84 config: None,
85 output_dir: None,
86 file_prefix: None,
87 force_overwrite: false,
88 log_level: None,
89 debugger: None,
90 web: None,
91 timeline_progress_max: None,
92 no_stats: false,
93 }
94 }
95}
96
97impl Default for BaseArgs {
98 fn default() -> Self {
99 BaseArgs::new()
100 }
101}
102
103#[derive(Args)]
104pub struct PlaceholderCustom {}
105
106fn create_ixa_cli() -> Command {
107 let cli = Command::new("ixa");
108 BaseArgs::augment_args(cli)
109}
110
111#[allow(clippy::missing_errors_doc)]
122pub fn run_with_custom_args<A, F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
123where
124 A: Args,
125 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
126{
127 let mut cli = create_ixa_cli();
128 cli = A::augment_args(cli);
129 let matches = cli.get_matches();
130
131 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
132 let custom_matches = A::from_arg_matches(&matches)?;
133 run_with_args_internal(base_args_matches, Some(custom_matches), setup_fn)
134}
135
136#[allow(clippy::missing_errors_doc)]
146pub fn run_with_args<F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
147where
148 F: Fn(&mut Context, BaseArgs, Option<PlaceholderCustom>) -> Result<(), IxaError>,
149{
150 let cli = create_ixa_cli();
151 let matches = cli.get_matches();
152
153 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
154 run_with_args_internal(base_args_matches, None, setup_fn)
155}
156
157fn run_with_args_internal<A, F>(
158 args: BaseArgs,
159 custom_args: Option<A>,
160 setup_fn: F,
161) -> Result<Context, Box<dyn std::error::Error>>
162where
163 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
164{
165 let mut context = Context::new();
167
168 if args.config.is_some() {
170 let config_path = args.config.clone().unwrap();
171 println!("Loading global properties from: {config_path:?}");
172 context.load_global_properties(&config_path)?;
173 }
174
175 let report_config = context.report_options();
177 if args.output_dir.is_some() {
178 report_config.directory(args.output_dir.clone().unwrap());
179 }
180 if args.file_prefix.is_some() {
181 report_config.file_prefix(args.file_prefix.clone().unwrap());
182 }
183 if args.force_overwrite {
184 report_config.overwrite(true);
185 }
186 if let Some(log_level) = args.log_level.as_ref() {
187 if let Ok(level) = LevelFilter::from_str(log_level) {
188 set_log_level(level);
189 info!("Logging enabled at level {level}");
190 } else if let Ok(log_levels) = parse_log_levels(log_level) {
191 let log_levels_slice: Vec<(&String, LevelFilter)> =
192 log_levels.iter().map(|(k, v)| (k, *v)).collect();
193 set_module_filters(log_levels_slice.as_slice());
194 for (key, value) in log_levels {
195 println!("Logging enabled for {key} at level {value}");
196 }
198 } else {
199 return Err(format!("Invalid log level format: {log_level}").into());
200 }
201 } else {
202 info!("Logging disabled.");
203 }
204
205 context.init_random(args.random_seed);
206
207 #[cfg(feature = "debugger")]
209 if let Some(t) = args.debugger {
210 assert!(
211 args.web.is_none(),
212 "Cannot run with both the debugger and the Web API"
213 );
214 match t {
215 None => {
216 context.request_debugger();
217 }
218 Some(time) => {
219 context.schedule_debugger(time, None, Box::new(enter_debugger));
220 }
221 }
222 }
223 #[cfg(not(feature = "debugger"))]
224 if args.debugger.is_some() {
225 warn!("Ixa was not compiled with the debugger feature, but a debugger option was provided");
226 }
227
228 #[cfg(feature = "web_api")]
230 if let Some(t) = args.web {
231 let port = t.unwrap_or(33334);
232 let url = context.setup_web_api(port).unwrap();
233 println!("Web API active on {url}");
234 context.schedule_web_api(0.0);
235 }
236 #[cfg(not(feature = "web_api"))]
237 if args.web.is_some() {
238 warn!("Ixa was not compiled with the web_api feature, but a web_api option was provided");
239 }
240
241 if let Some(max_time) = args.timeline_progress_max {
242 if cfg!(not(feature = "progress_bar")) && max_time > 0.0 {
244 warn!("Ixa was not compiled with the progress_bar feature, but a progress_bar option was provided");
245 } else if max_time < 0.0 {
246 warn!("timeline progress maximum must be nonnegative");
247 }
248 #[cfg(feature = "progress_bar")]
249 if max_time > 0.0 {
250 println!("ProgressBar max set to {}", max_time);
251 init_timeline_progress_bar(max_time);
252 }
253 }
254
255 if args.no_stats {
256 context.print_execution_statistics = false;
257 } else {
258 if cfg!(target_family = "wasm") {
259 warn!("the print-stats option is enabled; some statistics are not supported for the wasm target family");
260 }
261 context.print_execution_statistics = true;
262 }
263
264 setup_fn(&mut context, args, custom_args)?;
266
267 context.execute();
269 Ok(context)
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{define_global_property, define_rng};
    use serde::{Deserialize, Serialize};

    /// Minimal model-specific arguments used to exercise the custom-args path.
    #[derive(Args, Debug)]
    struct CustomArgs {
        #[arg(short, long, default_value = "0")]
        a: u32,
    }

    // NOTE(review): the next two tests go through clap's `get_matches`, which
    // parses the real process arguments. Extra arguments passed to the test
    // binary (e.g. a test-name filter) may be rejected by clap and end the
    // process. Confirm how these tests are normally invoked.
    #[test]
    fn test_run_with_custom_args() {
        let result = run_with_custom_args(|_, _, _: Option<CustomArgs>| Ok(()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_args() {
        let result = run_with_args(|_, _, _| Ok(()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_random_seed() {
        let test_args = BaseArgs {
            random_seed: 42,
            ..Default::default()
        };

        // A separately seeded context must produce the same sample as the
        // context configured by the runner.
        let mut compare_ctx = Context::new();
        compare_ctx.init_random(42);
        define_rng!(TestRng);
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            assert_eq!(
                ctx.sample_range(TestRng, 0..100),
                compare_ctx.sample_range(TestRng, 0..100)
            );
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[derive(Serialize, Deserialize)]
    pub struct RunnerPropertyType {
        field_int: u32,
    }
    define_global_property!(RunnerProperty, RunnerPropertyType);

    #[test]
    fn test_run_with_config_path() {
        let test_args = BaseArgs {
            config: Some(PathBuf::from("tests/data/global_properties_runner.json")),
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let p3 = ctx.get_global_property_value(RunnerProperty).unwrap();
            assert_eq!(p3.field_int, 0);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_report_options() {
        let test_args = BaseArgs {
            output_dir: Some(PathBuf::from("data")),
            file_prefix: Some("test".to_string()),
            force_overwrite: true,
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let opts = &ctx.report_options();
            assert_eq!(opts.output_dir, PathBuf::from("data"));
            assert_eq!(opts.file_prefix, "test".to_string());
            assert!(opts.overwrite);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_custom() {
        let test_args = BaseArgs::new();
        let custom = CustomArgs { a: 42 };
        let result = run_with_args_internal(test_args, Some(custom), |_, _, c| {
            assert_eq!(c.unwrap().a, 42);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_logging_enabled() {
        let mut test_args = BaseArgs::new();
        test_args.log_level = Some(LevelFilter::Info.to_string());
        let result = run_with_args_internal(test_args, None, |_, _, _: Option<()>| Ok(()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_log_levels_module_pairs() {
        let levels = parse_log_levels("ixa=info,my_model=debug").unwrap();
        assert_eq!(levels.len(), 2);
        assert_eq!(levels[0].0, "ixa");
        assert!(matches!(levels[0].1, LevelFilter::Info));
        assert_eq!(levels[1].0, "my_model");
        assert!(matches!(levels[1].1, LevelFilter::Debug));
    }

    #[test]
    fn test_parse_log_levels_rejects_malformed_input() {
        // A bare level is not a `module=level` pair.
        assert!(parse_log_levels("info").is_err());
        // The text after `=` must be a recognized log level.
        assert!(parse_log_levels("ixa=loud").is_err());
    }
}
373}