use std::path::PathBuf;
use std::str::FromStr;

use crate::context::Context;
#[cfg(feature = "debugger")]
use crate::debugger::enter_debugger;
use crate::error::IxaError;
use crate::global_properties::ContextGlobalPropertiesExt;
use crate::random::ContextRandomExt;
use crate::report::ContextReportExt;
#[cfg(feature = "web_api")]
use crate::web_api::ContextWebApiExt;
use crate::{info, set_log_level, set_module_filters, LevelFilter};
use clap::{Args, Command, FromArgMatches as _};
// `warn!` is used by the "not compiled with the debugger feature" fallback as
// well as the "not compiled with the web_api feature" fallback, so it must be
// in scope whenever either feature is disabled.
#[cfg(any(not(feature = "debugger"), not(feature = "web_api")))]
use log::warn;
17
18fn parse_log_levels(s: &str) -> Result<Vec<(String, LevelFilter)>, String> {
20 s.split(',')
21 .map(|pair| {
22 let mut iter = pair.split('=');
23 let key = iter
24 .next()
25 .ok_or_else(|| format!("Invalid key in pair: {pair}"))?;
26 let value = iter
27 .next()
28 .ok_or_else(|| format!("Invalid value in pair: {pair}"))?;
29 let level =
30 LevelFilter::from_str(value).map_err(|_| format!("Invalid log level: {value}"))?;
31 Ok((key.to_string(), level))
32 })
33 .collect()
34}
35
36#[derive(Args, Debug)]
38pub struct BaseArgs {
39 #[arg(short, long, default_value = "0")]
41 pub random_seed: u64,
42
43 #[arg(short, long)]
45 pub config: Option<PathBuf>,
46
47 #[arg(short, long = "output")]
49 pub output_dir: Option<PathBuf>,
50
51 #[arg(long = "prefix")]
53 pub file_prefix: Option<String>,
54
55 #[arg(short, long)]
57 pub force_overwrite: bool,
58
59 #[arg(short, long)]
61 pub log_level: Option<String>,
62
63 #[arg(short, long)]
65 pub debugger: Option<Option<f64>>,
66
67 #[arg(short, long)]
69 pub web: Option<Option<u16>>,
70}
71
72impl BaseArgs {
73 fn new() -> Self {
74 BaseArgs {
75 random_seed: 0,
76 config: None,
77 output_dir: None,
78 file_prefix: None,
79 force_overwrite: false,
80 log_level: None,
81 debugger: None,
82 web: None,
83 }
84 }
85}
86
87impl Default for BaseArgs {
88 fn default() -> Self {
89 BaseArgs::new()
90 }
91}
92
// Empty custom-argument set. `run_with_args` names it as the `A` type of its
// setup callback (always passing `None`), so that callback has the same
// `(context, base args, Option<custom args>)` shape as `run_with_custom_args`.
// A plain `//` comment is used because clap's derive may pick up a `///` doc
// comment on an `Args` struct as command help text.
#[derive(Args)]
pub struct PlaceholderCustom {}
95
96fn create_ixa_cli() -> Command {
97 let cli = Command::new("ixa");
98 BaseArgs::augment_args(cli)
99}
100
101#[allow(clippy::missing_errors_doc)]
112pub fn run_with_custom_args<A, F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
113where
114 A: Args,
115 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
116{
117 let mut cli = create_ixa_cli();
118 cli = A::augment_args(cli);
119 let matches = cli.get_matches();
120
121 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
122 let custom_matches = A::from_arg_matches(&matches)?;
123 run_with_args_internal(base_args_matches, Some(custom_matches), setup_fn)
124}
125
126#[allow(clippy::missing_errors_doc)]
136pub fn run_with_args<F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
137where
138 F: Fn(&mut Context, BaseArgs, Option<PlaceholderCustom>) -> Result<(), IxaError>,
139{
140 let cli = create_ixa_cli();
141 let matches = cli.get_matches();
142
143 let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
144 run_with_args_internal(base_args_matches, None, setup_fn)
145}
146
147fn run_with_args_internal<A, F>(
148 args: BaseArgs,
149 custom_args: Option<A>,
150 setup_fn: F,
151) -> Result<Context, Box<dyn std::error::Error>>
152where
153 F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
154{
155 let mut context = Context::new();
157
158 if args.config.is_some() {
160 let config_path = args.config.clone().unwrap();
161 println!("Loading global properties from: {config_path:?}");
162 context.load_global_properties(&config_path)?;
163 }
164
165 let report_config = context.report_options();
167 if args.output_dir.is_some() {
168 report_config.directory(args.output_dir.clone().unwrap());
169 }
170 if args.file_prefix.is_some() {
171 report_config.file_prefix(args.file_prefix.clone().unwrap());
172 }
173 if args.force_overwrite {
174 report_config.overwrite(true);
175 }
176 if let Some(log_level) = args.log_level.as_ref() {
177 if let Ok(level) = LevelFilter::from_str(log_level) {
178 set_log_level(level);
179 info!("Logging enabled at level {level}");
180 } else if let Ok(log_levels) = parse_log_levels(log_level) {
181 let log_levels_slice: Vec<(&String, LevelFilter)> =
182 log_levels.iter().map(|(k, v)| (k, *v)).collect();
183 set_module_filters(log_levels_slice.as_slice());
184 for (key, value) in log_levels {
185 println!("Logging enabled for {key} at level {value}");
186 }
188 } else {
189 return Err(format!("Invalid log level format: {log_level}").into());
190 }
191 } else {
192 info!("Logging disabled.");
193 }
194
195 context.init_random(args.random_seed);
196
197 #[cfg(feature = "debugger")]
199 if let Some(t) = args.debugger {
200 assert!(
201 args.web.is_none(),
202 "Cannot run with both the debugger and the Web API"
203 );
204 match t {
205 None => {
206 context.request_debugger();
207 }
208 Some(time) => {
209 context.schedule_debugger(time, None, Box::new(enter_debugger));
210 }
211 }
212 }
213 #[cfg(not(feature = "debugger"))]
214 if args.debugger.is_some() {
215 warn!("Ixa was not compiled with the debugger feature, but a debugger option was provided");
216 }
217
218 #[cfg(feature = "web_api")]
220 if let Some(t) = args.web {
221 let port = t.unwrap_or(33334);
222 let url = context.setup_web_api(port).unwrap();
223 println!("Web API active on {url}");
224 context.schedule_web_api(0.0);
225 }
226 #[cfg(not(feature = "web_api"))]
227 if args.web.is_some() {
228 warn!("Ixa was not compiled with the web_api feature, but a web_api option was provided");
229 }
230
231 setup_fn(&mut context, args, custom_args)?;
233
234 context.execute();
236 Ok(context)
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{define_global_property, define_rng};
    use serde::{Deserialize, Serialize};

    // Minimal caller-defined argument set used to exercise the custom-args path.
    #[derive(Args, Debug)]
    struct CustomArgs {
        #[arg(short, long, default_value = "0")]
        a: u32,
    }

    // NOTE(review): this test and `test_run_with_args` go through
    // `get_matches()`, which parses the real argv of the test binary. Extra
    // arguments passed to the harness (e.g. a test-name filter) may make clap
    // reject the command line and exit the process — confirm how these are run.
    #[test]
    fn test_run_with_custom_args() {
        let result = run_with_custom_args(|_, _, _: Option<CustomArgs>| Ok(()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_args() {
        let result = run_with_args(|_, _, _| Ok(()));
        assert!(result.is_ok());
    }

    // A context seeded through `random_seed` must draw the same sample as a
    // context seeded directly with `init_random` using the same value.
    #[test]
    fn test_run_with_random_seed() {
        let test_args = BaseArgs {
            random_seed: 42,
            ..Default::default()
        };

        let mut compare_ctx = Context::new();
        compare_ctx.init_random(42);
        define_rng!(TestRng);
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            assert_eq!(
                ctx.sample_range(TestRng, 0..100),
                compare_ctx.sample_range(TestRng, 0..100)
            );
            Ok(())
        });
        assert!(result.is_ok());
    }

    // Global property read back by `test_run_with_config_path`.
    #[derive(Serialize, Deserialize)]
    pub struct RunnerPropertyType {
        field_int: u32,
    }
    define_global_property!(RunnerProperty, RunnerPropertyType);

    // Expects the fixture file to set `RunnerProperty.field_int` to 0. The path
    // is relative, presumably to the crate root when run via cargo — confirm.
    #[test]
    fn test_run_with_config_path() {
        let test_args = BaseArgs {
            config: Some(PathBuf::from("tests/data/global_properties_runner.json")),
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let p3 = ctx.get_global_property_value(RunnerProperty).unwrap();
            assert_eq!(p3.field_int, 0);
            Ok(())
        });
        assert!(result.is_ok());
    }

    // The output directory, file prefix and overwrite flag must all reach the
    // context's report options before `setup_fn` runs.
    #[test]
    fn test_run_with_report_options() {
        let test_args = BaseArgs {
            output_dir: Some(PathBuf::from("data")),
            file_prefix: Some("test".to_string()),
            force_overwrite: true,
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let opts = &ctx.report_options();
            assert_eq!(opts.output_dir, PathBuf::from("data"));
            assert_eq!(opts.file_prefix, "test".to_string());
            assert!(opts.overwrite);
            Ok(())
        });
        assert!(result.is_ok());
    }

    // Custom arguments are handed to `setup_fn` unchanged.
    #[test]
    fn test_run_with_custom() {
        let test_args = BaseArgs::new();
        let custom = CustomArgs { a: 42 };
        let result = run_with_args_internal(test_args, Some(custom), |_, _, c| {
            assert_eq!(c.unwrap().a, 42);
            Ok(())
        });
        assert!(result.is_ok());
    }

    // A single global level string (here "INFO") is accepted.
    #[test]
    fn test_run_with_logging_enabled() {
        let mut test_args = BaseArgs::new();
        test_args.log_level = Some(LevelFilter::Info.to_string());
        let result = run_with_args_internal(test_args, None, |_, _, _: Option<()>| Ok(()));
        assert!(result.is_ok());
    }
}
340}