ixa/random/sampling_algorithms.rs
1//! Algorithms for uniform random sampling from hash sets or iterators. These algorithms are written to be generic
2//! over the container type using zero-cost trait abstractions.
3use std::collections::{HashMap, HashSet};
4
5use crate::rand::seq::index::sample as choose_range;
6use crate::rand::Rng;
7
/// The `len` capability: a zero-cost abstraction over containers whose element count is known
/// without iterating.
pub trait HasLen {
    /// Returns the number of elements the container currently holds.
    fn len(&self) -> usize;
}
12
/// The `iter` capability: a zero-cost abstraction over containers that can be iterated by
/// shared reference.
///
/// Both associated types are generic over the borrow lifetime so that the iterator (and the
/// items it yields) may borrow from the container.
pub trait HasIter {
    /// The item type yielded while the container is borrowed for `'a`.
    type Item<'a>
    where
        Self: 'a;

    /// The iterator type returned by `iter` for a borrow of lifetime `'a`.
    type Iter<'a>: Iterator<Item = Self::Item<'a>>
    where
        Self: 'a;

    /// Returns an iterator over the elements of the container.
    fn iter(&self) -> Self::Iter<'_>;
}
24
// Implements `HasLen` for a generic container by forwarding to the container's own `len`.
//
// Invoke as `impl_has_len!(Type<A, B, ...>)`, listing every generic parameter of the type.
macro_rules! impl_has_len {
    ($container:ident < $($param:ident),* >) => {
        impl<$($param),*> HasLen for $container<$($param),*> {
            fn len(&self) -> usize {
                // Type-qualified call: the container's inherent `len` takes precedence over
                // this trait method.
                <$container<$($param),*>>::len(self)
            }
        }
    };
}
34
// Implements `HasIter` for a generic container by forwarding to the container's own `iter`.
//
// Invoke as `impl_has_iter!(Type<A, B, ...>, IterType, ItemType)`. The iterator and item types
// must spell the borrow lifetime as `'a`, which is the lifetime parameter declared on the
// generated associated types.
macro_rules! impl_has_iter {
    ($container:ident < $($param:ident),* >, $iterator:ty, $element:ty) => {
        impl<$($param),*> HasIter for $container<$($param),*> {
            type Item<'a> = $element where Self: 'a;
            type Iter<'a> = $iterator where Self: 'a;

            fn iter(&self) -> Self::Iter<'_> {
                // Type-qualified call: the container's inherent `iter` takes precedence over
                // this trait method.
                <$container<$($param),*>>::iter(self)
            }
        }
    };
}
47
// `Vec<T>`: `HasLen` is generated by the forwarding macro.
impl_has_len!(Vec<T>);
50// We implement `HasIter` manually for `Vec<T>` because its `iter` method is from `Deref<Target = [T]>`.
51impl<T> HasIter for Vec<T> {
52 type Item<'a>
53 = &'a T
54 where
55 Self: 'a;
56 type Iter<'a>
57 = std::slice::Iter<'a, T>
58 where
59 Self: 'a;
60
61 fn iter(&self) -> Self::Iter<'_> {
62 <[T]>::iter(self)
63 }
64}
65
// `HashSet<T, H>`: both capabilities are generated by the forwarding macros; iteration yields `&T`.
impl_has_len!(HashSet<T, H>);
impl_has_iter!(HashSet<T, H>, std::collections::hash_set::Iter<'a, T>, &'a T);

// `HashMap<K, V, H>`: both capabilities are generated by the forwarding macros; iteration yields
// `(&K, &V)` pairs.
impl_has_len!(HashMap<K, V, H>);
impl_has_iter!(HashMap<K, V, H>, std::collections::hash_map::Iter<'a, K, V>, (&'a K, &'a V));
73
74/// Sample a random element uniformly from a container of known length.
75///
76/// We do not assume the container is randomly indexable, only that it can be iterated over. The value is cloned.
77/// This algorithm is used when the property is indexed, and thus we know the length of the result set.
78pub fn sample_single_from_known_length<'a, Container, R, T>(
79 rng: &mut R,
80 set: &'a Container,
81) -> Option<T>
82where
83 R: Rng,
84 Container: HasLen + HasIter<Item<'a> = &'a T>,
85 T: Clone + 'static,
86{
87 let len = set.len();
88 if len == 0 {
89 return None;
90 }
91 // This little trick with `u32` makes this function 30% faster.
92 let index = rng.random_range(0..len as u32) as usize;
93 // The set need not be randomly indexable, so we have to use the `nth` method.
94 set.iter().nth(index).cloned()
95}
96
97/// Sample a random element uniformly from a container of unknown length.
98///
99/// We do not assume the container is randomly indexable, only that it can be iterated over. The value is cloned.
100///
101/// This function implements "Algorithm L" from KIM-HUNG LI
102/// Reservoir-Sampling Algorithms of Time Complexity O(n(1 + log(N/n)))
103/// <https://dl.acm.org/doi/pdf/10.1145/198429.198435>
104///
105/// This algorithm is significantly slower than the "known length" algorithm (factor
106/// of 10^4). The reservoir algorithm from [`rand`](crate::rand) reduces to the "known length"
107/// algorithm when the iterator is an [`ExactSizeIterator`](std::iter::ExactSizeIterator), or more precisely,
108/// when `iterator.size_hint()` returns `(k, Some(k))` for some `k`. Otherwise,
109/// this algorithm is much faster than the [`rand`](crate::rand) implementation (factor of 100).
110// ToDo(RobertJacobsonCDC): This function will take an iterator once the `iter_query_results` API is ready.
111pub fn sample_single_l_reservoir<'a, Container, R, T>(rng: &mut R, set: &'a Container) -> Option<T>
112where
113 R: Rng,
114 Container: HasIter<Item<'a> = &'a T>,
115 T: Clone + 'static,
116{
117 let mut chosen_item: Option<T> = None; // the currently selected element
118 let mut weight: f64 = rng.random_range(0.0..1.0); // controls skip distance distribution
119 let mut position: usize = 0; // current index in data
120 let mut next_pick_position: usize = 1; // index of the next item to pick
121
122 set.iter().for_each(|item| {
123 position += 1;
124 if position == next_pick_position {
125 chosen_item = Some(item.clone());
126 next_pick_position +=
127 (f64::ln(rng.random_range(0.0..1.0)) / f64::ln(1.0 - weight)).floor() as usize + 1;
128 weight *= rng.random_range(0.0..1.0);
129 }
130 });
131
132 chosen_item
133}
134
135/// Sample multiple random elements uniformly without replacement from a container of known length.
136/// This function assumes `set.len() >= requested`.
137///
138/// We do not assume the container is randomly indexable, only that it can be iterated over. The values are cloned.
139///
140/// This algorithm can be used when the property is indexed, and thus we know the length of the result set.
141/// For very small `requested` values (<=5), this algorithm is faster than reservoir because it doesn't
142/// iterate over the entire set.
143pub fn sample_multiple_from_known_length<'a, Container, R, T>(
144 rng: &mut R,
145 set: &'a Container,
146 requested: usize,
147) -> Vec<T>
148where
149 R: Rng,
150 Container: HasLen + HasIter<Item<'a> = &'a T>,
151 T: Clone + 'static,
152{
153 let mut indexes = Vec::with_capacity(requested);
154 indexes.extend(choose_range(rng, set.len(), requested));
155 indexes.sort_unstable();
156 let mut index_iterator = indexes.into_iter();
157 let mut next_idx = index_iterator.next().unwrap();
158 let mut selected = Vec::with_capacity(requested);
159
160 for (idx, item) in set.iter().enumerate() {
161 if idx == next_idx {
162 selected.push(item.clone());
163 if let Some(i) = index_iterator.next() {
164 next_idx = i;
165 } else {
166 break;
167 }
168 }
169 }
170
171 selected
172}
173
174/// Sample multiple random elements uniformly without replacement from a container of known length. If
175/// more samples are requested than are in the set, the function returns as many items as it can.
176///
177/// We do not assume the container is randomly indexable, only that it can be iterated over. The values are cloned.
178///
179/// This function implements "Algorithm L" from KIM-HUNG LI
180/// Reservoir-Sampling Algorithms of Time Complexity O(n(1 + log(N/n)))
181/// <https://dl.acm.org/doi/pdf/10.1145/198429.198435>
182///
183/// This algorithm is significantly faster than the reservoir algorithm in `rand` and is
184/// on par with the "known length" algorithm for large `requested` values.
185// ToDo(RobertJacobsonCDC): This function will take an iterator once the `iter_query_results` API is ready.
186pub fn sample_multiple_l_reservoir<'a, Container, R, T>(
187 rng: &mut R,
188 set: &'a Container,
189 requested: usize,
190) -> Vec<T>
191where
192 R: Rng,
193 Container: HasIter<Item<'a> = &'a T>,
194 T: Clone + 'static,
195{
196 let mut weight: f64 = rng.random_range(0.0..1.0); // controls skip distance distribution
197 weight = weight.powf(1.0 / requested as f64);
198 let mut position: usize = 0; // current index in data
199 let mut next_pick_position: usize = 1; // index of the next item to pick
200 let mut reservoir = Vec::with_capacity(requested); // the sample reservoir
201
202 set.iter().for_each(|item| {
203 position += 1;
204 if position == next_pick_position {
205 if reservoir.len() == requested {
206 let to_remove = rng.random_range(0..reservoir.len());
207 reservoir.swap_remove(to_remove);
208 }
209 reservoir.push(item.clone());
210
211 if reservoir.len() == requested {
212 next_pick_position += (f64::ln(rng.random_range(0.0..1.0)) / f64::ln(1.0 - weight))
213 .floor() as usize
214 + 1;
215 let uniform_random: f64 = rng.random_range(0.0..1.0);
216 weight *= uniform_random.powf(1.0 / requested as f64);
217 } else {
218 next_pick_position += 1;
219 }
220 }
221 });
222
223 reservoir
224}
225
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    use super::*;

    /// Pearson's chi-square statistic for `observed` counts when every bin shares the same
    /// `expected` count.
    fn chi_square_statistic(observed: &[usize], expected: f64) -> f64 {
        observed
            .iter()
            .map(|&count| {
                let deviation = (count as f64) - expected;
                deviation * deviation / expected
            })
            .sum()
    }

    #[test]
    fn test_sample_multiple_l_reservoir_basic() {
        let requested = 100;
        let data: Vec<u32> = (0..1000).collect();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, requested);

        // The sample has exactly the requested size.
        assert_eq!(sample.len(), requested);

        // Every sampled value comes from the valid range.
        assert!(sample.iter().all(|&value| value < 1000));

        // Sampling is without replacement, so no value may repeat.
        let distinct: HashSet<&u32> = sample.iter().collect();
        assert_eq!(distinct.len(), sample.len());
    }

    // Checks that the reservoir sampler is uniform: it is run 1000 times, each run yields a
    // chi-square statistic over 10 equal-width bins, and those statistics are then tested
    // against the chi-square(9) distribution. The check is only approximate and is reasonable
    // only when `requested` is small relative to `population`, because
    // `sample_multiple_l_reservoir` samples without replacement while the chi-squared test
    // assumes independent samples.
    #[test]
    fn test_sample_multiple_l_reservoir_uniformity() {
        let population: u32 = 10000;
        let data: Vec<u32> = (0..population).collect();
        let requested = 100;
        let num_runs: usize = 1000;

        // The range 0..population is split into 10 equal-width bins.
        let bin_width = population as usize / 10;
        // Expected count per bin under uniform sampling (= 10.0).
        let expected = requested as f64 / 10.0;

        let chi_squares: Vec<f64> = (0..num_runs)
            .map(|run| {
                let mut rng = StdRng::seed_from_u64(42 + run as u64);
                let sample = sample_multiple_l_reservoir(&mut rng, &data, requested);

                let mut counts = [0usize; 10];
                for &value in &sample {
                    counts[(value as usize) / bin_width] += 1;
                }

                chi_square_statistic(&counts, expected)
            })
            .collect();

        // Decile boundaries of the chi-square distribution with df=9: if the statistics
        // really follow chi-square(9), each interval receives 10% of the observations.
        // Generate with Mathematica:
        //     Table[Quantile[ChiSquareDistribution[9], p/10], {p, 0, 10}]//N
        let quantiles = [
            0.0,           // 0th percentile (minimum)
            4.16816,       // 10th percentile
            5.38005,       // 20th percentile
            6.39331,       // 30th percentile
            7.35703,       // 40th percentile
            8.34283,       // 50th percentile (median)
            9.41364,       // 60th percentile
            10.6564,       // 70th percentile
            12.2421,       // 80th percentile
            14.6837,       // 90th percentile
            f64::INFINITY, // 100th percentile (maximum)
        ];

        let num_bins = quantiles.len() - 1;
        let mut chi_square_counts = vec![0usize; num_bins];

        for &chi_sq in &chi_squares {
            // Locate the half-open interval [lower, upper) containing this statistic.
            let bin = quantiles
                .windows(2)
                .position(|bounds| chi_sq >= bounds[0] && chi_sq < bounds[1]);
            if let Some(bin) = bin {
                chi_square_counts[bin] += 1;
            }
        }

        // Each interval should hold roughly num_runs / num_bins observations.
        let expected_per_bin = num_runs as f64 / num_bins as f64;
        let chi_square_of_chi_squares = chi_square_statistic(&chi_square_counts, expected_per_bin);

        // Degrees of freedom = (#bins - 1) = 9
        // Critical χ²₀.₉₉₉ for df=9 is 27.877
        let critical = 27.877;

        println!(
            "χ² = {}, counts = {:?}",
            chi_square_of_chi_squares, chi_square_counts
        );

        assert!(
            chi_square_of_chi_squares < critical,
            "Chi-square statistics fail to follow chi-square(9) distribution: χ² = {}, counts = {:?}",
            chi_square_of_chi_squares,
            chi_square_counts
        );
    }
}
353}