ixa/
data_plugin.rs

1use std::any::{Any, TypeId};
2use std::cell::RefCell;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{LazyLock, Mutex};
5
6use crate::{HashSet, PluginContext};
7
/// Global registry holding the [`TypeId`] of every [`DataPlugin`] type that has been added via
/// [`add_data_plugin_to_registry`].
///
/// NOTE(review): the inner `RefCell` is redundant — the `Mutex` already guarantees exclusive
/// access — so `Mutex<HashSet<TypeId>>` would suffice.
static DATA_PLUGINS: LazyLock<Mutex<RefCell<HashSet<TypeId>>>> = LazyLock::new(Default::default);
11
12pub fn add_data_plugin_to_registry<T: DataPlugin>() {
13    DATA_PLUGINS
14        .lock()
15        .unwrap()
16        .borrow_mut()
17        .insert(TypeId::of::<T>());
18}
19
20pub fn get_data_plugin_ids() -> Vec<TypeId> {
21    DATA_PLUGINS
22        .lock()
23        .unwrap()
24        .borrow()
25        .iter()
26        .copied()
27        .collect()
28}
29
30pub fn get_data_plugin_count() -> usize {
31    DATA_PLUGINS.lock().unwrap().borrow().len()
32}
33
/// Global data plugin index counter: the index that will be assigned to the next data plugin
/// that requests one.
///
/// Instead of storing data plugins in a `HashMap` in [`Context`](crate::Context), we store them
/// in a vector. To fetch a data plugin, we ask the data plugin type for the index into
/// `Context::data_plugins` at which an instance of the data plugin type should be stored.
/// Accessing a data plugin, then, is just an index into an array.
///
/// Indices are handed out by [`initialize_data_plugin_index`], which advances this counter only
/// when it actually assigns a new index.
static NEXT_DATA_PLUGIN_INDEX: Mutex<usize> = Mutex::new(0);
42
43/// Acquires a global lock on the next available plugin index, but only increments it if we
44/// successfully initialize the provided index. (Must be `pub`, as it's called from within a macro.)
45pub fn initialize_data_plugin_index(plugin_index: &AtomicUsize) -> usize {
46    // Acquire a global lock.
47    let mut guard = NEXT_DATA_PLUGIN_INDEX.lock().unwrap();
48    let candidate = *guard;
49
50    // Try to claim the candidate index. Here we guard against the potential race condition that
51    // another instance of this plugin in another thread just initialized the index prior to us
52    // obtaining the lock. If the index has been initialized beneath us, we do not update
53    // `NEXT_DATA_PLUGIN_INDEX`, we just return the value `plugin_index` was initialized to.
54    // For a justification of the data ordering, see:
55    //     https://github.com/CDCgov/ixa/pull/477#discussion_r2244302872
56    match plugin_index.compare_exchange(usize::MAX, candidate, Ordering::AcqRel, Ordering::Acquire)
57    {
58        Ok(_) => {
59            // We won the race — increment the global next plugin index and return the new index
60            *guard += 1;
61            candidate
62        }
63        Err(existing) => {
64            // Another thread beat us — don’t increment the global next plugin index,
65            // just return existing
66            existing
67        }
68    }
69}
70
71/// A trait for objects that can provide data containers to be held by [`Context`](crate::Context)
72pub trait DataPlugin: Any {
73    type DataContainer;
74
75    fn init<C: PluginContext>(context: &C) -> Self::DataContainer;
76
77    /// Returns the index into `Context::data_plugins`, the vector of data plugins, where
78    /// the instance of this data plugin can be found.
79    fn index_within_context() -> usize;
80}
81
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Barrier};
    use std::thread;

    use super::*;
    use crate::{define_data_plugin, Context};

    // We attempt an out-of-bounds index with a plugin
    #[test]
    #[should_panic(
        expected = "No data plugin found with index = 1000. You must use the `define_data_plugin!` macro to create a data plugin."
    )]
    fn test_wrong_data_plugin_impl_index_oob() {
        // Suppose a user doesn't use the `define_data_plugin` macro and tries to implement it
        // themselves. What error modes are possible? First let's try an obviously out-of-bounds
        // index.
        struct MyDataPlugin;
        impl DataPlugin for MyDataPlugin {
            type DataContainer = Vec<u32>;

            fn init<C: PluginContext>(_context: &C) -> Self::DataContainer {
                vec![]
            }

            fn index_within_context() -> usize {
                1000 // arbitrarily out of bounds
            }
        }

        let context = Context::new();
        let container = context.get_data(MyDataPlugin);
        println!("{}", container.len());
    }

    // We attempt a collision with a plugin
    define_data_plugin!(LegitDataPlugin, Vec<u32>, vec![]);
    #[should_panic(
        expected = "TypeID does not match data plugin type. You must use the `define_data_plugin!` macro to create a data plugin."
    )]
    #[test]
    fn test_wrong_data_plugin_impl_wrong_type() {
        // Suppose a user doesn't use the `define_data_plugin` macro and tries
        // to implement it themselves. What error modes are possible? Here we
        // test for an index collision.
        struct MyOtherDataPlugin;
        impl DataPlugin for MyOtherDataPlugin {
            type DataContainer = Vec<u8>;

            fn init<C: PluginContext>(_context: &C) -> Self::DataContainer {
                vec![]
            }

            fn index_within_context() -> usize {
                // Deliberately reuse the index assigned to `LegitDataPlugin`, a plugin of a
                // different type, to force an index collision.
                LegitDataPlugin::index_within_context()
            }
        }

        let context = Context::new();
        // Make sure the legit plugin is initialized first
        let _ = context.get_data(LegitDataPlugin);

        // Panics here:
        let container = context.get_data(MyOtherDataPlugin);
        // Some arbitrary code involving `container`
        println!("{}", container.len());
    }

    // Test thread safety of `initialize_data_plugin_index`.
    #[test]
    fn test_multithreaded_plugin_init() {
        struct DataPluginContainerA;
        define_data_plugin!(DataPluginA, DataPluginContainerA, DataPluginContainerA);
        struct DataPluginContainerB;
        define_data_plugin!(DataPluginB, DataPluginContainerB, DataPluginContainerB);
        struct DataPluginContainerC;
        define_data_plugin!(DataPluginC, DataPluginContainerC, DataPluginContainerC);
        struct DataPluginContainerD;
        define_data_plugin!(DataPluginD, DataPluginContainerD, DataPluginContainerD);

        // Plugin accessors
        let accessors: Vec<&(dyn Fn(&Context) + Send + Sync)> = vec![
            &|ctx: &Context| {
                let _ = ctx.get_data(DataPluginA);
            },
            &|ctx: &Context| {
                let _ = ctx.get_data(DataPluginB);
            },
            &|ctx: &Context| {
                let _ = ctx.get_data(DataPluginC);
            },
            &|ctx: &Context| {
                let _ = ctx.get_data(DataPluginD);
            },
        ];

        let num_threads = 20;
        let barrier = Arc::new(Barrier::new(num_threads));
        let mut handles = Vec::with_capacity(num_threads);

        for i in 0..num_threads {
            let barrier = Arc::clone(&barrier);
            let accessor = accessors[i % accessors.len()];

            let handle = thread::spawn(move || {
                let context = Context::new();
                barrier.wait();
                accessor(&context);
            });

            handles.push(handle);
        }

        for handle in handles {
            handle.join().expect("Thread panicked");
        }

        // Even though the indices were initialized concurrently, every plugin type must have
        // been assigned its own index.
        let mut indices = vec![
            DataPluginA::index_within_context(),
            DataPluginB::index_within_context(),
            DataPluginC::index_within_context(),
            DataPluginD::index_within_context(),
        ];
        indices.sort_unstable();
        indices.dedup();
        assert_eq!(indices.len(), 4, "data plugins were assigned colliding indices");
    }
}
201}