1use std::any::{Any, TypeId};
9
10use rand::Rng;
11
12use crate::context::Context;
13use crate::error::IxaError;
14use crate::people::PersonId;
15use crate::random::{ContextRandomExt, RngId};
16use crate::{define_data_plugin, ContextBase, HashMap};
17
/// A directed, weighted edge from `person` to `neighbor`.
///
/// `T` is the per-edge payload type; the network code stores edges of an
/// [`EdgeType`] `E` as `Edge<E::Value>`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct Edge<T: Sized> {
    /// The person the edge originates from.
    pub person: PersonId,
    /// The person at the other end of the edge.
    pub neighbor: PersonId,
    /// Edge weight. `NetworkData::add_edge` rejects NaN, infinite, and
    /// sign-negative values, so stored weights are finite and non-negative.
    pub weight: f32,
    /// Per-edge payload carried alongside the weight.
    pub inner: T,
}
31
/// Identifies one kind of network.
///
/// Each implementing type has its own independent set of edges: storage is
/// keyed by the implementor's `TypeId`, so edges of different edge types
/// between the same two people never interact. The tests in this file declare
/// implementations with the `define_edge_type!` macro.
pub trait EdgeType {
    /// Payload stored in [`Edge::inner`] for every edge of this type.
    type Value: Sized + Default + Copy;
}
35
/// All outgoing edges of a single person, grouped by edge type.
#[derive(Default)]
struct PersonNetwork {
    // Maps `TypeId::of::<T>()` for an edge type `T` to a boxed
    // `Vec<Edge<T::Value>>`. The value is type-erased (`dyn Any`) so one map
    // can hold edge lists of every edge type; readers downcast it back.
    neighbors: HashMap<TypeId, Box<dyn Any>>,
}
41
/// Storage for every network; registered below as the `NetworkPlugin` data plugin.
struct NetworkData {
    // Indexed by `PersonId.0`. The vector is only grown (in `add_edge`) up to
    // the highest person index that has had an outgoing edge added, so people
    // past that index have no slot at all.
    network: Vec<PersonNetwork>,
}
45
46impl NetworkData {
47 fn new() -> Self {
48 NetworkData {
49 network: Vec::new(),
50 }
51 }
52
53 fn add_edge<T: EdgeType + 'static>(
54 &mut self,
55 person: PersonId,
56 neighbor: PersonId,
57 weight: f32,
58 inner: T::Value,
59 ) -> Result<(), IxaError> {
60 if person == neighbor {
61 return Err(IxaError::IxaError(String::from("Cannot make edge to self")));
62 }
63
64 if weight.is_infinite() || weight.is_nan() || weight.is_sign_negative() {
65 return Err(IxaError::IxaError(String::from("Invalid weight")));
66 }
67
68 if person.0 >= self.network.len() {
70 self.network.resize_with(person.0 + 1, Default::default);
71 }
72
73 let entry = self.network[person.0]
74 .neighbors
75 .entry(TypeId::of::<T>())
76 .or_insert_with(|| Box::new(Vec::<Edge<T::Value>>::new()));
77 let edges: &mut Vec<Edge<T::Value>> = entry.downcast_mut().expect("Type mismatch");
78
79 for edge in edges.iter_mut() {
80 if edge.neighbor == neighbor {
81 return Err(IxaError::IxaError(String::from("Edge already exists")));
82 }
83 }
84
85 edges.push(Edge {
86 person,
87 neighbor,
88 weight,
89 inner,
90 });
91 Ok(())
92 }
93
94 fn remove_edge<T: EdgeType + 'static>(
95 &mut self,
96 person: PersonId,
97 neighbor: PersonId,
98 ) -> Result<(), IxaError> {
99 if person.0 >= self.network.len() {
100 return Err(IxaError::IxaError(String::from("Edge does not exist")));
101 }
102
103 let entry = match self.network[person.0].neighbors.get_mut(&TypeId::of::<T>()) {
104 None => {
105 return Err(IxaError::IxaError(String::from("Edge does not exist")));
106 }
107 Some(entry) => entry,
108 };
109
110 let edges: &mut Vec<Edge<T::Value>> = entry.downcast_mut().expect("Type mismatch");
111 for index in 0..edges.len() {
112 if edges[index].neighbor == neighbor {
113 edges.remove(index);
114 return Ok(());
115 }
116 }
117
118 Err(IxaError::IxaError(String::from("Edge does not exist")))
119 }
120
121 fn get_edge<T: EdgeType + 'static>(
122 &self,
123 person: PersonId,
124 neighbor: PersonId,
125 ) -> Option<&Edge<T::Value>> {
126 if person.0 >= self.network.len() {
127 return None;
128 }
129
130 let entry = self.network[person.0].neighbors.get(&TypeId::of::<T>())?;
131 let edges: &Vec<Edge<T::Value>> = entry.downcast_ref().expect("Type mismatch");
132 edges.iter().find(|&edge| edge.neighbor == neighbor)
133 }
134
135 fn get_edges<T: EdgeType + 'static>(&self, person: PersonId) -> Vec<Edge<T::Value>> {
136 if person.0 >= self.network.len() {
137 return Vec::new();
138 }
139
140 let entry = self.network[person.0].neighbors.get(&TypeId::of::<T>());
141 if entry.is_none() {
142 return Vec::new();
143 }
144
145 let edges: &Vec<Edge<T::Value>> = entry.unwrap().downcast_ref().expect("Type mismatch");
146 edges.clone()
147 }
148
149 fn find_people_by_degree<T: EdgeType + 'static>(&self, degree: usize) -> Vec<PersonId> {
150 let mut result = Vec::new();
151
152 for person_id in 0..self.network.len() {
153 let entry = self.network[person_id].neighbors.get(&TypeId::of::<T>());
154 if entry.is_none() {
155 continue;
156 }
157 let edges: &Vec<Edge<T::Value>> = entry.unwrap().downcast_ref().expect("Type mismatch");
158 if edges.len() == degree {
159 result.push(PersonId(person_id));
160 }
161 }
162 result
163 }
164}
165
// Registers `NetworkData` as the data plugin `NetworkPlugin`, initialized with
// `NetworkData::new()`. The `ContextNetworkExt` methods reach it through
// `get_data(NetworkPlugin)` / `get_data_mut(NetworkPlugin)`.
define_data_plugin!(NetworkPlugin, NetworkData, NetworkData::new());
167
168pub trait ContextNetworkExt: ContextBase + ContextRandomExt {
170 fn add_edge<T: EdgeType + 'static>(
182 &mut self,
183 person: PersonId,
184 neighbor: PersonId,
185 weight: f32,
186 inner: T::Value,
187 ) -> Result<(), IxaError> {
188 let data_container = self.get_data_mut(NetworkPlugin);
189 data_container.add_edge::<T>(person, neighbor, weight, inner)
190 }
191
192 fn add_edge_bidi<T: EdgeType + 'static>(
206 &mut self,
207 person1: PersonId,
208 person2: PersonId,
209 weight: f32,
210 inner: T::Value,
211 ) -> Result<(), IxaError> {
212 let data_container = self.get_data_mut(NetworkPlugin);
213 data_container.add_edge::<T>(person1, person2, weight, inner)?;
214 data_container.add_edge::<T>(person2, person1, weight, inner)
215 }
216
217 fn remove_edge<T: EdgeType + 'static>(
223 &mut self,
224 person: PersonId,
225 neighbor: PersonId,
226 ) -> Result<(), IxaError> {
227 let data_container = self.get_data_mut(NetworkPlugin);
228 data_container.remove_edge::<T>(person, neighbor)
229 }
230
231 fn get_edge<T: EdgeType + 'static>(
234 &self,
235 person: PersonId,
236 neighbor: PersonId,
237 ) -> Option<&Edge<T::Value>> {
238 self.get_data(NetworkPlugin).get_edge::<T>(person, neighbor)
239 }
240
241 fn get_edges<T: EdgeType + 'static>(&self, person: PersonId) -> Vec<Edge<T::Value>> {
243 self.get_data(NetworkPlugin).get_edges::<T>(person)
244 }
245
246 fn get_matching_edges<T: EdgeType + 'static>(
253 &self,
254 person: PersonId,
255 filter: impl Fn(&Context, &Edge<T::Value>) -> bool + 'static,
256 ) -> Vec<Edge<T::Value>>;
257
258 fn find_people_by_degree<T: EdgeType + 'static>(&self, degree: usize) -> Vec<PersonId> {
260 self.get_data(NetworkPlugin)
261 .find_people_by_degree::<T>(degree)
262 }
263
264 fn select_random_edge<T: EdgeType + 'static, R: RngId + 'static>(
270 &self,
271 rng_id: R,
272 person_id: PersonId,
273 ) -> Result<Edge<T::Value>, IxaError>
274 where
275 R::RngType: Rng,
276 {
277 let edges = self.get_edges::<T>(person_id);
278 if edges.is_empty() {
279 return Err(IxaError::IxaError(String::from(
280 "Can't sample from empty list",
281 )));
282 }
283
284 let weights: Vec<_> = edges.iter().map(|x| x.weight).collect();
285 let index = self.sample_weighted(rng_id, &weights);
286 Ok(edges[index])
287 }
288}
289impl ContextNetworkExt for Context {
290 fn get_matching_edges<T: EdgeType + 'static>(
291 &self,
292 person: PersonId,
293 filter: impl Fn(&Context, &Edge<T::Value>) -> bool + 'static,
294 ) -> Vec<Edge<T::Value>> {
295 let edges = self.get_edges::<T>(person);
296 let mut result = Vec::new();
297 for edge in &edges {
298 if filter(self, edge) {
299 result.push(*edge);
300 }
301 }
302 result
303 }
304}
305
#[cfg(test)]
// Exact float equality is intended: weights are stored and returned unchanged.
#[allow(clippy::float_cmp)]
mod test_inner {
    //! Unit tests that exercise `NetworkData` directly, without a `Context`.
    use super::{Edge, NetworkData};
    use crate::define_edge_type;
    use crate::error::IxaError;
    use crate::people::PersonId;

    // Two payload-free edge types, to check that networks are independent,
    // and one type carrying a `bool` payload.
    define_edge_type!(EdgeType1, ());
    define_edge_type!(EdgeType2, ());
    define_edge_type!(EdgeType3, bool);

    #[test]
    fn add_edge() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);
    }

    #[test]
    fn add_edge_with_inner() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType3>(PersonId(1), PersonId(2), 0.01, true)
            .unwrap();
        let edge = nd.get_edge::<EdgeType3>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);
        assert!(edge.inner);
    }

    #[test]
    fn add_two_edges() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(3), 0.02, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(3)).unwrap();
        assert_eq!(edge.weight, 0.02);

        // `get_edges` returns edges in insertion order.
        let edges = nd.get_edges::<EdgeType1>(PersonId(1));
        assert_eq!(
            edges,
            vec![
                Edge {
                    person: PersonId(1),
                    neighbor: PersonId(2),
                    weight: 0.01,
                    inner: ()
                },
                Edge {
                    person: PersonId(1),
                    neighbor: PersonId(3),
                    weight: 0.02,
                    inner: ()
                }
            ]
        );
    }

    #[test]
    fn add_two_edge_types() {
        let mut nd = NetworkData::new();

        // The same (person, neighbor) pair may have one edge per edge type.
        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        nd.add_edge::<EdgeType2>(PersonId(1), PersonId(2), 0.02, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);
        let edge = nd.get_edge::<EdgeType2>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.02);

        // Only the EdgeType1 edge shows up in the EdgeType1 network.
        let edges = nd.get_edges::<EdgeType1>(PersonId(1));
        assert_eq!(
            edges,
            vec![Edge {
                person: PersonId(1),
                neighbor: PersonId(2),
                weight: 0.01,
                inner: ()
            }]
        );
    }

    #[test]
    fn add_edge_twice_fails() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);

        // A duplicate is rejected even with a different weight.
        assert!(matches!(
            nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.02, ()),
            Err(IxaError::IxaError(_))
        ));
    }

    #[test]
    fn add_remove_add_edge() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);

        nd.remove_edge::<EdgeType1>(PersonId(1), PersonId(2))
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2));
        assert!(edge.is_none());

        // After removal the same edge can be added again.
        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.02, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.02);
    }

    #[test]
    fn remove_nonexistent_edge() {
        let mut nd = NetworkData::new();
        assert!(matches!(
            nd.remove_edge::<EdgeType1>(PersonId(1), PersonId(2)),
            Err(IxaError::IxaError(_))
        ));
    }

    #[test]
    fn add_edge_to_self() {
        let mut nd = NetworkData::new();

        let result = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(1), 0.01, ());
        assert!(matches!(result, Err(IxaError::IxaError(_))));
    }

    #[test]
    fn add_edge_bogus_weight() {
        let mut nd = NetworkData::new();

        // Negative, NaN and infinite weights are all rejected.
        let result = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), -1.0, ());
        assert!(matches!(result, Err(IxaError::IxaError(_))));

        let result = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), f32::NAN, ());
        assert!(matches!(result, Err(IxaError::IxaError(_))));

        let result = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), f32::INFINITY, ());
        assert!(matches!(result, Err(IxaError::IxaError(_))));
    }

    #[test]
    fn find_people_by_degree() {
        let mut nd = NetworkData::new();

        // Person 1 gets two outgoing edges; persons 2 and 3 get one each.
        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.0, ())
            .unwrap();
        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(3), 0.0, ())
            .unwrap();
        nd.add_edge::<EdgeType1>(PersonId(2), PersonId(3), 0.0, ())
            .unwrap();
        nd.add_edge::<EdgeType1>(PersonId(3), PersonId(2), 0.0, ())
            .unwrap();

        let matches = nd.find_people_by_degree::<EdgeType1>(2);
        assert_eq!(matches, vec![PersonId(1)]);
        let matches = nd.find_people_by_degree::<EdgeType1>(1);
        assert_eq!(matches, vec![PersonId(2), PersonId(3)]);
    }
}
483
#[cfg(test)]
// Exact float equality is intended: weights are stored and returned unchanged.
#[allow(clippy::float_cmp)]
mod test_api {
    //! Tests for the `ContextNetworkExt` API on a real `Context`.
    use crate::context::Context;
    use crate::error::IxaError;
    use crate::network::{ContextNetworkExt, Edge};
    use crate::people::{define_person_property, ContextPeopleExt, PersonId};
    use crate::random::ContextRandomExt;
    use crate::{define_edge_type, define_rng};

    define_edge_type!(EdgeType1, u32);
    define_person_property!(Age, u8);

    // Builds a context containing two people (ages 1 and 2) and returns them.
    fn setup() -> (Context, PersonId, PersonId) {
        let mut context = Context::new();
        let person1 = context.add_person((Age, 1)).unwrap();
        let person2 = context.add_person((Age, 2)).unwrap();

        (context, person1, person2)
    }

    #[test]
    fn add_edge() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person1, person2)
                .unwrap()
                .weight,
            0.01
        );
        assert_eq!(
            context.get_edges::<EdgeType1>(person1),
            vec![Edge {
                person: person1,
                neighbor: person2,
                weight: 0.01,
                inner: 1
            }]
        );
    }

    #[test]
    fn remove_edge() {
        let (mut context, person1, person2) = setup();
        // Removing an edge that was never added is an error.
        assert!(matches!(
            context.remove_edge::<EdgeType1>(person1, person2),
            Err(IxaError::IxaError(_))
        ));
        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context.remove_edge::<EdgeType1>(person1, person2).unwrap();
        assert!(context.get_edge::<EdgeType1>(person1, person2).is_none());
        assert_eq!(context.get_edges::<EdgeType1>(person1).len(), 0);
    }

    #[test]
    fn add_edge_bidi() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge_bidi::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        // Both directions exist with the same weight.
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person1, person2)
                .unwrap()
                .weight,
            0.01
        );
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person2, person1)
                .unwrap()
                .weight,
            0.01
        );
    }

    #[test]
    fn add_edge_different_weights() {
        let (mut context, person1, person2) = setup();

        // Edges are directed, so each direction keeps its own weight.
        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person2, person1, 0.02, 1)
            .unwrap();
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person1, person2)
                .unwrap()
                .weight,
            0.01
        );
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person2, person1)
                .unwrap()
                .weight,
            0.02
        );
    }

    #[test]
    fn get_matching_edges_weight() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.03, 1)
            .unwrap();
        let edges =
            context.get_matching_edges::<EdgeType1>(person1, |_context, edge| edge.weight > 0.01);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].person, person1);
        assert_eq!(edges[0].neighbor, person3);
    }

    #[test]
    fn get_matching_edges_inner() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.03, 3)
            .unwrap();
        let edges =
            context.get_matching_edges::<EdgeType1>(person1, |_context, edge| edge.inner == 3);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].person, person1);
        assert_eq!(edges[0].neighbor, person3);
    }

    #[test]
    fn get_matching_edges_person_property() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.03, 3)
            .unwrap();
        // The filter can consult the context, e.g. a property of the neighbor.
        let edges = context.get_matching_edges::<EdgeType1>(person1, |context, edge| {
            context.match_person(edge.neighbor, (Age, 3))
        });
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].person, person1);
        assert_eq!(edges[0].neighbor, person3);
    }

    #[test]
    fn select_random_edge() {
        define_rng!(NetworkTestRng);

        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();
        context.init_random(42);

        // The person3 edge's weight dwarfs the other, so with this fixed seed
        // it is the one selected.
        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 10_000_000.0, 3)
            .unwrap();

        let edge = context
            .select_random_edge::<EdgeType1, _>(NetworkTestRng, person1)
            .unwrap();
        assert_eq!(edge.person, person1);
        assert_eq!(edge.neighbor, person3);
    }
}