ixa/
data_plugin.rs

1use crate::{HashSet, PluginContext};
2use std::any::{Any, TypeId};
3use std::cell::RefCell;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::{LazyLock, Mutex};
6
7/// A collection of `TypeId`s of all `DataPlugin` types linked into the code.
8static DATA_PLUGINS: LazyLock<Mutex<RefCell<HashSet<TypeId>>>> =
9    LazyLock::new(|| Mutex::new(RefCell::new(HashSet::default())));
10
11pub fn add_data_plugin_to_registry<T: DataPlugin>() {
12    DATA_PLUGINS
13        .lock()
14        .unwrap()
15        .borrow_mut()
16        .insert(TypeId::of::<T>());
17}
18
19pub fn get_data_plugin_ids() -> Vec<TypeId> {
20    DATA_PLUGINS
21        .lock()
22        .unwrap()
23        .borrow()
24        .iter()
25        .copied()
26        .collect()
27}
28
29pub fn get_data_plugin_count() -> usize {
30    DATA_PLUGINS.lock().unwrap().borrow().len()
31}
32
/// Global counter holding the index that will be handed to the next data plugin type that asks
/// for one.
///
/// Rather than keeping data plugins in a `HashMap` inside `Context`, they are stored in a vector
/// (`Context::data_plugins`). Each data plugin type asks, once, for the slot at which its
/// instance is stored, so fetching a data plugin is just an index into that vector.
static NEXT_DATA_PLUGIN_INDEX: Mutex<usize> = Mutex::new(0);

/// Assigns a context-wide index to the data plugin whose cached index lives in `plugin_index`,
/// and returns the index that `plugin_index` ends up holding.
///
/// `plugin_index` holds `usize::MAX` while it is unassigned. The global counter's lock is held
/// for the entire call, and the counter advances only when this call is the one that moves
/// `plugin_index` off that sentinel. If the index was already assigned, possibly by another
/// thread that got the lock just before us, the stored value is returned unchanged and no new
/// slot is consumed.
///
/// Must be `pub`, because the expansion of `define_data_plugin!` calls it.
///
/// # Panics
///
/// Panics if the counter's mutex has been poisoned.
pub fn initialize_data_plugin_index(plugin_index: &AtomicUsize) -> usize {
    let mut next_index = NEXT_DATA_PLUGIN_INDEX.lock().unwrap();
    let claimed = *next_index;

    // For a justification of the memory orderings, see:
    //     https://github.com/CDCgov/ixa/pull/477#discussion_r2244302872
    let outcome =
        plugin_index.compare_exchange(usize::MAX, claimed, Ordering::AcqRel, Ordering::Acquire);

    if let Err(already_assigned) = outcome {
        // Someone else stored an index first: keep theirs and leave the counter alone.
        return already_assigned;
    }

    // `claimed` now belongs to this plugin, so move the counter past it.
    *next_index += 1;
    claimed
}
69
/// A trait for objects that can provide data containers to be held by `Context`.
///
/// Implementations are normally generated by the `define_data_plugin!` macro. The macro's
/// `index_within_context` lazily claims a unique index, and the macro also registers the
/// type's `TypeId` with the data plugin registry from a `#[ctor]` function. Hand-written
/// implementations that return an out-of-range index, or another plugin's index, cause a panic
/// when the data is accessed through `Context` (see this module's tests).
pub trait DataPlugin: Any {
    /// The type of the data this plugin stores in `Context`.
    type DataContainer;

    /// Creates this plugin's data container. `context` gives the initializer access to the
    /// plugin context.
    fn init(context: &impl PluginContext) -> Self::DataContainer;

    /// Returns the index into `Context::data_plugins`, the vector of data plugins, where
    /// the instance of this data plugin can be found.
    ///
    /// The macro-generated implementation assigns the index on its first call and returns the
    /// same value on every later call.
    fn index_within_context() -> usize;
}
80
/// Helper for `define_data_plugin!`; an implementation detail, not part of the public API.
///
/// Expands to:
/// - a unit struct named `$data_plugin`;
/// - an implementation of `DataPlugin` for it, whose `init` evaluates `$body` with `$ctx` bound
///   to the plugin context, and whose `index_within_context` claims a global index on its first
///   call and caches it;
/// - a `#[ctor]` function that adds the plugin's `TypeId` to the data plugin registry.
///
/// Standard-library paths are written as `::std::...` so that an item named `std` in scope at
/// the call site cannot change what the expansion refers to.
#[doc(hidden)]
#[macro_export]
macro_rules! __define_data_plugin {
    ($data_plugin:ident, $data_container:ty, |$ctx:ident| $body:expr) => {
        struct $data_plugin;

        impl $crate::DataPlugin for $data_plugin {
            type DataContainer = $data_container;

            fn init($ctx: &impl $crate::PluginContext) -> Self::DataContainer {
                $body
            }

            fn index_within_context() -> usize {
                // This static must be initialized with a compile-time constant expression, so
                // `usize::MAX` is used as a sentinel meaning "not yet assigned". Each expansion
                // of this macro gets its own `INDEX`, shared by every use of this plugin type.
                static INDEX: ::std::sync::atomic::AtomicUsize =
                    ::std::sync::atomic::AtomicUsize::new(usize::MAX);

                // Fast path: the index has already been assigned.
                let index = INDEX.load(::std::sync::atomic::Ordering::Relaxed);
                if index != usize::MAX {
                    return index;
                }

                // Slow path: claim an index under the global lock.
                $crate::initialize_data_plugin_index(&INDEX)
            }
        }

        $crate::paste::paste! {
            $crate::ctor::declarative::ctor!{
                #[ctor]
                fn [<_register_plugin_$data_plugin:snake>]() {
                    $crate::add_data_plugin_to_registry::<$data_plugin>()
                }
            }
        }
    };
}
122
/// Defines a new type for storing data in `Context`.
///
/// Expands to a unit struct named `$data_plugin` that implements `DataPlugin` with
/// `DataContainer = $data_container`, and registers that type with the data plugin registry.
/// Two forms are accepted:
///
/// - `define_data_plugin!(Name, ContainerType, |context| initializer)`: `initializer` is the
///   body of `DataPlugin::init`, with `context` bound to the `&impl PluginContext` argument.
/// - `define_data_plugin!(Name, ContainerType, default_expression)`: `default_expression` is
///   the body of `DataPlugin::init` and has no access to the context.
///
/// The first form is tried first. A default expression that is itself a single-parameter
/// closure (`|x| ...`) is therefore interpreted as the first form.
///
/// A trailing comma after the last argument is accepted.
///
/// ```ignore
/// define_data_plugin!(MyPlugin, Vec<u32>, vec![]);
/// ```
#[macro_export]
macro_rules! define_data_plugin {
    ($data_plugin:ident, $data_container:ty, |$ctx:ident| $body:expr $(,)?) => {
        $crate::__define_data_plugin!($data_plugin, $data_container, |$ctx| $body);
    };

    ($data_plugin:ident, $data_container:ty, $default:expr $(,)?) => {
        $crate::__define_data_plugin!($data_plugin, $data_container, |_context| $default);
    };
}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Context;
    use std::sync::{Arc, Barrier};
    use std::thread;

    // We attempt an out-of-bounds index with a plugin
    #[test]
    #[should_panic(
        expected = "No data plugin found with index = 1000. You must use the `define_data_plugin!` macro to create a data plugin."
    )]
    fn test_wrong_data_plugin_impl_index_oob() {
        // Suppose a user doesn't use the `define_data_plugin` macro and tries to implement it
        // themselves. What error modes are possible? First, let's try an obviously out-of-bounds
        // index.
        struct MyDataPlugin;
        impl DataPlugin for MyDataPlugin {
            type DataContainer = Vec<u32>;

            fn init(_context: &impl PluginContext) -> Self::DataContainer {
                vec![]
            }

            fn index_within_context() -> usize {
                1000 // arbitrarily out of bounds
            }
        }

        let context = Context::new();
        let container = context.get_data(MyDataPlugin);
        println!("{}", container.len());
    }

    // A macro-defined plugin whose index the next test deliberately collides with.
    define_data_plugin!(LegitDataPlugin, Vec<u32>, vec![]);
    #[should_panic(
        expected = "TypeID does not match data plugin type. You must use the `define_data_plugin!` macro to create a data plugin."
    )]
    #[test]
    fn test_wrong_data_plugin_impl_wrong_type() {
        // Suppose a user doesn't use the `define_data_plugin` macro and tries
        // to implement it themselves. What error modes are possible? Here we
        // test for an index collision.
        struct MyOtherDataPlugin;
        impl DataPlugin for MyOtherDataPlugin {
            type DataContainer = Vec<u8>;

            fn init(_context: &impl PluginContext) -> Self::DataContainer {
                vec![]
            }

            fn index_within_context() -> usize {
                // Reuse the index of a plugin of a different type to force a collision.
                LegitDataPlugin::index_within_context()
            }
        }

        let context = Context::new();
        // Make sure the legit plugin is initialized first
        let _ = context.get_data(LegitDataPlugin);

        // Panics here:
        let container = context.get_data(MyOtherDataPlugin);
        // Some arbitrary code involving `container`
        println!("{}", container.len());
    }

    // Test thread safety of `initialize_data_plugin_index`. Many threads race to initialize the
    // indices of four plugin types. Every thread must see the same index for a given plugin
    // type, and distinct plugin types must receive distinct indices.
    #[test]
    fn test_multithreaded_plugin_init() {
        struct DataPluginContainerA;
        define_data_plugin!(DataPluginA, DataPluginContainerA, DataPluginContainerA);
        struct DataPluginContainerB;
        define_data_plugin!(DataPluginB, DataPluginContainerB, DataPluginContainerB);
        struct DataPluginContainerC;
        define_data_plugin!(DataPluginC, DataPluginContainerC, DataPluginContainerC);
        struct DataPluginContainerD;
        define_data_plugin!(DataPluginD, DataPluginContainerD, DataPluginContainerD);

        // Plugin accessors: each fetches its plugin's data and reports the plugin's index.
        // Non-capturing closures coerce to `fn` pointers, which are `Copy + Send + 'static`.
        let accessors: [fn(&Context) -> usize; 4] = [
            |ctx: &Context| {
                let _ = ctx.get_data(DataPluginA);
                DataPluginA::index_within_context()
            },
            |ctx: &Context| {
                let _ = ctx.get_data(DataPluginB);
                DataPluginB::index_within_context()
            },
            |ctx: &Context| {
                let _ = ctx.get_data(DataPluginC);
                DataPluginC::index_within_context()
            },
            |ctx: &Context| {
                let _ = ctx.get_data(DataPluginD);
                DataPluginD::index_within_context()
            },
        ];

        let num_threads = 20;
        let barrier = Arc::new(Barrier::new(num_threads));
        let mut handles = Vec::with_capacity(num_threads);

        for i in 0..num_threads {
            let barrier = Arc::clone(&barrier);
            let accessor = accessors[i % accessors.len()];

            let handle = thread::spawn(move || {
                let context = Context::new();
                barrier.wait();
                accessor(&context)
            });

            handles.push(handle);
        }

        // The index reported for each plugin type, recorded by the first thread that reports it.
        let mut observed: [Option<usize>; 4] = [None; 4];
        for (i, handle) in handles.into_iter().enumerate() {
            let index = handle.join().expect("Thread panicked");
            let slot = &mut observed[i % accessors.len()];
            match *slot {
                None => *slot = Some(index),
                Some(previous) => assert_eq!(
                    previous, index,
                    "threads disagree on the index of the same data plugin"
                ),
            }
        }

        let mut indices: Vec<usize> = observed
            .iter()
            .map(|index| index.expect("every data plugin is accessed by some thread"))
            .collect();
        indices.sort_unstable();
        indices.dedup();
        assert_eq!(
            indices.len(),
            accessors.len(),
            "distinct data plugins must receive distinct indices"
        );
    }
}