ixa/
network.rs

//! A module for modeling contact networks.
//!
//! A network is modeled as a directed graph. Edges are typed in the
//! usual fashion, i.e., keyed by a Rust type, and each person can have an
//! arbitrary number of outgoing edges of a given type, with each edge
//! having a weight. Edge types can also specify their own per-type
//! data, which is stored along with the edge.
8use std::any::{Any, TypeId};
9
10use rand::Rng;
11
12use crate::context::Context;
13use crate::error::IxaError;
14use crate::people::PersonId;
15use crate::random::{ContextRandomExt, RngId};
16use crate::{define_data_plugin, ContextBase, HashMap};
17
18#[derive(Copy, Clone, Debug, PartialEq)]
19/// An edge in network graph. Edges are directed, so the
20/// source person is implicit.
21pub struct Edge<T: Sized> {
22    /// The person this edge comes from.
23    pub person: PersonId,
24    /// The person this edge points to.
25    pub neighbor: PersonId,
26    /// The weight associated with the edge.
27    pub weight: f32,
28    /// An inner value defined by type `T`.
29    pub inner: T,
30}
31
/// Identifies one kind of edge in the network.
///
/// Each implementing type keys its own, independent set of edges, and
/// `Value` is the per-edge payload stored in [`Edge::inner`].
pub trait EdgeType {
    /// Per-edge data stored alongside every edge of this type.
    type Value: Copy + Default + Sized;
}
35
/// All outgoing edges of a single person, across every edge type.
#[derive(Default)]
struct PersonNetwork {
    // Keyed by `TypeId::of::<T>()` for an edge type `T`. Each value is a
    // boxed `Vec<Edge<T::Value>>`, type-erased as `dyn Any` so that edge
    // types with different payload types can share a single map.
    neighbors: HashMap<TypeId, Box<dyn Any>>,
}
41
42struct NetworkData {
43    network: Vec<PersonNetwork>,
44}
45
46impl NetworkData {
47    fn new() -> Self {
48        NetworkData {
49            network: Vec::new(),
50        }
51    }
52
53    fn add_edge<T: EdgeType + 'static>(
54        &mut self,
55        person: PersonId,
56        neighbor: PersonId,
57        weight: f32,
58        inner: T::Value,
59    ) -> Result<(), IxaError> {
60        if person == neighbor {
61            return Err(IxaError::IxaError(String::from("Cannot make edge to self")));
62        }
63
64        if weight.is_infinite() || weight.is_nan() || weight.is_sign_negative() {
65            return Err(IxaError::IxaError(String::from("Invalid weight")));
66        }
67
68        // Make sure we have data for this person.
69        if person.0 >= self.network.len() {
70            self.network.resize_with(person.0 + 1, Default::default);
71        }
72
73        let entry = self.network[person.0]
74            .neighbors
75            .entry(TypeId::of::<T>())
76            .or_insert_with(|| Box::new(Vec::<Edge<T::Value>>::new()));
77        let edges: &mut Vec<Edge<T::Value>> = entry.downcast_mut().expect("Type mismatch");
78
79        for edge in edges.iter_mut() {
80            if edge.neighbor == neighbor {
81                return Err(IxaError::IxaError(String::from("Edge already exists")));
82            }
83        }
84
85        edges.push(Edge {
86            person,
87            neighbor,
88            weight,
89            inner,
90        });
91        Ok(())
92    }
93
94    fn remove_edge<T: EdgeType + 'static>(
95        &mut self,
96        person: PersonId,
97        neighbor: PersonId,
98    ) -> Result<(), IxaError> {
99        if person.0 >= self.network.len() {
100            return Err(IxaError::IxaError(String::from("Edge does not exist")));
101        }
102
103        let entry = match self.network[person.0].neighbors.get_mut(&TypeId::of::<T>()) {
104            None => {
105                return Err(IxaError::IxaError(String::from("Edge does not exist")));
106            }
107            Some(entry) => entry,
108        };
109
110        let edges: &mut Vec<Edge<T::Value>> = entry.downcast_mut().expect("Type mismatch");
111        for index in 0..edges.len() {
112            if edges[index].neighbor == neighbor {
113                edges.remove(index);
114                return Ok(());
115            }
116        }
117
118        Err(IxaError::IxaError(String::from("Edge does not exist")))
119    }
120
121    fn get_edge<T: EdgeType + 'static>(
122        &self,
123        person: PersonId,
124        neighbor: PersonId,
125    ) -> Option<&Edge<T::Value>> {
126        if person.0 >= self.network.len() {
127            return None;
128        }
129
130        let entry = self.network[person.0].neighbors.get(&TypeId::of::<T>())?;
131        let edges: &Vec<Edge<T::Value>> = entry.downcast_ref().expect("Type mismatch");
132        edges.iter().find(|&edge| edge.neighbor == neighbor)
133    }
134
135    fn get_edges<T: EdgeType + 'static>(&self, person: PersonId) -> Vec<Edge<T::Value>> {
136        if person.0 >= self.network.len() {
137            return Vec::new();
138        }
139
140        let entry = self.network[person.0].neighbors.get(&TypeId::of::<T>());
141        if entry.is_none() {
142            return Vec::new();
143        }
144
145        let edges: &Vec<Edge<T::Value>> = entry.unwrap().downcast_ref().expect("Type mismatch");
146        edges.clone()
147    }
148
149    fn find_people_by_degree<T: EdgeType + 'static>(&self, degree: usize) -> Vec<PersonId> {
150        let mut result = Vec::new();
151
152        for person_id in 0..self.network.len() {
153            let entry = self.network[person_id].neighbors.get(&TypeId::of::<T>());
154            if entry.is_none() {
155                continue;
156            }
157            let edges: &Vec<Edge<T::Value>> = entry.unwrap().downcast_ref().expect("Type mismatch");
158            if edges.len() == degree {
159                result.push(PersonId(person_id));
160            }
161        }
162        result
163    }
164}
165
// Registers `NetworkData` as context data under the `NetworkPlugin` key, which
// the `ContextNetworkExt` methods pass to `get_data` / `get_data_mut`.
// NOTE(review): the third argument appears to be the initial value (an empty
// network from `NetworkData::new()`); confirm against the macro definition.
define_data_plugin!(NetworkPlugin, NetworkData, NetworkData::new());
167
168// Public API.
169pub trait ContextNetworkExt: ContextBase + ContextRandomExt {
170    /// Add an edge of type `T` between `person` and `neighbor` with a
171    /// given `weight`.  `inner` is a value of whatever type is
172    /// associated with `T`.
173    ///
174    /// # Errors
175    ///
176    /// Returns [`IxaError`] if:
177    ///
178    /// * `person` and `neighbor` are the same or an edge already
179    ///   exists between them.
180    /// * `weight` is invalid
181    fn add_edge<T: EdgeType + 'static>(
182        &mut self,
183        person: PersonId,
184        neighbor: PersonId,
185        weight: f32,
186        inner: T::Value,
187    ) -> Result<(), IxaError> {
188        let data_container = self.get_data_mut(NetworkPlugin);
189        data_container.add_edge::<T>(person, neighbor, weight, inner)
190    }
191
192    /// Add a pair of edges of type `T` between `person1` and
193    /// `neighbor2` with a given `weight`, one edge in each
194    /// direction. `inner` is a value of whatever type is associated
195    /// with `T`. This is syntactic sugar for calling [`add_edge`](Self::add_edge)
196    /// twice.
197    ///
198    /// # Errors
199    ///
200    /// Returns [`IxaError`] if:
201    ///
202    /// * `person` and `neighbor` are the same or an edge already
203    ///   exists between them.
204    /// * `weight` is invalid
205    fn add_edge_bidi<T: EdgeType + 'static>(
206        &mut self,
207        person1: PersonId,
208        person2: PersonId,
209        weight: f32,
210        inner: T::Value,
211    ) -> Result<(), IxaError> {
212        let data_container = self.get_data_mut(NetworkPlugin);
213        data_container.add_edge::<T>(person1, person2, weight, inner)?;
214        data_container.add_edge::<T>(person2, person1, weight, inner)
215    }
216
217    /// Remove an edge of type `T` between `person` and `neighbor`
218    /// if one exists.
219    ///
220    /// # Errors
221    /// Returns [`IxaError`] if no edge exists.
222    fn remove_edge<T: EdgeType + 'static>(
223        &mut self,
224        person: PersonId,
225        neighbor: PersonId,
226    ) -> Result<(), IxaError> {
227        let data_container = self.get_data_mut(NetworkPlugin);
228        data_container.remove_edge::<T>(person, neighbor)
229    }
230
231    /// Get an edge of type `T` between `person` and `neighbor`
232    /// if one exists.
233    fn get_edge<T: EdgeType + 'static>(
234        &self,
235        person: PersonId,
236        neighbor: PersonId,
237    ) -> Option<&Edge<T::Value>> {
238        self.get_data(NetworkPlugin).get_edge::<T>(person, neighbor)
239    }
240
241    /// Get all edges of type `T` from `person`.
242    fn get_edges<T: EdgeType + 'static>(&self, person: PersonId) -> Vec<Edge<T::Value>> {
243        self.get_data(NetworkPlugin).get_edges::<T>(person)
244    }
245
246    /// Get all edges of type `T` from `person` that match the predicate
247    /// provided in `filter`. Note that because `filter` has access to
248    /// both the edge, which contains the neighbor and [`Context`], it is
249    /// possible to filter on properties of the neighbor. The function
250    /// [`match_person`](crate::people::ContextPeopleExt::match_person) might be helpful here.
251    ///
252    fn get_matching_edges<T: EdgeType + 'static>(
253        &self,
254        person: PersonId,
255        filter: impl Fn(&Context, &Edge<T::Value>) -> bool + 'static,
256    ) -> Vec<Edge<T::Value>>;
257
258    /// Find all people who have an edge of type `T` and degree `degree`.
259    fn find_people_by_degree<T: EdgeType + 'static>(&self, degree: usize) -> Vec<PersonId> {
260        self.get_data(NetworkPlugin)
261            .find_people_by_degree::<T>(degree)
262    }
263
264    /// Select a random edge out of the list of outgoing edges of type
265    /// `T` from `person_id`, weighted by the edge weights.
266    ///
267    /// # Errors
268    /// Returns [`IxaError`] if there are no edges.
269    fn select_random_edge<T: EdgeType + 'static, R: RngId + 'static>(
270        &self,
271        rng_id: R,
272        person_id: PersonId,
273    ) -> Result<Edge<T::Value>, IxaError>
274    where
275        R::RngType: Rng,
276    {
277        let edges = self.get_edges::<T>(person_id);
278        if edges.is_empty() {
279            return Err(IxaError::IxaError(String::from(
280                "Can't sample from empty list",
281            )));
282        }
283
284        let weights: Vec<_> = edges.iter().map(|x| x.weight).collect();
285        let index = self.sample_weighted(rng_id, &weights);
286        Ok(edges[index])
287    }
288}
289impl ContextNetworkExt for Context {
290    fn get_matching_edges<T: EdgeType + 'static>(
291        &self,
292        person: PersonId,
293        filter: impl Fn(&Context, &Edge<T::Value>) -> bool + 'static,
294    ) -> Vec<Edge<T::Value>> {
295        let edges = self.get_edges::<T>(person);
296        let mut result = Vec::new();
297        for edge in &edges {
298            if filter(self, edge) {
299                result.push(*edge);
300            }
301        }
302        result
303    }
304}
305
#[cfg(test)]
#[allow(clippy::float_cmp)]
// Tests for the inner core: these drive `NetworkData` directly, without a `Context`.
mod test_inner {
    use super::{Edge, NetworkData};
    use crate::define_edge_type;
    use crate::error::IxaError;
    use crate::people::PersonId;

    define_edge_type!(EdgeType1, ());
    define_edge_type!(EdgeType2, ());
    define_edge_type!(EdgeType3, bool);

    /// True when `result` is an `IxaError::IxaError` failure.
    fn is_ixa_error<V>(result: &Result<V, IxaError>) -> bool {
        matches!(result, Err(IxaError::IxaError(_)))
    }

    #[test]
    fn add_edge() {
        let mut network = NetworkData::new();

        network
            .add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .expect("adding a fresh edge should succeed");
        let stored = network
            .get_edge::<EdgeType1>(PersonId(1), PersonId(2))
            .expect("edge should be retrievable");
        assert_eq!(stored.weight, 0.01);
    }

    #[test]
    fn add_edge_with_inner() {
        let mut network = NetworkData::new();

        network
            .add_edge::<EdgeType3>(PersonId(1), PersonId(2), 0.01, true)
            .expect("adding a fresh edge should succeed");
        let stored = network
            .get_edge::<EdgeType3>(PersonId(1), PersonId(2))
            .expect("edge should be retrievable");
        assert_eq!(stored.weight, 0.01);
        assert!(stored.inner);
    }

    #[test]
    fn add_two_edges() {
        let mut network = NetworkData::new();
        let source = PersonId(1);

        for (target, weight) in [(PersonId(2), 0.01), (PersonId(3), 0.02)] {
            network
                .add_edge::<EdgeType1>(source, target, weight, ())
                .expect("adding a fresh edge should succeed");
        }
        assert_eq!(
            network.get_edge::<EdgeType1>(source, PersonId(2)).unwrap().weight,
            0.01
        );
        assert_eq!(
            network.get_edge::<EdgeType1>(source, PersonId(3)).unwrap().weight,
            0.02
        );

        let expected = vec![
            Edge {
                person: source,
                neighbor: PersonId(2),
                weight: 0.01,
                inner: (),
            },
            Edge {
                person: source,
                neighbor: PersonId(3),
                weight: 0.02,
                inner: (),
            },
        ];
        assert_eq!(network.get_edges::<EdgeType1>(source), expected);
    }

    #[test]
    fn add_two_edge_types() {
        let mut network = NetworkData::new();
        let (from, to) = (PersonId(1), PersonId(2));

        network.add_edge::<EdgeType1>(from, to, 0.01, ()).unwrap();
        network.add_edge::<EdgeType2>(from, to, 0.02, ()).unwrap();
        assert_eq!(network.get_edge::<EdgeType1>(from, to).unwrap().weight, 0.01);
        assert_eq!(network.get_edge::<EdgeType2>(from, to).unwrap().weight, 0.02);

        // Edges of another type must not show up in the listing.
        assert_eq!(
            network.get_edges::<EdgeType1>(from),
            vec![Edge {
                person: from,
                neighbor: to,
                weight: 0.01,
                inner: (),
            }]
        );
    }

    #[test]
    fn add_edge_twice_fails() {
        let mut network = NetworkData::new();
        let (from, to) = (PersonId(1), PersonId(2));

        network.add_edge::<EdgeType1>(from, to, 0.01, ()).unwrap();
        assert_eq!(network.get_edge::<EdgeType1>(from, to).unwrap().weight, 0.01);

        let duplicate = network.add_edge::<EdgeType1>(from, to, 0.02, ());
        assert!(is_ixa_error(&duplicate));
    }

    #[test]
    fn add_remove_add_edge() {
        let mut network = NetworkData::new();
        let (from, to) = (PersonId(1), PersonId(2));

        network.add_edge::<EdgeType1>(from, to, 0.01, ()).unwrap();
        assert_eq!(network.get_edge::<EdgeType1>(from, to).unwrap().weight, 0.01);

        network.remove_edge::<EdgeType1>(from, to).unwrap();
        assert!(network.get_edge::<EdgeType1>(from, to).is_none());

        // Re-adding after removal must succeed and store the new weight.
        network.add_edge::<EdgeType1>(from, to, 0.02, ()).unwrap();
        assert_eq!(network.get_edge::<EdgeType1>(from, to).unwrap().weight, 0.02);
    }

    #[test]
    fn remove_nonexistent_edge() {
        let mut network = NetworkData::new();
        let outcome = network.remove_edge::<EdgeType1>(PersonId(1), PersonId(2));
        assert!(is_ixa_error(&outcome));
    }

    #[test]
    fn add_edge_to_self() {
        let mut network = NetworkData::new();
        let outcome = network.add_edge::<EdgeType1>(PersonId(1), PersonId(1), 0.01, ());
        assert!(is_ixa_error(&outcome));
    }

    #[test]
    fn add_edge_bogus_weight() {
        let mut network = NetworkData::new();

        for bogus in [-1.0, f32::NAN, f32::INFINITY] {
            let outcome = network.add_edge::<EdgeType1>(PersonId(1), PersonId(2), bogus, ());
            assert!(is_ixa_error(&outcome), "weight {} should be rejected", bogus);
        }
    }

    #[test]
    fn find_people_by_degree() {
        let mut network = NetworkData::new();

        let links = [(1, 2), (1, 3), (2, 3), (3, 2)];
        for (from, to) in links {
            network
                .add_edge::<EdgeType1>(PersonId(from), PersonId(to), 0.0, ())
                .unwrap();
        }

        assert_eq!(
            network.find_people_by_degree::<EdgeType1>(2),
            vec![PersonId(1)]
        );
        assert_eq!(
            network.find_people_by_degree::<EdgeType1>(1),
            vec![PersonId(2), PersonId(3)]
        );
    }
}
483
#[cfg(test)]
#[allow(clippy::float_cmp)]
// Tests for the API: these go through `ContextNetworkExt` on a real `Context`.
mod test_api {
    use crate::context::Context;
    use crate::error::IxaError;
    use crate::network::{ContextNetworkExt, Edge};
    use crate::people::{define_person_property, ContextPeopleExt, PersonId};
    use crate::random::ContextRandomExt;
    use crate::{define_edge_type, define_rng};

    define_edge_type!(EdgeType1, u32);
    define_person_property!(Age, u8);

    /// Builds a context holding two people, aged 1 and 2.
    fn setup() -> (Context, PersonId, PersonId) {
        let mut context = Context::new();
        let younger = context.add_person((Age, 1)).unwrap();
        let older = context.add_person((Age, 2)).unwrap();
        (context, younger, older)
    }

    /// Extends `setup` with a third person (age 3), then gives the first
    /// person an edge to the second (weight 0.01, inner 1) and an edge to
    /// the third (weight 0.03, inner `third_inner`).
    fn setup_fan_out(third_inner: u32) -> (Context, PersonId, PersonId, PersonId) {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();
        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.03, third_inner)
            .unwrap();
        (context, person1, person2, person3)
    }

    /// Weight of the `EdgeType1` edge from `from` to `to`; panics if absent.
    fn weight_of(context: &Context, from: PersonId, to: PersonId) -> f32 {
        context.get_edge::<EdgeType1>(from, to).unwrap().weight
    }

    /// Asserts that `edges` holds exactly one edge, running `from` -> `to`.
    fn assert_single_edge<V>(edges: &[Edge<V>], from: PersonId, to: PersonId) {
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].person, from);
        assert_eq!(edges[0].neighbor, to);
    }

    #[test]
    fn add_edge() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        assert_eq!(weight_of(&context, person1, person2), 0.01);
        assert_eq!(
            context.get_edges::<EdgeType1>(person1),
            vec![Edge {
                person: person1,
                neighbor: person2,
                weight: 0.01,
                inner: 1,
            }]
        );
    }

    #[test]
    fn remove_edge() {
        let (mut context, person1, person2) = setup();

        // Removing before anything has been added must fail.
        let premature = context.remove_edge::<EdgeType1>(person1, person2);
        assert!(matches!(premature, Err(IxaError::IxaError(_))));

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context.remove_edge::<EdgeType1>(person1, person2).unwrap();
        assert!(context.get_edge::<EdgeType1>(person1, person2).is_none());
        assert!(context.get_edges::<EdgeType1>(person1).is_empty());
    }

    #[test]
    fn add_edge_bidi() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge_bidi::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        for (from, to) in [(person1, person2), (person2, person1)] {
            assert_eq!(weight_of(&context, from, to), 0.01);
        }
    }

    #[test]
    fn add_edge_different_weights() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person2, person1, 0.02, 1)
            .unwrap();
        assert_eq!(weight_of(&context, person1, person2), 0.01);
        assert_eq!(weight_of(&context, person2, person1), 0.02);
    }

    #[test]
    fn get_matching_edges_weight() {
        let (context, person1, _, person3) = setup_fan_out(1);
        let edges =
            context.get_matching_edges::<EdgeType1>(person1, |_context, edge| edge.weight > 0.01);
        assert_single_edge(&edges, person1, person3);
    }

    #[test]
    fn get_matching_edges_inner() {
        let (context, person1, _, person3) = setup_fan_out(3);
        let edges =
            context.get_matching_edges::<EdgeType1>(person1, |_context, edge| edge.inner == 3);
        assert_single_edge(&edges, person1, person3);
    }

    #[test]
    fn get_matching_edges_person_property() {
        let (context, person1, _, person3) = setup_fan_out(3);
        let edges = context.get_matching_edges::<EdgeType1>(person1, |context, edge| {
            context.match_person(edge.neighbor, (Age, 3))
        });
        assert_single_edge(&edges, person1, person3);
    }

    #[test]
    fn select_random_edge() {
        define_rng!(NetworkTestRng);

        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();
        context.init_random(42);

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        // This edge is overwhelmingly heavier, so the weighted draw should pick it.
        context
            .add_edge::<EdgeType1>(person1, person3, 10_000_000.0, 3)
            .unwrap();

        let chosen = context
            .select_random_edge::<EdgeType1, _>(NetworkTestRng, person1)
            .unwrap();
        assert_eq!(chosen.person, person1);
        assert_eq!(chosen.neighbor, person3);
    }
}