1use crate::HashMap;
9use crate::{
10 context::Context, define_data_plugin, error::IxaError, people::PersonId,
11 random::ContextRandomExt, random::RngId,
12};
13use rand::Rng;
14use std::any::{Any, TypeId};
15
16#[derive(Copy, Clone, Debug, PartialEq)]
17pub struct Edge<T: Sized> {
20 pub person: PersonId,
22 pub neighbor: PersonId,
24 pub weight: f32,
26 pub inner: T,
28}
29
/// Marker trait identifying one kind of network edge.
///
/// Each implementing type names an independent network, and `Value` is the payload stored in
/// the `inner` field of every edge of that type. Implementations are normally generated with
/// `define_edge_type!`.
pub trait EdgeType {
    /// Per-edge payload type. It must be `Copy` (edges are handed out by value) and `Default`.
    /// Associated types are implicitly `Sized`, so no explicit `Sized` bound is needed.
    type Value: Default + Copy;
}
33
34#[derive(Default)]
35struct PersonNetwork {
36 neighbors: HashMap<TypeId, Box<dyn Any>>,
38}
39
40struct NetworkData {
41 network: Vec<PersonNetwork>,
42}
43
44impl NetworkData {
45 fn new() -> Self {
46 NetworkData {
47 network: Vec::new(),
48 }
49 }
50
51 fn add_edge<T: EdgeType + 'static>(
52 &mut self,
53 person: PersonId,
54 neighbor: PersonId,
55 weight: f32,
56 inner: T::Value,
57 ) -> Result<(), IxaError> {
58 if person == neighbor {
59 return Err(IxaError::IxaError(String::from("Cannot make edge to self")));
60 }
61
62 if weight.is_infinite() || weight.is_nan() || weight.is_sign_negative() {
63 return Err(IxaError::IxaError(String::from("Invalid weight")));
64 }
65
66 if person.0 >= self.network.len() {
68 self.network.resize_with(person.0 + 1, Default::default);
69 }
70
71 let entry = self.network[person.0]
72 .neighbors
73 .entry(TypeId::of::<T>())
74 .or_insert_with(|| Box::new(Vec::<Edge<T::Value>>::new()));
75 let edges: &mut Vec<Edge<T::Value>> = entry.downcast_mut().expect("Type mismatch");
76
77 for edge in edges.iter_mut() {
78 if edge.neighbor == neighbor {
79 return Err(IxaError::IxaError(String::from("Edge already exists")));
80 }
81 }
82
83 edges.push(Edge {
84 person,
85 neighbor,
86 weight,
87 inner,
88 });
89 Ok(())
90 }
91
92 fn remove_edge<T: EdgeType + 'static>(
93 &mut self,
94 person: PersonId,
95 neighbor: PersonId,
96 ) -> Result<(), IxaError> {
97 if person.0 >= self.network.len() {
98 return Err(IxaError::IxaError(String::from("Edge does not exist")));
99 }
100
101 let entry = match self.network[person.0].neighbors.get_mut(&TypeId::of::<T>()) {
102 None => {
103 return Err(IxaError::IxaError(String::from("Edge does not exist")));
104 }
105 Some(entry) => entry,
106 };
107
108 let edges: &mut Vec<Edge<T::Value>> = entry.downcast_mut().expect("Type mismatch");
109 for index in 0..edges.len() {
110 if edges[index].neighbor == neighbor {
111 edges.remove(index);
112 return Ok(());
113 }
114 }
115
116 Err(IxaError::IxaError(String::from("Edge does not exist")))
117 }
118
119 fn get_edge<T: EdgeType + 'static>(
120 &self,
121 person: PersonId,
122 neighbor: PersonId,
123 ) -> Option<&Edge<T::Value>> {
124 if person.0 >= self.network.len() {
125 return None;
126 }
127
128 let entry = self.network[person.0].neighbors.get(&TypeId::of::<T>())?;
129 let edges: &Vec<Edge<T::Value>> = entry.downcast_ref().expect("Type mismatch");
130 edges.iter().find(|&edge| edge.neighbor == neighbor)
131 }
132
133 fn get_edges<T: EdgeType + 'static>(&self, person: PersonId) -> Vec<Edge<T::Value>> {
134 if person.0 >= self.network.len() {
135 return Vec::new();
136 }
137
138 let entry = self.network[person.0].neighbors.get(&TypeId::of::<T>());
139 if entry.is_none() {
140 return Vec::new();
141 }
142
143 let edges: &Vec<Edge<T::Value>> = entry.unwrap().downcast_ref().expect("Type mismatch");
144 edges.clone()
145 }
146
147 fn find_people_by_degree<T: EdgeType + 'static>(&self, degree: usize) -> Vec<PersonId> {
148 let mut result = Vec::new();
149
150 for person_id in 0..self.network.len() {
151 let entry = self.network[person_id].neighbors.get(&TypeId::of::<T>());
152 if entry.is_none() {
153 continue;
154 }
155 let edges: &Vec<Edge<T::Value>> = entry.unwrap().downcast_ref().expect("Type mismatch");
156 if edges.len() == degree {
157 result.push(PersonId(person_id));
158 }
159 }
160 result
161 }
162}
163
/// Defines a unit struct `$edge_type` and implements
/// [`EdgeType`](crate::network::EdgeType) for it with `Value = $value`.
///
/// The generated struct is `pub` and derives `Debug`, `Copy` and `Clone`. A trailing comma
/// after the value type is accepted.
///
/// ```ignore
/// define_edge_type!(Household, ());
/// define_edge_type!(Contact, f32);
/// ```
#[macro_export]
macro_rules! define_edge_type {
    ($edge_type:ident, $value:ty $(,)?) => {
        #[derive(Debug, Copy, Clone)]
        pub struct $edge_type;

        impl $crate::network::EdgeType for $edge_type {
            type Value = $value;
        }
    };
}
180
// Registers `NetworkData` as a context data plugin keyed by `NetworkPlugin`, initialised to an
// empty network via `NetworkData::new()`. The `ContextNetworkExt` methods reach it through
// `get_data_container` (read-only, may be `None`) and `get_data_container_mut`.
define_data_plugin!(NetworkPlugin, NetworkData, NetworkData::new());
182
/// Extension methods on [`Context`] for building and querying networks of typed, weighted,
/// directed edges between people.
///
/// Each edge type `T` (see [`EdgeType`] and `define_edge_type!`) is stored under its own
/// `TypeId`, so edges of different types between the same two people are independent.
pub trait ContextNetworkExt {
    /// Adds a directed edge of type `T` from `person` to `neighbor` with the given `weight`
    /// and payload `inner`.
    ///
    /// # Errors
    /// Returns an error if `person == neighbor`, if `weight` is NaN, infinite or
    /// sign-negative (including `-0.0`), or if an edge of type `T` from `person` to
    /// `neighbor` already exists.
    fn add_edge<T: EdgeType + 'static>(
        &mut self,
        person: PersonId,
        neighbor: PersonId,
        weight: f32,
        inner: T::Value,
    ) -> Result<(), IxaError>;

    /// Adds edges of type `T` in both directions between `person1` and `person2`, both with
    /// the same `weight` and `inner` value.
    ///
    /// # Errors
    /// Returns an error under the same conditions as [`ContextNetworkExt::add_edge`], applied
    /// to either direction.
    fn add_edge_bidi<T: EdgeType + 'static>(
        &mut self,
        person1: PersonId,
        person2: PersonId,
        weight: f32,
        inner: T::Value,
    ) -> Result<(), IxaError>;

    /// Removes the directed edge of type `T` from `person` to `neighbor`.
    ///
    /// # Errors
    /// Returns an error if the network data has not been created yet or no such edge exists.
    fn remove_edge<T: EdgeType + 'static>(
        &mut self,
        person: PersonId,
        neighbor: PersonId,
    ) -> Result<(), IxaError>;

    /// Returns the directed edge of type `T` from `person` to `neighbor`, if it exists.
    fn get_edge<T: EdgeType + 'static>(
        &self,
        person: PersonId,
        neighbor: PersonId,
    ) -> Option<&Edge<T::Value>>;

    /// Returns copies of all outgoing edges of type `T` from `person`, in insertion order.
    /// Returns an empty vector if there are none.
    fn get_edges<T: EdgeType + 'static>(&self, person: PersonId) -> Vec<Edge<T::Value>>;

    /// Returns the outgoing edges of type `T` from `person` for which `filter` returns
    /// `true`. `filter` receives the context and each candidate edge.
    ///
    /// NOTE(review): the `'static` bound on `filter` looks stricter than needed, since the
    /// closure is only called during this call. Relaxing it would require updating
    /// implementors too.
    fn get_matching_edges<T: EdgeType + 'static>(
        &self,
        person: PersonId,
        filter: impl Fn(&Context, &Edge<T::Value>) -> bool + 'static,
    ) -> Vec<Edge<T::Value>>;

    /// Returns every person with exactly `degree` outgoing edges of type `T`.
    ///
    /// NOTE(review): people who never had an edge of type `T` are not reported, even for
    /// `degree == 0`.
    fn find_people_by_degree<T: EdgeType + 'static>(&self, degree: usize) -> Vec<PersonId>;

    /// Picks one outgoing edge of type `T` from `person_id` at random. The edge weights are
    /// passed to `sample_weighted`, using the RNG identified by `rng_id`.
    ///
    /// # Errors
    /// Returns an error if `person_id` has no outgoing edges of type `T`.
    fn select_random_edge<T: EdgeType + 'static, R: RngId + 'static>(
        &self,
        rng_id: R,
        person_id: PersonId,
    ) -> Result<Edge<T::Value>, IxaError>
    where
        R::RngType: Rng;
}
274
275impl ContextNetworkExt for Context {
277 fn add_edge<T: EdgeType + 'static>(
278 &mut self,
279 person: PersonId,
280 neighbor: PersonId,
281 weight: f32,
282 inner: T::Value,
283 ) -> Result<(), IxaError> {
284 let data_container = self.get_data_container_mut(NetworkPlugin);
285 data_container.add_edge::<T>(person, neighbor, weight, inner)
286 }
287
288 fn add_edge_bidi<T: EdgeType + 'static>(
289 &mut self,
290 person1: PersonId,
291 person2: PersonId,
292 weight: f32,
293 inner: T::Value,
294 ) -> Result<(), IxaError> {
295 let data_container = self.get_data_container_mut(NetworkPlugin);
296 data_container.add_edge::<T>(person1, person2, weight, inner)?;
297 data_container.add_edge::<T>(person2, person1, weight, inner)
298 }
299
300 fn remove_edge<T: EdgeType + 'static>(
301 &mut self,
302 person: PersonId,
303 neighbor: PersonId,
304 ) -> Result<(), IxaError> {
305 let data_container = self.get_data_container(NetworkPlugin);
306 if data_container.is_none() {
308 return Err(IxaError::IxaError(String::from("Network not initialized")));
309 }
310 let data_container = self.get_data_container_mut(NetworkPlugin);
311 data_container.remove_edge::<T>(person, neighbor)
312 }
313
314 fn get_edge<T: EdgeType + 'static>(
315 &self,
316 person: PersonId,
317 neighbor: PersonId,
318 ) -> Option<&Edge<T::Value>> {
319 let data_container = self.get_data_container(NetworkPlugin);
320
321 match data_container {
322 None => None,
323 Some(data_container) => data_container.get_edge::<T>(person, neighbor),
324 }
325 }
326
327 fn get_edges<T: EdgeType + 'static>(&self, person: PersonId) -> Vec<Edge<T::Value>> {
328 let data_container = self.get_data_container(NetworkPlugin);
329
330 match data_container {
331 None => Vec::new(),
332 Some(data_container) => data_container.get_edges::<T>(person),
333 }
334 }
335
336 fn get_matching_edges<T: EdgeType + 'static>(
337 &self,
338 person: PersonId,
339 filter: impl Fn(&Context, &Edge<T::Value>) -> bool + 'static,
340 ) -> Vec<Edge<T::Value>> {
341 let edges = self.get_edges::<T>(person);
342 let mut result = Vec::new();
343 for edge in &edges {
344 if filter(self, edge) {
345 result.push(*edge);
346 }
347 }
348 result
349 }
350
351 fn find_people_by_degree<T: EdgeType + 'static>(&self, degree: usize) -> Vec<PersonId> {
352 let data_container = self.get_data_container(NetworkPlugin);
353
354 match data_container {
355 None => Vec::new(),
356 Some(data_container) => data_container.find_people_by_degree::<T>(degree),
357 }
358 }
359
360 fn select_random_edge<T: EdgeType + 'static, R: RngId + 'static>(
361 &self,
362 rng_id: R,
363 person_id: PersonId,
364 ) -> Result<Edge<T::Value>, IxaError>
365 where
366 R::RngType: Rng,
367 {
368 let edges = self.get_edges::<T>(person_id);
369 if edges.is_empty() {
370 return Err(IxaError::IxaError(String::from(
371 "Can't sample from empty list",
372 )));
373 }
374
375 let weights: Vec<_> = edges.iter().map(|x| x.weight).collect();
376 let index = self.sample_weighted(rng_id, &weights);
377 Ok(edges[index])
378 }
379}
380
#[cfg(test)]
// Weights are compared with the exact literals they were stored from, so exact float
// equality is intended.
#[allow(clippy::float_cmp)]
mod test_inner {
    use super::{Edge, NetworkData};
    use crate::error::IxaError;
    use crate::people::PersonId;

    define_edge_type!(EdgeType1, ());
    define_edge_type!(EdgeType2, ());
    define_edge_type!(EdgeType3, bool);

    #[test]
    fn add_edge() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);
    }

    #[test]
    fn add_edge_with_inner() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType3>(PersonId(1), PersonId(2), 0.01, true)
            .unwrap();
        let edge = nd.get_edge::<EdgeType3>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);
        assert!(edge.inner);
    }

    #[test]
    fn add_two_edges() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(3), 0.02, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(3)).unwrap();
        assert_eq!(edge.weight, 0.02);

        let edges = nd.get_edges::<EdgeType1>(PersonId(1));
        assert_eq!(
            edges,
            vec![
                Edge {
                    person: PersonId(1),
                    neighbor: PersonId(2),
                    weight: 0.01,
                    inner: ()
                },
                Edge {
                    person: PersonId(1),
                    neighbor: PersonId(3),
                    weight: 0.02,
                    inner: ()
                }
            ]
        );
    }

    #[test]
    fn add_two_edge_types() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        nd.add_edge::<EdgeType2>(PersonId(1), PersonId(2), 0.02, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);
        let edge = nd.get_edge::<EdgeType2>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.02);

        let edges = nd.get_edges::<EdgeType1>(PersonId(1));
        assert_eq!(
            edges,
            vec![Edge {
                person: PersonId(1),
                neighbor: PersonId(2),
                weight: 0.01,
                inner: ()
            }]
        );
    }

    #[test]
    fn add_edge_twice_fails() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);

        assert!(matches!(
            nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.02, ()),
            Err(IxaError::IxaError(_))
        ));
    }

    #[test]
    fn add_remove_add_edge() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);

        nd.remove_edge::<EdgeType1>(PersonId(1), PersonId(2))
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2));
        assert!(edge.is_none());

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.02, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.02);
    }

    #[test]
    fn remove_nonexistent_edge() {
        let mut nd = NetworkData::new();
        assert!(matches!(
            nd.remove_edge::<EdgeType1>(PersonId(1), PersonId(2)),
            Err(IxaError::IxaError(_))
        ));
    }

    #[test]
    fn remove_edge_to_unknown_neighbor_or_type() {
        let mut nd = NetworkData::new();
        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();

        // The person has edges of this type, but not to this neighbor.
        assert!(matches!(
            nd.remove_edge::<EdgeType1>(PersonId(1), PersonId(3)),
            Err(IxaError::IxaError(_))
        ));
        // The edge exists, but under a different edge type.
        assert!(matches!(
            nd.remove_edge::<EdgeType2>(PersonId(1), PersonId(2)),
            Err(IxaError::IxaError(_))
        ));
        // Failed removals leave the existing edge untouched.
        assert!(nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).is_some());
    }

    #[test]
    fn add_edge_to_self() {
        let mut nd = NetworkData::new();

        let result = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(1), 0.01, ());
        assert!(matches!(result, Err(IxaError::IxaError(_))));
    }

    #[test]
    fn add_edge_bogus_weight() {
        let mut nd = NetworkData::new();

        let result = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), -1.0, ());
        assert!(matches!(result, Err(IxaError::IxaError(_))));

        let result = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), f32::NAN, ());
        assert!(matches!(result, Err(IxaError::IxaError(_))));

        let result = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), f32::INFINITY, ());
        assert!(matches!(result, Err(IxaError::IxaError(_))));

        let result = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), f32::NEG_INFINITY, ());
        assert!(matches!(result, Err(IxaError::IxaError(_))));

        // None of the rejected edges may have been stored.
        assert!(nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).is_none());
    }

    #[test]
    fn edges_are_directed() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        assert!(nd.get_edge::<EdgeType1>(PersonId(2), PersonId(1)).is_none());

        // The reverse edge is independent and may carry a different weight.
        nd.add_edge::<EdgeType1>(PersonId(2), PersonId(1), 0.02, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(2), PersonId(1)).unwrap();
        assert_eq!(edge.weight, 0.02);
    }

    #[test]
    fn lookups_on_missing_data() {
        let mut nd = NetworkData::new();

        // Person beyond the end of the (empty) network.
        assert!(nd.get_edge::<EdgeType1>(PersonId(5), PersonId(6)).is_none());
        assert!(nd.get_edges::<EdgeType1>(PersonId(5)).is_empty());

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();

        // Person with edges, but none of the requested type.
        assert!(nd.get_edge::<EdgeType2>(PersonId(1), PersonId(2)).is_none());
        assert!(nd.get_edges::<EdgeType2>(PersonId(1)).is_empty());
        // Person inside the network vector who never had an edge.
        assert!(nd.get_edges::<EdgeType1>(PersonId(0)).is_empty());
    }

    #[test]
    fn find_people_by_degree() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.0, ())
            .unwrap();
        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(3), 0.0, ())
            .unwrap();
        nd.add_edge::<EdgeType1>(PersonId(2), PersonId(3), 0.0, ())
            .unwrap();
        nd.add_edge::<EdgeType1>(PersonId(3), PersonId(2), 0.0, ())
            .unwrap();

        let matches = nd.find_people_by_degree::<EdgeType1>(2);
        assert_eq!(matches, vec![PersonId(1)]);
        let matches = nd.find_people_by_degree::<EdgeType1>(1);
        assert_eq!(matches, vec![PersonId(2), PersonId(3)]);
    }
}
557
#[cfg(test)]
// Weights are compared with the exact literals they were stored from, so exact float
// equality is intended.
#[allow(clippy::float_cmp)]
mod test_api {
    use crate::context::Context;
    use crate::define_rng;
    use crate::error::IxaError;
    use crate::network::{ContextNetworkExt, Edge};
    use crate::people::{define_person_property, ContextPeopleExt, PersonId};
    use crate::random::ContextRandomExt;

    define_edge_type!(EdgeType1, u32);
    define_person_property!(Age, u8);

    fn setup() -> (Context, PersonId, PersonId) {
        let mut context = Context::new();
        let person1 = context.add_person((Age, 1)).unwrap();
        let person2 = context.add_person((Age, 2)).unwrap();

        (context, person1, person2)
    }

    #[test]
    fn add_edge() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person1, person2)
                .unwrap()
                .weight,
            0.01
        );
        assert_eq!(
            context.get_edges::<EdgeType1>(person1),
            vec![Edge {
                person: person1,
                neighbor: person2,
                weight: 0.01,
                inner: 1
            }]
        );
    }

    #[test]
    fn remove_edge() {
        let (mut context, person1, person2) = setup();
        // Removing before any edge exists is an error.
        assert!(matches!(
            context.remove_edge::<EdgeType1>(person1, person2),
            Err(IxaError::IxaError(_))
        ));
        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context.remove_edge::<EdgeType1>(person1, person2).unwrap();
        assert!(context.get_edge::<EdgeType1>(person1, person2).is_none());
        assert_eq!(context.get_edges::<EdgeType1>(person1).len(), 0);
    }

    #[test]
    fn add_edge_bidi() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge_bidi::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person1, person2)
                .unwrap()
                .weight,
            0.01
        );
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person2, person1)
                .unwrap()
                .weight,
            0.01
        );
    }

    #[test]
    fn add_edge_bidi_rejects_existing_reverse_edge() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge::<EdgeType1>(person2, person1, 0.02, 1)
            .unwrap();
        assert!(matches!(
            context.add_edge_bidi::<EdgeType1>(person1, person2, 0.01, 1),
            Err(IxaError::IxaError(_))
        ));
        // The pre-existing reverse edge is left unchanged.
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person2, person1)
                .unwrap()
                .weight,
            0.02
        );
    }

    #[test]
    fn add_edge_bidi_to_self_fails() {
        let (mut context, person1, _person2) = setup();

        assert!(matches!(
            context.add_edge_bidi::<EdgeType1>(person1, person1, 0.01, 1),
            Err(IxaError::IxaError(_))
        ));
        assert!(context.get_edges::<EdgeType1>(person1).is_empty());
    }

    #[test]
    fn add_edge_different_weights() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person2, person1, 0.02, 1)
            .unwrap();
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person1, person2)
                .unwrap()
                .weight,
            0.01
        );
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person2, person1)
                .unwrap()
                .weight,
            0.02
        );
    }

    #[test]
    fn queries_before_any_edge_is_added() {
        let (context, person1, person2) = setup();

        assert!(context.get_edge::<EdgeType1>(person1, person2).is_none());
        assert!(context.get_edges::<EdgeType1>(person1).is_empty());
        assert!(context
            .get_matching_edges::<EdgeType1>(person1, |_context, _edge| true)
            .is_empty());
        assert!(context.find_people_by_degree::<EdgeType1>(0).is_empty());
    }

    #[test]
    fn find_people_by_degree() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person2, person3, 0.01, 1)
            .unwrap();

        assert_eq!(context.find_people_by_degree::<EdgeType1>(2), vec![person1]);
        assert_eq!(context.find_people_by_degree::<EdgeType1>(1), vec![person2]);
    }

    #[test]
    fn get_matching_edges_weight() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.03, 1)
            .unwrap();
        let edges =
            context.get_matching_edges::<EdgeType1>(person1, |_context, edge| edge.weight > 0.01);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].person, person1);
        assert_eq!(edges[0].neighbor, person3);
    }

    #[test]
    fn get_matching_edges_inner() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.03, 3)
            .unwrap();
        let edges =
            context.get_matching_edges::<EdgeType1>(person1, |_context, edge| edge.inner == 3);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].person, person1);
        assert_eq!(edges[0].neighbor, person3);
    }

    #[test]
    fn get_matching_edges_person_property() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.03, 3)
            .unwrap();
        let edges = context.get_matching_edges::<EdgeType1>(person1, |context, edge| {
            context.match_person(edge.neighbor, (Age, 3))
        });
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].person, person1);
        assert_eq!(edges[0].neighbor, person3);
    }

    #[test]
    fn select_random_edge() {
        define_rng!(NetworkTestRng);

        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();
        context.init_random(42);

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 10_000_000.0, 3)
            .unwrap();

        let edge = context
            .select_random_edge::<EdgeType1, _>(NetworkTestRng, person1)
            .unwrap();
        assert_eq!(edge.person, person1);
        assert_eq!(edge.neighbor, person3);
    }

    #[test]
    fn select_random_edge_without_edges() {
        define_rng!(EmptyNetworkTestRng);

        let (mut context, person1, _person2) = setup();
        context.init_random(42);

        assert!(matches!(
            context.select_random_edge::<EdgeType1, _>(EmptyNetworkTestRng, person1),
            Err(IxaError::IxaError(_))
        ));
    }
}