ixa/random/
sampling_algorithms.rs

1//! Algorithms for uniform random sampling from hash sets or iterators. These algorithms are written to be generic
2//! over the container type using zero-cost trait abstractions.
3
4use crate::rand::seq::index::sample as choose_range;
5use crate::rand::Rng;
6
7/// Samples one element uniformly at random from an iterator whose length is known at runtime.
8///
9/// The caller must ensure that `(len, Some(len)) == iter.size_hint()`, i.e. the iterator
10/// reports its exact length via `size_hint`. We do not require `ExactSizeIterator`
11/// because that is a compile-time guarantee, whereas our requirement is a runtime condition.
12///
13/// The implementation selects a random index and uses `Iterator::nth`. For iterators
14/// with O(1) `nth` (e.g., randomly indexable structures), this is very efficient.
15/// The selected value is cloned.
16///
17/// The iterator need only support iteration; random indexing is not required.
18/// This function is intended for use when the result set is indexed and its length is known.
19pub fn sample_single_from_known_length<I, R, T>(rng: &mut R, mut iter: I) -> Option<T>
20where
21    R: Rng,
22    I: Iterator<Item = T>,
23{
24    // It is the caller's responsibility to ensure that `(len, Some(len)) == iter.size_hint()`.
25    let (length, _) = iter.size_hint();
26    if length == 0 {
27        return None;
28    }
29    // This little trick with `u32` makes this function 30% faster.
30    let index = rng.random_range(0..length as u32) as usize;
31    // The set need not be randomly indexable, so we have to use the `nth` method.
32    iter.nth(index)
33}
34
35/// Sample a random element uniformly from an iterator of unknown length.
36///
37/// We do not assume the container is randomly indexable, only that it can be iterated over.
38///
39/// This function implements "Algorithm L" from KIM-HUNG LI
40/// Reservoir-Sampling Algorithms of Time Complexity O(n(1 + log(N/n)))
41/// <https://dl.acm.org/doi/pdf/10.1145/198429.198435>
42///
43/// This algorithm is significantly slower than the "known length" algorithm (factor
44/// of 10^4). The reservoir algorithm from [`rand`](crate::rand) reduces to the "known length"
45/// algorithm when `iterator.size_hint()` returns `(k, Some(k))` for some `k`. Otherwise,
46/// this algorithm is much faster than the [`rand`](crate::rand)  implementation (factor of 100).
47pub fn sample_single_l_reservoir<I, R, T>(rng: &mut R, iterable: I) -> Option<T>
48where
49    R: Rng,
50    I: IntoIterator<Item = T>,
51{
52    let mut iter = iterable.into_iter();
53    let mut weight: f64 = rng.random(); // controls skip distance distribution
54    let mut log_one_minus_weight = (-weight).ln_1p();
55    let mut chosen_item: T = iter.next()?; // the currently selected element
56
57    // Number of elements to skip before the next candidate to consider for the reservoir.
58    // `iter.nth(skip)` skips `skip` elements and returns the next one.
59    let mut skip = (rng.random::<f64>().ln() / log_one_minus_weight).floor() as usize;
60    weight *= rng.random::<f64>();
61    log_one_minus_weight = (-weight).ln_1p();
62
63    loop {
64        match iter.nth(skip) {
65            Some(item) => {
66                chosen_item = item;
67                skip = (rng.random::<f64>().ln() / log_one_minus_weight).floor() as usize;
68                weight *= rng.random::<f64>();
69                log_one_minus_weight = (-weight).ln_1p();
70            }
71            None => return Some(chosen_item),
72        }
73    }
74}
75
76/// Count elements and sample one element uniformly from an iterator of unknown
77/// length.
78///
79/// Returns `(count, sample)` where `count` is the total number of items observed
80/// and `sample` is `None` iff `count == 0`.
81///
82/// This uses single-item reservoir sampling while tracking total count.
83pub fn count_and_sample_single_l_reservoir<I, R, T>(rng: &mut R, iterable: I) -> (usize, Option<T>)
84where
85    R: Rng,
86    I: IntoIterator<Item = T>,
87{
88    let mut count = 0usize;
89    let mut chosen_item: Option<T> = None;
90
91    for item in iterable {
92        count += 1;
93        if rng.random_range(0..count as u64) == 0 {
94            chosen_item = Some(item);
95        }
96    }
97
98    (count, chosen_item)
99}
100
101/// Samples `requested` elements uniformly at random without replacement from an iterator
102/// whose length is known at runtime. Requires `len >= requested`.
103///
104/// The caller must ensure that `(len, Some(len)) == iter.size_hint()`, i.e. the iterator
105/// reports its exact length via `size_hint`. We do not require `ExactSizeIterator`
106/// because that is a compile-time guarantee, whereas our requirement is a runtime condition.
107///
108/// The implementation selects random indices and uses `Iterator::nth`. For iterators
109/// with O(1) `nth` (e.g., randomly indexable structures), this is very efficient.
110/// Selected values are cloned.
111///
112/// This strategy is particularly effective for small `requested` (≤ 5), since it
113/// avoids iterating over the entire set and is typically faster than reservoir sampling.
114pub fn sample_multiple_from_known_length<I, R, T>(rng: &mut R, iter: I, requested: usize) -> Vec<T>
115where
116    R: Rng,
117    I: IntoIterator<Item = T>,
118{
119    let mut iter = iter.into_iter();
120    // It is the caller's responsibility to ensure that `(length, Some(length)) == iter.size_hint()`.
121    let (length, _) = iter.size_hint();
122
123    let mut indexes = Vec::with_capacity(requested);
124    indexes.extend(choose_range(rng, length, requested));
125    indexes.sort_unstable();
126
127    let mut selected = Vec::with_capacity(requested);
128    let mut consumed: usize = 0; // number of elements consumed from the iterator so far
129
130    // `iter.nth(n)` skips `n` elements and returns the next one, so to reach
131    // index `idx` we skip `idx - consumed` where `consumed` tracks how many
132    // elements have already been consumed.
133    for idx in indexes {
134        if let Some(item) = iter.nth(idx - consumed) {
135            selected.push(item);
136        }
137        consumed = idx + 1;
138    }
139
140    selected
141}
142
143/// Sample multiple random elements uniformly without replacement from a container of unknown length. If
144/// more samples are requested than are in the set, the function returns as many items as it can.
145///
146/// The implementation uses `Iterator::nth`. Randomly indexable structures will have a O(1) `nth`
147/// implementation and will be very efficient. The values are cloned.
148///
149/// This function implements "Algorithm L" from KIM-HUNG LI
150/// Reservoir-Sampling Algorithms of Time Complexity O(n(1 + log(N/n)))
151/// <https://dl.acm.org/doi/pdf/10.1145/198429.198435>
152///
153/// This algorithm is significantly faster than the reservoir algorithm in `rand` and is
154/// on par with the "known length" algorithm for large `requested` values.
155pub fn sample_multiple_l_reservoir<I, R, T>(rng: &mut R, iter: I, requested: usize) -> Vec<T>
156where
157    R: Rng,
158    I: IntoIterator<Item = T>,
159{
160    if requested == 0 {
161        return Vec::new();
162    }
163
164    let requested_recip = 1.0 / requested as f64;
165    let mut weight: f64 = rng.random(); // controls skip distance distribution
166    weight = weight.powf(requested_recip);
167    let mut log_one_minus_weight = (-weight).ln_1p();
168    let mut iter = iter.into_iter();
169    let mut reservoir: Vec<T> = iter.by_ref().take(requested).collect(); // the sample reservoir
170
171    if reservoir.len() < requested {
172        return reservoir;
173    }
174
175    // Number of elements to skip before the next candidate to consider for the reservoir.
176    // `iter.nth(skip)` skips `skip` elements and returns the next one.
177    let mut skip = (rng.random::<f64>().ln() / log_one_minus_weight).floor() as usize;
178    let uniform_random: f64 = rng.random();
179    weight *= uniform_random.powf(requested_recip);
180    log_one_minus_weight = (-weight).ln_1p();
181
182    loop {
183        match iter.nth(skip) {
184            Some(item) => {
185                let to_remove = rng.random_range(0..reservoir.len());
186                reservoir.swap_remove(to_remove);
187                reservoir.push(item);
188
189                skip = (rng.random::<f64>().ln() / log_one_minus_weight).floor() as usize;
190                let uniform_random: f64 = rng.random();
191                weight *= uniform_random.powf(requested_recip);
192                log_one_minus_weight = (-weight).ln_1p();
193            }
194            None => return reservoir,
195        }
196    }
197}
198
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    use super::*;
    use crate::hashing::{HashSet, HashSetExt};

    /// Critical value χ²₀.₉₉₉ of the chi-square distribution with 9 degrees of freedom.
    const CHI_SQUARE_CRITICAL_DF_9: f64 = 27.877;

    /// Pearson's chi-square statistic for `observed` bin counts when every bin has the
    /// same `expected` count.
    fn chi_square_statistic(observed: &[usize], expected: f64) -> f64 {
        observed
            .iter()
            .map(|&obs| {
                let diff = (obs as f64) - expected;
                diff * diff / expected
            })
            .sum()
    }

    /// Asserts that no value occurs more than once in `sample`.
    fn assert_all_distinct<T: Eq + std::hash::Hash>(sample: &[T]) {
        let unique: HashSet<&T> = sample.iter().collect();
        assert_eq!(unique.len(), sample.len(), "sample contains duplicates");
    }

    #[test]
    fn test_sample_single_from_known_length_empty() {
        let data: Vec<u32> = Vec::new();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_single_from_known_length(&mut rng, data.into_iter());

        // Should return None for an empty iterator
        assert!(sample.is_none());
    }

    #[test]
    fn test_sample_single_from_known_length_single_element() {
        let data: Vec<u32> = vec![7];
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_single_from_known_length(&mut rng, data.into_iter());

        // Should return the only element
        assert_eq!(sample, Some(7));
    }

    #[test]
    fn test_sample_single_from_known_length_basic() {
        let data: Vec<u32> = (0..1000).collect();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_single_from_known_length(&mut rng, data.iter());

        // Should return a value from the population
        assert!(matches!(sample, Some(&value) if value < 1000));
    }

    #[test]
    fn test_sample_single_l_reservoir_basic() {
        let data: Vec<u32> = (0..1000).collect();
        let seed: u64 = 42;
        let mut rng = StdRng::seed_from_u64(seed);
        let sample = sample_single_l_reservoir(&mut rng, data);

        // Should return Some value, and that value should be in the valid range
        assert!(matches!(sample, Some(value) if value < 1000));
    }

    #[test]
    fn test_sample_single_l_reservoir_empty() {
        let data: Vec<u32> = Vec::new();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_single_l_reservoir(&mut rng, data);

        // Should return None for empty container
        assert!(sample.is_none());
    }

    #[test]
    fn test_sample_single_l_reservoir_single_element() {
        let data: Vec<u32> = vec![42];
        let mut rng = StdRng::seed_from_u64(1);
        let sample = sample_single_l_reservoir(&mut rng, data);

        // Should return the only element
        assert_eq!(sample, Some(42));
    }

    #[test]
    fn test_sample_single_l_reservoir_uniformity() {
        let population: u32 = 1000;
        let data: Vec<u32> = (0..population).collect();
        let num_runs = 10000;
        let num_bins = 10;
        let bin_width = population as usize / num_bins;
        let mut counts = vec![0usize; num_bins];

        for run in 0..num_runs {
            let mut rng = StdRng::seed_from_u64(42 + run as u64);
            let sample = sample_single_l_reservoir(&mut rng, data.iter().cloned());

            if let Some(value) = sample {
                counts[(value as usize) / bin_width] += 1;
            }
        }

        // Expected count per bin for uniform sampling
        let expected = num_runs as f64 / num_bins as f64;
        let chi_square = chi_square_statistic(&counts, expected);

        println!("χ² = {}, counts = {:?}", chi_square, counts);

        // Degrees of freedom = num_bins - 1 = 9
        assert!(
            chi_square < CHI_SQUARE_CRITICAL_DF_9,
            "Single sample fails uniformity test: χ² = {}, counts = {:?}",
            chi_square,
            counts
        );
    }

    #[test]
    fn test_sample_single_l_reservoir_hashset() {
        let mut data = HashSet::new();
        for i in 0..100 {
            data.insert(i);
        }

        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_single_l_reservoir(&mut rng, &data);

        assert!(sample.is_some());
        let value = sample.unwrap();
        assert!(data.contains(value));
    }

    #[test]
    fn test_count_and_sample_single_l_reservoir_empty() {
        let data: Vec<u32> = Vec::new();
        let mut rng = StdRng::seed_from_u64(42);
        let (count, sample) = count_and_sample_single_l_reservoir(&mut rng, data);
        assert_eq!(count, 0);
        assert!(sample.is_none());
    }

    #[test]
    fn test_count_and_sample_single_l_reservoir_count_matches() {
        let data: Vec<u32> = (0..1000).collect();
        let mut rng = StdRng::seed_from_u64(42);
        let (count, sample) = count_and_sample_single_l_reservoir(&mut rng, data);
        assert_eq!(count, 1000);

        // The sample must exist and come from the population
        assert!(matches!(sample, Some(value) if value < 1000));
    }

    #[test]
    fn test_count_and_sample_single_l_reservoir_single_element() {
        let data: Vec<u32> = vec![7];
        let mut rng = StdRng::seed_from_u64(42);
        let (count, sample) = count_and_sample_single_l_reservoir(&mut rng, data);
        assert_eq!(count, 1);
        assert_eq!(sample, Some(7));
    }

    #[test]
    fn test_sample_multiple_from_known_length_basic() {
        let data: Vec<u32> = (0..1000).collect();
        let requested = 100;
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_from_known_length(&mut rng, data, requested);

        // Correct sample size
        assert_eq!(sample.len(), requested);

        // All sampled values are within the valid range
        assert!(sample.iter().all(|v| *v < 1000));

        // The sample should not have duplicates
        assert_all_distinct(&sample);
    }

    #[test]
    fn test_sample_multiple_from_known_length_zero_requested() {
        let data: Vec<u32> = (0..100).collect();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_from_known_length(&mut rng, data, 0);

        // Should return empty vector when 0 requested
        assert!(sample.is_empty());
    }

    #[test]
    fn test_sample_multiple_from_known_length_exact_population() {
        let data: Vec<u32> = (0..100).collect();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_from_known_length(&mut rng, data.clone(), 100);

        // Selected indices are visited in ascending order, so requesting the whole
        // population must reproduce it exactly.
        assert_eq!(sample, data);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_basic() {
        let data: Vec<u32> = (0..1000).collect();
        let requested = 100;
        let seed: u64 = 42;
        let mut rng = StdRng::seed_from_u64(seed);
        let sample = sample_multiple_l_reservoir(&mut rng, data, requested);

        // Correct sample size
        assert_eq!(sample.len(), requested);

        // All sampled values are within the valid range
        assert!(sample.iter().all(|v| *v < 1000));

        // The sample should not have duplicates
        assert_all_distinct(&sample);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_empty() {
        let data: Vec<u32> = Vec::new();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, 10);

        // Should return empty vector for empty container
        assert_eq!(sample.len(), 0);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_zero_requested() {
        let data: Vec<u32> = (0..100).collect();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, 0);

        // Should return empty vector when 0 requested
        assert_eq!(sample.len(), 0);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_requested_exceeds_population() {
        let data: Vec<u32> = (0..50).collect();
        let requested = 100;
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, data, requested);

        // Should return all available items when requested > population
        assert_eq!(sample.len(), 50);

        // All elements should be unique
        assert_all_distinct(&sample);

        // All elements should be from the original data
        assert!(sample.iter().all(|v| *v < 50));
    }

    #[test]
    fn test_sample_multiple_l_reservoir_exact_population() {
        let data: Vec<u32> = (0..100).collect();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, data, 100);

        // Should return all elements when requested == population
        assert_eq!(sample.len(), 100);
        assert_all_distinct(&sample);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_single_element() {
        let data: Vec<u32> = vec![42];
        let mut rng = StdRng::seed_from_u64(1);
        let sample = sample_multiple_l_reservoir(&mut rng, data, 1);

        assert_eq!(sample, vec![42]);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_hashset() {
        let mut data = HashSet::new();
        for i in 0..100 {
            data.insert(i);
        }

        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, 10);

        assert_eq!(sample.len(), 10);

        // All sampled values should be in the original set
        assert!(sample.iter().all(|v| data.contains(v)));

        // No duplicates
        assert_all_distinct(&sample);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_small_sample() {
        let data: Vec<u32> = (0..1000).collect();
        let requested = 5;
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, requested);

        assert_eq!(sample.len(), requested);

        // No duplicates
        assert_all_distinct(&sample);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_large_sample() {
        let data: Vec<u32> = (0..1000).collect();
        let requested = 900;
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, requested);

        assert_eq!(sample.len(), requested);

        // No duplicates
        assert_all_distinct(&sample);
    }

    // Verifies that the reservoir sampling algorithm produces uniformly distributed
    // samples by running it 1000 times and checking that the resulting chi-square
    // statistics follow the expected chi-square(9) distribution. Note that this
    // test is only approximately correct, reasonable only when `requested` is small
    // relative to `population`, because `sample_multiple_l_reservoir` samples
    // without replacement, while the chi-squared test assumes independent samples.
    #[test]
    fn test_sample_multiple_l_reservoir_uniformity() {
        let population: u32 = 10000;
        let data: Vec<u32> = (0..population).collect();
        let requested = 100;
        let num_runs = 1000;
        // Width of each of the 10 equal-width bins partitioning 0..population
        let bin_width = population as usize / 10;
        // Expected count per bin for uniform sampling (= 10.0)
        let expected = requested as f64 / 10.0;
        let mut chi_squares = Vec::with_capacity(num_runs);

        for run in 0..num_runs {
            let mut rng = StdRng::seed_from_u64(42 + run as u64);
            let sample = sample_multiple_l_reservoir(&mut rng, data.iter().cloned(), requested);

            let mut counts = [0usize; 10];
            for &value in &sample {
                counts[(value as usize) / bin_width] += 1;
            }

            chi_squares.push(chi_square_statistic(&counts, expected));
        }

        // Now test that chi_squares follow a chi-square distribution with df=9.
        // We use quantiles of the chi-square(9) distribution to create bins
        // and check if the observed counts match the expected uniform distribution.

        // Quantiles of chi-square distribution with df=9 at deciles (10 bins).
        // These values define the bin boundaries such that each bin should contain
        // 10% of the observations if they truly follow chi-square(9).
        // Generate with Mathematica:
        //     Table[Quantile[ChiSquareDistribution[9], p/10], {p, 0, 10}]//N
        let quantiles = [
            0.0,           // 0th percentile (minimum)
            4.16816,       // 10th percentile
            5.38005,       // 20th percentile
            6.39331,       // 30th percentile
            7.35703,       // 40th percentile
            8.34283,       // 50th percentile (median)
            9.41364,       // 60th percentile
            10.6564,       // 70th percentile
            12.2421,       // 80th percentile
            14.6837,       // 90th percentile
            f64::INFINITY, // 100th percentile (maximum)
        ];

        let num_bins = quantiles.len() - 1;
        let mut chi_square_counts = vec![0usize; num_bins];

        for &chi_sq in &chi_squares {
            // Find the half-open interval [lower, upper) this chi-square value falls into
            let bin = quantiles
                .windows(2)
                .position(|bounds| chi_sq >= bounds[0] && chi_sq < bounds[1]);
            if let Some(bin) = bin {
                chi_square_counts[bin] += 1;
            }
        }

        // Each bin should contain approximately num_runs / num_bins observations
        let expected_per_bin = num_runs as f64 / num_bins as f64;
        let chi_square_of_chi_squares = chi_square_statistic(&chi_square_counts, expected_per_bin);

        println!(
            "χ² = {}, counts = {:?}",
            chi_square_of_chi_squares, chi_square_counts
        );

        // Degrees of freedom = (#bins - 1) = 9
        assert!(
            chi_square_of_chi_squares < CHI_SQUARE_CRITICAL_DF_9,
            "Chi-square statistics fail to follow chi-square(9) distribution: χ² = {}, counts = {:?}",
            chi_square_of_chi_squares,
            chi_square_counts
        );
    }

    // Test that each element has equal probability of being selected
    #[test]
    fn test_sample_multiple_l_reservoir_element_probability() {
        let population: u32 = 100;
        let data: Vec<u32> = (0..population).collect();
        let requested = 10;
        let num_runs = 10000;
        let mut selection_counts = vec![0usize; population as usize];

        for run in 0..num_runs {
            let mut rng = StdRng::seed_from_u64(42 + run as u64);
            let sample = sample_multiple_l_reservoir(&mut rng, data.iter().cloned(), requested);

            for &value in &sample {
                selection_counts[value as usize] += 1;
            }
        }

        // Each element should be selected with probability requested/population
        // Expected count per element
        let expected = (num_runs * requested) as f64 / population as f64;
        let chi_square = chi_square_statistic(&selection_counts, expected);

        // Degrees of freedom = population - 1 = 99.
        // Critical value uses p = 0.999 (alpha = 0.001): χ²_{0.999, 99} ≈ 148.23
        // from the inverse chi-square CDF.
        let critical = 148.23;

        println!(
            "χ² = {}, expected = {}, min = {}, max = {}",
            chi_square,
            expected,
            selection_counts.iter().min().unwrap(),
            selection_counts.iter().max().unwrap()
        );

        assert!(
            chi_square < critical,
            "Element selection probabilities are not uniform: χ² = {}",
            chi_square
        );
    }

    // Test reproducibility with same seed
    #[test]
    fn test_sample_multiple_l_reservoir_reproducibility() {
        let data: Vec<u32> = (0..1000).collect();
        let test_sizes = [1, 2, 5, 10, 100, 500];

        for &requested in &test_sizes {
            let seed: u64 = 12345;

            let mut rng1 = StdRng::seed_from_u64(seed);
            let sample1 = sample_multiple_l_reservoir(&mut rng1, &data, requested);

            let mut rng2 = StdRng::seed_from_u64(seed);
            let sample2 = sample_multiple_l_reservoir(&mut rng2, &data, requested);

            // Verify correct sample size
            assert_eq!(
                sample1.len(),
                requested,
                "Sample size {} doesn't match requested size {}",
                sample1.len(),
                requested
            );
            assert_eq!(
                sample2.len(),
                requested,
                "Sample size {} doesn't match requested size {}",
                sample2.len(),
                requested
            );

            // Same seed should produce identical samples
            assert_eq!(
                sample1, sample2,
                "Reproducibility failed for requested={}",
                requested
            );
        }
    }

    #[test]
    fn test_sample_single_l_reservoir_reproducibility() {
        let data: Vec<u32> = (0..1000).collect();
        let seed: u64 = 12345;

        let mut rng1 = StdRng::seed_from_u64(seed);
        let sample1 = sample_single_l_reservoir(&mut rng1, &data);

        let mut rng2 = StdRng::seed_from_u64(seed);
        let sample2 = sample_single_l_reservoir(&mut rng2, &data);

        // Same seed should produce identical samples
        assert_eq!(sample1, sample2);
    }
}