1pub mod edge;
10mod network;
11mod network_store;
12
13use std::any::Any;
14use std::cell::OnceCell;
15
16pub use edge::{Edge, EdgeType};
17use network::{AdjacencyList, Network};
18use network_store::NetworkStore;
19use rand::Rng;
20
21use crate::context::Context;
22use crate::entity::entity_store::get_registered_entity_count;
23use crate::entity::{Entity, EntityId};
24use crate::error::IxaError;
25use crate::random::{ContextRandomExt, RngId};
26use crate::{define_data_plugin, ContextBase};
27
/// Plugin data holding every network known to a context.
///
/// `network_stores` has one slot per registered entity type and is indexed by `E::id()`.
/// A slot starts out empty and is filled with a type-erased `NetworkStore<E>` the first time
/// a network for that entity type is accessed.
#[derive(Debug)]
pub struct NetworkData {
    network_stores: Vec<OnceCell<Box<dyn Any>>>,
}
32
33impl Default for NetworkData {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl NetworkData {
40 #[must_use]
41 pub fn new() -> Self {
42 let entity_count = get_registered_entity_count();
43 let network_stores = (0..entity_count)
44 .map(|_| OnceCell::new())
45 .collect::<Vec<_>>();
46
47 NetworkData { network_stores }
48 }
49
50 pub fn add_edge<E: Entity, ET: EdgeType<E>>(
53 &mut self,
54 entity_id: EntityId<E>,
55 neighbor: EntityId<E>,
56 weight: f32,
57 inner: ET,
58 ) -> Result<(), IxaError> {
59 if entity_id == neighbor {
60 return Err(IxaError::IxaError(String::from("Cannot make edge to self")));
61 }
62 if weight.is_infinite() || weight.is_nan() || weight.is_sign_negative() {
63 return Err(IxaError::IxaError(String::from("Invalid weight")));
64 }
65
66 let edge = Edge {
67 neighbor,
68 weight,
69 inner,
70 };
71 let network = self.get_network_mut::<E, ET>();
72
73 network.add_edge(entity_id, edge)
74 }
75
76 pub fn remove_edge<E: Entity, ET: EdgeType<E>>(
79 &mut self,
80 entity_id: EntityId<E>,
81 neighbor: EntityId<E>,
82 ) -> Option<Edge<E, ET>> {
83 let network = self.get_network_mut::<E, ET>();
84 network.remove_edge(entity_id, neighbor)
85 }
86
87 #[must_use]
90 pub fn get_edge<E: Entity, ET: EdgeType<E>>(
91 &self,
92 entity_id: EntityId<E>,
93 neighbor: EntityId<E>,
94 ) -> Option<&Edge<E, ET>> {
95 let network = self.get_network::<E, ET>();
96 network.get_edge(entity_id, neighbor)
97 }
98
99 #[must_use]
102 pub fn get_edges<E: Entity, ET: EdgeType<E>>(
103 &self,
104 entity_id: EntityId<E>,
105 ) -> AdjacencyList<E, ET> {
106 let network = self.get_network::<E, ET>();
107 network.get_list_cloned(entity_id)
108 }
109
110 #[must_use]
113 pub fn get_edges_ref<E: Entity, ET: EdgeType<E>>(
114 &self,
115 entity_id: EntityId<E>,
116 ) -> Option<&AdjacencyList<E, ET>> {
117 let network = self.get_network::<E, ET>();
118 network.get_list(entity_id)
119 }
120
121 #[must_use]
123 pub fn find_entities_by_degree<E: Entity, ET: EdgeType<E>>(
124 &self,
125 degree: usize,
126 ) -> Vec<EntityId<E>> {
127 let network = self.get_network::<E, ET>();
128 network.find_entities_by_degree(degree)
129 }
130
131 #[must_use]
132 fn get_network<E: Entity, ET: EdgeType<E>>(&self) -> &Network<E, ET> {
133 self.network_stores
134 .get(E::id())
135 .unwrap_or_else(|| {
136 panic!(
137 "internal error: NetworkStore for Entity {} not found",
138 E::name()
139 )
140 })
141 .get_or_init(NetworkStore::<E>::new_boxed)
142 .downcast_ref::<NetworkStore<E>>()
143 .unwrap_or_else(|| {
144 panic!(
145 "internal error: found wrong NetworkStore type when accessing Entity {}",
146 E::name()
147 )
148 })
149 .get::<ET>()
150 }
151
152 #[must_use]
153 fn get_network_mut<E: Entity, ET: EdgeType<E>>(&mut self) -> &mut Network<E, ET> {
154 let network_store = self.network_stores.get_mut(E::id()).unwrap_or_else(|| {
155 panic!(
156 "internal error: NetworkStore for Entity {} not found",
157 E::name()
158 )
159 });
160
161 if network_store.get().is_none() {
163 network_store.set(NetworkStore::<E>::new_boxed()).unwrap();
164 }
165
166 network_store
168 .get_mut()
169 .unwrap()
170 .downcast_mut::<NetworkStore<E>>()
171 .unwrap_or_else(|| {
172 panic!(
173 "internal error: found wrong NetworkStore type when accessing Entity {}",
174 E::name()
175 )
176 })
177 .get_mut::<ET>()
178 }
179}
180
// Declares `NetworkPlugin`, the handle passed to `get_data` / `get_data_mut` to reach the
// `NetworkData` of a context; the data is built with `NetworkData::new()`.
// NOTE(review): presumably the initializer runs lazily on first access — confirm against the
// `define_data_plugin!` macro definition.
define_data_plugin!(NetworkPlugin, NetworkData, NetworkData::new());
182
183pub trait ContextNetworkExt: ContextBase + ContextRandomExt {
185 fn add_edge<E: Entity, ET: EdgeType<E>>(
196 &mut self,
197 entity_id: EntityId<E>,
198 neighbor: EntityId<E>,
199 weight: f32,
200 inner: ET,
201 ) -> Result<(), IxaError> {
202 let data_container = self.get_data_mut(NetworkPlugin);
203 data_container.add_edge::<E, ET>(entity_id, neighbor, weight, inner)
204 }
205
206 fn add_edge_bidi<E: Entity, ET: EdgeType<E>>(
218 &mut self,
219 entity1: EntityId<E>,
220 entity2: EntityId<E>,
221 weight: f32,
222 inner: ET,
223 ) -> Result<(), IxaError> {
224 let data_container = self.get_data_mut(NetworkPlugin);
225 data_container.add_edge::<E, ET>(entity1, entity2, weight, inner.clone())?;
226 data_container.add_edge::<E, ET>(entity2, entity1, weight, inner)
227 }
228
229 fn remove_edge<E: Entity, ET: EdgeType<E>>(
232 &mut self,
233 entity_id: EntityId<E>,
234 neighbor: EntityId<E>,
235 ) -> Option<Edge<E, ET>> {
236 let data_container = self.get_data_mut(NetworkPlugin);
237 data_container.remove_edge::<E, ET>(entity_id, neighbor)
238 }
239
240 #[must_use]
242 fn get_edge<E: Entity, ET: EdgeType<E>>(
243 &self,
244 entity_id: EntityId<E>,
245 neighbor: EntityId<E>,
246 ) -> Option<&Edge<E, ET>> {
247 self.get_data(NetworkPlugin)
248 .get_edge::<E, ET>(entity_id, neighbor)
249 }
250
251 #[must_use]
253 fn get_edges<E: Entity, ET: EdgeType<E>>(
254 &self,
255 entity_id: EntityId<E>,
256 ) -> AdjacencyList<E, ET> {
257 self.get_data(NetworkPlugin).get_edges::<E, ET>(entity_id)
258 }
259
260 #[must_use]
265 fn get_matching_edges<E: Entity, ET: EdgeType<E>>(
266 &self,
267 entity_id: EntityId<E>,
268 filter: impl Fn(&Self, &Edge<E, ET>) -> bool,
269 ) -> AdjacencyList<E, ET> {
270 let network_data = self.get_data(NetworkPlugin);
271 let empty = vec![];
272 let edges = network_data
273 .get_edges_ref::<E, ET>(entity_id)
274 .unwrap_or(&empty);
275 edges
276 .iter()
277 .filter(|&edge| filter(self, edge))
278 .cloned()
279 .collect()
280 }
281
282 #[must_use]
284 fn find_entities_by_degree<E: Entity, ET: EdgeType<E>>(
285 &self,
286 degree: usize,
287 ) -> Vec<EntityId<E>> {
288 self.get_data(NetworkPlugin)
289 .find_entities_by_degree::<E, ET>(degree)
290 }
291
292 fn select_random_edge<E: Entity, ET: EdgeType<E>, R: RngId + 'static>(
297 &self,
298 rng_id: R,
299 entity_id: EntityId<E>,
300 ) -> Result<Edge<E, ET>, IxaError>
301 where
302 R::RngType: Rng,
303 {
304 let edges = self.get_edges::<E, ET>(entity_id);
305 if edges.is_empty() {
306 return Err(IxaError::IxaError(String::from(
307 "Can't sample from empty list",
308 )));
309 }
310
311 let weights: Vec<_> = edges.iter().map(|x| x.weight).collect();
312 let index = self.sample_weighted(rng_id, &weights);
313 Ok(edges[index].clone())
314 }
315}
316
// `Context` gets every network operation from the default method bodies of the trait, so the
// impl body is intentionally empty.
impl ContextNetworkExt for Context {}
318
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod test_inner {
    //! Tests that drive `NetworkData` directly, without going through a `Context`.

    use super::NetworkData;
    use crate::error::IxaError;
    use crate::network::edge::Edge;
    use crate::{define_edge_type, define_entity};

    define_entity!(Person);

    define_edge_type!(struct EdgeType1, Person);
    define_edge_type!(struct EdgeType2, Person);
    define_edge_type!(struct EdgeType3(pub bool), Person);

    /// Inserts an `EdgeType1` edge `from -> to` and hands back the raw result.
    fn link(
        nd: &mut NetworkData,
        from: PersonId,
        to: PersonId,
        weight: f32,
    ) -> Result<(), IxaError> {
        nd.add_edge::<Person, EdgeType1>(from, to, weight, EdgeType1)
    }

    /// Weight of the stored `EdgeType1` edge `from -> to`; panics when the edge is missing.
    fn weight_of(nd: &NetworkData, from: PersonId, to: PersonId) -> f32 {
        nd.get_edge::<Person, EdgeType1>(from, to).unwrap().weight
    }

    /// The `EdgeType1` edge value expected in an adjacency list.
    fn edge1(neighbor: PersonId, weight: f32) -> Edge<Person, EdgeType1> {
        Edge {
            neighbor,
            weight,
            inner: EdgeType1,
        }
    }

    #[test]
    fn add_edge() {
        let mut nd = NetworkData::new();
        let (p1, p2) = (PersonId::new(1), PersonId::new(2));

        link(&mut nd, p1, p2, 0.01).unwrap();

        assert_eq!(weight_of(&nd, p1, p2), 0.01);
    }

    #[test]
    fn add_edge_with_inner() {
        let mut nd = NetworkData::new();
        let (p1, p2) = (PersonId::new(1), PersonId::new(2));

        nd.add_edge::<Person, EdgeType3>(p1, p2, 0.01, EdgeType3(true))
            .unwrap();

        let stored = nd.get_edge::<Person, EdgeType3>(p1, p2).unwrap();
        assert_eq!(stored.weight, 0.01);
        assert_eq!(stored.inner, EdgeType3(true));
    }

    #[test]
    fn add_two_edges() {
        let mut nd = NetworkData::new();
        let (p1, p2, p3) = (PersonId::new(1), PersonId::new(2), PersonId::new(3));

        link(&mut nd, p1, p2, 0.01).unwrap();
        link(&mut nd, p1, p3, 0.02).unwrap();

        assert_eq!(weight_of(&nd, p1, p2), 0.01);
        assert_eq!(weight_of(&nd, p1, p3), 0.02);

        let edges = nd.get_edges::<Person, EdgeType1>(p1);
        assert_eq!(edges, vec![edge1(p2, 0.01), edge1(p3, 0.02)]);
    }

    #[test]
    fn add_two_edge_types() {
        let mut nd = NetworkData::new();
        let (p1, p2) = (PersonId::new(1), PersonId::new(2));

        link(&mut nd, p1, p2, 0.01).unwrap();
        nd.add_edge::<Person, EdgeType2>(p1, p2, 0.02, EdgeType2)
            .unwrap();

        // The two edge types live in separate networks and do not overwrite each other.
        assert_eq!(weight_of(&nd, p1, p2), 0.01);
        let other = nd.get_edge::<Person, EdgeType2>(p1, p2).unwrap();
        assert_eq!(other.weight, 0.02);

        let edges = nd.get_edges::<Person, EdgeType1>(p1);
        assert_eq!(edges, vec![edge1(p2, 0.01)]);
    }

    #[test]
    fn add_edge_twice_fails() {
        let mut nd = NetworkData::new();
        let (p1, p2) = (PersonId::new(1), PersonId::new(2));

        link(&mut nd, p1, p2, 0.01).unwrap();
        assert_eq!(weight_of(&nd, p1, p2), 0.01);

        let second_attempt = link(&mut nd, p1, p2, 0.02);
        assert!(matches!(second_attempt, Err(IxaError::IxaError(_))));
    }

    #[test]
    fn add_remove_add_edge() {
        let mut nd = NetworkData::new();
        let (p1, p2) = (PersonId::new(1), PersonId::new(2));

        link(&mut nd, p1, p2, 0.01).unwrap();
        assert_eq!(weight_of(&nd, p1, p2), 0.01);

        nd.remove_edge::<Person, EdgeType1>(p1, p2).unwrap();
        assert!(nd.get_edge::<Person, EdgeType1>(p1, p2).is_none());

        link(&mut nd, p1, p2, 0.02).unwrap();
        assert_eq!(weight_of(&nd, p1, p2), 0.02);
    }

    #[test]
    fn remove_nonexistent_edge() {
        let mut nd = NetworkData::new();
        let removed = nd.remove_edge::<Person, EdgeType1>(PersonId::new(1), PersonId::new(2));
        assert!(removed.is_none());
    }

    #[test]
    fn add_edge_to_self() {
        let mut nd = NetworkData::new();
        let p1 = PersonId::new(1);

        let result = link(&mut nd, p1, p1, 0.01);
        assert!(matches!(result, Err(IxaError::IxaError(_))));
    }

    #[test]
    fn add_edge_bogus_weight() {
        let mut nd = NetworkData::new();
        let (p1, p2) = (PersonId::new(1), PersonId::new(2));

        for weight in [-1.0, f32::NAN, f32::INFINITY] {
            let result = link(&mut nd, p1, p2, weight);
            assert!(
                matches!(result, Err(IxaError::IxaError(_))),
                "weight {weight} should have been rejected"
            );
        }
    }

    #[test]
    fn find_people_by_degree() {
        let mut nd = NetworkData::new();
        let (p1, p2, p3) = (PersonId::new(1), PersonId::new(2), PersonId::new(3));

        for (from, to) in [(p1, p2), (p1, p3), (p2, p3), (p3, p2)] {
            link(&mut nd, from, to, 0.0).unwrap();
        }

        let matches = nd.find_entities_by_degree::<Person, EdgeType1>(2);
        assert_eq!(matches, vec![p1]);
        let matches = nd.find_entities_by_degree::<Person, EdgeType1>(1);
        assert_eq!(matches, vec![p2, p3]);
    }
}
524
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod test_api {
    //! Tests that exercise the network API through `ContextNetworkExt` on a real `Context`.

    use crate::context::Context;
    use crate::network::edge::Edge;
    use crate::network::ContextNetworkExt;
    use crate::prelude::*;
    use crate::random::ContextRandomExt;
    use crate::{define_edge_type, define_entity, define_property, define_rng};

    define_entity!(Person);

    define_edge_type!(struct EdgeType1(pub u32), Person);
    define_property!(struct Age(u8), Person);

    /// A fresh context holding two people, aged 1 and 2.
    fn setup() -> (Context, PersonId, PersonId) {
        let mut context = Context::new();
        let person1 = context.add_entity((Age(1),)).unwrap();
        let person2 = context.add_entity((Age(2),)).unwrap();

        (context, person1, person2)
    }

    /// Like [`setup`], plus a third person aged 3.
    fn setup_three() -> (Context, PersonId, PersonId, PersonId) {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_entity((Age(3),)).unwrap();

        (context, person1, person2, person3)
    }

    /// Adds an `EdgeType1(tag)` edge `from -> to`, panicking if the context refuses it.
    fn connect(context: &mut Context, from: PersonId, to: PersonId, weight: f32, tag: u32) {
        context
            .add_edge::<Person, EdgeType1>(from, to, weight, EdgeType1(tag))
            .unwrap();
    }

    /// Weight of the stored edge `from -> to`; panics when the edge is missing.
    fn weight_of(context: &Context, from: PersonId, to: PersonId) -> f32 {
        context
            .get_edge::<Person, EdgeType1>(from, to)
            .unwrap()
            .weight
    }

    #[test]
    fn add_edge() {
        let (mut context, person1, person2) = setup();

        connect(&mut context, person1, person2, 0.01, 1);

        assert_eq!(weight_of(&context, person1, person2), 0.01);
        assert_eq!(
            context.get_edges::<Person, EdgeType1>(person1),
            vec![Edge {
                neighbor: person2,
                weight: 0.01,
                inner: EdgeType1(1)
            }]
        );
    }

    #[test]
    fn remove_edge() {
        let (mut context, person1, person2) = setup();

        // Nothing to remove before the edge exists.
        let removed = context.remove_edge::<Person, EdgeType1>(person1, person2);
        assert!(removed.is_none());

        connect(&mut context, person1, person2, 0.01, 1);

        let removed = context.remove_edge::<Person, EdgeType1>(person1, person2);
        assert!(removed.is_some());

        let lookup = context.get_edge::<Person, EdgeType1>(person1, person2);
        assert!(lookup.is_none());
        assert!(context.get_edges::<Person, EdgeType1>(person1).is_empty());
    }

    #[test]
    fn add_edge_bidi() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge_bidi::<Person, EdgeType1>(person1, person2, 0.01, EdgeType1(1))
            .unwrap();

        assert_eq!(weight_of(&context, person1, person2), 0.01);
        assert_eq!(weight_of(&context, person2, person1), 0.01);
    }

    #[test]
    fn add_edge_different_weights() {
        let (mut context, person1, person2) = setup();

        connect(&mut context, person1, person2, 0.01, 1);
        connect(&mut context, person2, person1, 0.02, 1);

        assert_eq!(weight_of(&context, person1, person2), 0.01);
        assert_eq!(weight_of(&context, person2, person1), 0.02);
    }

    #[test]
    fn get_matching_edges_weight() {
        let (mut context, person1, person2, person3) = setup_three();

        connect(&mut context, person1, person2, 0.01, 1);
        connect(&mut context, person1, person3, 0.03, 1);

        let edges = context
            .get_matching_edges::<Person, EdgeType1>(person1, |_context, edge| edge.weight > 0.01);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].neighbor, person3);
    }

    #[test]
    fn get_matching_edges_inner() {
        let (mut context, person1, person2, person3) = setup_three();

        connect(&mut context, person1, person2, 0.01, 1);
        connect(&mut context, person1, person3, 0.03, 3);

        let edges = context
            .get_matching_edges::<Person, EdgeType1>(person1, |_context, edge| edge.inner.0 == 3);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].neighbor, person3);
    }

    #[test]
    fn get_matching_edges_person_property() {
        let (mut context, person1, person2, person3) = setup_three();

        connect(&mut context, person1, person2, 0.01, 1);
        connect(&mut context, person1, person3, 0.03, 3);

        let edges = context.get_matching_edges::<Person, EdgeType1>(person1, |context, edge| {
            context.match_entity(edge.neighbor, (Age(3),))
        });
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].neighbor, person3);
    }

    #[test]
    fn select_random_edge() {
        define_rng!(NetworkTestRng);

        let (mut context, person1, person2, person3) = setup_three();
        context.init_random(42);

        // The second edge outweighs the first by nine orders of magnitude, so it is the one
        // expected to be drawn.
        connect(&mut context, person1, person2, 0.01, 1);
        connect(&mut context, person1, person3, 10_000_000.0, 3);

        let edge = context
            .select_random_edge::<Person, EdgeType1, _>(NetworkTestRng, person1)
            .unwrap();
        assert_eq!(edge.neighbor, person3);
        assert_eq!(edge.inner, EdgeType1(3));
    }
}