1use std::any::{Any, TypeId};
2use std::cell::RefCell;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{LazyLock, Mutex};
5
6use crate::{HashSet, PluginContext};
7
8static DATA_PLUGINS: LazyLock<Mutex<RefCell<HashSet<TypeId>>>> =
10 LazyLock::new(|| Mutex::new(RefCell::new(HashSet::default())));
11
12pub fn add_data_plugin_to_registry<T: DataPlugin>() {
13 DATA_PLUGINS
14 .lock()
15 .unwrap()
16 .borrow_mut()
17 .insert(TypeId::of::<T>());
18}
19
20pub fn get_data_plugin_ids() -> Vec<TypeId> {
21 DATA_PLUGINS
22 .lock()
23 .unwrap()
24 .borrow()
25 .iter()
26 .copied()
27 .collect()
28}
29
30pub fn get_data_plugin_count() -> usize {
31 DATA_PLUGINS.lock().unwrap().borrow().len()
32}
33
/// The index that will be handed to the next data plugin whose index is still unassigned.
///
/// Kept behind a `Mutex` so that reading the value, attempting the assignment, and bumping
/// the counter in [`initialize_data_plugin_index`] happen as one critical section.
static NEXT_DATA_PLUGIN_INDEX: Mutex<usize> = Mutex::new(0);

/// Assigns the next free index to `plugin_index` if it still holds `usize::MAX`, and
/// returns the index the plugin ends up with.
///
/// `usize::MAX` is treated as the "not yet assigned" marker. When `plugin_index` already
/// holds any other value, that value is returned and the global counter is left untouched,
/// so repeated or racing calls for the same plugin all observe the same index.
///
/// # Panics
///
/// Panics if the counter mutex has been poisoned by a panic in another thread.
pub fn initialize_data_plugin_index(plugin_index: &AtomicUsize) -> usize {
    // Hold the lock for the whole function so no other caller can claim `proposed`.
    let mut next_index = NEXT_DATA_PLUGIN_INDEX.lock().unwrap();
    let proposed = *next_index;

    let claim = plugin_index.compare_exchange(
        usize::MAX,
        proposed,
        Ordering::AcqRel,
        Ordering::Acquire,
    );

    // Someone assigned an index earlier: report it without consuming a new one.
    if let Err(already_assigned) = claim {
        return already_assigned;
    }

    *next_index += 1;
    proposed
}
70
71pub trait DataPlugin: Any {
73 type DataContainer;
74
75 fn init<C: PluginContext>(context: &C) -> Self::DataContainer;
76
77 fn index_within_context() -> usize;
80}
81
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Barrier};
    use std::thread;

    use super::*;
    use crate::{define_data_plugin, Context};

    /// A hand-written `DataPlugin` impl that reports index 1000 must make `get_data`
    /// panic with the "No data plugin found" message.
    #[test]
    #[should_panic(
        expected = "No data plugin found with index = 1000. You must use the `define_data_plugin!` macro to create a data plugin."
    )]
    fn test_wrong_data_plugin_impl_index_oob() {
        struct MyDataPlugin;

        impl DataPlugin for MyDataPlugin {
            type DataContainer = Vec<u32>;

            fn init<C: PluginContext>(_context: &C) -> Self::DataContainer {
                Vec::new()
            }

            fn index_within_context() -> usize {
                1000
            }
        }

        let ctx = Context::new();
        let data = ctx.get_data(MyDataPlugin);
        println!("{}", data.len());
    }

    define_data_plugin!(LegitDataPlugin, Vec<u32>, vec![]);

    /// A hand-written impl that reuses `LegitDataPlugin`'s index but declares a different
    /// container type must make `get_data` panic with the "TypeID does not match" message.
    #[test]
    #[should_panic(
        expected = "TypeID does not match data plugin type. You must use the `define_data_plugin!` macro to create a data plugin."
    )]
    fn test_wrong_data_plugin_impl_wrong_type() {
        struct MyOtherDataPlugin;

        impl DataPlugin for MyOtherDataPlugin {
            type DataContainer = Vec<u8>;

            fn init<C: PluginContext>(_context: &C) -> Self::DataContainer {
                Vec::new()
            }

            fn index_within_context() -> usize {
                LegitDataPlugin::index_within_context()
            }
        }

        let ctx = Context::new();
        // Access the legitimate plugin first, then the impostor sharing its index.
        let _ = ctx.get_data(LegitDataPlugin);

        let data = ctx.get_data(MyOtherDataPlugin);
        println!("{}", data.len());
    }

    /// Twenty threads, each owning its own `Context`, are released together by a barrier
    /// and each accesses one of four freshly defined plugins. The test passes when no
    /// thread panics.
    #[test]
    fn test_multithreaded_plugin_init() {
        struct DataPluginContainerA;
        define_data_plugin!(DataPluginA, DataPluginContainerA, DataPluginContainerA);
        struct DataPluginContainerB;
        define_data_plugin!(DataPluginB, DataPluginContainerB, DataPluginContainerB);
        struct DataPluginContainerC;
        define_data_plugin!(DataPluginC, DataPluginContainerC, DataPluginContainerC);
        struct DataPluginContainerD;
        define_data_plugin!(DataPluginD, DataPluginContainerD, DataPluginContainerD);

        const THREAD_COUNT: usize = 20;

        // Non-capturing closures coerce to plain function pointers, which are `Copy`,
        // `Send`, and `'static`, so they can be moved into the spawned threads directly.
        let accessors: [fn(&Context); 4] = [
            |ctx: &Context| {
                let _ = ctx.get_data(DataPluginA);
            },
            |ctx: &Context| {
                let _ = ctx.get_data(DataPluginB);
            },
            |ctx: &Context| {
                let _ = ctx.get_data(DataPluginC);
            },
            |ctx: &Context| {
                let _ = ctx.get_data(DataPluginD);
            },
        ];

        let start_gate = Arc::new(Barrier::new(THREAD_COUNT));

        // Spawn every worker before joining any of them: the barrier only opens once all
        // `THREAD_COUNT` threads have reached it.
        let workers: Vec<_> = (0..THREAD_COUNT)
            .map(|worker| {
                let start_gate = Arc::clone(&start_gate);
                let access = accessors[worker % accessors.len()];

                thread::spawn(move || {
                    let ctx = Context::new();
                    start_gate.wait();
                    access(&ctx);
                })
            })
            .collect();

        for worker in workers {
            worker.join().expect("Thread panicked");
        }
    }
}