ixa/
random.rs

1use crate::context::Context;
2use crate::hashing::hash_str;
3use crate::{HashMap, HashMapExt, PluginContext};
4use log::trace;
5use rand::distributions::uniform::{SampleRange, SampleUniform};
6use rand::distributions::WeightedIndex;
7use rand::prelude::Distribution;
8use rand::{Rng, SeedableRng};
9use std::any::{Any, TypeId};
10use std::cell::{RefCell, RefMut};
11
/// Declares a zero-sized marker type that identifies an independent random
/// number generator stream.
///
/// Pass a value of the generated type to the sampling methods of
/// `ContextRandomExt` (for example `context.sample(MyRng, ...)`). Each declared
/// type gets its own generator, backed by `rand::rngs::SmallRng`, whose seed is
/// derived from the context's base seed and the type's name.
#[macro_export]
macro_rules! define_rng {
    ($name:ident) => {
        // An exported, unmangled static named after the identifier means a second
        // `define_rng!` with the same name produces a conflicting symbol, so two
        // RngIds can never share a name.
        $crate::paste::paste! {
            #[doc(hidden)]
            #[no_mangle]
            #[allow(non_upper_case_globals)]
            pub static [<rng_name_duplication_guard_ $name>]: () = ();
        }

        #[derive(Clone, Copy)]
        struct $name;

        impl $crate::random::RngId for $name {
            type RngType = $crate::rand::rngs::SmallRng;

            fn get_name() -> &'static str {
                stringify!($name)
            }
        }
    };
}
37pub use define_rng;
38
39pub trait RngId: Copy + Clone {
40    type RngType: SeedableRng;
41    fn get_name() -> &'static str;
42}
43
/// Per-context random-number state stored in the `RngPlugin` data container.
struct RngData {
    /// Generators created so far, keyed by the `TypeId` of their `RngId`.
    /// The `RefCell` lets a generator be mutated through a shared context reference.
    rng_holders: RefCell<HashMap<TypeId, RngHolder>>,
    /// Base seed that each generator's seed is derived from.
    base_seed: u64,
}

/// Type-erased box around a single generator.
///
/// Storing the generator as `dyn Any` lets any `SeedableRng` type live in the same
/// map; the concrete type is recovered with `downcast_mut` when it is used.
struct RngHolder {
    rng: Box<dyn Any>,
}
54
55// Registers a data container which stores:
56// * base_seed: A base seed for all rngs
57// * rng_holders: A map of rngs, keyed by their RngId. Note that this is
58//   stored in a RefCell to allow for mutable borrow without requiring a
59//   mutable borrow of the Context itself.
60crate::context::define_data_plugin!(
61    RngPlugin,
62    RngData,
63    RngData {
64        base_seed: 0,
65        rng_holders: RefCell::new(HashMap::new()),
66    }
67);
68
69/// Gets a mutable reference to the random number generator associated with the given
70/// `RngId`. If the Rng has not been used before, one will be created with the base seed
71/// you defined in `init`. Note that this will panic if `init` was not called yet.
72fn get_rng<R: RngId + 'static>(context: &impl PluginContext) -> RefMut<R::RngType> {
73    let data_container = context
74        .get_data_container(RngPlugin)
75        .expect("You must initialize the random number generator with a base seed");
76
77    let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap();
78    RefMut::map(rng_holders, |holders| {
79        holders
80            .entry(TypeId::of::<R>())
81            // Create a new rng holder if it doesn't exist yet
82            .or_insert_with(|| {
83                trace!(
84                    "creating new RNG (seed={}) for type id {:?}",
85                    data_container.base_seed,
86                    TypeId::of::<R>()
87                );
88                let base_seed = data_container.base_seed;
89                let seed_offset = hash_str(R::get_name());
90                RngHolder {
91                    rng: Box::new(R::RngType::seed_from_u64(
92                        base_seed.wrapping_add(seed_offset),
93                    )),
94                }
95            })
96            .rng
97            .downcast_mut::<R::RngType>()
98            .unwrap()
99    })
100}
101
102// This is a trait extension on Context for
103// random number generation functionality.
104pub trait ContextRandomExt: PluginContext {
105    /// Initializes the `RngPlugin` data container to store rngs as well as a base
106    /// seed. Note that rngs are created lazily when `get_rng` is called.
107    fn init_random(&mut self, base_seed: u64) {
108        trace!("initializing random module");
109        let data_container = self.get_data_container_mut(RngPlugin);
110        data_container.base_seed = base_seed;
111
112        // Clear any existing Rngs to ensure they get re-seeded when `get_rng` is called
113        let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap();
114        rng_map.clear();
115    }
116
117    /// Gets a random sample from the random number generator associated with the given
118    /// `RngId` by applying the specified sampler function. If the Rng has not been used
119    /// before, one will be created with the base seed you defined in `set_base_random_seed`.
120    /// Note that this will panic if `set_base_random_seed` was not called yet.
121    fn sample<R: RngId + 'static, T>(
122        &self,
123        _rng_type: R,
124        sampler: impl FnOnce(&mut R::RngType) -> T,
125    ) -> T {
126        let mut rng = get_rng::<R>(self);
127        sampler(&mut rng)
128    }
129
130    /// Gets a random sample from the specified distribution using a random number generator
131    /// associated with the given `RngId`. If the Rng has not been used before, one will be
132    /// created with the base seed you defined in `set_base_random_seed`.
133    /// Note that this will panic if `set_base_random_seed` was not called yet.
134    fn sample_distr<R: RngId + 'static, T>(
135        &self,
136        _rng_type: R,
137        distribution: impl Distribution<T>,
138    ) -> T
139    where
140        R::RngType: Rng,
141    {
142        let mut rng = get_rng::<R>(self);
143        distribution.sample::<R::RngType>(&mut rng)
144    }
145
146    /// Gets a random sample within the range provided by `range`
147    /// using the generator associated with the given `RngId`.
148    /// Note that this will panic if `set_base_random_seed` was not called yet.
149    fn sample_range<R: RngId + 'static, S, T>(&self, rng_id: R, range: S) -> T
150    where
151        R::RngType: Rng,
152        S: SampleRange<T>,
153        T: SampleUniform,
154    {
155        self.sample(rng_id, |rng| rng.gen_range(range))
156    }
157
158    /// Gets a random boolean value which is true with probability `p`
159    /// using the generator associated with the given `RngId`.
160    /// Note that this will panic if `set_base_random_seed` was not called yet.
161    fn sample_bool<R: RngId + 'static>(&self, rng_id: R, p: f64) -> bool
162    where
163        R::RngType: Rng,
164    {
165        self.sample(rng_id, |rng| rng.gen_bool(p))
166    }
167
168    /// Draws a random entry out of the list provided in `weights`
169    /// with the given weights using the generator associated with the
170    /// given `RngId`.  Note that this will panic if
171    /// `set_base_random_seed` was not called yet.
172    fn sample_weighted<R: RngId + 'static, T>(&self, _rng_id: R, weights: &[T]) -> usize
173    where
174        R::RngType: Rng,
175        T: Clone + Default + SampleUniform + for<'a> std::ops::AddAssign<&'a T> + PartialOrd,
176    {
177        let index = WeightedIndex::new(weights).unwrap();
178        let mut rng = get_rng::<R>(self);
179        index.sample(&mut *rng)
180    }
181}
182impl ContextRandomExt for Context {}
183
#[cfg(test)]
mod test {
    use crate::context::Context;
    use crate::define_data_plugin;
    use crate::random::ContextRandomExt;
    use rand::RngCore;
    use rand::{distributions::WeightedIndex, prelude::Distribution};

    define_rng!(FooRng);
    define_rng!(BarRng);

    #[test]
    fn get_rng_basic() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(FooRng, RngCore::next_u64)
        );
    }

    #[test]
    #[should_panic(expected = "You must initialize the random number generator with a base seed")]
    fn panic_if_not_initialized() {
        let context = Context::new();
        context.sample(FooRng, RngCore::next_u64);
    }

    #[test]
    fn multiple_rng_types() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(BarRng, RngCore::next_u64)
        );
    }

    #[test]
    fn reset_seed() {
        let mut context = Context::new();
        context.init_random(42);

        let run_0 = context.sample(FooRng, RngCore::next_u64);
        let run_1 = context.sample(FooRng, RngCore::next_u64);

        // Reset with same seed, ensure we get the same values
        context.init_random(42);
        assert_eq!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_eq!(run_1, context.sample(FooRng, RngCore::next_u64));

        // Reset with different seed, ensure we get different values
        context.init_random(88);
        assert_ne!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_ne!(run_1, context.sample(FooRng, RngCore::next_u64));
    }

    // Sampling from one RngId inside the sampler closure of another must panic,
    // because the generator map is already mutably borrowed.
    #[test]
    #[should_panic]
    fn nested_sampling_panics() {
        let mut context = Context::new();
        context.init_random(42);
        context.sample(FooRng, |_| context.sample(BarRng, RngCore::next_u64));
    }

    define_data_plugin!(
        SamplerData,
        WeightedIndex<f64>,
        WeightedIndex::new(vec![1.0]).unwrap()
    );

    #[test]
    fn sampler_function_closure_capture() {
        let mut context = Context::new();
        context.init_random(42);

        // Initialize weighted sampler. Zero is selected with probability 1/3, one with a
        // probability of 2/3.
        *context.get_data_container_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data_container(SamplerData).unwrap();
        let n_samples = 3000;
        let mut zero_counter = 0;
        for _ in 0..n_samples {
            let sample = context.sample(FooRng, |rng| parameters.sample(rng));
            if sample == 0 {
                zero_counter += 1;
            }
        }
        // The expected value of `zero_counter` is 1000.
        assert!((zero_counter - 1000_i32).abs() < 100);
    }

    #[test]
    fn sample_distribution() {
        let mut context = Context::new();
        context.init_random(42);

        // Initialize weighted sampler. Zero is selected with probability 1/3, one with a
        // probability of 2/3.
        *context.get_data_container_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data_container(SamplerData).unwrap();
        let n_samples = 3000;
        let mut zero_counter = 0;
        for _ in 0..n_samples {
            let sample = context.sample_distr(FooRng, parameters);
            if sample == 0 {
                zero_counter += 1;
            }
        }
        // The expected value of `zero_counter` is 1000.
        assert!((zero_counter - 1000_i32).abs() < 100);
    }

    #[test]
    fn sample_range() {
        let mut context = Context::new();
        context.init_random(42);
        let result = context.sample_range(FooRng, 0..10);
        assert!((0..10).contains(&result));
    }

    #[test]
    fn sample_range_reproducible_after_reseed() {
        let mut context = Context::new();
        context.init_random(42);
        let first: Vec<u32> = (0..10)
            .map(|_| context.sample_range(FooRng, 0..1000))
            .collect();

        context.init_random(42);
        let second: Vec<u32> = (0..10)
            .map(|_| context.sample_range(FooRng, 0..1000))
            .collect();

        assert_eq!(first, second);
    }

    #[test]
    fn sample_bool() {
        let mut context = Context::new();
        context.init_random(42);
        // p = 1.0 and p = 0.0 are degenerate, so the outcome is deterministic.
        assert!(context.sample_bool(FooRng, 1.0));
        assert!(!context.sample_bool(FooRng, 0.0));
    }

    #[test]
    fn sample_weighted() {
        let mut context = Context::new();
        context.init_random(42);
        let r: usize = context.sample_weighted(FooRng, &[0.1, 0.3, 0.4]);
        assert!(r < 3);
    }

    #[test]
    fn sample_weighted_never_picks_zero_weight() {
        let mut context = Context::new();
        context.init_random(42);
        for _ in 0..100 {
            assert_eq!(context.sample_weighted(FooRng, &[0.0, 1.0, 0.0]), 1);
        }
    }
}