ixa/random/
sampling_algorithms.rs

//! Algorithms for uniform random sampling from containers (such as `Vec` and `HashSet`) that can
//! be iterated over. These algorithms are written to be generic over the container type using
//! zero-cost trait abstractions.
3use std::collections::{HashMap, HashSet};
4
5use crate::rand::seq::index::sample as choose_range;
6use crate::rand::Rng;
7
/// The `len` capability, a zero-cost abstraction for types that have a known length.
pub trait HasLen {
    /// Returns the number of elements in the container.
    fn len(&self) -> usize;

    /// Returns `true` if the container holds no elements.
    ///
    /// Provided in terms of [`HasLen::len`], so implementors only need to supply `len`.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
12
/// The `iter` capability: a zero-cost abstraction over containers that can be iterated through a
/// shared borrow.
///
/// Both associated types are generic over the lifetime `'a` of that borrow, so implementors can
/// yield references into `self`.
pub trait HasIter {
    /// The element type yielded while `self` is borrowed for `'a`.
    type Item<'a> where Self: 'a;

    /// The concrete iterator type returned by [`HasIter::iter`].
    type Iter<'a>: Iterator<Item = Self::Item<'a>> where Self: 'a;

    /// Returns an iterator over the container's elements.
    fn iter(&self) -> Self::Iter<'_>;
}
24
// Implements `HasLen` for a generic container by forwarding to the container's own `len`.
//
// Invoke with the type name followed by all of its generic parameters, e.g.
// `impl_has_len!(HashSet<T, H>);`. The generated impl places no bounds on those parameters.
macro_rules! impl_has_len {
    ($container:ident < $($param:ident),* >) => {
        impl<$($param),*> HasLen for $container<$($param),*> {
            fn len(&self) -> usize {
                // Naming the concrete type selects its inherent `len` rather than this trait
                // method, which would otherwise recurse.
                <$container<$($param),*>>::len(self)
            }
        }
    };
}
34
// Implements `HasIter` for a generic container by forwarding to the container's own `iter`.
//
// Arguments, in order:
// 1. the container type with all of its generic parameters;
// 2. the concrete iterator type;
// 3. the item type that iterator yields.
//
// The iterator and item types must spell the borrow lifetime as `'a`, because that is the
// lifetime parameter the generated associated types declare.
macro_rules! impl_has_iter {
    ($container:ident < $($param:ident),* >, $iter_ty:ty, $item_ty:ty) => {
        impl<$($param),*> HasIter for $container<$($param),*> {
            type Item<'a> = $item_ty where Self: 'a;
            type Iter<'a> = $iter_ty where Self: 'a;

            fn iter(&self) -> Self::Iter<'_> {
                // Naming the concrete type selects its inherent `iter`, not this trait method.
                <$container<$($param),*>>::iter(self)
            }
        }
    };
}
47
48// Vec<T>
49impl_has_len!(Vec<T>);
50// We implement `HasIter` manually for `Vec<T>` because its `iter` method is from `Deref<Target = [T]>`.
51impl<T> HasIter for Vec<T> {
52    type Item<'a>
53        = &'a T
54    where
55        Self: 'a;
56    type Iter<'a>
57        = std::slice::Iter<'a, T>
58    where
59        Self: 'a;
60
61    fn iter(&self) -> Self::Iter<'_> {
62        <[T]>::iter(self)
63    }
64}
65
66// HashSet<T, H>
67impl_has_len!(HashSet<T, H>);
68impl_has_iter!(HashSet<T, H>, std::collections::hash_set::Iter<'a, T>, &'a T);
69
70// HashMap<K, V, H>
71impl_has_len!(HashMap<K, V, H>);
72impl_has_iter!(HashMap<K, V, H>, std::collections::hash_map::Iter<'a, K, V>, (&'a K, &'a V));
73
74/// Sample a random element uniformly from a container of known length.
75///
76/// We do not assume the container is randomly indexable, only that it can be iterated over. The value is cloned.
77/// This algorithm is used when the property is indexed, and thus we know the length of the result set.
78pub fn sample_single_from_known_length<'a, Container, R, T>(
79    rng: &mut R,
80    set: &'a Container,
81) -> Option<T>
82where
83    R: Rng,
84    Container: HasLen + HasIter<Item<'a> = &'a T>,
85    T: Clone + 'static,
86{
87    let len = set.len();
88    if len == 0 {
89        return None;
90    }
91    // This little trick with `u32` makes this function 30% faster.
92    let index = rng.random_range(0..len as u32) as usize;
93    // The set need not be randomly indexable, so we have to use the `nth` method.
94    set.iter().nth(index).cloned()
95}
96
97/// Sample a random element uniformly from a container of unknown length.
98///
99/// We do not assume the container is randomly indexable, only that it can be iterated over. The value is cloned.
100///
101/// This function implements "Algorithm L" from KIM-HUNG LI
102/// Reservoir-Sampling Algorithms of Time Complexity O(n(1 + log(N/n)))
103/// <https://dl.acm.org/doi/pdf/10.1145/198429.198435>
104///
105/// This algorithm is significantly slower than the "known length" algorithm (factor
106/// of 10^4). The reservoir algorithm from [`rand`](crate::rand) reduces to the "known length"
107/// algorithm when the iterator is an [`ExactSizeIterator`](std::iter::ExactSizeIterator), or more precisely,
108/// when `iterator.size_hint()` returns `(k, Some(k))` for some `k`. Otherwise,
109/// this algorithm is much faster than the [`rand`](crate::rand) implementation (factor of 100).
110// ToDo(RobertJacobsonCDC): This function will take an iterator once the `iter_query_results` API is ready.
111pub fn sample_single_l_reservoir<'a, Container, R, T>(rng: &mut R, set: &'a Container) -> Option<T>
112where
113    R: Rng,
114    Container: HasIter<Item<'a> = &'a T>,
115    T: Clone + 'static,
116{
117    let mut chosen_item: Option<T> = None; // the currently selected element
118    let mut weight: f64 = rng.random_range(0.0..1.0); // controls skip distance distribution
119    let mut position: usize = 0; // current index in data
120    let mut next_pick_position: usize = 1; // index of the next item to pick
121
122    set.iter().for_each(|item| {
123        position += 1;
124        if position == next_pick_position {
125            chosen_item = Some(item.clone());
126            next_pick_position +=
127                (f64::ln(rng.random_range(0.0..1.0)) / f64::ln(1.0 - weight)).floor() as usize + 1;
128            weight *= rng.random_range(0.0..1.0);
129        }
130    });
131
132    chosen_item
133}
134
135/// Sample multiple random elements uniformly without replacement from a container of known length.
136/// This function assumes `set.len() >= requested`.
137///
138/// We do not assume the container is randomly indexable, only that it can be iterated over. The values are cloned.
139///
140/// This algorithm can be used when the property is indexed, and thus we know the length of the result set.
141/// For very small `requested` values (<=5), this algorithm is faster than reservoir because it doesn't
142/// iterate over the entire set.
143pub fn sample_multiple_from_known_length<'a, Container, R, T>(
144    rng: &mut R,
145    set: &'a Container,
146    requested: usize,
147) -> Vec<T>
148where
149    R: Rng,
150    Container: HasLen + HasIter<Item<'a> = &'a T>,
151    T: Clone + 'static,
152{
153    let mut indexes = Vec::with_capacity(requested);
154    indexes.extend(choose_range(rng, set.len(), requested));
155    indexes.sort_unstable();
156    let mut index_iterator = indexes.into_iter();
157    let mut next_idx = index_iterator.next().unwrap();
158    let mut selected = Vec::with_capacity(requested);
159
160    for (idx, item) in set.iter().enumerate() {
161        if idx == next_idx {
162            selected.push(item.clone());
163            if let Some(i) = index_iterator.next() {
164                next_idx = i;
165            } else {
166                break;
167            }
168        }
169    }
170
171    selected
172}
173
174/// Sample multiple random elements uniformly without replacement from a container of known length. If
175/// more samples are requested than are in the set, the function returns as many items as it can.
176///
177/// We do not assume the container is randomly indexable, only that it can be iterated over. The values are cloned.
178///
179/// This function implements "Algorithm L" from KIM-HUNG LI
180/// Reservoir-Sampling Algorithms of Time Complexity O(n(1 + log(N/n)))
181/// <https://dl.acm.org/doi/pdf/10.1145/198429.198435>
182///
183/// This algorithm is significantly faster than the reservoir algorithm in `rand` and is
184/// on par with the "known length" algorithm for large `requested` values.
185// ToDo(RobertJacobsonCDC): This function will take an iterator once the `iter_query_results` API is ready.
186pub fn sample_multiple_l_reservoir<'a, Container, R, T>(
187    rng: &mut R,
188    set: &'a Container,
189    requested: usize,
190) -> Vec<T>
191where
192    R: Rng,
193    Container: HasIter<Item<'a> = &'a T>,
194    T: Clone + 'static,
195{
196    let mut weight: f64 = rng.random_range(0.0..1.0); // controls skip distance distribution
197    weight = weight.powf(1.0 / requested as f64);
198    let mut position: usize = 0; // current index in data
199    let mut next_pick_position: usize = 1; // index of the next item to pick
200    let mut reservoir = Vec::with_capacity(requested); // the sample reservoir
201
202    set.iter().for_each(|item| {
203        position += 1;
204        if position == next_pick_position {
205            if reservoir.len() == requested {
206                let to_remove = rng.random_range(0..reservoir.len());
207                reservoir.swap_remove(to_remove);
208            }
209            reservoir.push(item.clone());
210
211            if reservoir.len() == requested {
212                next_pick_position += (f64::ln(rng.random_range(0.0..1.0)) / f64::ln(1.0 - weight))
213                    .floor() as usize
214                    + 1;
215                let uniform_random: f64 = rng.random_range(0.0..1.0);
216                weight *= uniform_random.powf(1.0 / requested as f64);
217            } else {
218                next_pick_position += 1;
219            }
220        }
221    });
222
223    reservoir
224}
225
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    use super::*;

    #[test]
    fn test_sample_single_from_known_length_empty_and_singleton() {
        let mut rng = StdRng::seed_from_u64(42);
        let empty: Vec<u32> = Vec::new();
        assert_eq!(sample_single_from_known_length(&mut rng, &empty), None);

        let singleton: Vec<u32> = vec![7];
        assert_eq!(
            sample_single_from_known_length(&mut rng, &singleton),
            Some(7)
        );
    }

    #[test]
    fn test_sample_single_from_known_length_returns_members() {
        let data: HashSet<u32> = (0..100).collect();
        let mut rng = StdRng::seed_from_u64(42);
        for _ in 0..100 {
            let value = sample_single_from_known_length(&mut rng, &data)
                .expect("a non-empty set always yields a sample");
            assert!(data.contains(&value));
        }
    }

    #[test]
    fn test_sample_single_l_reservoir_empty_and_singleton() {
        let mut rng = StdRng::seed_from_u64(42);
        let empty: Vec<u32> = Vec::new();
        assert_eq!(sample_single_l_reservoir(&mut rng, &empty), None);

        let singleton: Vec<u32> = vec![7];
        assert_eq!(sample_single_l_reservoir(&mut rng, &singleton), Some(7));
    }

    #[test]
    fn test_sample_single_l_reservoir_returns_members() {
        let data: Vec<u32> = (0..1000).collect();
        let mut rng = StdRng::seed_from_u64(42);
        for _ in 0..100 {
            let value = sample_single_l_reservoir(&mut rng, &data)
                .expect("a non-empty container always yields a sample");
            assert!(value < 1000);
        }
    }

    #[test]
    fn test_sample_multiple_from_known_length_basic() {
        let data: Vec<u32> = (0..1000).collect();
        let requested = 100;
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_from_known_length(&mut rng, &data, requested);

        // Correct sample size
        assert_eq!(sample.len(), requested);

        // Items are collected in iteration order from distinct indexes, so sampling a sorted
        // `Vec` yields a strictly increasing (hence duplicate-free) result.
        assert!(sample.windows(2).all(|pair| pair[0] < pair[1]));

        // All sampled values are within the valid range
        assert!(sample.iter().all(|v| *v < 1000));
    }

    #[test]
    fn test_sample_multiple_from_known_length_entire_set() {
        let data: Vec<u32> = (0..20).collect();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_from_known_length(&mut rng, &data, data.len());
        assert_eq!(sample, data);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_requested_exceeds_len() {
        let data: Vec<u32> = (0..5).collect();
        let mut rng = StdRng::seed_from_u64(42);
        let mut sample = sample_multiple_l_reservoir(&mut rng, &data, 10);
        sample.sort_unstable();
        assert_eq!(sample, data);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_basic() {
        let data: Vec<u32> = (0..1000).collect();
        let requested = 100;
        let seed: u64 = 42;
        let mut rng = StdRng::seed_from_u64(seed);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, requested);

        // Correct sample size
        assert_eq!(sample.len(), requested);

        // All sampled values are within the valid range
        assert!(sample.iter().all(|v| *v < 1000));

        // The sample should not have duplicates
        let unique: HashSet<_> = sample.iter().collect();
        assert_eq!(unique.len(), sample.len());
    }

    // Verifies that the reservoir sampling algorithm produces uniformly distributed
    // samples by running it 1000 times and checking that the resulting chi-square
    // statistics follow the expected chi-square(9) distribution. Note that this
    // test is only approximately correct, reasonable only when `requested` is small
    // relative to `population`, because `sample_multiple_l_reservoir` samples
    // without replacement, while the chi-squared test assumes independent samples.
    #[test]
    fn test_sample_multiple_l_reservoir_uniformity() {
        let population: u32 = 10000;
        let data: Vec<u32> = (0..population).collect();
        let requested = 100;
        let num_runs = 1000;
        let mut chi_squares = Vec::with_capacity(num_runs);

        for run in 0..num_runs {
            let mut rng = StdRng::seed_from_u64(42 + run as u64);
            let sample = sample_multiple_l_reservoir(&mut rng, &data, requested);

            // Partition range 0..population into 10 equal-width bins
            let mut counts = [0usize; 10];
            for &value in &sample {
                let bin = (value as usize) / (population as usize / 10);
                counts[bin] += 1;
            }

            // Expected count per bin for uniform sampling
            let expected = requested as f64 / 10.0; // = 10.0

            // Compute chi-square statistic
            let chi_square: f64 = counts
                .iter()
                .map(|&obs| {
                    let diff = (obs as f64) - expected;
                    diff * diff / expected
                })
                .sum();

            chi_squares.push(chi_square);
        }

        // Now test that chi_squares follow a chi-square distribution with df=9
        // We use quantiles of the chi-square(9) distribution to create bins
        // and check if the observed counts match the expected uniform distribution

        // Quantiles of chi-square distribution with df=9 at deciles (10 bins)
        // These values define the bin boundaries such that each bin should contain
        // 10% of the observations if they truly follow chi-square(9).
        // Generate with Mathematica:
        //     Table[Quantile[ChiSquareDistribution[9], p/10], {p, 0, 10}]//N
        let quantiles = [
            0.0,           // 0th percentile (minimum)
            4.16816,       // 10th percentile
            5.38005,       // 20th percentile
            6.39331,       // 30th percentile
            7.35703,       // 40th percentile
            8.34283,       // 50th percentile (median)
            9.41364,       // 60th percentile
            10.6564,       // 70th percentile
            12.2421,       // 80th percentile
            14.6837,       // 90th percentile
            f64::INFINITY, // 100th percentile (maximum)
        ];

        let num_bins = quantiles.len() - 1;
        let mut chi_square_counts = vec![0usize; num_bins];

        for &chi_sq in &chi_squares {
            // Find which bin this chi-square value falls into
            for i in 0..num_bins {
                if chi_sq >= quantiles[i] && chi_sq < quantiles[i + 1] {
                    chi_square_counts[i] += 1;
                    break;
                }
            }
        }

        // Each bin should contain approximately num_runs / num_bins observations
        let expected_per_bin = num_runs as f64 / num_bins as f64;
        let chi_square_of_chi_squares: f64 = chi_square_counts
            .iter()
            .map(|&obs| {
                let diff = (obs as f64) - expected_per_bin;
                diff * diff / expected_per_bin
            })
            .sum();

        // Degrees of freedom = (#bins - 1) = 9
        // Critical χ²₀.₉₉₉ for df=9 is 27.877
        let critical = 27.877;

        println!(
            "χ² = {}, counts = {:?}",
            chi_square_of_chi_squares, chi_square_counts
        );

        assert!(
            chi_square_of_chi_squares < critical,
            "Chi-square statistics fail to follow chi-square(9) distribution: χ² = {}, counts = {:?}",
            chi_square_of_chi_squares,
            chi_square_counts
        );
    }
}