1use crate::context::Context;
2use crate::hashing::hash_str;
3use crate::{define_data_plugin, HashMap, HashMapExt, PluginContext};
4use log::trace;
5use rand::distributions::uniform::{SampleRange, SampleUniform};
6use rand::distributions::WeightedIndex;
7use rand::prelude::Distribution;
8use rand::{Rng, SeedableRng};
9use std::any::{Any, TypeId};
10use std::cell::{RefCell, RefMut};
11
/// Defines a new unit struct, `$random_id`, that implements `RngId` and names one
/// independent random number stream.
///
/// The generated type uses `SmallRng` as its generator. `get_name` returns the type's
/// identifier as a string (via `stringify!`), and `get_rng` hashes that name to derive
/// the stream's seed offset.
///
/// The macro also emits a `#[no_mangle]` static named
/// `rng_name_duplication_guard_<name>`. `#[no_mangle]` symbols must be unique at link
/// time, so two `define_rng!` invocations with the same identifier should fail to link
/// rather than silently producing two streams with the same seed.
#[macro_export]
macro_rules! define_rng {
    ($random_id:ident) => {
        #[derive(Copy, Clone)]
        struct $random_id;

        impl $crate::random::RngId for $random_id {
            type RngType = $crate::rand::rngs::SmallRng;

            // The returned name feeds the per-stream seed offset, so it must be unique;
            // the guard static below enforces that.
            fn get_name() -> &'static str {
                stringify!($random_id)
            }
        }

        $crate::paste::paste! {
            // Unmangled, name-derived symbol: a second definition with the same
            // `$random_id` collides at link time (duplicate-name guard).
            #[doc(hidden)]
            #[no_mangle]
            #[allow(non_upper_case_globals)]
            pub static [<rng_name_duplication_guard_ $random_id>]: () = ();
        }
    };
}
// Re-export so the macro can also be reached by path through this module.
pub use define_rng;
38
39pub trait RngId: Copy + Clone {
40 type RngType: SeedableRng;
41 fn get_name() -> &'static str;
42}
43
/// Type-erased owner of a single stream's generator.
///
/// The boxed value is created as `<R as RngId>::RngType` and downcast back to that
/// concrete type by `get_rng`.
struct RngHolder {
    rng: Box<dyn Any>,
}

/// Plugin state backing the random module.
struct RngData {
    /// Generators keyed by the `TypeId` of their `RngId` type; entries are created
    /// lazily on first sample and cleared by `init_random`.
    rng_holders: RefCell<HashMap<TypeId, RngHolder>>,
    /// Seed supplied to `init_random`; each stream's seed is derived from it.
    base_seed: u64,
}
54
// Declares `RngPlugin`, the handle passed to `get_data` / `get_data_mut` in this file to
// reach the module's `RngData`.
// NOTE(review): the third argument is presumably the value a fresh context starts with
// (base seed 0, no generators) — confirm against `define_data_plugin!`. Generators are
// only created lazily by `get_rng`, so if this is the starting value, sampling before
// `init_random` is called would seed every stream from a base seed of 0.
define_data_plugin!(
    RngPlugin,
    RngData,
    RngData {
        base_seed: 0,
        rng_holders: RefCell::new(HashMap::new()),
    }
);
68
69fn get_rng<R: RngId + 'static>(context: &impl PluginContext) -> RefMut<R::RngType> {
73 let data_container = context.get_data(RngPlugin);
74
75 let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap();
76 RefMut::map(rng_holders, |holders| {
77 holders
78 .entry(TypeId::of::<R>())
79 .or_insert_with(|| {
81 trace!(
82 "creating new RNG (seed={}) for type id {:?}",
83 data_container.base_seed,
84 TypeId::of::<R>()
85 );
86 let base_seed = data_container.base_seed;
87 let seed_offset = hash_str(R::get_name());
88 RngHolder {
89 rng: Box::new(R::RngType::seed_from_u64(
90 base_seed.wrapping_add(seed_offset),
91 )),
92 }
93 })
94 .rng
95 .downcast_mut::<R::RngType>()
96 .unwrap()
97 })
98}
99
100pub trait ContextRandomExt: PluginContext {
103 fn init_random(&mut self, base_seed: u64) {
106 trace!("initializing random module");
107 let data_container = self.get_data_mut(RngPlugin);
108 data_container.base_seed = base_seed;
109
110 let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap();
112 rng_map.clear();
113 }
114
115 fn sample<R: RngId + 'static, T>(
120 &self,
121 _rng_type: R,
122 sampler: impl FnOnce(&mut R::RngType) -> T,
123 ) -> T {
124 let mut rng = get_rng::<R>(self);
125 sampler(&mut rng)
126 }
127
128 fn sample_distr<R: RngId + 'static, T>(
133 &self,
134 _rng_type: R,
135 distribution: impl Distribution<T>,
136 ) -> T
137 where
138 R::RngType: Rng,
139 {
140 let mut rng = get_rng::<R>(self);
141 distribution.sample::<R::RngType>(&mut rng)
142 }
143
144 fn sample_range<R: RngId + 'static, S, T>(&self, rng_id: R, range: S) -> T
148 where
149 R::RngType: Rng,
150 S: SampleRange<T>,
151 T: SampleUniform,
152 {
153 self.sample(rng_id, |rng| rng.gen_range(range))
154 }
155
156 fn sample_bool<R: RngId + 'static>(&self, rng_id: R, p: f64) -> bool
160 where
161 R::RngType: Rng,
162 {
163 self.sample(rng_id, |rng| rng.gen_bool(p))
164 }
165
166 fn sample_weighted<R: RngId + 'static, T>(&self, _rng_id: R, weights: &[T]) -> usize
171 where
172 R::RngType: Rng,
173 T: Clone + Default + SampleUniform + for<'a> std::ops::AddAssign<&'a T> + PartialOrd,
174 {
175 let index = WeightedIndex::new(weights).unwrap();
176 let mut rng = get_rng::<R>(self);
177 index.sample(&mut *rng)
178 }
179}
180impl ContextRandomExt for Context {}
181
#[cfg(test)]
mod test {
    use crate::context::Context;
    use crate::define_data_plugin;
    use crate::random::ContextRandomExt;
    use rand::RngCore;
    use rand::{distributions::WeightedIndex, prelude::Distribution};

    define_rng!(FooRng);
    define_rng!(BarRng);

    #[test]
    fn get_rng_basic() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(FooRng, RngCore::next_u64)
        );
    }

    #[test]
    fn multiple_rng_types() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(BarRng, RngCore::next_u64)
        );
    }

    #[test]
    fn reset_seed() {
        let mut context = Context::new();
        context.init_random(42);

        let run_0 = context.sample(FooRng, RngCore::next_u64);
        let run_1 = context.sample(FooRng, RngCore::next_u64);

        // Same seed replays the same sequence.
        context.init_random(42);
        assert_eq!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_eq!(run_1, context.sample(FooRng, RngCore::next_u64));

        // Different seed produces a different sequence.
        context.init_random(88);
        assert_ne!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_ne!(run_1, context.sample(FooRng, RngCore::next_u64));
    }

    #[test]
    fn same_seed_is_reproducible_across_contexts() {
        let mut first = Context::new();
        first.init_random(42);
        let mut second = Context::new();
        second.init_random(42);

        for _ in 0..10 {
            assert_eq!(
                first.sample(FooRng, RngCore::next_u64),
                second.sample(FooRng, RngCore::next_u64)
            );
        }
    }

    #[test]
    fn streams_do_not_perturb_each_other() {
        let mut plain = Context::new();
        plain.init_random(42);
        let mut interleaved = Context::new();
        interleaved.init_random(42);

        let expected = (
            plain.sample(FooRng, RngCore::next_u64),
            plain.sample(FooRng, RngCore::next_u64),
        );

        let a = interleaved.sample(FooRng, RngCore::next_u64);
        let _ = interleaved.sample(BarRng, RngCore::next_u64);
        let b = interleaved.sample(FooRng, RngCore::next_u64);

        assert_eq!(expected, (a, b));
    }

    #[test]
    #[should_panic]
    fn nested_sampling_panics() {
        let mut context = Context::new();
        context.init_random(42);
        let _ = context.sample(FooRng, |_| context.sample(BarRng, RngCore::next_u64));
    }

    define_data_plugin!(
        SamplerData,
        WeightedIndex<f64>,
        WeightedIndex::new(vec![1.0]).unwrap()
    );

    #[test]
    fn sampler_function_closure_capture() {
        let mut context = Context::new();
        context.init_random(42);

        *context.get_data_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data(SamplerData);
        let n_samples = 3000;
        let mut zero_counter = 0;
        for _ in 0..n_samples {
            let sample = context.sample(FooRng, |rng| parameters.sample(rng));
            if sample == 0 {
                zero_counter += 1;
            }
        }
        assert!((zero_counter - 1000_i32).abs() < 100);
    }

    #[test]
    fn sample_distribution() {
        let mut context = Context::new();
        context.init_random(42);

        *context.get_data_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data(SamplerData);
        let n_samples = 3000;
        let mut zero_counter = 0;
        for _ in 0..n_samples {
            let sample = context.sample_distr(FooRng, parameters);
            if sample == 0 {
                zero_counter += 1;
            }
        }
        assert!((zero_counter - 1000_i32).abs() < 100);
    }

    #[test]
    fn sample_range() {
        let mut context = Context::new();
        context.init_random(42);
        for _ in 0..100 {
            let result = context.sample_range(FooRng, 0..10);
            assert!((0..10).contains(&result));
        }
        // A single-value inclusive range can only yield that value.
        assert_eq!(context.sample_range(FooRng, 5..=5), 5);
    }

    #[test]
    fn sample_bool() {
        let mut context = Context::new();
        context.init_random(42);
        let _r: bool = context.sample_bool(FooRng, 0.5);
        for _ in 0..100 {
            assert!(!context.sample_bool(FooRng, 0.0));
            assert!(context.sample_bool(FooRng, 1.0));
        }
    }

    #[test]
    fn sample_weighted() {
        let mut context = Context::new();
        context.init_random(42);
        let r: usize = context.sample_weighted(FooRng, &[0.1, 0.3, 0.4]);
        assert!(r < 3);

        // Zero-weight indices must never be chosen.
        for _ in 0..100 {
            assert_eq!(context.sample_weighted(FooRng, &[0.0, 1.0, 0.0]), 1);
        }
    }

    #[test]
    #[should_panic]
    fn sample_weighted_rejects_all_zero_weights() {
        let mut context = Context::new();
        context.init_random(42);
        let _ = context.sample_weighted(FooRng, &[0.0, 0.0]);
    }
}