ixa/random/
context_ext.rs1use std::any::TypeId;
2use std::cell::RefMut;
3
4use log::trace;
5
6use crate::hashing::hash_str;
7use crate::rand::distr::uniform::{SampleRange, SampleUniform};
8use crate::rand::distr::weighted::{Weight, WeightedIndex};
9use crate::rand::distr::Distribution;
10use crate::rand::{Rng, SeedableRng};
11use crate::random::{RngHolder, RngPlugin};
12use crate::{Context, ContextBase, RngId};
13
14fn get_rng<R: RngId + 'static>(context: &impl ContextBase) -> RefMut<R::RngType> {
18 let data_container = context.get_data(RngPlugin);
19
20 let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap();
21 RefMut::map(rng_holders, |holders| {
22 holders
23 .entry(TypeId::of::<R>())
24 .or_insert_with(|| {
26 trace!(
27 "creating new RNG (seed={}) for type id {:?}",
28 data_container.base_seed,
29 TypeId::of::<R>()
30 );
31 let base_seed = data_container.base_seed;
32 let seed_offset = hash_str(R::get_name());
33 RngHolder {
34 rng: Box::new(R::RngType::seed_from_u64(
35 base_seed.wrapping_add(seed_offset),
36 )),
37 }
38 })
39 .rng
40 .downcast_mut::<R::RngType>()
41 .unwrap()
42 })
43}
44
45pub trait ContextRandomExt: ContextBase {
48 fn init_random(&mut self, base_seed: u64) {
51 trace!("initializing random module");
52 let data_container = self.get_data_mut(RngPlugin);
53 data_container.base_seed = base_seed;
54
55 let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap();
57 rng_map.clear();
58 }
59
60 fn sample<R: RngId + 'static, T>(
65 &self,
66 _rng_type: R,
67 sampler: impl FnOnce(&mut R::RngType) -> T,
68 ) -> T {
69 let mut rng = get_rng::<R>(self);
70 sampler(&mut rng)
71 }
72
73 fn sample_distr<R: RngId + 'static, T>(
78 &self,
79 _rng_type: R,
80 distribution: impl Distribution<T>,
81 ) -> T
82 where
83 R::RngType: Rng,
84 {
85 let mut rng = get_rng::<R>(self);
86 distribution.sample::<R::RngType>(&mut rng)
87 }
88
89 fn sample_range<R: RngId + 'static, S, T>(&self, rng_id: R, range: S) -> T
93 where
94 R::RngType: Rng,
95 S: SampleRange<T>,
96 T: SampleUniform,
97 {
98 self.sample(rng_id, |rng| rng.random_range(range))
99 }
100
101 fn sample_bool<R: RngId + 'static>(&self, rng_id: R, p: f64) -> bool
105 where
106 R::RngType: Rng,
107 {
108 self.sample(rng_id, |rng| rng.random_bool(p))
109 }
110
111 fn sample_weighted<R: RngId + 'static, T>(&self, _rng_id: R, weights: &[T]) -> usize
116 where
117 R::RngType: Rng,
118 T: Clone
119 + Default
120 + SampleUniform
121 + for<'a> std::ops::AddAssign<&'a T>
122 + PartialOrd
123 + Weight,
124 {
125 let index = WeightedIndex::new(weights).unwrap();
126 let mut rng = get_rng::<R>(self);
127 index.sample(&mut *rng)
128 }
129}
130
131impl ContextRandomExt for Context {}
132
#[cfg(test)]
mod test {
    use crate::context::Context;
    use crate::rand::distr::weighted::WeightedIndex;
    use crate::rand::distr::Distribution;
    use crate::rand::RngCore;
    use crate::random::context_ext::ContextRandomExt;
    use crate::{define_data_plugin, define_rng};

    define_rng!(FooRng);
    define_rng!(BarRng);

    /// Two consecutive draws from the same generator differ.
    #[test]
    fn get_rng_basic() {
        let mut context = Context::new();
        context.init_random(42);

        let first = context.sample(FooRng, RngCore::next_u64);
        let second = context.sample(FooRng, RngCore::next_u64);
        assert_ne!(first, second);
    }

    /// Different `RngId`s produce different first draws under the same base seed.
    #[test]
    fn multiple_rng_types() {
        let mut context = Context::new();
        context.init_random(42);

        let from_foo = context.sample(FooRng, RngCore::next_u64);
        let from_bar = context.sample(BarRng, RngCore::next_u64);
        assert_ne!(from_foo, from_bar);
    }

    /// Re-initializing with the same seed replays the stream; a new seed does not.
    #[test]
    fn reset_seed() {
        let mut context = Context::new();
        context.init_random(42);

        let run_0 = context.sample(FooRng, RngCore::next_u64);
        let run_1 = context.sample(FooRng, RngCore::next_u64);

        // Same seed: the first two draws repeat exactly.
        context.init_random(42);
        let replay_0 = context.sample(FooRng, RngCore::next_u64);
        assert_eq!(run_0, replay_0);
        let replay_1 = context.sample(FooRng, RngCore::next_u64);
        assert_eq!(run_1, replay_1);

        // Different seed: the first two draws change.
        context.init_random(88);
        let fresh_0 = context.sample(FooRng, RngCore::next_u64);
        assert_ne!(run_0, fresh_0);
        let fresh_1 = context.sample(FooRng, RngCore::next_u64);
        assert_ne!(run_1, fresh_1);
    }

    define_data_plugin!(
        SamplerData,
        WeightedIndex<f64>,
        WeightedIndex::new(vec![1.0]).unwrap()
    );

    /// A sampler closure may capture data borrowed from the context.
    #[test]
    fn sampler_function_closure_capture() {
        let mut context = Context::new();
        context.init_random(42);

        // Weights 1:2, so index 0 is expected about one third of the time.
        *context.get_data_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data(SamplerData);
        let n_samples = 3000;
        let zero_count = (0..n_samples)
            .filter(|_| {
                let sample = context.sample(FooRng, |rng| parameters.sample(rng));
                sample == 0
            })
            .count();
        assert!(zero_count.abs_diff(1000) < 100);
    }

    /// `sample_distr` accepts a borrowed distribution and matches its weights.
    #[test]
    fn sample_distribution() {
        let mut context = Context::new();
        context.init_random(42);

        // Weights 1:2, so index 0 is expected about one third of the time.
        *context.get_data_mut(SamplerData) = WeightedIndex::new(vec![1.0, 2.0]).unwrap();

        let parameters = context.get_data(SamplerData);
        let n_samples = 3000;
        let zero_count = (0..n_samples)
            .filter(|_| {
                let sample = context.sample_distr(FooRng, parameters);
                sample == 0
            })
            .count();
        assert!(zero_count.abs_diff(1000) < 100);
    }

    /// `sample_range` stays inside the requested range.
    #[test]
    fn sample_range() {
        let mut context = Context::new();
        context.init_random(42);
        let drawn = context.sample_range(FooRng, 0..10);
        assert!((0..10).contains(&drawn));
    }

    /// `sample_bool` runs and yields a `bool`.
    #[test]
    fn sample_bool() {
        let mut context = Context::new();
        context.init_random(42);
        let _outcome: bool = context.sample_bool(FooRng, 0.5);
    }

    /// `sample_weighted` returns a valid index into the weight slice.
    #[test]
    fn sample_weighted() {
        let mut context = Context::new();
        context.init_random(42);
        let chosen: usize = context.sample_weighted(FooRng, &[0.1, 0.3, 0.4]);
        assert!(chosen < 3);
    }
}