ixa/random/
sampling_algorithms.rs

//! Algorithms for uniform random sampling from hash sets or iterators. These algorithms are
//! written to be generic over the container type using zero-cost trait abstractions.
3
4use crate::rand::seq::index::sample as choose_range;
5use crate::rand::Rng;
6
7/// Samples one element uniformly at random from an iterator whose length is known at runtime.
8///
9/// The caller must ensure that `(len, Some(len)) == iter.size_hint()`, i.e. the iterator
10/// reports its exact length via `size_hint`. We do not require `ExactSizeIterator`
11/// because that is a compile-time guarantee, whereas our requirement is a runtime condition.
12///
13/// The implementation selects a random index and uses `Iterator::nth`. For iterators
14/// with O(1) `nth` (e.g., randomly indexable structures), this is very efficient.
15/// The selected value is cloned.
16///
17/// The iterator need only support iteration; random indexing is not required.
18/// This function is intended for use when the result set is indexed and its length is known.
19pub fn sample_single_from_known_length<I, R, T>(rng: &mut R, mut iter: I) -> Option<T>
20where
21    R: Rng,
22    I: Iterator<Item = T>,
23{
24    // It is the caller's responsibility to ensure that `(len, Some(len)) == iter.size_hint()`.
25    let (length, _) = iter.size_hint();
26    if length == 0 {
27        return None;
28    }
29    // This little trick with `u32` makes this function 30% faster.
30    let index = rng.random_range(0..length as u32) as usize;
31    // The set need not be randomly indexable, so we have to use the `nth` method.
32    iter.nth(index)
33}
34
35/// Sample a random element uniformly from an iterator of unknown length.
36///
37/// We do not assume the container is randomly indexable, only that it can be iterated over.
38///
39/// This function implements "Algorithm L" from KIM-HUNG LI
40/// Reservoir-Sampling Algorithms of Time Complexity O(n(1 + log(N/n)))
41/// <https://dl.acm.org/doi/pdf/10.1145/198429.198435>
42///
43/// This algorithm is significantly slower than the "known length" algorithm (factor
44/// of 10^4). The reservoir algorithm from [`rand`](crate::rand) reduces to the "known length"
45/// algorithm when `iterator.size_hint()` returns `(k, Some(k))` for some `k`. Otherwise,
46/// this algorithm is much faster than the [`rand`](crate::rand)  implementation (factor of 100).
47pub fn sample_single_l_reservoir<I, R, T>(rng: &mut R, iterable: I) -> Option<T>
48where
49    R: Rng,
50    I: IntoIterator<Item = T>,
51{
52    let mut iter = iterable.into_iter();
53    let mut weight: f64 = rng.random_range(0.0..1.0); // controls skip distance distribution
54    let mut chosen_item: T = iter.next()?; // the currently selected element
55
56    // Number of elements to skip before the next candidate to consider for the reservoir.
57    // `iter.nth(skip)` skips `skip` elements and returns the next one.
58    let mut skip = (f64::ln(rng.random_range(0.0..1.0)) / f64::ln(1.0 - weight)).floor() as usize;
59    weight *= rng.random_range(0.0..1.0);
60
61    loop {
62        match iter.nth(skip) {
63            Some(item) => {
64                chosen_item = item;
65                skip =
66                    (f64::ln(rng.random_range(0.0..1.0)) / f64::ln(1.0 - weight)).floor() as usize;
67                weight *= rng.random_range(0.0..1.0);
68            }
69            None => return Some(chosen_item),
70        }
71    }
72}
73
74/// Samples `requested` elements uniformly at random without replacement from an iterator
75/// whose length is known at runtime. Requires `len >= requested`.
76///
77/// The caller must ensure that `(len, Some(len)) == iter.size_hint()`, i.e. the iterator
78/// reports its exact length via `size_hint`. We do not require `ExactSizeIterator`
79/// because that is a compile-time guarantee, whereas our requirement is a runtime condition.
80///
81/// The implementation selects random indices and uses `Iterator::nth`. For iterators
82/// with O(1) `nth` (e.g., randomly indexable structures), this is very efficient.
83/// Selected values are cloned.
84///
85/// This strategy is particularly effective for small `requested` (≤ 5), since it
86/// avoids iterating over the entire set and is typically faster than reservoir sampling.
87pub fn sample_multiple_from_known_length<I, R, T>(rng: &mut R, iter: I, requested: usize) -> Vec<T>
88where
89    R: Rng,
90    I: IntoIterator<Item = T>,
91{
92    let mut iter = iter.into_iter();
93    // It is the caller's responsibility to ensure that `(length, Some(length)) == iter.size_hint()`.
94    let (length, _) = iter.size_hint();
95
96    let mut indexes = Vec::with_capacity(requested);
97    indexes.extend(choose_range(rng, length, requested));
98    indexes.sort_unstable();
99
100    let mut selected = Vec::with_capacity(requested);
101    let mut consumed: usize = 0; // number of elements consumed from the iterator so far
102
103    // `iter.nth(n)` skips `n` elements and returns the next one, so to reach
104    // index `idx` we skip `idx - consumed` where `consumed` tracks how many
105    // elements have already been consumed.
106    for idx in indexes {
107        if let Some(item) = iter.nth(idx - consumed) {
108            selected.push(item);
109        }
110        consumed = idx + 1;
111    }
112
113    selected
114}
115
116/// Sample multiple random elements uniformly without replacement from a container of unknown length. If
117/// more samples are requested than are in the set, the function returns as many items as it can.
118///
119/// The implementation uses `Iterator::nth`. Randomly indexable structures will have a O(1) `nth`
120/// implementation and will be very efficient. The values are cloned.
121///
122/// This function implements "Algorithm L" from KIM-HUNG LI
123/// Reservoir-Sampling Algorithms of Time Complexity O(n(1 + log(N/n)))
124/// <https://dl.acm.org/doi/pdf/10.1145/198429.198435>
125///
126/// This algorithm is significantly faster than the reservoir algorithm in `rand` and is
127/// on par with the "known length" algorithm for large `requested` values.
128pub fn sample_multiple_l_reservoir<I, R, T>(rng: &mut R, iter: I, requested: usize) -> Vec<T>
129where
130    R: Rng,
131    I: IntoIterator<Item = T>,
132{
133    if requested == 0 {
134        return Vec::new();
135    }
136
137    let mut weight: f64 = rng.random_range(0.0..1.0); // controls skip distance distribution
138    weight = weight.powf(1.0 / requested as f64);
139    let mut iter = iter.into_iter();
140    let mut reservoir: Vec<T> = iter.by_ref().take(requested).collect(); // the sample reservoir
141
142    if reservoir.len() < requested {
143        return reservoir;
144    }
145
146    // Number of elements to skip before the next candidate to consider for the reservoir.
147    // `iter.nth(skip)` skips `skip` elements and returns the next one.
148    let mut skip = (f64::ln(rng.random_range(0.0..1.0)) / f64::ln(1.0 - weight)).floor() as usize;
149    let uniform_random: f64 = rng.random_range(0.0..1.0);
150    weight *= uniform_random.powf(1.0 / requested as f64);
151
152    loop {
153        match iter.nth(skip) {
154            Some(item) => {
155                let to_remove = rng.random_range(0..reservoir.len());
156                reservoir.swap_remove(to_remove);
157                reservoir.push(item);
158
159                skip =
160                    (f64::ln(rng.random_range(0.0..1.0)) / f64::ln(1.0 - weight)).floor() as usize;
161                let uniform_random: f64 = rng.random_range(0.0..1.0);
162                weight *= uniform_random.powf(1.0 / requested as f64);
163            }
164            None => return reservoir,
165        }
166    }
167}
168
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    use super::*;
    use crate::hashing::{HashSet, HashSetExt};

    // Pearson's chi-square statistic for observed bin counts that all share the same
    // expected count.
    fn chi_square_statistic(observed: &[usize], expected: f64) -> f64 {
        observed
            .iter()
            .map(|&count| {
                let deviation = (count as f64) - expected;
                deviation * deviation / expected
            })
            .sum()
    }

    #[test]
    fn test_sample_single_l_reservoir_basic() {
        let mut rng = StdRng::seed_from_u64(42);
        let population: Vec<u32> = (0..1000).collect();

        let picked = sample_single_l_reservoir(&mut rng, population);

        // A non-empty input must yield an element, and that element must lie in the
        // valid range.
        assert!(picked.is_some());
        assert!(picked.unwrap() < 1000);
    }

    #[test]
    fn test_sample_single_l_reservoir_empty() {
        let mut rng = StdRng::seed_from_u64(42);
        let nothing: Vec<u32> = Vec::new();

        // There is nothing to pick from an empty container.
        assert!(sample_single_l_reservoir(&mut rng, nothing).is_none());
    }

    #[test]
    fn test_sample_single_l_reservoir_single_element() {
        let mut rng = StdRng::seed_from_u64(1);
        let lone: Vec<u32> = vec![42];

        // The only element is the only possible result.
        assert_eq!(sample_single_l_reservoir(&mut rng, lone), Some(42));
    }

    #[test]
    fn test_sample_single_l_reservoir_uniformity() {
        let population: u32 = 1000;
        let num_runs: usize = 10000;
        let num_bins: usize = 10;
        let bin_width = population as usize / num_bins;

        let data: Vec<u32> = (0..population).collect();
        let mut counts = vec![0usize; num_bins];

        // One draw per seed, tallied into equal-width bins.
        for run in 0..num_runs {
            let mut rng = StdRng::seed_from_u64(42 + run as u64);
            if let Some(value) = sample_single_l_reservoir(&mut rng, data.iter().cloned()) {
                counts[value as usize / bin_width] += 1;
            }
        }

        // Under uniform sampling every bin expects the same number of hits.
        let expected = num_runs as f64 / num_bins as f64;
        let chi_square = chi_square_statistic(&counts, expected);

        // Degrees of freedom = num_bins - 1 = 9
        // Critical χ²₀.₉₉₉ for df=9 is 27.877
        let critical = 27.877;

        println!("χ² = {}, counts = {:?}", chi_square, counts);

        assert!(
            chi_square < critical,
            "Single sample fails uniformity test: χ² = {}, counts = {:?}",
            chi_square,
            counts
        );
    }

    #[test]
    fn test_sample_single_l_reservoir_hashset() {
        let mut rng = StdRng::seed_from_u64(42);
        let mut members = HashSet::new();
        for member in 0..100 {
            members.insert(member);
        }

        let picked = sample_single_l_reservoir(&mut rng, &members);

        // The pick must exist and must be one of the set's members.
        assert!(picked.is_some());
        assert!(members.contains(picked.unwrap()));
    }

    #[test]
    fn test_sample_multiple_l_reservoir_basic() {
        let requested = 100;
        let mut rng = StdRng::seed_from_u64(42);
        let population: Vec<u32> = (0..1000).collect();

        let sample = sample_multiple_l_reservoir(&mut rng, population, requested);

        // Exactly as many items as requested.
        assert_eq!(sample.len(), requested);

        // Every item comes from the valid range.
        assert!(sample.iter().all(|&value| value < 1000));

        // Sampling is without replacement, so no item may repeat.
        let distinct: HashSet<_> = sample.iter().collect();
        assert_eq!(distinct.len(), sample.len());
    }

    #[test]
    fn test_sample_multiple_l_reservoir_empty() {
        let mut rng = StdRng::seed_from_u64(42);
        let nothing: Vec<u32> = Vec::new();

        let sample = sample_multiple_l_reservoir(&mut rng, &nothing, 10);

        // An empty container can only produce an empty sample.
        assert_eq!(sample.len(), 0);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_zero_requested() {
        let mut rng = StdRng::seed_from_u64(42);
        let population: Vec<u32> = (0..100).collect();

        let sample = sample_multiple_l_reservoir(&mut rng, &population, 0);

        // Asking for nothing yields nothing.
        assert_eq!(sample.len(), 0);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_requested_exceeds_population() {
        let mut rng = StdRng::seed_from_u64(42);
        let population: Vec<u32> = (0..50).collect();

        let sample = sample_multiple_l_reservoir(&mut rng, population, 100);

        // When more is requested than exists, every available item is returned.
        assert_eq!(sample.len(), 50);

        // Each item appears exactly once.
        let distinct: HashSet<_> = sample.iter().collect();
        assert_eq!(distinct.len(), 50);

        // Every item stems from the original data.
        assert!(sample.iter().all(|&value| value < 50));
    }

    #[test]
    fn test_sample_multiple_l_reservoir_exact_population() {
        let mut rng = StdRng::seed_from_u64(42);
        let population: Vec<u32> = (0..100).collect();

        let sample = sample_multiple_l_reservoir(&mut rng, population, 100);

        // Requesting the whole population returns every element exactly once.
        assert_eq!(sample.len(), 100);

        let distinct: HashSet<_> = sample.iter().collect();
        assert_eq!(distinct.len(), 100);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_single_element() {
        let mut rng = StdRng::seed_from_u64(1);
        let lone: Vec<u32> = vec![42];

        let sample = sample_multiple_l_reservoir(&mut rng, lone, 1);

        assert_eq!(sample.len(), 1);
        assert_eq!(sample[0], 42);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_hashset() {
        let mut rng = StdRng::seed_from_u64(42);
        let mut members = HashSet::new();
        for member in 0..100 {
            members.insert(member);
        }

        let sample = sample_multiple_l_reservoir(&mut rng, &members, 10);

        assert_eq!(sample.len(), 10);

        // Every sampled value must be a member of the original set.
        assert!(sample.iter().all(|value| members.contains(value)));

        // No duplicates
        let distinct: HashSet<_> = sample.iter().collect();
        assert_eq!(distinct.len(), 10);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_small_sample() {
        let requested = 5;
        let mut rng = StdRng::seed_from_u64(42);
        let population: Vec<u32> = (0..1000).collect();

        let sample = sample_multiple_l_reservoir(&mut rng, &population, requested);

        assert_eq!(sample.len(), requested);

        // No duplicates
        let distinct: HashSet<_> = sample.iter().collect();
        assert_eq!(distinct.len(), requested);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_large_sample() {
        let requested = 900;
        let mut rng = StdRng::seed_from_u64(42);
        let population: Vec<u32> = (0..1000).collect();

        let sample = sample_multiple_l_reservoir(&mut rng, &population, requested);

        assert_eq!(sample.len(), requested);

        // No duplicates
        let distinct: HashSet<_> = sample.iter().collect();
        assert_eq!(distinct.len(), requested);
    }

    // Verifies that the reservoir sampling algorithm produces uniformly distributed
    // samples by running it 1000 times and checking that the resulting chi-square
    // statistics follow the expected chi-square(9) distribution. Note that this
    // test is only approximately correct, reasonable only when `requested` is small
    // relative to `population`, because `sample_multiple_l_reservoir` samples
    // without replacement, while the chi-squared test assumes independent samples.
    #[test]
    fn test_sample_multiple_l_reservoir_uniformity() {
        let population: u32 = 10000;
        let requested: usize = 100;
        let num_runs: usize = 1000;

        let data: Vec<u32> = (0..population).collect();

        // The range 0..population is partitioned into 10 equal-width bins.
        let bin_width = population as usize / 10;
        // Expected count per bin for uniform sampling
        let expected = requested as f64 / 10.0; // = 10.0

        // One chi-square statistic per seeded run.
        let chi_squares: Vec<f64> = (0..num_runs)
            .map(|run| {
                let mut rng = StdRng::seed_from_u64(42 + run as u64);
                let sample =
                    sample_multiple_l_reservoir(&mut rng, data.iter().cloned(), requested);

                let mut counts = [0usize; 10];
                for value in sample {
                    counts[value as usize / bin_width] += 1;
                }

                chi_square_statistic(&counts, expected)
            })
            .collect();

        // Now test that chi_squares follow a chi-square distribution with df=9
        // We use quantiles of the chi-square(9) distribution to create bins
        // and check if the observed counts match the expected uniform distribution

        // Quantiles of chi-square distribution with df=9 at deciles (10 bins)
        // These values define the bin boundaries such that each bin should contain
        // 10% of the observations if they truly follow chi-square(9).
        // Generate with Mathematica:
        //     Table[Quantile[ChiSquareDistribution[9], p/10], {p, 0, 10}]//N
        let quantiles = [
            0.0,           // 0th percentile (minimum)
            4.16816,       // 10th percentile
            5.38005,       // 20th percentile
            6.39331,       // 30th percentile
            7.35703,       // 40th percentile
            8.34283,       // 50th percentile (median)
            9.41364,       // 60th percentile
            10.6564,       // 70th percentile
            12.2421,       // 80th percentile
            14.6837,       // 90th percentile
            f64::INFINITY, // 100th percentile (maximum)
        ];

        let num_bins = quantiles.len() - 1;
        let mut chi_square_counts = vec![0usize; num_bins];

        for &statistic in &chi_squares {
            // Locate the half-open decile interval containing this statistic.
            let bin = quantiles
                .windows(2)
                .position(|bounds| bounds[0] <= statistic && statistic < bounds[1]);
            if let Some(bin) = bin {
                chi_square_counts[bin] += 1;
            }
        }

        // Each bin should contain approximately num_runs / num_bins observations
        let expected_per_bin = num_runs as f64 / num_bins as f64;
        let chi_square_of_chi_squares =
            chi_square_statistic(&chi_square_counts, expected_per_bin);

        // Degrees of freedom = (#bins - 1) = 9
        // Critical χ²₀.₉₉₉ for df=9 is 27.877
        let critical = 27.877;

        println!(
            "χ² = {}, counts = {:?}",
            chi_square_of_chi_squares, chi_square_counts
        );

        assert!(
            chi_square_of_chi_squares < critical,
            "Chi-square statistics fail to follow chi-square(9) distribution: χ² = {}, counts = {:?}",
            chi_square_of_chi_squares,
            chi_square_counts
        );
    }

    // Test that each element has equal probability of being selected
    #[test]
    fn test_sample_multiple_l_reservoir_element_probability() {
        let population: u32 = 100;
        let requested: usize = 10;
        let num_runs: usize = 10000;

        let data: Vec<u32> = (0..population).collect();
        let mut selection_counts = vec![0usize; population as usize];

        // Tally how often each element is selected across all seeded runs.
        for run in 0..num_runs {
            let mut rng = StdRng::seed_from_u64(42 + run as u64);
            for value in sample_multiple_l_reservoir(&mut rng, data.iter().cloned(), requested) {
                selection_counts[value as usize] += 1;
            }
        }

        // Each element should be selected with probability requested/population
        // Expected count per element
        let expected = (num_runs * requested) as f64 / population as f64;

        let chi_square = chi_square_statistic(&selection_counts, expected);

        // Degrees of freedom = population - 1 = 99.
        // Critical value uses p = 0.999 (alpha = 0.001): χ²_{0.999, 99} ≈ 148.23
        // from the inverse chi-square CDF.
        let critical = 148.23;

        println!(
            "χ² = {}, expected = {}, min = {}, max = {}",
            chi_square,
            expected,
            selection_counts.iter().min().unwrap(),
            selection_counts.iter().max().unwrap()
        );

        assert!(
            chi_square < critical,
            "Element selection probabilities are not uniform: χ² = {}",
            chi_square
        );
    }

    // Test reproducibility with same seed
    #[test]
    fn test_sample_multiple_l_reservoir_reproducibility() {
        let data: Vec<u32> = (0..1000).collect();
        let seed: u64 = 12345;

        for &requested in &[1, 2, 5, 10, 100, 500] {
            // Each call starts from a freshly seeded generator.
            let draw = || {
                let mut rng = StdRng::seed_from_u64(seed);
                sample_multiple_l_reservoir(&mut rng, &data, requested)
            };
            let sample1 = draw();
            let sample2 = draw();

            // Verify correct sample size
            assert_eq!(
                sample1.len(),
                requested,
                "Sample size {} doesn't match requested size {}",
                sample1.len(),
                requested
            );
            assert_eq!(
                sample2.len(),
                requested,
                "Sample size {} doesn't match requested size {}",
                sample2.len(),
                requested
            );

            // Same seed should produce identical samples
            assert_eq!(
                sample1, sample2,
                "Reproducibility failed for requested={}",
                requested
            );
        }
    }

    #[test]
    fn test_sample_single_l_reservoir_reproducibility() {
        let data: Vec<u32> = (0..1000).collect();
        let seed: u64 = 12345;

        // Each call starts from a freshly seeded generator.
        let draw = || {
            let mut rng = StdRng::seed_from_u64(seed);
            sample_single_l_reservoir(&mut rng, &data)
        };
        let sample1 = draw();
        let sample2 = draw();

        // Same seed should produce identical samples
        assert_eq!(sample1, sample2);
    }
}