1use crate::context::Context;
2use crate::hashing::hash_str;
3use crate::{HashMap, HashMapExt, PluginContext};
4use log::trace;
5use rand::distributions::uniform::{SampleRange, SampleUniform};
6use rand::distributions::WeightedIndex;
7use rand::prelude::Distribution;
8use rand::{Rng, SeedableRng};
9use std::any::{Any, TypeId};
10use std::cell::{RefCell, RefMut};
11
12#[macro_export]
15macro_rules! define_rng {
16 ($random_id:ident) => {
17 #[derive(Copy, Clone)]
18 struct $random_id;
19
20 impl $crate::random::RngId for $random_id {
21 type RngType = $crate::rand::rngs::SmallRng;
22
23 fn get_name() -> &'static str {
24 stringify!($random_id)
25 }
26 }
27
28 $crate::paste::paste! {
30 #[doc(hidden)]
31 #[no_mangle]
32 #[allow(non_upper_case_globals)]
33 pub static [<rng_name_duplication_guard_ $random_id>]: () = ();
34 }
35 };
36}
37pub use define_rng;
38
39pub trait RngId: Copy + Clone {
40 type RngType: SeedableRng;
41 fn get_name() -> &'static str;
42}
43
/// Type-erased storage for one lazily created generator.
///
/// The boxed value is the `RngType` of the `RngId` whose `TypeId` keys this holder in
/// `RngData::rng_holders`; `get_rng` downcasts it back to that concrete type.
struct RngHolder {
    rng: Box<dyn std::any::Any>,
}
49
50struct RngData {
51 base_seed: u64,
52 rng_holders: RefCell<HashMap<TypeId, RngHolder>>,
53}
54
// Registers `RngData` as a data plugin under the key `RngPlugin`. The third argument is
// presumably the initial value of the container: base seed 0 and no generators yet.
// NOTE(review): `get_rng` treats a missing container as "not initialized", which
// suggests the container is only created on first mutable access (e.g. through
// `init_random`) — confirm against `define_data_plugin!`.
crate::context::define_data_plugin!(
    RngPlugin,
    RngData,
    RngData {
        base_seed: 0,
        rng_holders: RefCell::new(HashMap::new()),
    }
);
68
69fn get_rng<R: RngId + 'static>(context: &impl PluginContext) -> RefMut<R::RngType> {
73 let data_container = context
74 .get_data_container(RngPlugin)
75 .expect("You must initialize the random number generator with a base seed");
76
77 let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap();
78 RefMut::map(rng_holders, |holders| {
79 holders
80 .entry(TypeId::of::<R>())
81 .or_insert_with(|| {
83 trace!(
84 "creating new RNG (seed={}) for type id {:?}",
85 data_container.base_seed,
86 TypeId::of::<R>()
87 );
88 let base_seed = data_container.base_seed;
89 let seed_offset = hash_str(R::get_name());
90 RngHolder {
91 rng: Box::new(R::RngType::seed_from_u64(
92 base_seed.wrapping_add(seed_offset),
93 )),
94 }
95 })
96 .rng
97 .downcast_mut::<R::RngType>()
98 .unwrap()
99 })
100}
101
102pub trait ContextRandomExt: PluginContext {
105 fn init_random(&mut self, base_seed: u64) {
108 trace!("initializing random module");
109 let data_container = self.get_data_container_mut(RngPlugin);
110 data_container.base_seed = base_seed;
111
112 let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap();
114 rng_map.clear();
115 }
116
117 fn sample<R: RngId + 'static, T>(
122 &self,
123 _rng_type: R,
124 sampler: impl FnOnce(&mut R::RngType) -> T,
125 ) -> T {
126 let mut rng = get_rng::<R>(self);
127 sampler(&mut rng)
128 }
129
130 fn sample_distr<R: RngId + 'static, T>(
135 &self,
136 _rng_type: R,
137 distribution: impl Distribution<T>,
138 ) -> T
139 where
140 R::RngType: Rng,
141 {
142 let mut rng = get_rng::<R>(self);
143 distribution.sample::<R::RngType>(&mut rng)
144 }
145
146 fn sample_range<R: RngId + 'static, S, T>(&self, rng_id: R, range: S) -> T
150 where
151 R::RngType: Rng,
152 S: SampleRange<T>,
153 T: SampleUniform,
154 {
155 self.sample(rng_id, |rng| rng.gen_range(range))
156 }
157
158 fn sample_bool<R: RngId + 'static>(&self, rng_id: R, p: f64) -> bool
162 where
163 R::RngType: Rng,
164 {
165 self.sample(rng_id, |rng| rng.gen_bool(p))
166 }
167
168 fn sample_weighted<R: RngId + 'static, T>(&self, _rng_id: R, weights: &[T]) -> usize
173 where
174 R::RngType: Rng,
175 T: Clone + Default + SampleUniform + for<'a> std::ops::AddAssign<&'a T> + PartialOrd,
176 {
177 let index = WeightedIndex::new(weights).unwrap();
178 let mut rng = get_rng::<R>(self);
179 index.sample(&mut *rng)
180 }
181}
182impl ContextRandomExt for Context {}
183
#[cfg(test)]
mod test {
    use crate::context::Context;
    use crate::define_data_plugin;
    use crate::random::ContextRandomExt;
    use rand::RngCore;
    use rand::{distributions::WeightedIndex, prelude::Distribution};

    define_rng!(FooRng);
    define_rng!(BarRng);

    #[test]
    fn get_rng_basic() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(FooRng, RngCore::next_u64)
        );
    }

    #[test]
    #[should_panic(expected = "You must initialize the random number generator with a base seed")]
    fn panic_if_not_initialized() {
        let context = Context::new();
        context.sample(FooRng, RngCore::next_u64);
    }

    #[test]
    fn multiple_rng_types() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(BarRng, RngCore::next_u64)
        );
    }

    #[test]
    fn reset_seed() {
        let mut context = Context::new();
        context.init_random(42);

        let run_0 = context.sample(FooRng, RngCore::next_u64);
        let run_1 = context.sample(FooRng, RngCore::next_u64);

        // Same seed: the stream replays from the start.
        context.init_random(42);
        assert_eq!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_eq!(run_1, context.sample(FooRng, RngCore::next_u64));

        // Different seed: a different stream.
        context.init_random(88);
        assert_ne!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_ne!(run_1, context.sample(FooRng, RngCore::next_u64));
    }

    define_data_plugin!(
        SamplerData,
        WeightedIndex<f64>,
        WeightedIndex::new(vec![1.0]).unwrap()
    );

    #[test]
    fn sampler_function_closure_capture() {
        let mut context = Context::new();
        context.init_random(42);

        *context.get_data_container_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data_container(SamplerData).unwrap();
        let n_samples = 3000;
        let mut zero_counter = 0;
        for _ in 0..n_samples {
            let sample = context.sample(FooRng, |rng| parameters.sample(rng));
            if sample == 0 {
                zero_counter += 1;
            }
        }
        // Weights [1, 2]: index 0 should come up about a third of the time.
        assert!((zero_counter - 1000_i32).abs() < 100);
    }

    #[test]
    fn sample_distribution() {
        let mut context = Context::new();
        context.init_random(42);

        *context.get_data_container_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data_container(SamplerData).unwrap();
        let n_samples = 3000;
        let mut zero_counter = 0;
        for _ in 0..n_samples {
            let sample = context.sample_distr(FooRng, parameters);
            if sample == 0 {
                zero_counter += 1;
            }
        }
        // Weights [1, 2]: index 0 should come up about a third of the time.
        assert!((zero_counter - 1000_i32).abs() < 100);
    }

    #[test]
    fn sample_range() {
        let mut context = Context::new();
        context.init_random(42);
        let result = context.sample_range(FooRng, 0..10);
        assert!((0..10).contains(&result));

        // Re-seeding with the same base seed must reproduce the draw.
        context.init_random(42);
        assert_eq!(result, context.sample_range(FooRng, 0..10));
    }

    #[test]
    fn sample_bool() {
        let mut context = Context::new();
        context.init_random(42);
        // p = 0 and p = 1 are deterministic, so they pin down the meaning of `p`.
        for _ in 0..100 {
            assert!(!context.sample_bool(FooRng, 0.0));
            assert!(context.sample_bool(FooRng, 1.0));
        }
    }

    #[test]
    fn sample_weighted() {
        let mut context = Context::new();
        context.init_random(42);
        let r: usize = context.sample_weighted(FooRng, &[0.1, 0.3, 0.4]);
        assert!(r < 3);

        // Entries with zero weight must never be chosen.
        for _ in 0..100 {
            assert_eq!(context.sample_weighted(FooRng, &[0.0, 1.0, 0.0]), 1);
        }
    }
}