ixa/random/
context_ext.rs

1use crate::rand::distr::uniform::{SampleRange, SampleUniform};
2use crate::rand::distr::weighted::{Weight, WeightedIndex};
3use crate::rand::distr::Distribution;
4use crate::rand::{Rng, SeedableRng};
5use crate::{
6    hashing::hash_str,
7    random::{RngHolder, RngPlugin},
8    Context, ContextBase, RngId,
9};
10use log::trace;
11use std::{any::TypeId, cell::RefMut};
12
13/// Gets a mutable reference to the random number generator associated with the given
14/// `RngId`. If the Rng has not been used before, one will be created with the base seed
15/// you defined in `init`. Note that this will panic if `init` was not called yet.
16fn get_rng<R: RngId + 'static>(context: &impl ContextBase) -> RefMut<R::RngType> {
17    let data_container = context.get_data(RngPlugin);
18
19    let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap();
20    RefMut::map(rng_holders, |holders| {
21        holders
22            .entry(TypeId::of::<R>())
23            // Create a new rng holder if it doesn't exist yet
24            .or_insert_with(|| {
25                trace!(
26                    "creating new RNG (seed={}) for type id {:?}",
27                    data_container.base_seed,
28                    TypeId::of::<R>()
29                );
30                let base_seed = data_container.base_seed;
31                let seed_offset = hash_str(R::get_name());
32                RngHolder {
33                    rng: Box::new(R::RngType::seed_from_u64(
34                        base_seed.wrapping_add(seed_offset),
35                    )),
36                }
37            })
38            .rng
39            .downcast_mut::<R::RngType>()
40            .unwrap()
41    })
42}
43
44// This is a trait extension on Context for
45// random number generation functionality.
46pub trait ContextRandomExt: ContextBase {
47    /// Initializes the `RngPlugin` data container to store rngs as well as a base
48    /// seed. Note that rngs are created lazily when `get_rng` is called.
49    fn init_random(&mut self, base_seed: u64) {
50        trace!("initializing random module");
51        let data_container = self.get_data_mut(RngPlugin);
52        data_container.base_seed = base_seed;
53
54        // Clear any existing Rngs to ensure they get re-seeded when `get_rng` is called
55        let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap();
56        rng_map.clear();
57    }
58
59    /// Gets a random sample from the random number generator associated with the given
60    /// `RngId` by applying the specified sampler function. If the Rng has not been used
61    /// before, one will be created with the base seed you defined in `set_base_random_seed`.
62    /// Note that this will panic if `set_base_random_seed` was not called yet.
63    fn sample<R: RngId + 'static, T>(
64        &self,
65        _rng_type: R,
66        sampler: impl FnOnce(&mut R::RngType) -> T,
67    ) -> T {
68        let mut rng = get_rng::<R>(self);
69        sampler(&mut rng)
70    }
71
72    /// Gets a random sample from the specified distribution using a random number generator
73    /// associated with the given `RngId`. If the Rng has not been used before, one will be
74    /// created with the base seed you defined in `set_base_random_seed`.
75    /// Note that this will panic if `set_base_random_seed` was not called yet.
76    fn sample_distr<R: RngId + 'static, T>(
77        &self,
78        _rng_type: R,
79        distribution: impl Distribution<T>,
80    ) -> T
81    where
82        R::RngType: Rng,
83    {
84        let mut rng = get_rng::<R>(self);
85        distribution.sample::<R::RngType>(&mut rng)
86    }
87
88    /// Gets a random sample within the range provided by `range`
89    /// using the generator associated with the given `RngId`.
90    /// Note that this will panic if `set_base_random_seed` was not called yet.
91    fn sample_range<R: RngId + 'static, S, T>(&self, rng_id: R, range: S) -> T
92    where
93        R::RngType: Rng,
94        S: SampleRange<T>,
95        T: SampleUniform,
96    {
97        self.sample(rng_id, |rng| rng.random_range(range))
98    }
99
100    /// Gets a random boolean value which is true with probability `p`
101    /// using the generator associated with the given `RngId`.
102    /// Note that this will panic if `set_base_random_seed` was not called yet.
103    fn sample_bool<R: RngId + 'static>(&self, rng_id: R, p: f64) -> bool
104    where
105        R::RngType: Rng,
106    {
107        self.sample(rng_id, |rng| rng.random_bool(p))
108    }
109
110    /// Draws a random entry out of the list provided in `weights`
111    /// with the given weights using the generator associated with the
112    /// given `RngId`.  Note that this will panic if
113    /// `set_base_random_seed` was not called yet.
114    fn sample_weighted<R: RngId + 'static, T>(&self, _rng_id: R, weights: &[T]) -> usize
115    where
116        R::RngType: Rng,
117        T: Clone
118            + Default
119            + SampleUniform
120            + for<'a> std::ops::AddAssign<&'a T>
121            + PartialOrd
122            + Weight,
123    {
124        let index = WeightedIndex::new(weights).unwrap();
125        let mut rng = get_rng::<R>(self);
126        index.sample(&mut *rng)
127    }
128}
129
130impl ContextRandomExt for Context {}
131
#[cfg(test)]
mod test {
    use crate::context::Context;
    use crate::rand::distr::{weighted::WeightedIndex, Distribution};
    use crate::rand::RngCore;
    use crate::random::context_ext::ContextRandomExt;
    use crate::{define_data_plugin, define_rng};

    define_rng!(FooRng);
    define_rng!(BarRng);

    /// Number of draws used by the statistical tests below.
    const N_SAMPLES: usize = 3000;
    /// Expected number of zeros among `N_SAMPLES` draws when zero has probability 1/3.
    const EXPECTED_ZEROS: usize = N_SAMPLES / 3;
    /// Allowed deviation from `EXPECTED_ZEROS` (the seed is fixed, so the count is too).
    const TOLERANCE: usize = 100;

    define_data_plugin!(
        SamplerData,
        WeightedIndex<f64>,
        WeightedIndex::new(vec![1.0]).unwrap()
    );

    /// Returns a context seeded with 42 whose `SamplerData` selects zero with probability
    /// 1/3 and one with probability 2/3.
    fn context_with_weighted_sampler() -> Context {
        let mut context = Context::new();
        context.init_random(42);
        *context.get_data_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();
        context
    }

    #[test]
    fn get_rng_basic() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(FooRng, RngCore::next_u64)
        );
    }

    #[test]
    fn multiple_rng_types() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(BarRng, RngCore::next_u64)
        );
    }

    #[test]
    fn same_seed_is_reproducible_across_contexts() {
        let mut first = Context::new();
        first.init_random(42);
        let mut second = Context::new();
        second.init_random(42);

        for _ in 0..3 {
            assert_eq!(
                first.sample(FooRng, RngCore::next_u64),
                second.sample(FooRng, RngCore::next_u64)
            );
        }
    }

    #[test]
    fn reset_seed() {
        let mut context = Context::new();
        context.init_random(42);

        let run_0 = context.sample(FooRng, RngCore::next_u64);
        let run_1 = context.sample(FooRng, RngCore::next_u64);

        // Reset with same seed, ensure we get the same values
        context.init_random(42);
        assert_eq!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_eq!(run_1, context.sample(FooRng, RngCore::next_u64));

        // Reset with different seed, ensure we get different values
        context.init_random(88);
        assert_ne!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_ne!(run_1, context.sample(FooRng, RngCore::next_u64));
    }

    #[test]
    fn sampler_function_closure_capture() {
        let context = context_with_weighted_sampler();
        let parameters = context.get_data(SamplerData);

        let zero_count = (0..N_SAMPLES)
            .filter(|_| context.sample(FooRng, |rng| parameters.sample(rng)) == 0)
            .count();
        assert!(zero_count.abs_diff(EXPECTED_ZEROS) < TOLERANCE);
    }

    #[test]
    fn sample_distribution() {
        let context = context_with_weighted_sampler();
        let parameters = context.get_data(SamplerData);

        let zero_count = (0..N_SAMPLES)
            .filter(|_| context.sample_distr(FooRng, parameters) == 0)
            .count();
        assert!(zero_count.abs_diff(EXPECTED_ZEROS) < TOLERANCE);
    }

    #[test]
    fn sample_range() {
        let mut context = Context::new();
        context.init_random(42);
        let result = context.sample_range(FooRng, 0..10);
        assert!((0..10).contains(&result));

        // A range containing a single value can only produce that value.
        assert_eq!(context.sample_range(FooRng, 7..8), 7);
        assert_eq!(context.sample_range(FooRng, 3..=3), 3);
    }

    #[test]
    fn sample_bool() {
        let mut context = Context::new();
        context.init_random(42);
        let _r: bool = context.sample_bool(FooRng, 0.5);

        // The probability endpoints are deterministic regardless of the RNG state.
        for _ in 0..10 {
            assert!(!context.sample_bool(FooRng, 0.0));
            assert!(context.sample_bool(FooRng, 1.0));
        }
    }

    #[test]
    fn sample_weighted() {
        let mut context = Context::new();
        context.init_random(42);
        let r: usize = context.sample_weighted(FooRng, &[0.1, 0.3, 0.4]);
        assert!(r < 3);

        // Entries with zero weight are never selected.
        for _ in 0..10 {
            assert_eq!(context.sample_weighted(FooRng, &[0.0, 1.0, 0.0]), 1);
        }
    }

    #[test]
    #[should_panic]
    fn sample_weighted_rejects_empty_weights() {
        let mut context = Context::new();
        context.init_random(42);
        let no_weights: &[f64] = &[];
        let _ = context.sample_weighted(FooRng, no_weights);
    }
}
256}