ixa/random/
sampling_algorithms.rs

//! Algorithms for uniform random sampling from containers such as `Vec`, `HashSet`, and `HashMap`.
//! These algorithms are written to be generic over the container type using zero-cost trait
//! abstractions.
3use crate::rand::{seq::index::sample as choose_range, Rng};
4use std::collections::{HashMap, HashSet};
5
/// The `len` capability, a zero-cost abstraction for types that have a known length.
pub trait HasLen {
    /// Returns the number of elements in the container.
    fn len(&self) -> usize;

    /// Returns `true` if the container holds no elements.
    ///
    /// Provided in terms of `len`, so implementors only need to supply `len`. Having it on the
    /// trait keeps the `len`/`is_empty` pair together, as Clippy's `len_without_is_empty` expects.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
10
/// The `iter` capability: a zero-cost abstraction for containers that can lend out an iterator.
///
/// Both the yielded element type and the concrete iterator type are associated types, so
/// generic code is monomorphized over the real iterator rather than going through a trait
/// object.
pub trait HasIter {
    /// Returns an iterator over the container's elements, borrowing the container.
    fn iter(&self) -> Self::Iter<'_>;

    /// Element type yielded while the container is borrowed for `'a`.
    type Item<'a>
    where
        Self: 'a;

    /// Concrete iterator type produced while the container is borrowed for `'a`.
    type Iter<'a>: Iterator<Item = Self::Item<'a>>
    where
        Self: 'a;
}
22
// Implements `HasLen` for a generic container by forwarding to the container's own inherent
// `len` method. Invoke as `impl_has_len!(Name<P1, P2, ...>);`, listing the generic parameter
// names.
macro_rules! impl_has_len {
    ($container:ident < $($param:ident),* >) => {
        impl<$($param),*> HasLen for $container<$($param),*> {
            fn len(&self) -> usize {
                // A type-relative path prefers the inherent method over this trait method.
                Self::len(self)
            }
        }
    };
}
32
// Implements `HasIter` for a generic container by forwarding to the container's own inherent
// `iter` method. Arguments: the container with its generic parameter names, the concrete
// borrowing iterator type, and the item type it yields; the latter two may mention the
// lifetime `'a` bound by the generated associated types.
macro_rules! impl_has_iter {
    ($container:ident < $($param:ident),* >, $iter_ty:ty, $item_ty:ty) => {
        impl<$($param),*> HasIter for $container<$($param),*> {
            type Item<'a> = $item_ty where Self: 'a;
            type Iter<'a> = $iter_ty where Self: 'a;

            fn iter(&self) -> Self::Iter<'_> {
                // A type-relative path prefers the inherent method over this trait method.
                Self::iter(self)
            }
        }
    };
}
45
46// Vec<T>
47impl_has_len!(Vec<T>);
48// We implement `HasIter` manually for `Vec<T>` because its `iter` method is from `Deref<Target = [T]>`.
49impl<T> HasIter for Vec<T> {
50    type Item<'a>
51        = &'a T
52    where
53        Self: 'a;
54    type Iter<'a>
55        = std::slice::Iter<'a, T>
56    where
57        Self: 'a;
58
59    fn iter(&self) -> Self::Iter<'_> {
60        <[T]>::iter(self)
61    }
62}
63
64// HashSet<T, H>
65impl_has_len!(HashSet<T, H>);
66impl_has_iter!(HashSet<T, H>, std::collections::hash_set::Iter<'a, T>, &'a T);
67
68// HashMap<K, V, H>
69impl_has_len!(HashMap<K, V, H>);
70impl_has_iter!(HashMap<K, V, H>, std::collections::hash_map::Iter<'a, K, V>, (&'a K, &'a V));
71
72/// Sample a random element uniformly from a container of known length.
73///
74/// We do not assume the container is randomly indexable, only that it can be iterated over. The value is cloned.
75/// This algorithm is used when the property is indexed, and thus we know the length of the result set.
76pub fn sample_single_from_known_length<'a, Container, R, T>(
77    rng: &mut R,
78    set: &'a Container,
79) -> Option<T>
80where
81    R: Rng,
82    Container: HasLen + HasIter<Item<'a> = &'a T>,
83    T: Clone + 'static,
84{
85    let len = set.len();
86    if len == 0 {
87        return None;
88    }
89    // This little trick with `u32` makes this function 30% faster.
90    let index = rng.random_range(0..len as u32) as usize;
91    // The set need not be randomly indexable, so we have to use the `nth` method.
92    set.iter().nth(index).cloned()
93}
94
95/// Sample a random element uniformly from a container of unknown length.
96///
97/// We do not assume the container is randomly indexable, only that it can be iterated over. The value is cloned.
98///
99/// This function implements "Algorithm L" from KIM-HUNG LI
100/// Reservoir-Sampling Algorithms of Time Complexity O(n(1 + log(N/n)))
101/// <https://dl.acm.org/doi/pdf/10.1145/198429.198435>
102///
103/// This algorithm is significantly slower than the "known length" algorithm (factor
104/// of 10^4). The reservoir algorithm from `rand` reduces to the "known length`
105/// algorithm when the iterator is an `ExactSizeIterator`, or more precisely,
106/// when `iterator.size_hint()` returns `(k, Some(k))` for some `k`. Otherwise,
107/// this algorithm is much faster than the `rand` implementation (factor of 100).
108// ToDo(RobertJacobsonCDC): This function will take an iterator once the `iter_query_results` API is ready.
109pub fn sample_single_l_reservoir<'a, Container, R, T>(rng: &mut R, set: &'a Container) -> Option<T>
110where
111    R: Rng,
112    Container: HasIter<Item<'a> = &'a T>,
113    T: Clone + 'static,
114{
115    let mut chosen_item: Option<T> = None; // the currently selected element
116    let mut weight: f64 = rng.random_range(0.0..1.0); // controls skip distance distribution
117    let mut position: usize = 0; // current index in data
118    let mut next_pick_position: usize = 1; // index of the next item to pick
119
120    set.iter().for_each(|item| {
121        position += 1;
122        if position == next_pick_position {
123            chosen_item = Some(item.clone());
124            next_pick_position +=
125                (f64::ln(rng.random_range(0.0..1.0)) / f64::ln(1.0 - weight)).floor() as usize + 1;
126            weight *= rng.random_range(0.0..1.0);
127        }
128    });
129
130    chosen_item
131}
132
133/// Sample multiple random elements uniformly without replacement from a container of known length.
134/// This function assumes `set.len() >= requested`.
135///
136/// We do not assume the container is randomly indexable, only that it can be iterated over. The values are cloned.
137///
138/// This algorithm can be used when the property is indexed, and thus we know the length of the result set.
139/// For very small `requested` values (<=5), this algorithm is faster than reservoir because it doesn't
140/// iterate over the entire set.
141pub fn sample_multiple_from_known_length<'a, Container, R, T>(
142    rng: &mut R,
143    set: &'a Container,
144    requested: usize,
145) -> Vec<T>
146where
147    R: Rng,
148    Container: HasLen + HasIter<Item<'a> = &'a T>,
149    T: Clone + 'static,
150{
151    let mut indexes = Vec::with_capacity(requested);
152    indexes.extend(choose_range(rng, set.len(), requested));
153    indexes.sort_unstable();
154    let mut index_iterator = indexes.into_iter();
155    let mut next_idx = index_iterator.next().unwrap();
156    let mut selected = Vec::with_capacity(requested);
157
158    for (idx, item) in set.iter().enumerate() {
159        if idx == next_idx {
160            selected.push(item.clone());
161            if let Some(i) = index_iterator.next() {
162                next_idx = i;
163            } else {
164                break;
165            }
166        }
167    }
168
169    selected
170}
171
172/// Sample multiple random elements uniformly without replacement from a container of known length. If
173/// more samples are requested than are in the set, the function returns as many items as it can.
174///
175/// We do not assume the container is randomly indexable, only that it can be iterated over. The values are cloned.
176///
177/// This function implements "Algorithm L" from KIM-HUNG LI
178/// Reservoir-Sampling Algorithms of Time Complexity O(n(1 + log(N/n)))
179/// <https://dl.acm.org/doi/pdf/10.1145/198429.198435>
180///
181/// This algorithm is significantly faster than the reservoir algorithm in `rand` and is
182/// on par with the "known length" algorithm for large `requested` values.
183// ToDo(RobertJacobsonCDC): This function will take an iterator once the `iter_query_results` API is ready.
184pub fn sample_multiple_l_reservoir<'a, Container, R, T>(
185    rng: &mut R,
186    set: &'a Container,
187    requested: usize,
188) -> Vec<T>
189where
190    R: Rng,
191    Container: HasLen + HasIter<Item<'a> = &'a T>,
192    T: Clone + 'static,
193{
194    let mut weight: f64 = rng.random_range(0.0..1.0); // controls skip distance distribution
195    let mut position: usize = 0; // current index in data
196    let mut next_pick_position: usize = 1; // index of the next item to pick
197    let mut reservoir = Vec::with_capacity(requested); // the sample reservoir
198
199    set.iter().for_each(|item| {
200        position += 1;
201        if position == next_pick_position {
202            if reservoir.len() == requested {
203                let to_remove = rng.random_range(0..reservoir.len());
204                reservoir.swap_remove(to_remove);
205            }
206            reservoir.push(item.clone());
207
208            if reservoir.len() == requested {
209                next_pick_position += (f64::ln(rng.random_range(0.0..1.0)) / f64::ln(1.0 - weight))
210                    .floor() as usize
211                    + 1;
212                weight *= rng.random_range(0.0..1.0);
213            } else {
214                next_pick_position += 1;
215            }
216        }
217    });
218
219    reservoir
220}