ixa/random/
context_ext.rs1use crate::rand::distr::uniform::{SampleRange, SampleUniform};
2use crate::rand::distr::weighted::{Weight, WeightedIndex};
3use crate::rand::distr::Distribution;
4use crate::rand::{Rng, SeedableRng};
5use crate::{
6 hashing::hash_str,
7 random::{RngHolder, RngPlugin},
8 Context, ContextBase, RngId,
9};
10use log::trace;
11use std::{any::TypeId, cell::RefMut};
12
13fn get_rng<R: RngId + 'static>(context: &impl ContextBase) -> RefMut<R::RngType> {
17 let data_container = context.get_data(RngPlugin);
18
19 let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap();
20 RefMut::map(rng_holders, |holders| {
21 holders
22 .entry(TypeId::of::<R>())
23 .or_insert_with(|| {
25 trace!(
26 "creating new RNG (seed={}) for type id {:?}",
27 data_container.base_seed,
28 TypeId::of::<R>()
29 );
30 let base_seed = data_container.base_seed;
31 let seed_offset = hash_str(R::get_name());
32 RngHolder {
33 rng: Box::new(R::RngType::seed_from_u64(
34 base_seed.wrapping_add(seed_offset),
35 )),
36 }
37 })
38 .rng
39 .downcast_mut::<R::RngType>()
40 .unwrap()
41 })
42}
43
44pub trait ContextRandomExt: ContextBase {
47 fn init_random(&mut self, base_seed: u64) {
50 trace!("initializing random module");
51 let data_container = self.get_data_mut(RngPlugin);
52 data_container.base_seed = base_seed;
53
54 let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap();
56 rng_map.clear();
57 }
58
59 fn sample<R: RngId + 'static, T>(
64 &self,
65 _rng_type: R,
66 sampler: impl FnOnce(&mut R::RngType) -> T,
67 ) -> T {
68 let mut rng = get_rng::<R>(self);
69 sampler(&mut rng)
70 }
71
72 fn sample_distr<R: RngId + 'static, T>(
77 &self,
78 _rng_type: R,
79 distribution: impl Distribution<T>,
80 ) -> T
81 where
82 R::RngType: Rng,
83 {
84 let mut rng = get_rng::<R>(self);
85 distribution.sample::<R::RngType>(&mut rng)
86 }
87
88 fn sample_range<R: RngId + 'static, S, T>(&self, rng_id: R, range: S) -> T
92 where
93 R::RngType: Rng,
94 S: SampleRange<T>,
95 T: SampleUniform,
96 {
97 self.sample(rng_id, |rng| rng.random_range(range))
98 }
99
100 fn sample_bool<R: RngId + 'static>(&self, rng_id: R, p: f64) -> bool
104 where
105 R::RngType: Rng,
106 {
107 self.sample(rng_id, |rng| rng.random_bool(p))
108 }
109
110 fn sample_weighted<R: RngId + 'static, T>(&self, _rng_id: R, weights: &[T]) -> usize
115 where
116 R::RngType: Rng,
117 T: Clone
118 + Default
119 + SampleUniform
120 + for<'a> std::ops::AddAssign<&'a T>
121 + PartialOrd
122 + Weight,
123 {
124 let index = WeightedIndex::new(weights).unwrap();
125 let mut rng = get_rng::<R>(self);
126 index.sample(&mut *rng)
127 }
128}
129
130impl ContextRandomExt for Context {}
131
#[cfg(test)]
mod test {
    use crate::context::Context;
    use crate::rand::distr::{weighted::WeightedIndex, Distribution};
    use crate::rand::RngCore;
    use crate::random::context_ext::ContextRandomExt;
    use crate::{define_data_plugin, define_rng};

    define_rng!(FooRng);
    define_rng!(BarRng);

    #[test]
    fn get_rng_basic() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(FooRng, RngCore::next_u64)
        );
    }

    #[test]
    fn multiple_rng_types() {
        let mut context = Context::new();
        context.init_random(42);

        assert_ne!(
            context.sample(FooRng, RngCore::next_u64),
            context.sample(BarRng, RngCore::next_u64)
        );
    }

    #[test]
    fn reset_seed() {
        let mut context = Context::new();
        context.init_random(42);

        let run_0 = context.sample(FooRng, RngCore::next_u64);
        let run_1 = context.sample(FooRng, RngCore::next_u64);

        // Re-seeding with the same seed replays the same sequence.
        context.init_random(42);
        assert_eq!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_eq!(run_1, context.sample(FooRng, RngCore::next_u64));

        // A different seed produces a different sequence.
        context.init_random(88);
        assert_ne!(run_0, context.sample(FooRng, RngCore::next_u64));
        assert_ne!(run_1, context.sample(FooRng, RngCore::next_u64));
    }

    define_data_plugin!(
        SamplerData,
        WeightedIndex<f64>,
        WeightedIndex::new(vec![1.0]).unwrap()
    );

    #[test]
    fn sampler_function_closure_capture() {
        let mut context = Context::new();
        context.init_random(42);

        // Index 0 has weight 1 of 3, so roughly a third of samples should be 0.
        *context.get_data_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data(SamplerData);
        let n_samples = 3000;
        let mut zero_counter = 0;
        for _ in 0..n_samples {
            let sample = context.sample(FooRng, |rng| parameters.sample(rng));
            if sample == 0 {
                zero_counter += 1;
            }
        }
        assert!((zero_counter - 1000_i32).abs() < 100);
    }

    #[test]
    fn sample_distribution() {
        let mut context = Context::new();
        context.init_random(42);

        // Index 0 has weight 1 of 3, so roughly a third of samples should be 0.
        *context.get_data_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data(SamplerData);
        let n_samples = 3000;
        let mut zero_counter = 0;
        for _ in 0..n_samples {
            let sample = context.sample_distr(FooRng, parameters);
            if sample == 0 {
                zero_counter += 1;
            }
        }
        assert!((zero_counter - 1000_i32).abs() < 100);
    }

    #[test]
    #[should_panic]
    fn nested_sampling_panics() {
        let mut context = Context::new();
        context.init_random(42);
        // All RNGs share one borrowed map, so sampling inside a sampler panics
        // even when a different RNG id is used.
        let _ = context.sample(FooRng, |_| context.sample(BarRng, RngCore::next_u64));
    }

    #[test]
    fn sample_range() {
        let mut context = Context::new();
        context.init_random(42);
        let result = context.sample_range(FooRng, 0..10);
        assert!((0..10).contains(&result));
    }

    #[test]
    fn sample_bool() {
        let mut context = Context::new();
        context.init_random(42);
        let _r: bool = context.sample_bool(FooRng, 0.5);
        // The degenerate probabilities are deterministic regardless of RNG state.
        assert!(!context.sample_bool(FooRng, 0.0));
        assert!(context.sample_bool(FooRng, 1.0));
    }

    #[test]
    fn sample_weighted() {
        let mut context = Context::new();
        context.init_random(42);
        let r: usize = context.sample_weighted(FooRng, &[0.1, 0.3, 0.4]);
        assert!(r < 3);
    }

    #[test]
    fn sample_weighted_skips_zero_weights() {
        let mut context = Context::new();
        context.init_random(42);
        // Only index 1 has non-zero weight, so it must always be chosen.
        for _ in 0..100 {
            assert_eq!(context.sample_weighted(FooRng, &[0.0, 1.0, 0.0]), 1);
        }
    }

    #[test]
    #[should_panic]
    fn sample_weighted_all_zero_panics() {
        let mut context = Context::new();
        context.init_random(42);
        let _ = context.sample_weighted(FooRng, &[0.0, 0.0]);
    }
}