1use std::cell::Ref;
22
23use log::warn;
24use rand::Rng;
25
26use crate::entity::entity_set::source_set::{SourceIterator, SourceSet};
27use crate::entity::{Entity, EntityId, PopulationIterator};
28use crate::hashing::IndexSet;
29use crate::random::{
30 sample_multiple_from_known_length, sample_multiple_l_reservoir, sample_single_l_reservoir,
31};
32
33pub struct EntitySetIterator<'c, E: Entity> {
35 source: SourceIterator<'c, E>,
36 sources: Vec<SourceSet<'c, E>>,
37}
38
39impl<'c, E: Entity> EntitySetIterator<'c, E> {
40 pub fn empty() -> EntitySetIterator<'c, E> {
43 EntitySetIterator {
44 source: SourceIterator::Empty,
45 sources: vec![],
46 }
47 }
48
49 pub(crate) fn from_population_iterator(iter: PopulationIterator<E>) -> Self {
52 EntitySetIterator {
53 source: SourceIterator::WholePopulation(iter),
54 sources: vec![],
55 }
56 }
57
58 pub(crate) fn from_sources(mut sources: Vec<SourceSet<'c, E>>) -> Self {
61 if sources.is_empty() {
62 return Self::empty();
63 }
64
65 sources.sort_unstable_by_key(|x| x.upper_len());
66 let source = sources.remove(0).into_iter();
67 EntitySetIterator { source, sources }
68 }
69
70 pub fn from_index_set(set: Ref<'c, IndexSet<EntityId<E>>>) -> EntitySetIterator<'c, E> {
71 EntitySetIterator {
72 source: SourceSet::IndexSet(set).into_iter(),
73 sources: vec![],
74 }
75 }
76
77 pub fn sample_entity<R>(mut self, rng: &mut R) -> Option<EntityId<E>>
80 where
81 R: Rng,
82 {
83 let (lower, upper) = self.size_hint();
85 if Some(lower) == upper {
86 if lower == 0 {
87 warn!("Requested a sample entity from an empty population");
88 return None;
89 }
90 let index = rng.random_range(0..lower as u32);
92 return self.nth(index as usize);
93 }
94
95 sample_single_l_reservoir(rng, self)
97 }
98
99 pub fn sample_entities<R>(self, rng: &mut R, requested: usize) -> Vec<EntityId<E>>
103 where
104 R: Rng,
105 {
106 match self.size_hint() {
107 (lower, Some(upper)) if lower == upper => {
108 if lower == 0 {
109 warn!("Requested a sample of entities from an empty population");
110 return vec![];
111 }
112 sample_multiple_from_known_length(rng, self, requested)
113 }
114 _ => sample_multiple_l_reservoir(rng, self, requested),
115 }
116 }
117}
118
119impl<'a, E: Entity> Iterator for EntitySetIterator<'a, E> {
120 type Item = EntityId<E>;
121
122 fn next(&mut self) -> Option<Self::Item> {
123 'outer: for entity_id in self.source.by_ref() {
127 for source in &self.sources {
129 if !source.contains(entity_id) {
130 continue 'outer;
131 }
132 }
133
134 return Some(entity_id);
136 }
137
138 None
139 }
140
141 fn size_hint(&self) -> (usize, Option<usize>) {
142 let (lower, upper) = self.source.size_hint();
143 if self.sources.is_empty() {
144 (lower, upper)
145 } else {
146 (0, upper)
149 }
150 }
151
152 fn count(self) -> usize {
153 if self.sources.is_empty() {
154 self.source.count()
156 } else {
157 self.fold(0, |n, _| n + 1)
158 }
159 }
160
161 fn nth(&mut self, n: usize) -> Option<Self::Item> {
162 if self.sources.is_empty() {
163 self.source.nth(n)
166 } else {
167 for _ in 0..n {
170 self.next()?;
171 }
172 self.next()
173 }
174 }
175}
176impl<'c, E: Entity> std::iter::FusedIterator for EntitySetIterator<'c, E> {}
177
178#[cfg(test)]
179mod tests {
180 use indexmap::IndexSet;
206
207 use crate::prelude::*;
208 use crate::{define_derived_property, define_property};
209
210 define_entity!(Person);
211
212 define_property!(struct ExplicitProp(u8), Person);
216
217 define_property!(struct ConstantProp(u8), Person, default_const = ConstantProp(42));
219
220 define_derived_property!(struct DerivedProp(bool), Person, [ExplicitProp], |explicit| {
222 DerivedProp(explicit.0 % 2 == 0)
223 });
224
225 define_property!(struct ConstantProp2(u16), Person, default_const = ConstantProp2(100));
227 define_property!(struct ExplicitProp2(bool), Person);
228
229 define_property!(struct Age(u8), Person, default_const = Age(0));
230 define_property!(struct Alive(bool), Person, default_const = Alive(true));
231
232 define_derived_property!(
233 enum AgeGroupRisk {
234 NewBorn,
235 General,
236 OldAdult,
237 },
238 Person,
239 [Age],
240 [],
241 |age| {
242 if age.0 <= 1 {
243 AgeGroupRisk::NewBorn
244 } else if age.0 <= 65 {
245 AgeGroupRisk::General
246 } else {
247 AgeGroupRisk::OldAdult
248 }
249 }
250 );
251
252 fn setup_test_population(context: &mut Context, size: usize) -> Vec<EntityId<Person>> {
254 let mut people = Vec::new();
255 for i in 0..size {
256 let person = context
257 .add_entity((ExplicitProp((i % 20) as u8), ExplicitProp2(i % 2 == 0)))
258 .unwrap();
259 people.push(person);
260 }
261 people
262 }
263
/// Queries a single unindexed property that has no default value.
#[test]
fn test_explicit_non_default_not_indexed_initial_source_yes() {
    let mut context = Context::new();
    setup_test_population(&mut context, 100);

    let matches: Vec<_> = context
        .query_result_iterator((ExplicitProp(5),))
        .collect();

    // `ExplicitProp` cycles through 0..20, so the value 5 occurs five times in 100 people.
    assert_eq!(matches.len(), 5);
    for &person in &matches {
        assert_eq!(
            context.get_property::<_, ExplicitProp>(person),
            ExplicitProp(5)
        );
    }
}
284
/// Queries a single indexed property that has no default value.
#[test]
fn test_explicit_non_default_indexed_initial_source_yes() {
    let mut context = Context::new();
    // The index is requested before any entity exists.
    context.index_property::<Person, ExplicitProp>();
    setup_test_population(&mut context, 100);

    let matches: Vec<_> = context
        .query_result_iterator((ExplicitProp(7),))
        .collect();

    // `ExplicitProp` cycles through 0..20, so the value 7 occurs five times in 100 people.
    assert_eq!(matches.len(), 5);
    for &person in &matches {
        assert_eq!(
            context.get_property::<_, ExplicitProp>(person),
            ExplicitProp(7)
        );
    }
}
304
/// Queries an unindexed constant-default property by its default value, together
/// with an explicit property.
#[test]
fn test_constant_default_not_indexed_initial_source_yes() {
    let mut context = Context::new();
    // Nobody sets `ConstantProp`, so all 50 people carry the default of 42.
    (0..50).for_each(|_| {
        context
            .add_entity((ExplicitProp(1), ExplicitProp2(false)))
            .unwrap();
    });

    let matches: Vec<_> = context
        .query_result_iterator((ConstantProp(42), ExplicitProp2(false)))
        .collect();

    assert_eq!(matches.len(), 50);
    for person in matches {
        assert_eq!(
            context.get_property::<_, ConstantProp>(person),
            ConstantProp(42)
        );
    }
}
328
/// Queries an indexed constant-default property by its default value.
#[test]
fn test_constant_default_indexed_initial_source_yes() {
    let mut context = Context::new();
    context.index_property::<Person, ConstantProp>();

    // Nobody sets `ConstantProp`, so all 50 people carry the default of 42.
    (0..50).for_each(|_| {
        context
            .add_entity((ExplicitProp(1), ExplicitProp2(false)))
            .unwrap();
    });

    let matches: Vec<_> = context
        .query_result_iterator((ConstantProp(42),))
        .collect();
    assert_eq!(matches.len(), 50);
}
346
/// Queries an unindexed constant-default property by a non-default value.
#[test]
fn test_constant_non_default_not_indexed_initial_source_yes() {
    let mut context = Context::new();

    // The first 10 people override `ConstantProp`...
    for _ in 0..10 {
        context
            .add_entity((ExplicitProp(1), ExplicitProp2(false), ConstantProp(99)))
            .unwrap();
    }
    // ...and the remaining 40 keep its default.
    for _ in 10..50 {
        context
            .add_entity((ExplicitProp(1), ExplicitProp2(false)))
            .unwrap();
    }

    let matches: Vec<_> = context
        .query_result_iterator((ConstantProp(99),))
        .collect();

    assert_eq!(matches.len(), 10);
    for &person in &matches {
        assert_eq!(
            context.get_property::<_, ConstantProp>(person),
            ConstantProp(99)
        );
    }
}
376
/// Queries an indexed constant-default property by a non-default value.
#[test]
fn test_constant_non_default_indexed_initial_source_yes() {
    let mut context = Context::new();
    context.index_property::<Person, ConstantProp>();

    // The first 10 people override `ConstantProp`...
    for _ in 0..10 {
        context
            .add_entity((ExplicitProp(1), ExplicitProp2(false), ConstantProp(99)))
            .unwrap();
    }
    // ...and the remaining 40 keep its default.
    for _ in 10..50 {
        context
            .add_entity((ExplicitProp(1), ExplicitProp2(false)))
            .unwrap();
    }

    let matches: Vec<_> = context
        .query_result_iterator((ConstantProp(99),))
        .collect();

    assert_eq!(matches.len(), 10);
}
401
/// Queries an unindexed derived property.
#[test]
fn test_derived_not_indexed_initial_source_yes() {
    let mut context = Context::new();

    // `ExplicitProp` takes each value in 0..100 once; half of those values are even.
    (0..100).for_each(|i| {
        context
            .add_entity((ExplicitProp(i as u8), ExplicitProp2(false)))
            .unwrap();
    });

    let matches: Vec<_> = context
        .query_result_iterator((DerivedProp(true),))
        .collect();

    assert_eq!(matches.len(), 50);
    for person in matches {
        assert_eq!(
            context.get_property::<Person, DerivedProp>(person),
            DerivedProp(true)
        );
    }
}
426
/// Queries an indexed derived property.
#[test]
fn test_derived_indexed_initial_source_yes() {
    let mut context = Context::new();
    context.index_property::<Person, DerivedProp>();

    // `ExplicitProp` takes each value in 0..100 once; half of those values are odd.
    (0..100).for_each(|i| {
        context
            .add_entity((ExplicitProp(i as u8), ExplicitProp2(false)))
            .unwrap();
    });

    let matches: Vec<_> = context
        .query_result_iterator((DerivedProp(false),))
        .collect();

    assert_eq!(matches.len(), 50);
    for person in matches {
        assert_eq!(
            context.get_property::<Person, DerivedProp>(person),
            DerivedProp(false)
        );
    }
}
452
/// Queries an unindexed property with no default alongside a second, indexed property.
#[test]
fn test_explicit_non_default_not_indexed_initial_source_no() {
    let mut context = Context::new();
    // Only `ExplicitProp2` is indexed; `ExplicitProp` is left unindexed.
    context.index_property::<Person, ExplicitProp2>();

    for i in 0..100 {
        context
            .add_entity((ExplicitProp((i % 20) as u8), ExplicitProp2(i % 2 == 0)))
            .unwrap();
    }

    let results = context
        .query_result_iterator((ExplicitProp(5), ExplicitProp2(false)))
        .collect::<Vec<_>>();

    // `i % 20 == 5` holds for i in {5, 25, 45, 65, 85}. Each of those is odd, so each
    // also has `ExplicitProp2(false)`: exactly five people match both conditions.
    assert_eq!(results.len(), 5);
    for person in results {
        assert_eq!(
            context.get_property::<Person, ExplicitProp>(person),
            ExplicitProp(5)
        );
        assert_eq!(
            context.get_property::<Person, ExplicitProp2>(person),
            ExplicitProp2(false)
        );
    }
}
495
/// Queries two indexed properties at once: one with no default and one constant-default
/// property set to a non-default value.
#[test]
fn test_explicit_non_default_indexed_initial_source_no() {
    let mut context = Context::new();
    context.index_property::<Person, ExplicitProp>();
    context.index_property::<Person, ConstantProp2>();

    // Ten people satisfy both query conditions.
    for _ in 0..10 {
        context
            .add_entity((ExplicitProp(7), ExplicitProp2(false), ConstantProp2(200)))
            .unwrap();
    }
    // The rest keep the default `ConstantProp2`, even where `ExplicitProp` happens to be 7.
    for i in 10..100 {
        context
            .add_entity((ExplicitProp((i % 20) as u8), ExplicitProp2(false)))
            .unwrap();
    }

    let matches: Vec<_> = context
        .query_result_iterator((ExplicitProp(7), ConstantProp2(200)))
        .collect();

    assert_eq!(matches.len(), 10);
}
525
/// Queries an unindexed constant-default property by its default value alongside an
/// indexed explicit property.
#[test]
fn test_constant_default_not_indexed_initial_source_no() {
    let mut context = Context::new();
    context.index_property::<Person, ExplicitProp>();

    // Five people get the otherwise unused value 99...
    for _ in 0..5 {
        context
            .add_entity((ExplicitProp(99), ExplicitProp2(false)))
            .unwrap();
    }
    // ...while everyone else gets a value below 20.
    for i in 5..100 {
        context
            .add_entity((ExplicitProp((i % 20) as u8), ExplicitProp2(false)))
            .unwrap();
    }

    let matches: Vec<_> = context
        .query_result_iterator((ExplicitProp(99), ConstantProp(42)))
        .collect();

    assert_eq!(matches.len(), 5);
    for &person in &matches {
        assert_eq!(
            context.get_property::<Person, ConstantProp>(person),
            ConstantProp(42)
        );
    }
}
556
/// Queries an indexed constant-default property by a non-default value alongside a
/// second indexed property.
#[test]
fn test_constant_non_default_indexed_initial_source_no() {
    let mut context = Context::new();
    context.index_property::<Person, ConstantProp>();
    context.index_property::<Person, ExplicitProp2>();

    // Ten people satisfy both query conditions...
    for _ in 0..10 {
        context
            .add_entity((ConstantProp(99), ExplicitProp(0), ExplicitProp2(true)))
            .unwrap();
    }
    // ...and the other ninety satisfy neither.
    for _ in 10..100 {
        context
            .add_entity((ExplicitProp(0), ExplicitProp2(false)))
            .unwrap();
    }

    let matches: Vec<_> = context
        .query_result_iterator((ConstantProp(99), ExplicitProp2(true)))
        .collect();

    assert_eq!(matches.len(), 10);
}
582
/// Queries an unindexed derived property alongside an indexed explicit property.
#[test]
fn test_derived_not_indexed_initial_source_no() {
    let mut context = Context::new();
    context.index_property::<Person, ExplicitProp2>();

    (0..100).for_each(|i| {
        context
            .add_entity((ExplicitProp(i as u8), ExplicitProp2(i < 50)))
            .unwrap();
    });

    let matches: Vec<_> = context
        .query_result_iterator((ExplicitProp2(true), DerivedProp(true)))
        .collect();

    // `ExplicitProp2` is true for i in 0..50, and 25 of those values are even.
    assert_eq!(matches.len(), 25);
}
603
/// Queries an indexed derived property alongside an indexed explicit property.
#[test]
fn test_derived_indexed_initial_source_no() {
    let mut context = Context::new();
    context.index_property::<Person, DerivedProp>();
    context.index_property::<Person, ExplicitProp2>();

    (0..100).for_each(|i| {
        context
            .add_entity((ExplicitProp(i as u8), ExplicitProp2(i < 30)))
            .unwrap();
    });

    let matches: Vec<_> = context
        .query_result_iterator((ExplicitProp2(true), DerivedProp(false)))
        .collect();

    // `ExplicitProp2` is true for i in 0..30, and 15 of those values are odd.
    assert_eq!(matches.len(), 15);
}
625
/// Holds two query result iterators over the same context at the same time and checks
/// that both can then be consumed.
#[test]
fn test_multiple_query_result_iterators() {
    let mut context = Context::new();
    context.index_property::<Person, Age>();

    // Two cohorts of 100 people with ages 0..100; the cohorts differ only in the
    // multiplier used to derive `ExplicitProp` from the age.
    for multiplier in [7u8, 14u8] {
        for age in 0..100u8 {
            context
                .add_entity((
                    Age(age),
                    ExplicitProp(age.wrapping_mul(multiplier) % 100),
                    ExplicitProp2(false),
                ))
                .unwrap();
        }
    }

    // Both iterators are created before either one is consumed.
    let by_age = context.query_result_iterator((Age(25),));
    let by_age_and_prop = context.query_result_iterator((Age(25), ExplicitProp(75)));

    let by_age_ids: IndexSet<_> = by_age.collect();
    let by_age_and_prop_ids: IndexSet<_> = by_age_and_prop.collect();

    // Only the first cohort maps age 25 to 75 (25 * 7 = 175); in the second cohort
    // 25 * 14 wraps to 94.
    let shared = by_age_ids.intersection(&by_age_and_prop_ids).count();
    assert_eq!(shared, 1);
}
664}