1pub mod edge;
10mod network;
11mod network_store;
12
13use std::any::Any;
14use std::cell::OnceCell;
15
16pub use edge::{Edge, EdgeType};
17use network::{AdjacencyList, Network};
18use network_store::NetworkStore;
19use rand::Rng;
20
21use crate::context::Context;
22use crate::entity::entity_store::get_registered_entity_count;
23use crate::entity::{Entity, EntityId};
24use crate::error::IxaError;
25use crate::random::{ContextRandomExt, RngId};
26use crate::{define_data_plugin, ContextBase};
27
28pub struct NetworkData {
29 network_stores: Vec<OnceCell<Box<dyn Any>>>,
31}
32
33impl Default for NetworkData {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl NetworkData {
40 #[must_use]
41 pub fn new() -> Self {
42 let entity_count = get_registered_entity_count();
43 let network_stores = (0..entity_count)
44 .map(|_| OnceCell::new())
45 .collect::<Vec<_>>();
46
47 NetworkData { network_stores }
48 }
49
50 pub fn add_edge<E: Entity, ET: EdgeType<E>>(
53 &mut self,
54 entity_id: EntityId<E>,
55 neighbor: EntityId<E>,
56 weight: f32,
57 inner: ET,
58 ) -> Result<(), IxaError> {
59 if entity_id == neighbor {
60 return Err(IxaError::CannotMakeEdgeToSelf);
61 }
62 if weight.is_infinite() || weight.is_nan() || weight.is_sign_negative() {
63 return Err(IxaError::InvalidWeight);
64 }
65
66 let edge = Edge {
67 neighbor,
68 weight,
69 inner,
70 };
71 let network = self.get_network_mut::<E, ET>();
72
73 network.add_edge(entity_id, edge)
74 }
75
76 pub fn remove_edge<E: Entity, ET: EdgeType<E>>(
79 &mut self,
80 entity_id: EntityId<E>,
81 neighbor: EntityId<E>,
82 ) -> Option<Edge<E, ET>> {
83 let network = self.get_network_mut::<E, ET>();
84 network.remove_edge(entity_id, neighbor)
85 }
86
87 #[must_use]
90 pub fn get_edge<E: Entity, ET: EdgeType<E>>(
91 &self,
92 entity_id: EntityId<E>,
93 neighbor: EntityId<E>,
94 ) -> Option<&Edge<E, ET>> {
95 let network = self.get_network::<E, ET>();
96 network.get_edge(entity_id, neighbor)
97 }
98
99 #[must_use]
102 pub fn get_edges<E: Entity, ET: EdgeType<E>>(
103 &self,
104 entity_id: EntityId<E>,
105 ) -> AdjacencyList<E, ET> {
106 let network = self.get_network::<E, ET>();
107 network.get_list_cloned(entity_id)
108 }
109
110 #[must_use]
113 pub fn get_edges_ref<E: Entity, ET: EdgeType<E>>(
114 &self,
115 entity_id: EntityId<E>,
116 ) -> Option<&AdjacencyList<E, ET>> {
117 let network = self.get_network::<E, ET>();
118 network.get_list(entity_id)
119 }
120
121 #[must_use]
123 pub fn find_entities_by_degree<E: Entity, ET: EdgeType<E>>(
124 &self,
125 degree: usize,
126 ) -> Vec<EntityId<E>> {
127 let network = self.get_network::<E, ET>();
128 network.find_entities_by_degree(degree)
129 }
130
131 #[must_use]
132 fn get_network<E: Entity, ET: EdgeType<E>>(&self) -> &Network<E, ET> {
133 self.network_stores
134 .get(E::id())
135 .unwrap_or_else(|| {
136 panic!(
137 "internal error: NetworkStore for Entity {} not found",
138 E::name()
139 )
140 })
141 .get_or_init(NetworkStore::<E>::new_boxed)
142 .downcast_ref::<NetworkStore<E>>()
143 .unwrap_or_else(|| {
144 panic!(
145 "internal error: found wrong NetworkStore type when accessing Entity {}",
146 E::name()
147 )
148 })
149 .get::<ET>()
150 }
151
152 #[must_use]
153 fn get_network_mut<E: Entity, ET: EdgeType<E>>(&mut self) -> &mut Network<E, ET> {
154 let network_store = self.network_stores.get_mut(E::id()).unwrap_or_else(|| {
155 panic!(
156 "internal error: NetworkStore for Entity {} not found",
157 E::name()
158 )
159 });
160
161 if network_store.get().is_none() {
163 network_store.set(NetworkStore::<E>::new_boxed()).unwrap();
164 }
165
166 network_store
168 .get_mut()
169 .unwrap()
170 .downcast_mut::<NetworkStore<E>>()
171 .unwrap_or_else(|| {
172 panic!(
173 "internal error: found wrong NetworkStore type when accessing Entity {}",
174 E::name()
175 )
176 })
177 .get_mut::<ET>()
178 }
179}
180
181define_data_plugin!(NetworkPlugin, NetworkData, NetworkData::new());
182
183pub trait ContextNetworkExt: ContextBase + ContextRandomExt {
185 fn add_edge<E: Entity, ET: EdgeType<E>>(
196 &mut self,
197 entity_id: EntityId<E>,
198 neighbor: EntityId<E>,
199 weight: f32,
200 inner: ET,
201 ) -> Result<(), IxaError> {
202 let data_container = self.get_data_mut(NetworkPlugin);
203 data_container.add_edge::<E, ET>(entity_id, neighbor, weight, inner)
204 }
205
206 fn add_edge_bidi<E: Entity, ET: EdgeType<E>>(
218 &mut self,
219 entity1: EntityId<E>,
220 entity2: EntityId<E>,
221 weight: f32,
222 inner: ET,
223 ) -> Result<(), IxaError> {
224 let data_container = self.get_data_mut(NetworkPlugin);
225 data_container.add_edge::<E, ET>(entity1, entity2, weight, inner.clone())?;
226 data_container.add_edge::<E, ET>(entity2, entity1, weight, inner)
227 }
228
229 fn remove_edge<E: Entity, ET: EdgeType<E>>(
232 &mut self,
233 entity_id: EntityId<E>,
234 neighbor: EntityId<E>,
235 ) -> Option<Edge<E, ET>> {
236 let data_container = self.get_data_mut(NetworkPlugin);
237 data_container.remove_edge::<E, ET>(entity_id, neighbor)
238 }
239
240 #[must_use]
242 fn get_edge<E: Entity, ET: EdgeType<E>>(
243 &self,
244 entity_id: EntityId<E>,
245 neighbor: EntityId<E>,
246 ) -> Option<&Edge<E, ET>> {
247 self.get_data(NetworkPlugin)
248 .get_edge::<E, ET>(entity_id, neighbor)
249 }
250
251 #[must_use]
253 fn get_edges<E: Entity, ET: EdgeType<E>>(
254 &self,
255 entity_id: EntityId<E>,
256 ) -> AdjacencyList<E, ET> {
257 self.get_data(NetworkPlugin).get_edges::<E, ET>(entity_id)
258 }
259
260 #[must_use]
265 fn get_matching_edges<E: Entity, ET: EdgeType<E>>(
266 &self,
267 entity_id: EntityId<E>,
268 filter: impl Fn(&Self, &Edge<E, ET>) -> bool,
269 ) -> AdjacencyList<E, ET> {
270 let network_data = self.get_data(NetworkPlugin);
271 let empty = vec![];
272 let edges = network_data
273 .get_edges_ref::<E, ET>(entity_id)
274 .unwrap_or(&empty);
275 edges
276 .iter()
277 .filter(|&edge| filter(self, edge))
278 .cloned()
279 .collect()
280 }
281
282 #[must_use]
284 fn find_entities_by_degree<E: Entity, ET: EdgeType<E>>(
285 &self,
286 degree: usize,
287 ) -> Vec<EntityId<E>> {
288 self.get_data(NetworkPlugin)
289 .find_entities_by_degree::<E, ET>(degree)
290 }
291
292 fn select_random_edge<E: Entity, ET: EdgeType<E>, R: RngId + 'static>(
297 &self,
298 rng_id: R,
299 entity_id: EntityId<E>,
300 ) -> Result<Edge<E, ET>, IxaError>
301 where
302 R::RngType: Rng,
303 {
304 let edges = self.get_edges::<E, ET>(entity_id);
305 if edges.is_empty() {
306 return Err(IxaError::CannotSampleFromEmptyList);
307 }
308
309 let weights: Vec<_> = edges.iter().map(|x| x.weight).collect();
310 let index = self.sample_weighted(rng_id, &weights);
311 Ok(edges[index].clone())
312 }
313}
314
// `Context` gets every network method from the trait's default method
// bodies; nothing is overridden here.
impl ContextNetworkExt for Context {}
316
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod test_inner {
    // Unit tests that drive `NetworkData` directly, without a `Context`.
    // Exact float comparisons are intentional: weights are stored verbatim.

    use super::NetworkData;
    use crate::error::IxaError;
    use crate::network::edge::Edge;
    use crate::{define_edge_type, define_entity};

    define_entity!(Person);

    define_edge_type!(struct EdgeType1, Person);
    define_edge_type!(struct EdgeType2, Person);
    define_edge_type!(struct EdgeType3(pub bool), Person);

    #[test]
    fn add_edge() {
        let (source, target) = (PersonId::new(1), PersonId::new(2));
        let mut data = NetworkData::new();

        data.add_edge::<Person, EdgeType1>(source, target, 0.01, EdgeType1)
            .unwrap();

        let stored = data.get_edge::<Person, EdgeType1>(source, target).unwrap();
        assert_eq!(stored.weight, 0.01);
    }

    #[test]
    fn add_edge_with_inner() {
        let (source, target) = (PersonId::new(1), PersonId::new(2));
        let mut data = NetworkData::new();

        data.add_edge::<Person, EdgeType3>(source, target, 0.01, EdgeType3(true))
            .unwrap();

        let stored = data.get_edge::<Person, EdgeType3>(source, target).unwrap();
        assert_eq!(stored.weight, 0.01);
        assert_eq!(stored.inner, EdgeType3(true));
    }

    #[test]
    fn add_two_edges() {
        let hub = PersonId::new(1);
        let neighbors = [(PersonId::new(2), 0.01), (PersonId::new(3), 0.02)];
        let mut data = NetworkData::new();

        for (neighbor, weight) in neighbors {
            data.add_edge::<Person, EdgeType1>(hub, neighbor, weight, EdgeType1)
                .unwrap();
        }
        for (neighbor, weight) in neighbors {
            let stored = data.get_edge::<Person, EdgeType1>(hub, neighbor).unwrap();
            assert_eq!(stored.weight, weight);
        }

        let expected: Vec<_> = neighbors
            .iter()
            .map(|&(neighbor, weight)| Edge {
                neighbor,
                weight,
                inner: EdgeType1,
            })
            .collect();
        assert_eq!(data.get_edges::<Person, EdgeType1>(hub), expected);
    }

    #[test]
    fn add_two_edge_types() {
        let (source, target) = (PersonId::new(1), PersonId::new(2));
        let mut data = NetworkData::new();

        data.add_edge::<Person, EdgeType1>(source, target, 0.01, EdgeType1)
            .unwrap();
        data.add_edge::<Person, EdgeType2>(source, target, 0.02, EdgeType2)
            .unwrap();

        let first_weight = data
            .get_edge::<Person, EdgeType1>(source, target)
            .unwrap()
            .weight;
        assert_eq!(first_weight, 0.01);
        let second_weight = data
            .get_edge::<Person, EdgeType2>(source, target)
            .unwrap()
            .weight;
        assert_eq!(second_weight, 0.02);

        // Only the EdgeType1 edge belongs in the EdgeType1 adjacency list.
        let only_type_1 = vec![Edge {
            neighbor: target,
            weight: 0.01,
            inner: EdgeType1,
        }];
        assert_eq!(data.get_edges::<Person, EdgeType1>(source), only_type_1);
    }

    #[test]
    fn add_edge_twice_fails() {
        let (source, target) = (PersonId::new(1), PersonId::new(2));
        let mut data = NetworkData::new();

        data.add_edge::<Person, EdgeType1>(source, target, 0.01, EdgeType1)
            .unwrap();
        let first_weight = data
            .get_edge::<Person, EdgeType1>(source, target)
            .unwrap()
            .weight;
        assert_eq!(first_weight, 0.01);

        let duplicate = data.add_edge::<Person, EdgeType1>(source, target, 0.02, EdgeType1);
        assert!(matches!(duplicate, Err(IxaError::EdgeAlreadyExists)));
    }

    #[test]
    fn add_remove_add_edge() {
        let (source, target) = (PersonId::new(1), PersonId::new(2));
        let mut data = NetworkData::new();

        data.add_edge::<Person, EdgeType1>(source, target, 0.01, EdgeType1)
            .unwrap();
        let original_weight = data
            .get_edge::<Person, EdgeType1>(source, target)
            .unwrap()
            .weight;
        assert_eq!(original_weight, 0.01);

        data.remove_edge::<Person, EdgeType1>(source, target)
            .unwrap();
        assert!(data
            .get_edge::<Person, EdgeType1>(source, target)
            .is_none());

        data.add_edge::<Person, EdgeType1>(source, target, 0.02, EdgeType1)
            .unwrap();
        let replacement_weight = data
            .get_edge::<Person, EdgeType1>(source, target)
            .unwrap()
            .weight;
        assert_eq!(replacement_weight, 0.02);
    }

    #[test]
    fn remove_nonexistent_edge() {
        let mut data = NetworkData::new();
        let removed = data.remove_edge::<Person, EdgeType1>(PersonId::new(1), PersonId::new(2));
        assert!(removed.is_none());
    }

    #[test]
    fn add_edge_to_self() {
        let person = PersonId::new(1);
        let mut data = NetworkData::new();

        let outcome = data.add_edge::<Person, EdgeType1>(person, person, 0.01, EdgeType1);
        assert!(matches!(outcome, Err(IxaError::CannotMakeEdgeToSelf)));
    }

    #[test]
    fn add_edge_bogus_weight() {
        let (source, target) = (PersonId::new(1), PersonId::new(2));
        let mut data = NetworkData::new();

        for bogus in [-1.0, f32::NAN, f32::INFINITY] {
            let outcome = data.add_edge::<Person, EdgeType1>(source, target, bogus, EdgeType1);
            assert!(matches!(outcome, Err(IxaError::InvalidWeight)));
        }
    }

    #[test]
    fn find_people_by_degree() {
        let (p1, p2, p3) = (PersonId::new(1), PersonId::new(2), PersonId::new(3));
        let mut data = NetworkData::new();

        for (from, to) in [(p1, p2), (p1, p3), (p2, p3), (p3, p2)] {
            data.add_edge::<Person, EdgeType1>(from, to, 0.0, EdgeType1)
                .unwrap();
        }

        assert_eq!(
            data.find_entities_by_degree::<Person, EdgeType1>(2),
            vec![p1]
        );
        assert_eq!(
            data.find_entities_by_degree::<Person, EdgeType1>(1),
            vec![p2, p3]
        );
    }
}
522
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod test_api {
    // Tests that exercise the network through the `ContextNetworkExt` trait.
    // Exact float comparisons are intentional: weights are stored verbatim.

    use crate::context::Context;
    use crate::network::edge::Edge;
    use crate::network::ContextNetworkExt;
    use crate::prelude::*;
    use crate::random::ContextRandomExt;
    use crate::{define_edge_type, define_entity, define_property, define_rng};

    define_entity!(Person);

    define_edge_type!(struct EdgeType1(pub u32), Person);
    define_property!(struct Age(u8), Person);

    /// Builds a fresh context containing two people, aged 1 and 2.
    fn setup() -> (Context, PersonId, PersonId) {
        let mut ctx = Context::new();
        let first = ctx.add_entity((Age(1),)).unwrap();
        let second = ctx.add_entity((Age(2),)).unwrap();
        (ctx, first, second)
    }

    /// Weight of the `EdgeType1` edge `from -> to`; panics if it is missing.
    fn weight_of(context: &Context, from: PersonId, to: PersonId) -> f32 {
        context
            .get_edge::<Person, EdgeType1>(from, to)
            .unwrap()
            .weight
    }

    #[test]
    fn add_edge() {
        let (mut ctx, source, target) = setup();

        ctx.add_edge::<Person, EdgeType1>(source, target, 0.01, EdgeType1(1))
            .unwrap();

        assert_eq!(weight_of(&ctx, source, target), 0.01);
        let expected = vec![Edge {
            neighbor: target,
            weight: 0.01,
            inner: EdgeType1(1),
        }];
        assert_eq!(ctx.get_edges::<Person, EdgeType1>(source), expected);
    }

    #[test]
    fn remove_edge() {
        let (mut ctx, source, target) = setup();

        let before_insert = ctx.remove_edge::<Person, EdgeType1>(source, target);
        assert!(before_insert.is_none());

        ctx.add_edge::<Person, EdgeType1>(source, target, 0.01, EdgeType1(1))
            .unwrap();
        let removed = ctx.remove_edge::<Person, EdgeType1>(source, target);
        assert!(removed.is_some());

        assert!(ctx.get_edge::<Person, EdgeType1>(source, target).is_none());
        assert!(ctx.get_edges::<Person, EdgeType1>(source).is_empty());
    }

    #[test]
    fn add_edge_bidi() {
        let (mut ctx, a, b) = setup();

        ctx.add_edge_bidi::<Person, EdgeType1>(a, b, 0.01, EdgeType1(1))
            .unwrap();

        for (from, to) in [(a, b), (b, a)] {
            assert_eq!(weight_of(&ctx, from, to), 0.01);
        }
    }

    #[test]
    fn add_edge_different_weights() {
        let (mut ctx, a, b) = setup();
        let plan = [(a, b, 0.01), (b, a, 0.02)];

        for (from, to, weight) in plan {
            ctx.add_edge::<Person, EdgeType1>(from, to, weight, EdgeType1(1))
                .unwrap();
        }
        for (from, to, weight) in plan {
            assert_eq!(weight_of(&ctx, from, to), weight);
        }
    }

    #[test]
    fn get_matching_edges_weight() {
        let (mut ctx, person1, person2) = setup();
        let person3 = ctx.add_entity((Age(3),)).unwrap();

        ctx.add_edge::<Person, EdgeType1>(person1, person2, 0.01, EdgeType1(1))
            .unwrap();
        ctx.add_edge::<Person, EdgeType1>(person1, person3, 0.03, EdgeType1(1))
            .unwrap();

        let heavy = ctx
            .get_matching_edges::<Person, EdgeType1>(person1, |_ctx, edge| edge.weight > 0.01);
        assert_eq!(heavy.len(), 1);
        assert_eq!(heavy[0].neighbor, person3);
    }

    #[test]
    fn get_matching_edges_inner() {
        let (mut ctx, person1, person2) = setup();
        let person3 = ctx.add_entity((Age(3),)).unwrap();

        ctx.add_edge::<Person, EdgeType1>(person1, person2, 0.01, EdgeType1(1))
            .unwrap();
        ctx.add_edge::<Person, EdgeType1>(person1, person3, 0.03, EdgeType1(3))
            .unwrap();

        let tagged = ctx
            .get_matching_edges::<Person, EdgeType1>(person1, |_ctx, edge| edge.inner.0 == 3);
        assert_eq!(tagged.len(), 1);
        assert_eq!(tagged[0].neighbor, person3);
    }

    #[test]
    fn get_matching_edges_person_property() {
        let (mut ctx, person1, person2) = setup();
        let person3 = ctx.add_entity((Age(3),)).unwrap();

        ctx.add_edge::<Person, EdgeType1>(person1, person2, 0.01, EdgeType1(1))
            .unwrap();
        ctx.add_edge::<Person, EdgeType1>(person1, person3, 0.03, EdgeType1(3))
            .unwrap();

        let aged_three = ctx.get_matching_edges::<Person, EdgeType1>(person1, |context, edge| {
            context.match_entity(edge.neighbor, (Age(3),))
        });
        assert_eq!(aged_three.len(), 1);
        assert_eq!(aged_three[0].neighbor, person3);
    }

    #[test]
    fn select_random_edge() {
        define_rng!(NetworkTestRng);

        let (mut ctx, person1, person2) = setup();
        let person3 = ctx.add_entity((Age(3),)).unwrap();
        ctx.init_random(42);

        ctx.add_edge::<Person, EdgeType1>(person1, person2, 0.01, EdgeType1(1))
            .unwrap();
        ctx.add_edge::<Person, EdgeType1>(person1, person3, 10_000_000.0, EdgeType1(3))
            .unwrap();

        // The second edge outweighs the first by nine orders of magnitude.
        let chosen = ctx
            .select_random_edge::<Person, EdgeType1, _>(NetworkTestRng, person1)
            .unwrap();
        assert_eq!(chosen.neighbor, person3);
        assert_eq!(chosen.inner, EdgeType1(3));
    }
}