1use std::any::{Any, TypeId};
19use std::cell::RefCell;
20use std::collections::hash_map::Entry;
21use std::error::Error;
22use std::fmt::Debug;
23use std::fs;
24use std::io::BufReader;
25use std::path::Path;
26use std::sync::{Arc, LazyLock, Mutex};
27
28use serde::de::DeserializeOwned;
29
30use crate::context::Context;
31use crate::error::IxaError;
32use crate::{define_data_plugin, trace, ContextBase, HashMap, HashMapExt};
33
34type PropertySetterFn =
35 dyn Fn(&mut Context, &str, serde_json::Value) -> Result<(), IxaError> + Send + Sync;
36
37type PropertyGetterFn = dyn Fn(&Context) -> Result<Option<String>, IxaError> + Send + Sync;
38
39pub struct PropertyAccessors {
40 setter: Box<PropertySetterFn>,
41 getter: Box<PropertyGetterFn>,
42}
43
44#[allow(clippy::type_complexity)]
45#[doc(hidden)]
52pub static GLOBAL_PROPERTIES: LazyLock<Mutex<RefCell<HashMap<String, Arc<PropertyAccessors>>>>> =
53 LazyLock::new(|| Mutex::new(RefCell::new(HashMap::new())));
54
55#[allow(clippy::missing_panics_doc)]
56pub fn add_global_property<T: GlobalProperty>(name: &str)
57where
58 for<'de> <T as GlobalProperty>::Value: serde::Deserialize<'de> + serde::Serialize,
59{
60 trace!("Adding global property {name}");
61 let properties = GLOBAL_PROPERTIES.lock().unwrap();
62 properties
63 .borrow_mut()
64 .insert(
65 name.to_string(),
66 Arc::new(PropertyAccessors {
67 setter: Box::new(
68 |context: &mut Context, name, value| -> Result<(), IxaError> {
69 let val: T::Value = serde_json::from_value(value)?;
70 T::validate(&val).map_err(|source| {
71 IxaError::IllegalGlobalPropertyValue {
72 name: T::name().to_string(),
73 source,
74 }
75 })?;
76 if context.get_global_property_value(T::new()).is_some() {
77 return Err(IxaError::DuplicateProperty {
78 name: name.to_string(),
79 });
80 }
81 context.set_global_property_value(T::new(), val)?;
82 Ok(())
83 },
84 ),
85 getter: Box::new(|context: &Context| -> Result<Option<String>, IxaError> {
86 let value = context.get_global_property_value(T::new());
87 match value {
88 Some(val) => Ok(Some(serde_json::to_string(val)?)),
89 None => Ok(None),
90 }
91 }),
92 }),
93 )
94 .inspect(|_| panic!("Duplicate global property {}", name));
95}
96
97fn get_global_property_accessor(name: &str) -> Option<Arc<PropertyAccessors>> {
98 let properties = GLOBAL_PROPERTIES.lock().unwrap();
99 let tmp = properties.borrow();
100 tmp.get(name).map(Arc::clone)
101}
102
/// A typed global property that can be stored in a context.
pub trait GlobalProperty: Any {
    /// The type of the value stored for this property.
    type Value: Any;

    /// Constructs an instance of the property type; used wherever an instance is needed to
    /// name the property (e.g. `context.get_global_property_value(T::new())`).
    fn new() -> Self;

    /// Returns the property's short name: the last path segment of the implementing type's
    /// name, with any generic arguments removed (e.g. `a::b::Foo<c::Bar>` yields `Foo`).
    fn name() -> &'static str {
        let full = std::any::type_name::<Self>();
        // Strip generic arguments first; otherwise `rsplit("::")` would return the last
        // segment of the *argument* (e.g. `Bar>`) instead of the type's own name.
        let base = full.split('<').next().unwrap_or(full);
        base.rsplit("::").next().unwrap_or(base)
    }

    /// Checks whether `value` is an acceptable value for this property.
    ///
    /// # Errors
    ///
    /// Returns the reason the value was rejected.
    fn validate(value: &Self::Value) -> Result<(), Box<dyn Error + 'static>>;
}
127
/// Storage for global property values, keyed by the property type's `TypeId`.
///
/// Each entry holds the property's value as a `Box<dyn Any>`.
struct GlobalPropertiesDataContainer {
    global_property_container: HashMap<TypeId, Box<dyn Any>>,
}

// Declares `GlobalPropertiesPlugin` as the data plugin holding a
// `GlobalPropertiesDataContainer`; the last argument is the initial value (an empty map).
// NOTE(review): presumably each context gets its own container built from this
// initializer — confirm against `define_data_plugin!`.
define_data_plugin!(
    GlobalPropertiesPlugin,
    GlobalPropertiesDataContainer,
    GlobalPropertiesDataContainer {
        global_property_container: HashMap::default(),
    }
);
139
140impl GlobalPropertiesDataContainer {
141 fn set_global_property_value<T: GlobalProperty + 'static>(
142 &mut self,
143 _property: &T,
144 value: T::Value,
145 ) -> Result<(), IxaError> {
146 match self.global_property_container.entry(TypeId::of::<T>()) {
147 Entry::Vacant(entry) => {
148 entry.insert(Box::new(value));
149 Ok(())
150 }
151 Entry::Occupied(_) => Err(IxaError::EntryAlreadyExists),
155 }
156 }
157
158 #[must_use]
159 fn get_global_property_value<T: GlobalProperty + 'static>(&self) -> Option<&T::Value> {
160 let data_container = self.global_property_container.get(&TypeId::of::<T>());
161
162 match data_container {
163 Some(property) => Some(property.downcast_ref::<T::Value>().unwrap()),
164 None => None,
165 }
166 }
167}
168
169pub trait ContextGlobalPropertiesExt: ContextBase {
170 fn set_global_property_value<T: GlobalProperty + 'static>(
175 &mut self,
176 property: T,
177 value: T::Value,
178 ) -> Result<(), IxaError> {
179 T::validate(&value).map_err(|source| IxaError::IllegalGlobalPropertyValue {
180 name: T::name().to_string(),
181 source,
182 })?;
183 let data_container = self.get_data_mut(GlobalPropertiesPlugin);
184 data_container.set_global_property_value(&property, value)
185 }
186
187 #[allow(unused_variables)]
189 fn get_global_property_value<T: GlobalProperty + 'static>(
190 &self,
191 _property: T,
192 ) -> Option<&T::Value> {
193 self.get_data(GlobalPropertiesPlugin)
194 .get_global_property_value::<T>()
195 }
196
197 fn list_registered_global_properties(&self) -> Vec<String> {
198 let properties = GLOBAL_PROPERTIES.lock().unwrap();
199 let tmp = properties.borrow();
200 tmp.keys().cloned().collect()
201 }
202
203 fn get_serialized_value_by_string(&self, name: &str) -> Result<Option<String>, IxaError>;
209
210 fn load_parameters_from_json<T: 'static + Debug + DeserializeOwned>(
218 &mut self,
219 file_name: &Path,
220 ) -> Result<T, IxaError> {
221 trace!("Loading parameters from JSON: {file_name:?}");
222 let config_file = fs::File::open(file_name)?;
223 let reader = BufReader::new(config_file);
224 let config = serde_json::from_reader(reader)?;
225 Ok(config)
226 }
227
228 fn load_global_properties(&mut self, file_name: &Path) -> Result<(), IxaError>;
250}
251impl ContextGlobalPropertiesExt for Context {
252 fn get_serialized_value_by_string(&self, name: &str) -> Result<Option<String>, IxaError> {
253 let accessor = get_global_property_accessor(name);
254 match accessor {
255 Some(accessor) => (accessor.getter)(self),
256 None => Err(IxaError::NoGlobalProperty {
257 name: name.to_string(),
258 }),
259 }
260 }
261
262 fn load_global_properties(&mut self, file_name: &Path) -> Result<(), IxaError> {
263 trace!("Loading global properties from {file_name:?}");
264 let config_file = fs::File::open(file_name)?;
265 let reader = BufReader::new(config_file);
266 let val: serde_json::Map<String, serde_json::Value> = serde_json::from_reader(reader)?;
267
268 for (k, v) in val {
269 if let Some(accessor) = get_global_property_accessor(&k) {
270 (accessor.setter)(self, &k, v)?;
271 } else {
272 return Err(IxaError::NoGlobalProperty { name: k });
273 }
274 }
275
276 Ok(())
277 }
278}
279
#[cfg(test)]
mod test {
    use std::error::Error;
    use std::fmt;

    use serde::{Deserialize, Serialize};
    use tempfile::tempdir;

    use super::*;
    use crate::context::Context;
    use crate::define_global_property;
    use crate::error::IxaError;

    /// Validation error reported by `Property3` when `field_int` is non-zero.
    #[derive(Debug)]
    struct InvalidProperty3Value {
        field_int: u32,
    }

    impl fmt::Display for InvalidProperty3Value {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "field_int must be zero, got {}", self.field_int)
        }
    }

    impl Error for InvalidProperty3Value {}

    #[derive(Serialize, Deserialize, Debug, Clone)]
    pub struct ParamType {
        pub days: usize,
        pub diseases: usize,
    }

    define_global_property!(DiseaseParams, ParamType);

    #[test]
    fn set_get_global_property() {
        let params: ParamType = ParamType {
            days: 10,
            diseases: 2,
        };
        let params2: ParamType = ParamType {
            days: 11,
            diseases: 3,
        };

        let mut context = Context::new();

        // The first set succeeds.
        context
            .set_global_property_value(DiseaseParams, params.clone())
            .unwrap();
        let global_params = context
            .get_global_property_value(DiseaseParams)
            .unwrap()
            .clone();
        assert_eq!(global_params.days, params.days);
        assert_eq!(global_params.diseases, params.diseases);

        // A second set of the same property is rejected...
        assert!(context
            .set_global_property_value(DiseaseParams, params2)
            .is_err());

        // ...and leaves the original value in place.
        let global_params = context
            .get_global_property_value(DiseaseParams)
            .unwrap()
            .clone();
        assert_eq!(global_params.days, params.days);
        assert_eq!(global_params.diseases, params.diseases);
    }

    #[test]
    fn get_global_property_missing() {
        let context = Context::new();
        let global_params = context.get_global_property_value(DiseaseParams);
        assert!(global_params.is_none());
    }

    #[test]
    fn set_parameters() {
        let mut context = Context::new();
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("test.json");
        let config = fs::File::create(&file_path).unwrap();

        let params: ParamType = ParamType {
            days: 10,
            diseases: 2,
        };

        define_global_property!(Parameters, ParamType);

        // Fail loudly if the fixture cannot be written instead of ignoring the error.
        serde_json::to_writer(config, &params).unwrap();
        let params_json = context
            .load_parameters_from_json::<ParamType>(&file_path)
            .unwrap();

        context
            .set_global_property_value(Parameters, params_json)
            .unwrap();

        let params_read = context
            .get_global_property_value(Parameters)
            .unwrap()
            .clone();
        assert_eq!(params_read.days, params.days);
        assert_eq!(params_read.diseases, params.diseases);
    }

    #[derive(Serialize, Deserialize)]
    pub struct Property1Type {
        field_int: u32,
        field_str: String,
    }
    define_global_property!(Property1, Property1Type);

    #[derive(Serialize, Deserialize)]
    pub struct Property2Type {
        field_int: u32,
    }
    define_global_property!(Property2, Property2Type);

    #[test]
    fn read_global_properties() {
        let mut context = Context::new();
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/global_properties_test1.json");
        context.load_global_properties(&path).unwrap();
        let p1 = context.get_global_property_value(Property1).unwrap();
        assert_eq!(p1.field_int, 1);
        assert_eq!(p1.field_str, "test");
        let p2 = context.get_global_property_value(Property2).unwrap();
        assert_eq!(p2.field_int, 2);
    }

    #[test]
    fn read_unknown_property() {
        let mut context = Context::new();
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/global_properties_missing.json");
        match context.load_global_properties(&path) {
            Err(IxaError::NoGlobalProperty { name }) => assert_eq!(name, "ixa.PropertyUnknown"),
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn read_malformed_property() {
        let mut context = Context::new();
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/global_properties_malformed.json");
        let error = context.load_global_properties(&path);
        println!("Error {error:?}");
        match error {
            Err(IxaError::JsonError(_)) => {}
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn read_duplicate_property() {
        let mut context = Context::new();
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/global_properties_test1.json");
        context.load_global_properties(&path).unwrap();
        let error = context.load_global_properties(&path);
        match error {
            Err(IxaError::DuplicateProperty { .. }) => {}
            _ => panic!("Unexpected error type"),
        }
    }

    #[derive(Serialize, Deserialize)]
    pub struct Property3Type {
        field_int: u32,
    }
    define_global_property!(Property3, Property3Type, |v: &Property3Type| {
        match v.field_int {
            0 => Ok(()),
            _ => Err(Box::new(InvalidProperty3Value {
                field_int: v.field_int,
            }) as Box<dyn Error + 'static>),
        }
    });

    #[test]
    fn validate_property_set_success() {
        let mut context = Context::new();
        context
            .set_global_property_value(Property3, Property3Type { field_int: 0 })
            .unwrap();
    }

    #[test]
    fn validate_property_set_failure() {
        let mut context = Context::new();
        let error = context
            .set_global_property_value(Property3, Property3Type { field_int: 1 })
            .unwrap_err();
        assert_eq!(
            error.to_string(),
            "illegal value for global property `Property3`: field_int must be zero, got 1"
        );
        match error {
            IxaError::IllegalGlobalPropertyValue { name, source } => {
                assert_eq!(name, "Property3");
                assert_eq!(source.to_string(), "field_int must be zero, got 1");
            }
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn validate_property_load_success() {
        let mut context = Context::new();
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/global_properties_valid.json");
        context.load_global_properties(&path).unwrap();
    }

    #[test]
    fn validate_property_load_failure() {
        let mut context = Context::new();
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/global_properties_invalid.json");
        let error = context.load_global_properties(&path).unwrap_err();
        assert_eq!(
            error.to_string(),
            "illegal value for global property `Property3`: field_int must be zero, got 42"
        );
        match error {
            IxaError::IllegalGlobalPropertyValue { name, source } => {
                assert_eq!(name, "Property3");
                assert_eq!(source.to_string(), "field_int must be zero, got 42");
            }
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn list_registered_global_properties() {
        let context = Context::new();
        let properties = context.list_registered_global_properties();
        assert!(properties.contains(&"ixa.DiseaseParams".to_string()));
    }

    #[test]
    fn get_serialized_value_by_string() {
        let mut context = Context::new();
        context
            .set_global_property_value(
                DiseaseParams,
                ParamType {
                    days: 10,
                    diseases: 2,
                },
            )
            .unwrap();
        let serialized = context
            .get_serialized_value_by_string("ixa.DiseaseParams")
            .unwrap();
        assert_eq!(serialized, Some("{\"days\":10,\"diseases\":2}".to_string()));
    }
}
547}