1use super::{EntitySetIterator, SourceSet};
15use crate::entity::{Entity, EntityId};
16
17pub struct EntitySet<'a, E: Entity>(EntitySetInner<'a, E>);
19
20pub(super) enum EntitySetInner<'a, E: Entity> {
22 Source(SourceSet<'a, E>),
23 Union(Box<EntitySet<'a, E>>, Box<EntitySet<'a, E>>),
24 Intersection(Vec<EntitySet<'a, E>>),
25 Difference(Box<EntitySet<'a, E>>, Box<EntitySet<'a, E>>),
26}
27
28impl<'a, E: Entity> Default for EntitySet<'a, E> {
29 fn default() -> Self {
30 Self::empty()
31 }
32}
33
34impl<'a, E: Entity> EntitySet<'a, E> {
35 pub(super) fn into_inner(self) -> EntitySetInner<'a, E> {
36 self.0
37 }
38
39 pub(super) fn is_source_leaf(&self) -> bool {
40 matches!(self.0, EntitySetInner::Source(_))
41 }
42
43 pub(super) fn into_source_leaf(self) -> Option<SourceSet<'a, E>> {
44 match self.0 {
45 EntitySetInner::Source(source) => Some(source),
46 _ => None,
47 }
48 }
49
50 pub fn empty() -> Self {
52 EntitySet(EntitySetInner::Source(SourceSet::Empty))
53 }
54
55 pub(crate) fn from_source(source: SourceSet<'a, E>) -> Self {
57 EntitySet(EntitySetInner::Source(source))
58 }
59
60 pub(crate) fn from_intersection_sources(mut sources: Vec<SourceSet<'a, E>>) -> Self {
61 match sources.len() {
62 0 => return Self::empty(),
63 1 => return Self::from_source(sources.pop().unwrap()),
64 _ => {}
65 }
66
67 sources.sort_unstable_by_key(SourceSet::sort_key);
70
71 let sets = sources.into_iter().map(Self::from_source).collect();
72
73 EntitySet(EntitySetInner::Intersection(sets))
74 }
75
76 pub fn union(self, other: Self) -> Self {
77 if self.is_empty() {
79 return other;
80 }
81 if other.is_empty() {
82 return self;
83 }
84 if self.structurally_eq(&other) {
86 return self;
87 }
88 if self.is_universal() {
90 return self;
91 }
92 if other.is_universal() {
93 return other;
94 }
95 if let Some(e) = self.as_singleton() {
97 if other.contains(e) {
98 return other;
99 }
100 }
101 if let Some(e) = other.as_singleton() {
102 if self.contains(e) {
103 return self;
104 }
105 }
106
107 let (left, right) = if self.sort_key() >= other.sort_key() {
109 (self, other)
110 } else {
111 (other, self)
112 };
113 EntitySet(EntitySetInner::Union(Box::new(left), Box::new(right)))
114 }
115
116 pub fn intersection(self, other: Self) -> Self {
117 if self.is_empty() || other.is_empty() {
119 return Self::empty();
120 }
121 if self.structurally_eq(&other) {
123 return self;
124 }
125 if self.is_universal() {
127 return other;
128 }
129 if other.is_universal() {
130 return self;
131 }
132 if let Some(e) = self.as_singleton() {
135 return if other.contains(e) {
136 self
137 } else {
138 Self::empty()
139 };
140 }
141 if let Some(e) = other.as_singleton() {
142 return if self.contains(e) {
143 other
144 } else {
145 Self::empty()
146 };
147 }
148
149 let mut sets = match self {
150 EntitySet(EntitySetInner::Intersection(sets)) => sets,
151 _ => vec![self],
152 };
153
154 sets.push(other);
155 sets.sort_unstable_by_key(EntitySet::sort_key);
158 EntitySet(EntitySetInner::Intersection(sets))
159 }
160
161 pub fn difference(self, other: Self) -> Self {
162 if other.is_empty() {
164 return self;
165 }
166 if self.is_empty() {
168 return Self::empty();
169 }
170 if self.structurally_eq(&other) {
172 return Self::empty();
173 }
174 if other.is_universal() {
176 return Self::empty();
177 }
178 if let Some(e) = self.as_singleton() {
181 return if other.contains(e) {
182 Self::empty()
183 } else {
184 self
185 };
186 }
187 EntitySet(EntitySetInner::Difference(Box::new(self), Box::new(other)))
188 }
189
190 pub fn contains(&self, entity_id: EntityId<E>) -> bool {
192 match self {
193 EntitySet(EntitySetInner::Source(source)) => source.contains(entity_id),
194 EntitySet(EntitySetInner::Union(a, b)) => {
195 a.contains(entity_id) || b.contains(entity_id)
196 }
197 EntitySet(EntitySetInner::Intersection(sets)) => {
198 sets.iter().all(|set| set.contains(entity_id))
199 }
200 EntitySet(EntitySetInner::Difference(a, b)) => {
201 a.contains(entity_id) && !b.contains(entity_id)
202 }
203 }
204 }
205
206 pub fn to_owned_vec(self) -> Vec<EntityId<E>> {
208 self.into_iter().collect()
209 }
210
211 pub fn try_len(&self) -> Option<usize> {
216 match self {
217 EntitySet(EntitySetInner::Source(source)) => source.try_len(),
218 _ => None,
219 }
220 }
221
222 fn is_empty(&self) -> bool {
227 matches!(self, EntitySet(EntitySetInner::Source(SourceSet::Empty)))
228 }
229 fn is_universal(&self) -> bool {
231 matches!(
232 self,
233 EntitySet(EntitySetInner::Source(SourceSet::Population(_)))
234 )
235 }
236 fn as_singleton(&self) -> Option<EntityId<E>> {
238 match self {
239 EntitySet(EntitySetInner::Source(SourceSet::Entity(e))) => Some(*e),
240 _ => None,
241 }
242 }
243
244 fn sort_key(&self) -> (usize, u8) {
245 match self {
246 EntitySet(EntitySetInner::Source(source)) => source.sort_key(),
247 EntitySet(EntitySetInner::Union(left, right)) => {
248 let (left_upper, left_hint) = left.sort_key();
250 let (right_upper, right_hint) = right.sort_key();
251 (
252 left_upper.saturating_add(right_upper),
253 left_hint.min(right_hint),
254 )
255 }
256 EntitySet(EntitySetInner::Intersection(sets)) => {
257 let mut upper = usize::MAX;
258 let mut hint = 0u8;
259 for set in sets {
260 let (set_upper, set_hint) = set.sort_key();
261 upper = upper.min(set_upper);
262 hint = hint.saturating_add(set_hint);
263 }
264 if upper == usize::MAX {
265 upper = 0;
266 }
267 (upper, hint)
268 }
269 EntitySet(EntitySetInner::Difference(left, right)) => {
270 let (left_upper, left_hint) = left.sort_key();
271 let (_, right_hint) = right.sort_key();
272 (left_upper, left_hint.saturating_add(right_hint))
273 }
274 }
275 }
276
277 fn structurally_eq(&self, other: &Self) -> bool {
279 match (self, other) {
280 (EntitySet(EntitySetInner::Source(a)), EntitySet(EntitySetInner::Source(b))) => a == b,
281 (
282 EntitySet(EntitySetInner::Union(a1, a2)),
283 EntitySet(EntitySetInner::Union(b1, b2)),
284 )
285 | (
286 EntitySet(EntitySetInner::Difference(a1, a2)),
287 EntitySet(EntitySetInner::Difference(b1, b2)),
288 ) => a1.structurally_eq(b1) && a2.structurally_eq(b2),
289 (
290 EntitySet(EntitySetInner::Intersection(a_sets)),
291 EntitySet(EntitySetInner::Intersection(b_sets)),
292 ) => {
293 a_sets.len() == b_sets.len()
294 && a_sets
295 .iter()
296 .zip(b_sets.iter())
297 .all(|(a_set, b_set)| a_set.structurally_eq(b_set))
298 }
299 _ => false,
300 }
301 }
302}
303
304impl<'a, E: Entity> IntoIterator for EntitySet<'a, E> {
305 type Item = EntityId<E>;
306 type IntoIter = EntitySetIterator<'a, E>;
307
308 fn into_iter(self) -> Self::IntoIter {
309 EntitySetIterator::new(self)
310 }
311}
312
#[cfg(test)]
mod tests {
    use std::cell::RefCell;

    use super::*;
    use crate::entity::ContextEntitiesExt;
    use crate::hashing::IndexSet;
    use crate::{define_entity, define_property, Context};

    define_entity!(Person);
    define_property!(struct Age(u8), Person);

    /// Shorthand for a `Person` id built from its raw index.
    fn pid(raw: usize) -> EntityId<Person> {
        EntityId::<Person>::new(raw)
    }

    /// Wraps the given raw ids (in order) in a `RefCell<IndexSet<_>>` that an
    /// `EntitySet` leaf can borrow from.
    fn finite_set(ids: &[usize]) -> RefCell<IndexSet<EntityId<Person>>> {
        let members: IndexSet<EntityId<Person>> = ids.iter().map(|&raw| pid(raw)).collect();
        RefCell::new(members)
    }

    /// Borrows `cell` as an index-set leaf.
    fn as_entity_set(cell: &RefCell<IndexSet<EntityId<Person>>>) -> EntitySet<'_, Person> {
        let borrowed = cell.borrow();
        EntitySet::from_source(SourceSet::IndexSet(borrowed))
    }

    /// A leaf holding exactly one `Person`.
    fn entity_leaf<'a>(raw: usize) -> EntitySet<'a, Person> {
        EntitySet::from_source(SourceSet::Entity(pid(raw)))
    }

    #[test]
    fn from_source_empty_is_empty() {
        let empty = EntitySet::<Person>::from_source(SourceSet::Empty);
        assert_eq!(empty.sort_key().0, 0);
        assert!((0..10).all(|raw| !empty.contains(pid(raw))));
    }

    #[test]
    fn from_source_entity_and_population() {
        let single = entity_leaf(5);
        assert!(single.contains(pid(5)));
        assert!(!single.contains(pid(4)));
        assert_eq!(single.sort_key().0, 1);

        let population = EntitySet::from_source(SourceSet::<Person>::Population(3));
        assert!(population.contains(pid(0)));
        assert!(population.contains(pid(2)));
        assert!(!population.contains(pid(3)));
        assert_eq!(population.sort_key().0, 3);
    }

    #[test]
    fn union_algebraic_reductions() {
        let members = finite_set(&[1, 2, 3]);

        let with_empty = as_entity_set(&members).union(EntitySet::empty());
        assert!(with_empty.contains(pid(1)));
        assert!(!with_empty.contains(pid(4)));

        let universe = EntitySet::from_source(SourceSet::<Person>::Population(10));
        let absorbed = universe.union(as_entity_set(&members));
        assert!(matches!(
            absorbed,
            EntitySet(EntitySetInner::Source(SourceSet::Population(10)))
        ));
    }

    #[test]
    fn union_entity_absorption() {
        let first = finite_set(&[1, 2, 3]);
        let absorbed = entity_leaf(2).union(as_entity_set(&first));
        for raw in 1..=3 {
            assert!(absorbed.contains(pid(raw)));
        }

        let second = finite_set(&[1, 2, 3]);
        let not_absorbed = entity_leaf(8).union(as_entity_set(&second));
        assert!(not_absorbed.contains(pid(8)));
        assert!(not_absorbed.contains(pid(1)));
    }

    #[test]
    fn intersection_algebraic_reductions() {
        let first = finite_set(&[1, 2, 3]);
        let universe = EntitySet::from_source(SourceSet::<Person>::Population(10));
        let clipped = as_entity_set(&first).intersection(universe);
        assert!(clipped.contains(pid(1)));
        assert!(clipped.contains(pid(2)));
        assert!(!clipped.contains(pid(9)));

        let second = finite_set(&[1, 2, 3]);
        let present = entity_leaf(2).intersection(as_entity_set(&second));
        assert!(matches!(
            present,
            EntitySet(EntitySetInner::Source(SourceSet::Entity(_)))
        ));

        let third = finite_set(&[1, 2, 3]);
        let absent = entity_leaf(7).intersection(as_entity_set(&third));
        assert!(!absent.contains(pid(7)));
    }

    #[test]
    fn difference_algebraic_reductions() {
        let first = finite_set(&[1, 2, 3]);

        let minus_empty = as_entity_set(&first).difference(EntitySet::empty());
        assert!(minus_empty.contains(pid(1)));
        assert!(!minus_empty.contains(pid(9)));

        let universe = EntitySet::from_source(SourceSet::<Person>::Population(10));
        let minus_universe = as_entity_set(&first).difference(universe);
        assert!((0..10).all(|raw| !minus_universe.contains(pid(raw))));

        let second = finite_set(&[1, 2, 3]);
        let singleton_absent = entity_leaf(8).difference(as_entity_set(&second));
        assert!(singleton_absent.contains(pid(8)));

        let third = finite_set(&[1, 2, 3]);
        let singleton_present = entity_leaf(2).difference(as_entity_set(&third));
        assert!(!singleton_present.contains(pid(2)));
    }

    #[test]
    fn difference_is_not_commutative() {
        let low = finite_set(&[1, 2, 3]);
        let high = finite_set(&[2, 3, 4]);
        let low_minus_high = as_entity_set(&low).difference(as_entity_set(&high));

        let high_again = finite_set(&[2, 3, 4]);
        let low_again = finite_set(&[1, 2, 3]);
        let high_minus_low = as_entity_set(&high_again).difference(as_entity_set(&low_again));

        assert!(low_minus_high.contains(pid(1)));
        assert!(!low_minus_high.contains(pid(4)));
        assert!(high_minus_low.contains(pid(4)));
        assert!(!high_minus_low.contains(pid(1)));
    }

    #[test]
    fn sort_key_rules() {
        let a = finite_set(&[1, 2]);
        let b = finite_set(&[2, 3, 4]);
        let a_len = a.borrow().len();
        let b_len = b.borrow().len();

        let union = as_entity_set(&a).union(as_entity_set(&b));
        assert_eq!(union.sort_key(), (a_len + b_len, 3));

        let intersection = as_entity_set(&a).intersection(as_entity_set(&b));
        assert_eq!(intersection.sort_key(), (a_len.min(b_len), 6));

        let difference = as_entity_set(&a).difference(as_entity_set(&b));
        assert_eq!(difference.sort_key(), (a_len, 6));
    }

    #[test]
    fn compound_expressions_membership() {
        let a = finite_set(&[1, 2, 3, 4]);
        let b = finite_set(&[3, 4, 5]);
        let c = finite_set(&[10, 20]);
        let d = finite_set(&[20]);

        let left = as_entity_set(&a).intersection(as_entity_set(&b));
        let right = as_entity_set(&c).intersection(as_entity_set(&d));
        let union_of_intersections = left.union(right);
        for raw in [3, 4, 20] {
            assert!(union_of_intersections.contains(pid(raw)));
        }
        assert!(!union_of_intersections.contains(pid(5)));

        let a2 = finite_set(&[1, 2, 3]);
        let b2 = finite_set(&[3, 4, 5]);
        let a3 = finite_set(&[1, 2, 3]);
        let a2_or_b2 = as_entity_set(&a2).union(as_entity_set(&b2));
        let law = as_entity_set(&a3).intersection(a2_or_b2);
        for raw in 1..=3 {
            assert!(law.contains(pid(raw)));
        }
        assert!(!law.contains(pid(4)));
    }

    #[test]
    fn population_zero_is_empty() {
        let nobody = EntitySet::from_source(SourceSet::<Person>::Population(0));
        assert_eq!(nobody.sort_key().0, 0);
        assert!(!nobody.contains(pid(0)));
    }

    #[test]
    fn try_len_known_only_for_non_property_sources() {
        let empty = EntitySet::<Person>::from_source(SourceSet::Empty);
        assert_eq!(empty.try_len(), Some(0));

        assert_eq!(entity_leaf(42).try_len(), Some(1));

        let population = EntitySet::<Person>::from_source(SourceSet::Population(5));
        assert_eq!(population.try_len(), Some(5));

        let index_data = finite_set(&[1, 2, 3]);
        assert_eq!(as_entity_set(&index_data).try_len(), Some(3));

        let mut context = Context::new();
        context.add_entity((Age(10),)).unwrap();
        let property_source = SourceSet::<Person>::new(Age(10), &context).unwrap();
        assert!(matches!(property_source, SourceSet::PropertySet(_)));
        assert_eq!(
            EntitySet::<Person>::from_source(property_source).try_len(),
            None
        );

        let composed = EntitySet::<Person>::from_source(SourceSet::Population(3))
            .difference(entity_leaf(1));
        assert_eq!(composed.try_len(), None);
    }
}