1use crate::{HashSet, PluginContext};
2use std::any::{Any, TypeId};
3use std::cell::RefCell;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::{LazyLock, Mutex};
6
/// Process-wide set holding the `TypeId` of every data plugin type that has been
/// registered through `add_data_plugin_to_registry`.
///
/// The set is created empty on first access.
///
/// NOTE(review): the `Mutex` already hands out exclusive access through its guard,
/// so the inner `RefCell` looks redundant. It is kept because the accessors in this
/// file call `.borrow()` / `.borrow_mut()` on the guarded value.
static DATA_PLUGINS: LazyLock<Mutex<RefCell<HashSet<TypeId>>>> = LazyLock::new(Mutex::default);
10
11pub fn add_data_plugin_to_registry<T: DataPlugin>() {
12 DATA_PLUGINS
13 .lock()
14 .unwrap()
15 .borrow_mut()
16 .insert(TypeId::of::<T>());
17}
18
19pub fn get_data_plugin_ids() -> Vec<TypeId> {
20 DATA_PLUGINS
21 .lock()
22 .unwrap()
23 .borrow()
24 .iter()
25 .copied()
26 .collect()
27}
28
29pub fn get_data_plugin_count() -> usize {
30 DATA_PLUGINS.lock().unwrap().borrow().len()
31}
32
/// The next index that `initialize_data_plugin_index` will hand out. Starts at zero.
static NEXT_DATA_PLUGIN_INDEX: Mutex<usize> = Mutex::new(0);

/// Ensures `plugin_index` holds an assigned index and returns that index.
///
/// `usize::MAX` is treated as the "not yet assigned" sentinel. If `plugin_index`
/// still holds the sentinel, the current value of the global counter is stored into
/// it, the counter is advanced by one, and the stored value is returned. If
/// `plugin_index` already holds any other value, that value is returned unchanged
/// and the counter is left untouched.
///
/// The counter lock is held across the compare-exchange, so the counter advances
/// only for the call that actually installs an index.
///
/// # Panics
///
/// Panics if the counter mutex has been poisoned.
pub fn initialize_data_plugin_index(plugin_index: &AtomicUsize) -> usize {
    let mut next_index = NEXT_DATA_PLUGIN_INDEX.lock().unwrap();
    let proposed = *next_index;

    let claim =
        plugin_index.compare_exchange(usize::MAX, proposed, Ordering::AcqRel, Ordering::Acquire);

    // Someone already installed an index: report theirs and consume nothing.
    if let Err(already_assigned) = claim {
        return already_assigned;
    }

    *next_index += 1;
    proposed
}
69
/// A type that identifies a piece of plugin-owned data and knows how to create it.
///
/// The test messages in this file state that implementations must be created with
/// the `define_data_plugin!` macro rather than written by hand; the macro-generated
/// implementation obtains its index from `initialize_data_plugin_index`.
pub trait DataPlugin: Any {
    /// The type of the value produced by [`DataPlugin::init`].
    type DataContainer;

    /// Builds this plugin's data container, with access to the given plugin context.
    fn init(context: &impl PluginContext) -> Self::DataContainer;

    /// Returns the index associated with this plugin type.
    ///
    /// NOTE(review): presumably a context uses this index to locate the plugin's
    /// container — the lookup code is outside this file; confirm against
    /// `Context::get_data`.
    fn index_within_context() -> usize;
}
80
/// Implementation macro behind `define_data_plugin!`.
///
/// For `$data_plugin`, `$data_container` and an initializer written as
/// `|ctx| body`, this expands to:
///
/// * a unit struct named `$data_plugin`;
/// * an `impl DataPlugin for $data_plugin` whose `init` evaluates `body` with the
///   context bound to `ctx`, and whose `index_within_context` caches the index in a
///   per-plugin static;
/// * a function wrapped in the re-exported `ctor` macro that adds `$data_plugin` to
///   the plugin registry (presumably run at program start-up — see the `ctor` crate).
#[macro_export]
macro_rules! __define_data_plugin {
    ($data_plugin:ident, $data_container:ty, |$ctx:ident| $body:expr) => {
        struct $data_plugin;

        impl $crate::DataPlugin for $data_plugin {
            type DataContainer = $data_container;

            fn init($ctx: &impl $crate::PluginContext) -> Self::DataContainer {
                $body
            }

            fn index_within_context() -> usize {
                use std::sync::atomic::{AtomicUsize, Ordering};

                // `usize::MAX` marks "index not assigned yet".
                static INDEX: AtomicUsize = AtomicUsize::new(usize::MAX);

                // Fast path: once assigned, the cached index is returned directly;
                // otherwise the shared initializer assigns one.
                let cached = INDEX.load(Ordering::Relaxed);
                if cached == usize::MAX {
                    $crate::initialize_data_plugin_index(&INDEX)
                } else {
                    cached
                }
            }
        }

        $crate::paste::paste! {
            $crate::ctor::declarative::ctor!{
                #[ctor]
                fn [<_register_plugin_$data_plugin:snake>]() {
                    $crate::add_data_plugin_to_registry::<$data_plugin>()
                }
            }
        }
    };
}
122
/// Defines a data plugin type together with its `DataPlugin` implementation.
///
/// Two forms are accepted:
///
/// * `define_data_plugin!(Name, ContainerType, |context| expr)` — `expr` builds the
///   container and may use the context through the given identifier;
/// * `define_data_plugin!(Name, ContainerType, expr)` — `expr` builds the container
///   without using the context.
///
/// Both forms forward to `__define_data_plugin!`.
#[macro_export]
macro_rules! define_data_plugin {
    // The closure-shaped arm must stay first: `|ctx| expr` is itself a valid
    // expression and would otherwise be captured by the plain-expression arm below.
    ($plugin:ident, $container:ty, |$context:ident| $init:expr) => {
        $crate::__define_data_plugin!($plugin, $container, |$context| $init);
    };

    // Plain initializer: the context argument is accepted but ignored.
    ($plugin:ident, $container:ty, $initial:expr) => {
        $crate::__define_data_plugin!($plugin, $container, |_context| $initial);
    };
}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Context;
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// A hand-written plugin that reports an index no macro-defined plugin owns is
    /// expected to make `get_data` panic with the "No data plugin found" message.
    #[test]
    #[should_panic(
        expected = "No data plugin found with index = 1000. You must use the `define_data_plugin!` macro to create a data plugin."
    )]
    fn test_wrong_data_plugin_impl_index_oob() {
        struct MyDataPlugin;

        impl DataPlugin for MyDataPlugin {
            type DataContainer = Vec<u32>;

            fn init(_context: &impl PluginContext) -> Self::DataContainer {
                Vec::new()
            }

            fn index_within_context() -> usize {
                1000
            }
        }

        let context = Context::new();
        let data = context.get_data(MyDataPlugin);
        println!("{}", data.len());
    }

    define_data_plugin!(LegitDataPlugin, Vec<u32>, vec![]);

    /// A hand-written plugin that borrows the index of a macro-defined plugin with a
    /// different container type is expected to make `get_data` panic with the
    /// "TypeID does not match" message.
    #[test]
    #[should_panic(
        expected = "TypeID does not match data plugin type. You must use the `define_data_plugin!` macro to create a data plugin."
    )]
    fn test_wrong_data_plugin_impl_wrong_type() {
        struct MyOtherDataPlugin;

        impl DataPlugin for MyOtherDataPlugin {
            type DataContainer = Vec<u8>;

            fn init(_context: &impl PluginContext) -> Self::DataContainer {
                Vec::new()
            }

            fn index_within_context() -> usize {
                LegitDataPlugin::index_within_context()
            }
        }

        let context = Context::new();

        // Access the legitimate plugin first, then the impostor sharing its index.
        let _ = context.get_data(LegitDataPlugin);

        let data = context.get_data(MyOtherDataPlugin);
        println!("{}", data.len());
    }

    /// Twenty threads, released together by a barrier, each access one of four
    /// plugins through their own `Context`; no thread may panic.
    #[test]
    fn test_multithreaded_plugin_init() {
        struct DataPluginContainerA;
        define_data_plugin!(DataPluginA, DataPluginContainerA, DataPluginContainerA);
        struct DataPluginContainerB;
        define_data_plugin!(DataPluginB, DataPluginContainerB, DataPluginContainerB);
        struct DataPluginContainerC;
        define_data_plugin!(DataPluginC, DataPluginContainerC, DataPluginContainerC);
        struct DataPluginContainerD;
        define_data_plugin!(DataPluginD, DataPluginContainerD, DataPluginContainerD);

        let plugin_accessors: Vec<&(dyn Fn(&Context) + Send + Sync)> = vec![
            &|context: &Context| {
                let _ = context.get_data(DataPluginA);
            },
            &|context: &Context| {
                let _ = context.get_data(DataPluginB);
            },
            &|context: &Context| {
                let _ = context.get_data(DataPluginC);
            },
            &|context: &Context| {
                let _ = context.get_data(DataPluginD);
            },
        ];

        let thread_count = 20;
        let start_gate = Arc::new(Barrier::new(thread_count));

        // Collect eagerly: every thread must be spawned before any is joined,
        // otherwise the barrier would never be released.
        let workers: Vec<_> = (0..thread_count)
            .map(|worker_id| {
                let start_gate = Arc::clone(&start_gate);
                let access = plugin_accessors[worker_id % plugin_accessors.len()];

                thread::spawn(move || {
                    let context = Context::new();
                    start_gate.wait();
                    access(&context);
                })
            })
            .collect();

        for worker in workers {
            worker.join().expect("Thread panicked");
        }
    }
}