1use crate::{
9 context::Context, define_data_plugin, error::IxaError, people::PersonId,
10 random::ContextRandomExt, random::RngId, HashMap, PluginContext,
11};
12use rand::Rng;
13use std::any::{Any, TypeId};
14
15#[derive(Copy, Clone, Debug, PartialEq)]
16pub struct Edge<T: Sized> {
19 pub person: PersonId,
21 pub neighbor: PersonId,
23 pub weight: f32,
25 pub inner: T,
27}
28
/// Marker trait identifying one kind of network edge.
///
/// Each implementing type gets its own, independent set of edges; `Value` is
/// the payload stored on every edge of that kind. Usually implemented through
/// the `define_edge_type!` macro.
pub trait EdgeType {
    /// Payload carried by each edge of this type (associated types are
    /// implicitly `Sized`).
    type Value: Copy + Default;
}
32
/// Outgoing edges of a single person, grouped by edge type.
///
/// Keyed by the `TypeId` of the edge type `T`; each value is a type-erased
/// `Vec<Edge<T::Value>>` holding that person's edges of type `T`.
#[derive(Default)]
struct PersonNetwork {
    neighbors: HashMap<TypeId, Box<dyn Any>>,
}

/// Network storage for every person, indexed by the person's numeric id.
///
/// Only grows as far as the highest person who has had an outgoing edge added.
struct NetworkData {
    network: Vec<PersonNetwork>,
}
42
43impl NetworkData {
44 fn new() -> Self {
45 NetworkData {
46 network: Vec::new(),
47 }
48 }
49
50 fn add_edge<T: EdgeType + 'static>(
51 &mut self,
52 person: PersonId,
53 neighbor: PersonId,
54 weight: f32,
55 inner: T::Value,
56 ) -> Result<(), IxaError> {
57 if person == neighbor {
58 return Err(IxaError::IxaError(String::from("Cannot make edge to self")));
59 }
60
61 if weight.is_infinite() || weight.is_nan() || weight.is_sign_negative() {
62 return Err(IxaError::IxaError(String::from("Invalid weight")));
63 }
64
65 if person.0 >= self.network.len() {
67 self.network.resize_with(person.0 + 1, Default::default);
68 }
69
70 let entry = self.network[person.0]
71 .neighbors
72 .entry(TypeId::of::<T>())
73 .or_insert_with(|| Box::new(Vec::<Edge<T::Value>>::new()));
74 let edges: &mut Vec<Edge<T::Value>> = entry.downcast_mut().expect("Type mismatch");
75
76 for edge in edges.iter_mut() {
77 if edge.neighbor == neighbor {
78 return Err(IxaError::IxaError(String::from("Edge already exists")));
79 }
80 }
81
82 edges.push(Edge {
83 person,
84 neighbor,
85 weight,
86 inner,
87 });
88 Ok(())
89 }
90
91 fn remove_edge<T: EdgeType + 'static>(
92 &mut self,
93 person: PersonId,
94 neighbor: PersonId,
95 ) -> Result<(), IxaError> {
96 if person.0 >= self.network.len() {
97 return Err(IxaError::IxaError(String::from("Edge does not exist")));
98 }
99
100 let entry = match self.network[person.0].neighbors.get_mut(&TypeId::of::<T>()) {
101 None => {
102 return Err(IxaError::IxaError(String::from("Edge does not exist")));
103 }
104 Some(entry) => entry,
105 };
106
107 let edges: &mut Vec<Edge<T::Value>> = entry.downcast_mut().expect("Type mismatch");
108 for index in 0..edges.len() {
109 if edges[index].neighbor == neighbor {
110 edges.remove(index);
111 return Ok(());
112 }
113 }
114
115 Err(IxaError::IxaError(String::from("Edge does not exist")))
116 }
117
118 fn get_edge<T: EdgeType + 'static>(
119 &self,
120 person: PersonId,
121 neighbor: PersonId,
122 ) -> Option<&Edge<T::Value>> {
123 if person.0 >= self.network.len() {
124 return None;
125 }
126
127 let entry = self.network[person.0].neighbors.get(&TypeId::of::<T>())?;
128 let edges: &Vec<Edge<T::Value>> = entry.downcast_ref().expect("Type mismatch");
129 edges.iter().find(|&edge| edge.neighbor == neighbor)
130 }
131
132 fn get_edges<T: EdgeType + 'static>(&self, person: PersonId) -> Vec<Edge<T::Value>> {
133 if person.0 >= self.network.len() {
134 return Vec::new();
135 }
136
137 let entry = self.network[person.0].neighbors.get(&TypeId::of::<T>());
138 if entry.is_none() {
139 return Vec::new();
140 }
141
142 let edges: &Vec<Edge<T::Value>> = entry.unwrap().downcast_ref().expect("Type mismatch");
143 edges.clone()
144 }
145
146 fn find_people_by_degree<T: EdgeType + 'static>(&self, degree: usize) -> Vec<PersonId> {
147 let mut result = Vec::new();
148
149 for person_id in 0..self.network.len() {
150 let entry = self.network[person_id].neighbors.get(&TypeId::of::<T>());
151 if entry.is_none() {
152 continue;
153 }
154 let edges: &Vec<Edge<T::Value>> = entry.unwrap().downcast_ref().expect("Type mismatch");
155 if edges.len() == degree {
156 result.push(PersonId(person_id));
157 }
158 }
159 result
160 }
161}
162
/// Defines a new network edge type.
///
/// `define_edge_type!(Name, Value)` declares a public unit struct `Name`
/// (deriving `Debug`, `Copy` and `Clone`) and implements
/// `$crate::network::EdgeType` for it with `Value` as the per-edge payload.
/// A trailing comma after `Value` is accepted.
///
/// # Example
///
/// ```ignore
/// define_edge_type!(Household, ());
/// define_edge_type!(Contact, f32,);
/// ```
#[macro_export]
macro_rules! define_edge_type {
    ($edge_type:ident, $value:ty $(,)?) => {
        #[derive(Debug, Copy, Clone)]
        pub struct $edge_type;

        impl $crate::network::EdgeType for $edge_type {
            type Value = $value;
        }
    };
}
179
180define_data_plugin!(NetworkPlugin, NetworkData, NetworkData::new());
181
182pub trait ContextNetworkExt: PluginContext + ContextRandomExt {
184 fn add_edge<T: EdgeType + 'static>(
196 &mut self,
197 person: PersonId,
198 neighbor: PersonId,
199 weight: f32,
200 inner: T::Value,
201 ) -> Result<(), IxaError> {
202 let data_container = self.get_data_mut(NetworkPlugin);
203 data_container.add_edge::<T>(person, neighbor, weight, inner)
204 }
205
206 fn add_edge_bidi<T: EdgeType + 'static>(
220 &mut self,
221 person1: PersonId,
222 person2: PersonId,
223 weight: f32,
224 inner: T::Value,
225 ) -> Result<(), IxaError> {
226 let data_container = self.get_data_mut(NetworkPlugin);
227 data_container.add_edge::<T>(person1, person2, weight, inner)?;
228 data_container.add_edge::<T>(person2, person1, weight, inner)
229 }
230
231 fn remove_edge<T: EdgeType + 'static>(
237 &mut self,
238 person: PersonId,
239 neighbor: PersonId,
240 ) -> Result<(), IxaError> {
241 let data_container = self.get_data_mut(NetworkPlugin);
242 data_container.remove_edge::<T>(person, neighbor)
243 }
244
245 fn get_edge<T: EdgeType + 'static>(
248 &self,
249 person: PersonId,
250 neighbor: PersonId,
251 ) -> Option<&Edge<T::Value>> {
252 self.get_data(NetworkPlugin).get_edge::<T>(person, neighbor)
253 }
254
255 fn get_edges<T: EdgeType + 'static>(&self, person: PersonId) -> Vec<Edge<T::Value>> {
257 self.get_data(NetworkPlugin).get_edges::<T>(person)
258 }
259
260 fn get_matching_edges<T: EdgeType + 'static>(
267 &self,
268 person: PersonId,
269 filter: impl Fn(&Context, &Edge<T::Value>) -> bool + 'static,
270 ) -> Vec<Edge<T::Value>>;
271
272 fn find_people_by_degree<T: EdgeType + 'static>(&self, degree: usize) -> Vec<PersonId> {
274 self.get_data(NetworkPlugin)
275 .find_people_by_degree::<T>(degree)
276 }
277
278 fn select_random_edge<T: EdgeType + 'static, R: RngId + 'static>(
284 &self,
285 rng_id: R,
286 person_id: PersonId,
287 ) -> Result<Edge<T::Value>, IxaError>
288 where
289 R::RngType: Rng,
290 {
291 let edges = self.get_edges::<T>(person_id);
292 if edges.is_empty() {
293 return Err(IxaError::IxaError(String::from(
294 "Can't sample from empty list",
295 )));
296 }
297
298 let weights: Vec<_> = edges.iter().map(|x| x.weight).collect();
299 let index = self.sample_weighted(rng_id, &weights);
300 Ok(edges[index])
301 }
302}
303impl ContextNetworkExt for Context {
304 fn get_matching_edges<T: EdgeType + 'static>(
305 &self,
306 person: PersonId,
307 filter: impl Fn(&Context, &Edge<T::Value>) -> bool + 'static,
308 ) -> Vec<Edge<T::Value>> {
309 let edges = self.get_edges::<T>(person);
310 let mut result = Vec::new();
311 for edge in &edges {
312 if filter(self, edge) {
313 result.push(*edge);
314 }
315 }
316 result
317 }
318}
319
#[cfg(test)]
// Weights are stored and read back unchanged, so exact float equality is intended.
#[allow(clippy::float_cmp)]
mod test_inner {
    use super::{Edge, NetworkData};
    use crate::error::IxaError;
    use crate::people::PersonId;

    define_edge_type!(EdgeType1, ());
    define_edge_type!(EdgeType2, ());
    define_edge_type!(EdgeType3, bool);

    #[test]
    fn add_edge() {
        let (alice, bob) = (PersonId(1), PersonId(2));
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(alice, bob, 0.01, ()).unwrap();

        let stored = nd.get_edge::<EdgeType1>(alice, bob).unwrap();
        assert_eq!(stored.weight, 0.01);
    }

    #[test]
    fn add_edge_with_inner() {
        let (alice, bob) = (PersonId(1), PersonId(2));
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType3>(alice, bob, 0.01, true).unwrap();

        let stored = nd.get_edge::<EdgeType3>(alice, bob).unwrap();
        assert_eq!(stored.weight, 0.01);
        assert!(stored.inner);
    }

    #[test]
    fn add_two_edges() {
        let (alice, bob, carol) = (PersonId(1), PersonId(2), PersonId(3));
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(alice, bob, 0.01, ()).unwrap();
        nd.add_edge::<EdgeType1>(alice, carol, 0.02, ()).unwrap();

        assert_eq!(nd.get_edge::<EdgeType1>(alice, bob).unwrap().weight, 0.01);
        assert_eq!(nd.get_edge::<EdgeType1>(alice, carol).unwrap().weight, 0.02);

        // Edges come back in insertion order.
        let expected = vec![
            Edge {
                person: alice,
                neighbor: bob,
                weight: 0.01,
                inner: (),
            },
            Edge {
                person: alice,
                neighbor: carol,
                weight: 0.02,
                inner: (),
            },
        ];
        assert_eq!(nd.get_edges::<EdgeType1>(alice), expected);
    }

    #[test]
    fn add_two_edge_types() {
        let (alice, bob) = (PersonId(1), PersonId(2));
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(alice, bob, 0.01, ()).unwrap();
        nd.add_edge::<EdgeType2>(alice, bob, 0.02, ()).unwrap();

        assert_eq!(nd.get_edge::<EdgeType1>(alice, bob).unwrap().weight, 0.01);
        assert_eq!(nd.get_edge::<EdgeType2>(alice, bob).unwrap().weight, 0.02);

        // An edge of another type must not show up in this type's list.
        let expected = vec![Edge {
            person: alice,
            neighbor: bob,
            weight: 0.01,
            inner: (),
        }];
        assert_eq!(nd.get_edges::<EdgeType1>(alice), expected);
    }

    #[test]
    fn add_edge_twice_fails() {
        let (alice, bob) = (PersonId(1), PersonId(2));
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(alice, bob, 0.01, ()).unwrap();
        assert_eq!(nd.get_edge::<EdgeType1>(alice, bob).unwrap().weight, 0.01);

        let duplicate = nd.add_edge::<EdgeType1>(alice, bob, 0.02, ());
        assert!(matches!(duplicate, Err(IxaError::IxaError(_))));
    }

    #[test]
    fn add_remove_add_edge() {
        let (alice, bob) = (PersonId(1), PersonId(2));
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(alice, bob, 0.01, ()).unwrap();
        assert_eq!(nd.get_edge::<EdgeType1>(alice, bob).unwrap().weight, 0.01);

        nd.remove_edge::<EdgeType1>(alice, bob).unwrap();
        assert!(nd.get_edge::<EdgeType1>(alice, bob).is_none());

        // Re-adding after removal succeeds and stores the new weight.
        nd.add_edge::<EdgeType1>(alice, bob, 0.02, ()).unwrap();
        assert_eq!(nd.get_edge::<EdgeType1>(alice, bob).unwrap().weight, 0.02);
    }

    #[test]
    fn remove_nonexistent_edge() {
        let mut nd = NetworkData::new();
        let outcome = nd.remove_edge::<EdgeType1>(PersonId(1), PersonId(2));
        assert!(matches!(outcome, Err(IxaError::IxaError(_))));
    }

    #[test]
    fn add_edge_to_self() {
        let mut nd = NetworkData::new();
        let alice = PersonId(1);
        let outcome = nd.add_edge::<EdgeType1>(alice, alice, 0.01, ());
        assert!(matches!(outcome, Err(IxaError::IxaError(_))));
    }

    #[test]
    fn add_edge_bogus_weight() {
        let mut nd = NetworkData::new();
        for bogus in [-1.0, f32::NAN, f32::INFINITY] {
            let outcome = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), bogus, ());
            assert!(matches!(outcome, Err(IxaError::IxaError(_))));
        }
    }

    #[test]
    fn find_people_by_degree() {
        let mut nd = NetworkData::new();
        let links = [(1, 2), (1, 3), (2, 3), (3, 2)];
        for (from, to) in links {
            nd.add_edge::<EdgeType1>(PersonId(from), PersonId(to), 0.0, ())
                .unwrap();
        }

        assert_eq!(nd.find_people_by_degree::<EdgeType1>(2), vec![PersonId(1)]);
        assert_eq!(
            nd.find_people_by_degree::<EdgeType1>(1),
            vec![PersonId(2), PersonId(3)]
        );
    }
}
496
#[cfg(test)]
// Weights are stored and read back unchanged, so exact float equality is intended.
#[allow(clippy::float_cmp)]
mod test_api {
    use crate::context::Context;
    use crate::define_rng;
    use crate::error::IxaError;
    use crate::network::{ContextNetworkExt, Edge};
    use crate::people::{define_person_property, ContextPeopleExt, PersonId};
    use crate::random::ContextRandomExt;

    define_edge_type!(EdgeType1, u32);
    define_person_property!(Age, u8);

    /// Builds a fresh context holding two people, aged 1 and 2.
    fn setup() -> (Context, PersonId, PersonId) {
        let mut context = Context::new();
        let first = context.add_person((Age, 1)).unwrap();
        let second = context.add_person((Age, 2)).unwrap();
        (context, first, second)
    }

    #[test]
    fn add_edge() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();

        let stored_weight = context
            .get_edge::<EdgeType1>(person1, person2)
            .unwrap()
            .weight;
        assert_eq!(stored_weight, 0.01);

        let expected = vec![Edge {
            person: person1,
            neighbor: person2,
            weight: 0.01,
            inner: 1,
        }];
        assert_eq!(context.get_edges::<EdgeType1>(person1), expected);
    }

    #[test]
    fn remove_edge() {
        let (mut context, person1, person2) = setup();

        // Removing an edge that was never added is an error.
        let missing = context.remove_edge::<EdgeType1>(person1, person2);
        assert!(matches!(missing, Err(IxaError::IxaError(_))));

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context.remove_edge::<EdgeType1>(person1, person2).unwrap();

        assert!(context.get_edge::<EdgeType1>(person1, person2).is_none());
        assert!(context.get_edges::<EdgeType1>(person1).is_empty());
    }

    #[test]
    fn add_edge_bidi() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge_bidi::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();

        for (from, to) in [(person1, person2), (person2, person1)] {
            let weight = context.get_edge::<EdgeType1>(from, to).unwrap().weight;
            assert_eq!(weight, 0.01);
        }
    }

    #[test]
    fn add_edge_different_weights() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person2, person1, 0.02, 1)
            .unwrap();

        let forward = context
            .get_edge::<EdgeType1>(person1, person2)
            .unwrap()
            .weight;
        let backward = context
            .get_edge::<EdgeType1>(person2, person1)
            .unwrap()
            .weight;
        assert_eq!(forward, 0.01);
        assert_eq!(backward, 0.02);
    }

    #[test]
    fn get_matching_edges_weight() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.03, 1)
            .unwrap();

        let heavy =
            context.get_matching_edges::<EdgeType1>(person1, |_, edge| edge.weight > 0.01);

        assert_eq!(heavy.len(), 1);
        assert_eq!((heavy[0].person, heavy[0].neighbor), (person1, person3));
    }

    #[test]
    fn get_matching_edges_inner() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.03, 3)
            .unwrap();

        let tagged = context.get_matching_edges::<EdgeType1>(person1, |_, edge| edge.inner == 3);

        assert_eq!(tagged.len(), 1);
        assert_eq!((tagged[0].person, tagged[0].neighbor), (person1, person3));
    }

    #[test]
    fn get_matching_edges_person_property() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.03, 3)
            .unwrap();

        let aged_three = context.get_matching_edges::<EdgeType1>(person1, |ctx, edge| {
            ctx.match_person(edge.neighbor, (Age, 3))
        });

        assert_eq!(aged_three.len(), 1);
        assert_eq!(
            (aged_three[0].person, aged_three[0].neighbor),
            (person1, person3)
        );
    }

    #[test]
    fn select_random_edge() {
        define_rng!(NetworkTestRng);

        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();
        context.init_random(42);

        // The edge to person3 carries nearly all of the weight, so it should be drawn.
        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 10_000_000.0, 3)
            .unwrap();

        let drawn = context
            .select_random_edge::<EdgeType1, _>(NetworkTestRng, person1)
            .unwrap();
        assert_eq!(drawn.person, person1);
        assert_eq!(drawn.neighbor, person3);
    }
}