// ixa/random/context_ext.rs

1use std::any::TypeId;
2use std::cell::RefMut;
3
4use log::trace;
5
6use crate::hashing::hash_str;
7use crate::rand::distr::uniform::{SampleRange, SampleUniform};
8use crate::rand::distr::weighted::{Weight, WeightedIndex};
9use crate::rand::distr::Distribution;
10use crate::rand::{Rng, SeedableRng};
11use crate::random::{RngHolder, RngPlugin};
12use crate::{Context, ContextBase, RngId};
13
14/// Gets a mutable reference to the random number generator associated with the given
15/// [`RngId`]. If the Rng has not been used before, one will be created with the base seed
16/// you defined in `init`. Note that this will panic if `init` was not called yet.
17fn get_rng<R: RngId + 'static>(context: &impl ContextBase) -> RefMut<R::RngType> {
18    let data_container = context.get_data(RngPlugin);
19
20    let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap();
21    RefMut::map(rng_holders, |holders| {
22        holders
23            .entry(TypeId::of::<R>())
24            // Create a new rng holder if it doesn't exist yet
25            .or_insert_with(|| {
26                trace!(
27                    "creating new RNG (seed={}) for type id {:?}",
28                    data_container.base_seed,
29                    TypeId::of::<R>()
30                );
31                let base_seed = data_container.base_seed;
32                let seed_offset = hash_str(R::get_name());
33                RngHolder {
34                    rng: Box::new(R::RngType::seed_from_u64(
35                        base_seed.wrapping_add(seed_offset),
36                    )),
37                }
38            })
39            .rng
40            .downcast_mut::<R::RngType>()
41            .unwrap()
42    })
43}
44
45// This is a trait extension on Context for
46// random number generation functionality.
47pub trait ContextRandomExt: ContextBase {
48    /// Initializes the `RngPlugin` data container to store rngs as well as a base
49    /// seed. Note that rngs are created lazily when `get_rng` is called.
50    fn init_random(&mut self, base_seed: u64) {
51        trace!("initializing random module");
52        let data_container = self.get_data_mut(RngPlugin);
53        data_container.base_seed = base_seed;
54
55        // Clear any existing Rngs to ensure they get re-seeded when `get_rng` is called
56        let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap();
57        rng_map.clear();
58    }
59
60    /// Gets a random sample from the random number generator associated with the given
61    /// [`RngId`] by applying the specified sampler function. If the Rng has not been used
62    /// before, one will be created with the base seed you defined in `set_base_random_seed`.
63    /// Note that this will panic if `set_base_random_seed` was not called yet.
64    fn sample<R: RngId + 'static, T>(
65        &self,
66        _rng_type: R,
67        sampler: impl FnOnce(&mut R::RngType) -> T,
68    ) -> T {
69        let mut rng = get_rng::<R>(self);
70        sampler(&mut rng)
71    }
72
73    /// Gets a random sample from the specified distribution using a random number generator
74    /// associated with the given [`RngId`]. If the Rng has not been used before, one will be
75    /// created with the base seed you defined in `set_base_random_seed`.
76    /// Note that this will panic if `set_base_random_seed` was not called yet.
77    fn sample_distr<R: RngId + 'static, T>(
78        &self,
79        _rng_type: R,
80        distribution: impl Distribution<T>,
81    ) -> T
82    where
83        R::RngType: Rng,
84    {
85        let mut rng = get_rng::<R>(self);
86        distribution.sample::<R::RngType>(&mut rng)
87    }
88
89    /// Gets a random sample within the range provided by `range`
90    /// using the generator associated with the given [`RngId`].
91    /// Note that this will panic if `set_base_random_seed` was not called yet.
92    fn sample_range<R: RngId + 'static, S, T>(&self, rng_id: R, range: S) -> T
93    where
94        R::RngType: Rng,
95        S: SampleRange<T>,
96        T: SampleUniform,
97    {
98        self.sample(rng_id, |rng| rng.random_range(range))
99    }
100
101    /// Gets a random boolean value which is true with probability `p`
102    /// using the generator associated with the given [`RngId`].
103    /// Note that this will panic if `set_base_random_seed` was not called yet.
104    fn sample_bool<R: RngId + 'static>(&self, rng_id: R, p: f64) -> bool
105    where
106        R::RngType: Rng,
107    {
108        self.sample(rng_id, |rng| rng.random_bool(p))
109    }
110
111    /// Draws a random entry out of the list provided in `weights`
112    /// with the given weights using the generator associated with the
113    /// given [`RngId`].  Note that this will panic if
114    /// `set_base_random_seed` was not called yet.
115    fn sample_weighted<R: RngId + 'static, T>(&self, _rng_id: R, weights: &[T]) -> usize
116    where
117        R::RngType: Rng,
118        T: Clone
119            + Default
120            + SampleUniform
121            + for<'a> std::ops::AddAssign<&'a T>
122            + PartialOrd
123            + Weight,
124    {
125        let index = WeightedIndex::new(weights).unwrap();
126        let mut rng = get_rng::<R>(self);
127        index.sample(&mut *rng)
128    }
129}
130
131impl ContextRandomExt for Context {}
132
#[cfg(test)]
mod test {
    use crate::context::Context;
    use crate::rand::distr::weighted::WeightedIndex;
    use crate::rand::distr::Distribution;
    use crate::rand::RngCore;
    use crate::random::context_ext::ContextRandomExt;
    use crate::{define_data_plugin, define_rng};

    define_rng!(FooRng);
    define_rng!(BarRng);

    #[test]
    fn get_rng_basic() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(FooRng, RngCore::next_u64)
        );
    }

    #[test]
    fn multiple_rng_types() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(BarRng, RngCore::next_u64)
        );
    }

    #[test]
    fn reset_seed() {
        let mut context = Context::new();
        context.init_random(42);

        let run_0 = context.sample(FooRng, RngCore::next_u64);
        let run_1 = context.sample(FooRng, RngCore::next_u64);

        // Reset with same seed, ensure we get the same values
        context.init_random(42);
        assert_eq!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_eq!(run_1, context.sample(FooRng, RngCore::next_u64));

        // Reset with different seed, ensure we get different values
        context.init_random(88);
        assert_ne!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_ne!(run_1, context.sample(FooRng, RngCore::next_u64));
    }

    define_data_plugin!(
        SamplerData,
        WeightedIndex<f64>,
        WeightedIndex::new(vec![1.0]).unwrap()
    );

    #[test]
    fn sampler_function_closure_capture() {
        let mut context = Context::new();
        context.init_random(42);

        // Initialize weighted sampler. Zero is selected with probability 1/3, one with a
        // probability of 2/3.
        *context.get_data_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data(SamplerData);
        let n_samples = 3000;
        let mut zero_counter = 0;
        for _ in 0..n_samples {
            let sample = context.sample(FooRng, |rng| parameters.sample(rng));
            if sample == 0 {
                zero_counter += 1;
            }
        }
        // The expected value of `zero_counter` is 1000.
        assert!((zero_counter - 1000_i32).abs() < 100);
    }

    #[test]
    fn sample_distribution() {
        let mut context = Context::new();
        context.init_random(42);

        // Initialize weighted sampler. Zero is selected with probability 1/3, one with a
        // probability of 2/3.
        *context.get_data_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data(SamplerData);
        let n_samples = 3000;
        let mut zero_counter = 0;
        for _ in 0..n_samples {
            let sample = context.sample_distr(FooRng, parameters);
            if sample == 0 {
                zero_counter += 1;
            }
        }
        // The expected value of `zero_counter` is 1000.
        assert!((zero_counter - 1000_i32).abs() < 100);
    }

    #[test]
    fn sample_range() {
        let mut context = Context::new();
        context.init_random(42);
        let result = context.sample_range(FooRng, 0..10);
        assert!((0..10).contains(&result));
    }

    #[test]
    fn sample_bool() {
        let mut context = Context::new();
        context.init_random(42);
        let _r: bool = context.sample_bool(FooRng, 0.5);

        // The degenerate probabilities 0 and 1 must be deterministic.
        for _ in 0..100 {
            assert!(!context.sample_bool(FooRng, 0.0));
            assert!(context.sample_bool(FooRng, 1.0));
        }
    }

    #[test]
    fn sample_weighted() {
        let mut context = Context::new();
        context.init_random(42);
        let r: usize = context.sample_weighted(FooRng, &[0.1, 0.3, 0.4]);
        assert!(r < 3);

        // Entries with zero weight are never selected.
        for _ in 0..100 {
            assert_eq!(context.sample_weighted(FooRng, &[0.0, 1.0, 0.0]), 1);
        }
    }

    #[test]
    #[should_panic]
    fn sample_weighted_empty_weights_panics() {
        let mut context = Context::new();
        context.init_random(42);
        let weights: [f64; 0] = [];
        context.sample_weighted(FooRng, &weights);
    }

    #[test]
    #[should_panic]
    fn nested_sampling_panics() {
        let mut context = Context::new();
        context.init_random(42);
        // The outer `sample` keeps the RNG map borrowed while the closure runs,
        // so drawing from another RNG inside it must panic.
        context.sample(FooRng, |_| context.sample(BarRng, RngCore::next_u64));
    }
}