ixa/
random.rs

1use crate::context::Context;
2use crate::hashing::hash_str;
3use crate::{define_data_plugin, HashMap, HashMapExt, PluginContext};
4use log::trace;
5use rand::distributions::uniform::{SampleRange, SampleUniform};
6use rand::distributions::WeightedIndex;
7use rand::prelude::Distribution;
8use rand::{Rng, SeedableRng};
9use std::any::{Any, TypeId};
10use std::cell::{RefCell, RefMut};
11
12/// Use this to define a unique type which will be used as a key to retrieve
13/// an independent rng instance when calling `.get_rng`.
14#[macro_export]
15macro_rules! define_rng {
16    ($random_id:ident) => {
17        #[derive(Copy, Clone)]
18        struct $random_id;
19
20        impl $crate::random::RngId for $random_id {
21            type RngType = $crate::rand::rngs::SmallRng;
22
23            fn get_name() -> &'static str {
24                stringify!($random_id)
25            }
26        }
27
28        // This ensures that you can't define two RngIds with the same name
29        $crate::paste::paste! {
30            #[doc(hidden)]
31            #[no_mangle]
32            #[allow(non_upper_case_globals)]
33            pub static [<rng_name_duplication_guard_ $random_id>]: () = ();
34        }
35    };
36}
37pub use define_rng;
38
39pub trait RngId: Copy + Clone {
40    type RngType: SeedableRng;
41    fn get_name() -> &'static str;
42}
43
// Type-erased wrapper around a single generator. The RNG is stored as
// `Box<dyn Any>` so that generators of different concrete types (any
// `RngId::RngType`, i.e. anything implementing `SeedableRng`) can live in the
// same map; `get_rng` downcasts it back to the concrete `R::RngType`.
struct RngHolder {
    rng: Box<dyn Any>,
}
49
// State of the random module, stored in the `RngPlugin` data container.
struct RngData {
    // Seed from which every generator's seed is derived; set by `init_random`.
    base_seed: u64,
    // Lazily created generators keyed by the `TypeId` of their `RngId`. The
    // `RefCell` lets a generator be borrowed mutably through a shared `&Context`.
    rng_holders: RefCell<HashMap<TypeId, RngHolder>>,
}
54
// Registers the `RngPlugin` data container, which stores:
// * base_seed: the base seed all generator seeds are derived from; it starts
//   at 0 here and is overwritten by `init_random`.
// * rng_holders: the map of generators, keyed by the `TypeId` of their RngId.
//   It is stored in a RefCell so a generator can be borrowed mutably without
//   requiring a mutable borrow of the Context itself.
// NOTE(review): the last argument is presumably the container's initial value,
// used when the plugin is first accessed — confirm against `define_data_plugin!`.
define_data_plugin!(
    RngPlugin,
    RngData,
    RngData {
        base_seed: 0,
        rng_holders: RefCell::new(HashMap::new()),
    }
);
68
69/// Gets a mutable reference to the random number generator associated with the given
70/// `RngId`. If the Rng has not been used before, one will be created with the base seed
71/// you defined in `init`. Note that this will panic if `init` was not called yet.
72fn get_rng<R: RngId + 'static>(context: &impl PluginContext) -> RefMut<R::RngType> {
73    let data_container = context.get_data(RngPlugin);
74
75    let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap();
76    RefMut::map(rng_holders, |holders| {
77        holders
78            .entry(TypeId::of::<R>())
79            // Create a new rng holder if it doesn't exist yet
80            .or_insert_with(|| {
81                trace!(
82                    "creating new RNG (seed={}) for type id {:?}",
83                    data_container.base_seed,
84                    TypeId::of::<R>()
85                );
86                let base_seed = data_container.base_seed;
87                let seed_offset = hash_str(R::get_name());
88                RngHolder {
89                    rng: Box::new(R::RngType::seed_from_u64(
90                        base_seed.wrapping_add(seed_offset),
91                    )),
92                }
93            })
94            .rng
95            .downcast_mut::<R::RngType>()
96            .unwrap()
97    })
98}
99
100// This is a trait extension on Context for
101// random number generation functionality.
102pub trait ContextRandomExt: PluginContext {
103    /// Initializes the `RngPlugin` data container to store rngs as well as a base
104    /// seed. Note that rngs are created lazily when `get_rng` is called.
105    fn init_random(&mut self, base_seed: u64) {
106        trace!("initializing random module");
107        let data_container = self.get_data_mut(RngPlugin);
108        data_container.base_seed = base_seed;
109
110        // Clear any existing Rngs to ensure they get re-seeded when `get_rng` is called
111        let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap();
112        rng_map.clear();
113    }
114
115    /// Gets a random sample from the random number generator associated with the given
116    /// `RngId` by applying the specified sampler function. If the Rng has not been used
117    /// before, one will be created with the base seed you defined in `set_base_random_seed`.
118    /// Note that this will panic if `set_base_random_seed` was not called yet.
119    fn sample<R: RngId + 'static, T>(
120        &self,
121        _rng_type: R,
122        sampler: impl FnOnce(&mut R::RngType) -> T,
123    ) -> T {
124        let mut rng = get_rng::<R>(self);
125        sampler(&mut rng)
126    }
127
128    /// Gets a random sample from the specified distribution using a random number generator
129    /// associated with the given `RngId`. If the Rng has not been used before, one will be
130    /// created with the base seed you defined in `set_base_random_seed`.
131    /// Note that this will panic if `set_base_random_seed` was not called yet.
132    fn sample_distr<R: RngId + 'static, T>(
133        &self,
134        _rng_type: R,
135        distribution: impl Distribution<T>,
136    ) -> T
137    where
138        R::RngType: Rng,
139    {
140        let mut rng = get_rng::<R>(self);
141        distribution.sample::<R::RngType>(&mut rng)
142    }
143
144    /// Gets a random sample within the range provided by `range`
145    /// using the generator associated with the given `RngId`.
146    /// Note that this will panic if `set_base_random_seed` was not called yet.
147    fn sample_range<R: RngId + 'static, S, T>(&self, rng_id: R, range: S) -> T
148    where
149        R::RngType: Rng,
150        S: SampleRange<T>,
151        T: SampleUniform,
152    {
153        self.sample(rng_id, |rng| rng.gen_range(range))
154    }
155
156    /// Gets a random boolean value which is true with probability `p`
157    /// using the generator associated with the given `RngId`.
158    /// Note that this will panic if `set_base_random_seed` was not called yet.
159    fn sample_bool<R: RngId + 'static>(&self, rng_id: R, p: f64) -> bool
160    where
161        R::RngType: Rng,
162    {
163        self.sample(rng_id, |rng| rng.gen_bool(p))
164    }
165
166    /// Draws a random entry out of the list provided in `weights`
167    /// with the given weights using the generator associated with the
168    /// given `RngId`.  Note that this will panic if
169    /// `set_base_random_seed` was not called yet.
170    fn sample_weighted<R: RngId + 'static, T>(&self, _rng_id: R, weights: &[T]) -> usize
171    where
172        R::RngType: Rng,
173        T: Clone + Default + SampleUniform + for<'a> std::ops::AddAssign<&'a T> + PartialOrd,
174    {
175        let index = WeightedIndex::new(weights).unwrap();
176        let mut rng = get_rng::<R>(self);
177        index.sample(&mut *rng)
178    }
179}
180impl ContextRandomExt for Context {}
181
#[cfg(test)]
mod test {
    use crate::context::Context;
    use crate::define_data_plugin;
    use crate::random::ContextRandomExt;
    use rand::RngCore;
    use rand::{distributions::WeightedIndex, prelude::Distribution};

    define_rng!(FooRng);
    define_rng!(BarRng);

    /// Number of draws made by the statistical weighted-sampling tests.
    const N_SAMPLES: usize = 3000;

    /// Asserts that `zero_count` zeros out of `N_SAMPLES` draws is consistent with a
    /// sampler returning zero with probability 1/3 (expected count: 1000).
    fn assert_about_one_third_zeros(zero_count: usize) {
        let expected = N_SAMPLES / 3;
        assert!(
            zero_count.abs_diff(expected) < 100,
            "expected about {expected} zeros out of {N_SAMPLES} draws, got {zero_count}"
        );
    }

    #[test]
    fn get_rng_basic() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(FooRng, RngCore::next_u64)
        );
    }

    #[test]
    fn multiple_rng_types() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(BarRng, RngCore::next_u64)
        );
    }

    #[test]
    fn reset_seed() {
        let mut context = Context::new();
        context.init_random(42);

        let run_0 = context.sample(FooRng, RngCore::next_u64);
        let run_1 = context.sample(FooRng, RngCore::next_u64);

        // Reset with same seed, ensure we get the same values
        context.init_random(42);
        assert_eq!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_eq!(run_1, context.sample(FooRng, RngCore::next_u64));

        // Reset with different seed, ensure we get different values
        context.init_random(88);
        assert_ne!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_ne!(run_1, context.sample(FooRng, RngCore::next_u64));
    }

    #[test]
    #[should_panic]
    fn nested_sampling_panics() {
        let mut context = Context::new();
        context.init_random(42);
        // All generators share one RefCell, so drawing from another generator while
        // `sample` still holds the store borrowed must panic.
        context.sample(FooRng, |_| context.sample(BarRng, RngCore::next_u64));
    }

    define_data_plugin!(
        SamplerData,
        WeightedIndex<f64>,
        WeightedIndex::new(vec![1.0]).unwrap()
    );

    #[test]
    fn sampler_function_closure_capture() {
        let mut context = Context::new();
        context.init_random(42);

        // Initialize weighted sampler. Zero is selected with probability 1/3, one with a
        // probability of 2/3.
        *context.get_data_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data(SamplerData);
        let zero_count = (0..N_SAMPLES)
            .filter(|_| context.sample(FooRng, |rng| parameters.sample(rng)) == 0)
            .count();
        assert_about_one_third_zeros(zero_count);
    }

    #[test]
    fn sample_distribution() {
        let mut context = Context::new();
        context.init_random(42);

        // Initialize weighted sampler. Zero is selected with probability 1/3, one with a
        // probability of 2/3.
        *context.get_data_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data(SamplerData);
        let zero_count = (0..N_SAMPLES)
            .filter(|_| context.sample_distr(FooRng, parameters) == 0)
            .count();
        assert_about_one_third_zeros(zero_count);
    }

    #[test]
    fn sample_range() {
        let mut context = Context::new();
        context.init_random(42);
        for _ in 0..100 {
            let result = context.sample_range(FooRng, 0..10);
            assert!((0..10).contains(&result));
        }
    }

    #[test]
    fn sample_bool() {
        let mut context = Context::new();
        context.init_random(42);
        let _r: bool = context.sample_bool(FooRng, 0.5);

        // Degenerate probabilities must be honored exactly.
        for _ in 0..100 {
            assert!(!context.sample_bool(FooRng, 0.0));
            assert!(context.sample_bool(FooRng, 1.0));
        }
    }

    #[test]
    fn sample_weighted() {
        let mut context = Context::new();
        context.init_random(42);
        let r: usize = context.sample_weighted(FooRng, &[0.1, 0.3, 0.4]);
        assert!(r < 3);

        // Entries with zero weight must never be drawn.
        for _ in 0..100 {
            assert_eq!(context.sample_weighted(FooRng, &[0.0, 1.0, 0.0]), 1);
        }
    }

    #[test]
    #[should_panic]
    fn sample_weighted_rejects_empty_weights() {
        let mut context = Context::new();
        context.init_random(42);
        let empty: [f64; 0] = [];
        context.sample_weighted(FooRng, &empty);
    }
}