use std::path::PathBuf;
use std::str::FromStr;
use crate::error::IxaError;
use crate::global_properties::ContextGlobalPropertiesExt;
use crate::random::ContextRandomExt;
use crate::report::ContextReportExt;
use crate::{context::Context, web_api::ContextWebApiExt};
use crate::{info, set_log_level, set_module_filters, LevelFilter};
use crate::debugger::enter_debugger;
use clap::{Args, Command, FromArgMatches as _};
/// Parses a comma-separated list of `module=level` pairs, e.g. `"ixa=trace,rustyline=debug"`.
///
/// Whitespace around each module name and level is ignored, so `"ixa = trace, rustyline=debug"`
/// yields the same result as the compact form. Pairs are returned in input order.
///
/// # Errors
///
/// Returns a human-readable message for the first pair that:
/// * contains no `=` ("Invalid value in pair"),
/// * has an empty module name ("Invalid key in pair"), or
/// * names a level `LevelFilter::from_str` rejects ("Invalid log level"). This includes a level
///   followed by a stray extra `=`, such as `ixa=info=x`.
fn parse_log_levels(s: &str) -> Result<Vec<(String, LevelFilter)>, String> {
    s.split(',')
        .map(|pair| {
            // `split_once` keeps everything after the first `=` as the level text, so a second
            // `=` makes the level unparseable instead of being silently dropped.
            let (key, value) = pair
                .split_once('=')
                .ok_or_else(|| format!("Invalid value in pair: {pair}"))?;
            let key = key.trim();
            if key.is_empty() {
                return Err(format!("Invalid key in pair: {pair}"));
            }
            let value = value.trim();
            let level =
                LevelFilter::from_str(value).map_err(|_| format!("Invalid log level: {value}"))?;
            Ok((key.to_string(), level))
        })
        .collect()
}
/// Command-line options shared by every ixa model.
#[derive(Args, Debug)]
pub struct BaseArgs {
    /// Seed for the context's random number generator
    #[arg(short, long, default_value = "0")]
    pub random_seed: u64,
    /// Path to a global properties config file to load before setup
    #[arg(short, long)]
    pub config: Option<PathBuf>,
    /// Directory to write report files to
    #[arg(short, long = "output")]
    pub output_dir: Option<PathBuf>,
    /// Prefix for report file names
    #[arg(long = "prefix")]
    pub file_prefix: Option<String>,
    /// Overwrite existing report files
    #[arg(short, long)]
    pub force_overwrite: bool,
    /// Log level (e.g. info), or comma-separated module=level pairs (e.g. ixa=trace,rustyline=debug)
    #[arg(short, long)]
    pub log_level: Option<String>,
    /// Start the interactive debugger, optionally at the given simulation time
    #[arg(short, long)]
    pub debugger: Option<Option<f64>>,
    /// Enable the Web API, optionally on the given port (default: 33334)
    #[arg(short, long)]
    pub web: Option<Option<u16>>,
}
impl BaseArgs {
fn new() -> Self {
BaseArgs {
random_seed: 0,
config: None,
output_dir: None,
file_prefix: None,
force_overwrite: false,
log_level: None,
debugger: None,
web: None,
}
}
}
impl Default for BaseArgs {
fn default() -> Self {
BaseArgs::new()
}
}
/// Empty stand-in for model-specific arguments.
///
/// `run_with_args` uses this as the custom-arguments type when a model defines no command-line
/// options of its own; the setup function then always receives `None` for it.
#[derive(Args, Debug)]
pub struct PlaceholderCustom {}
/// Builds the `ixa` command-line parser with only the [`BaseArgs`] options registered.
fn create_ixa_cli() -> Command {
    BaseArgs::augment_args(Command::new("ixa"))
}
/// Parses the process command line into [`BaseArgs`] plus a model-specific argument struct `A`,
/// then builds and runs a [`Context`].
///
/// `A`'s options are registered on the same `ixa` command as the base options, so both sets are
/// accepted in one invocation. `setup_fn` always receives `Some(custom_args)`.
///
/// If the command line itself is malformed, clap's `get_matches` prints a usage message and
/// exits the process.
///
/// # Errors
///
/// Returns an error in these cases:
/// * the matched arguments cannot be converted into `BaseArgs` or `A`;
/// * building the context from them fails (for example an unloadable config file or an
///   unrecognised `--log-level` value);
/// * `setup_fn` returns an error.
pub fn run_with_custom_args<A, F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
where
    A: Args,
    F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
{
    let matches = A::augment_args(create_ixa_cli()).get_matches();
    let base_args = BaseArgs::from_arg_matches(&matches)?;
    let custom_args = A::from_arg_matches(&matches)?;
    run_with_args_internal(base_args, Some(custom_args), setup_fn)
}
/// Parses the process command line into [`BaseArgs`], then builds and runs a [`Context`].
///
/// Use this when the model has no options of its own; `setup_fn` always receives `None` for the
/// custom-arguments parameter. See [`run_with_custom_args`] to add model-specific options.
///
/// If the command line itself is malformed, clap's `get_matches` prints a usage message and
/// exits the process.
///
/// # Errors
///
/// Returns an error in these cases:
/// * the matched arguments cannot be converted into `BaseArgs`;
/// * building the context from them fails (for example an unloadable config file or an
///   unrecognised `--log-level` value);
/// * `setup_fn` returns an error.
pub fn run_with_args<F>(setup_fn: F) -> Result<Context, Box<dyn std::error::Error>>
where
    F: Fn(&mut Context, BaseArgs, Option<PlaceholderCustom>) -> Result<(), IxaError>,
{
    let base_args = BaseArgs::from_arg_matches(&create_ixa_cli().get_matches())?;
    run_with_args_internal(base_args, None, setup_fn)
}
fn run_with_args_internal<A, F>(
args: BaseArgs,
custom_args: Option<A>,
setup_fn: F,
) -> Result<Context, Box<dyn std::error::Error>>
where
F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
{
let mut context = Context::new();
if args.config.is_some() {
let config_path = args.config.clone().unwrap();
println!("Loading global properties from: {config_path:?}");
context.load_global_properties(&config_path)?;
}
let report_config = context.report_options();
if args.output_dir.is_some() {
report_config.directory(args.output_dir.clone().unwrap());
}
if args.file_prefix.is_some() {
report_config.file_prefix(args.file_prefix.clone().unwrap());
}
if args.force_overwrite {
report_config.overwrite(true);
}
if let Some(log_level) = args.log_level.as_ref() {
if let Ok(level) = LevelFilter::from_str(log_level) {
set_log_level(level);
info!("Logging enabled at level {level}");
} else if let Ok(log_levels) = parse_log_levels(log_level) {
let log_levels_slice: Vec<(&String, LevelFilter)> =
log_levels.iter().map(|(k, v)| (k, *v)).collect();
set_module_filters(log_levels_slice.as_slice());
for (key, value) in log_levels {
println!("Logging enabled for {key} at level {value}");
}
} else {
return Err(format!("Invalid log level format: {log_level}").into());
}
} else {
info!("Logging disabled.");
}
context.init_random(args.random_seed);
if let Some(t) = args.debugger {
assert!(
args.web.is_none(),
"Cannot run with both the debugger and the Web API"
);
match t {
None => {
context.request_debugger();
}
Some(time) => {
context.schedule_debugger(time, None, Box::new(enter_debugger));
}
}
}
if let Some(t) = args.web {
let port = t.unwrap_or(33334);
let url = context.setup_web_api(port).unwrap();
println!("Web API active on {url}");
context.schedule_web_api(0.0);
}
setup_fn(&mut context, args, custom_args)?;
context.execute();
Ok(context)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::run_external_runner;
    use crate::{define_global_property, define_rng};
    use serde::{Deserialize, Serialize};

    // Model-specific options used to exercise the custom-args entry points.
    #[derive(Args, Debug)]
    struct CustomArgs {
        #[arg(short, long, default_value = "0")]
        a: u32,
    }

    // NOTE(review): `run_with_custom_args` / `run_with_args` call clap's `get_matches()`, which
    // parses this test binary's real argv. Running with a test-name filter or harness flags
    // (e.g. `cargo test foo` or `-- --nocapture`) likely makes clap reject the arguments and
    // exit the whole test process. Consider a `get_matches_from`-based entry point for tests.
    #[test]
    fn test_run_with_custom_args() {
        let result = run_with_custom_args(|_, _, _: Option<CustomArgs>| Ok(()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_cli_invocation_with_custom_args() {
        run_external_runner("runner_test_custom_args")
            .unwrap()
            .args(["-a", "42"])
            .assert()
            .success()
            .stdout("42\n");
    }

    #[test]
    fn test_run_with_args() {
        let result = run_with_args(|_, _, _| Ok(()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_random_seed() {
        let test_args = BaseArgs {
            random_seed: 42,
            ..Default::default()
        };
        let mut compare_ctx = Context::new();
        compare_ctx.init_random(42);
        define_rng!(TestRng);
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            assert_eq!(
                ctx.sample_range(TestRng, 0..100),
                compare_ctx.sample_range(TestRng, 0..100)
            );
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[derive(Serialize, Deserialize)]
    pub struct RunnerPropertyType {
        field_int: u32,
    }
    define_global_property!(RunnerProperty, RunnerPropertyType);

    #[test]
    fn test_run_with_config_path() {
        let test_args = BaseArgs {
            config: Some(PathBuf::from("tests/data/global_properties_runner.json")),
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let property = ctx.get_global_property_value(RunnerProperty).unwrap();
            assert_eq!(property.field_int, 0);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_report_options() {
        let test_args = BaseArgs {
            output_dir: Some(PathBuf::from("data")),
            file_prefix: Some("test".to_string()),
            force_overwrite: true,
            ..Default::default()
        };
        let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
            let opts = &ctx.report_options();
            assert_eq!(opts.output_dir, PathBuf::from("data"));
            assert_eq!(opts.file_prefix, "test".to_string());
            assert!(opts.overwrite);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_custom() {
        let test_args = BaseArgs::new();
        let custom = CustomArgs { a: 42 };
        let result = run_with_args_internal(test_args, Some(custom), |_, _, c| {
            assert_eq!(c.unwrap().a, 42);
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_logging_enabled() {
        let mut test_args = BaseArgs::new();
        test_args.log_level = Some(LevelFilter::Info.to_string());
        let result = run_with_args_internal(test_args, None, |_, _, _: Option<()>| Ok(()));
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_with_logging_modules() {
        assert_cmd::Command::new("cargo")
            .args(["build", "--bin", "runner_test_debug"])
            .ok()
            .expect("Failed to build runner_test_debug");
        let output = assert_cmd::Command::cargo_bin("runner_test_debug")
            .unwrap()
            .args([
                "--debugger",
                "1.0",
                "--log-level",
                "rustyline=Debug,ixa=Trace",
            ])
            .write_stdin("population\n")
            .output()
            .expect("Failed to run runner_test_debug");
        let stdout = String::from_utf8(output.stdout)
            .expect("runner_test_debug wrote non-UTF-8 output to stdout");
        for expected in [
            "Logging enabled for rustyline at level DEBUG",
            "Logging enabled for ixa at level TRACE",
            "TRACE ixa::plan - adding plan at 1",
        ] {
            assert!(
                stdout.contains(expected),
                "expected {expected:?} in runner output:\n{stdout}"
            );
        }
    }
}