ixa/
network.rs

//! A module for modeling contact networks.
//!
//! A network is modeled as a directed graph. Edges are typed in the
//! usual fashion, i.e., keyed by a Rust type, and each person can have an
//! arbitrary number of outgoing edges of a given type, with each edge
//! having a weight. Edge types can also specify their own per-type
//! data, which is stored along with the edge.
8use crate::{
9    context::Context, define_data_plugin, error::IxaError, people::PersonId,
10    random::ContextRandomExt, random::RngId, HashMap, PluginContext,
11};
12use rand::Rng;
13use std::any::{Any, TypeId};
14
15#[derive(Copy, Clone, Debug, PartialEq)]
16/// An edge in network graph. Edges are directed, so the
17/// source person is implicit.
18pub struct Edge<T: Sized> {
19    /// The person this edge comes from.
20    pub person: PersonId,
21    /// The person this edge points to.
22    pub neighbor: PersonId,
23    /// The weight associated with the edge.
24    pub weight: f32,
25    /// An inner value defined by type `T`.
26    pub inner: T,
27}
28
/// A family of network edges.
///
/// Each edge type is keyed by its own Rust type (via its `TypeId`) and
/// names the payload stored alongside every edge of that type. Use `()` when
/// no payload is needed. Normally implemented through `define_edge_type!`.
pub trait EdgeType {
    /// Per-edge payload stored in `Edge::inner`.
    type Value: Default + Copy;
}
32
/// Outgoing edges of a single person, grouped by edge type.
#[derive(Default)]
struct PersonNetwork {
    // Keyed by `TypeId::of::<T>()` for an edge type `T`. Each value is a
    // type-erased `Vec<Edge<T::Value>>`, recovered with `downcast_ref` /
    // `downcast_mut` when the edges of that type are accessed.
    neighbors: HashMap<TypeId, Box<dyn Any>>,
}
38
39struct NetworkData {
40    network: Vec<PersonNetwork>,
41}
42
43impl NetworkData {
44    fn new() -> Self {
45        NetworkData {
46            network: Vec::new(),
47        }
48    }
49
50    fn add_edge<T: EdgeType + 'static>(
51        &mut self,
52        person: PersonId,
53        neighbor: PersonId,
54        weight: f32,
55        inner: T::Value,
56    ) -> Result<(), IxaError> {
57        if person == neighbor {
58            return Err(IxaError::IxaError(String::from("Cannot make edge to self")));
59        }
60
61        if weight.is_infinite() || weight.is_nan() || weight.is_sign_negative() {
62            return Err(IxaError::IxaError(String::from("Invalid weight")));
63        }
64
65        // Make sure we have data for this person.
66        if person.0 >= self.network.len() {
67            self.network.resize_with(person.0 + 1, Default::default);
68        }
69
70        let entry = self.network[person.0]
71            .neighbors
72            .entry(TypeId::of::<T>())
73            .or_insert_with(|| Box::new(Vec::<Edge<T::Value>>::new()));
74        let edges: &mut Vec<Edge<T::Value>> = entry.downcast_mut().expect("Type mismatch");
75
76        for edge in edges.iter_mut() {
77            if edge.neighbor == neighbor {
78                return Err(IxaError::IxaError(String::from("Edge already exists")));
79            }
80        }
81
82        edges.push(Edge {
83            person,
84            neighbor,
85            weight,
86            inner,
87        });
88        Ok(())
89    }
90
91    fn remove_edge<T: EdgeType + 'static>(
92        &mut self,
93        person: PersonId,
94        neighbor: PersonId,
95    ) -> Result<(), IxaError> {
96        if person.0 >= self.network.len() {
97            return Err(IxaError::IxaError(String::from("Edge does not exist")));
98        }
99
100        let entry = match self.network[person.0].neighbors.get_mut(&TypeId::of::<T>()) {
101            None => {
102                return Err(IxaError::IxaError(String::from("Edge does not exist")));
103            }
104            Some(entry) => entry,
105        };
106
107        let edges: &mut Vec<Edge<T::Value>> = entry.downcast_mut().expect("Type mismatch");
108        for index in 0..edges.len() {
109            if edges[index].neighbor == neighbor {
110                edges.remove(index);
111                return Ok(());
112            }
113        }
114
115        Err(IxaError::IxaError(String::from("Edge does not exist")))
116    }
117
118    fn get_edge<T: EdgeType + 'static>(
119        &self,
120        person: PersonId,
121        neighbor: PersonId,
122    ) -> Option<&Edge<T::Value>> {
123        if person.0 >= self.network.len() {
124            return None;
125        }
126
127        let entry = self.network[person.0].neighbors.get(&TypeId::of::<T>())?;
128        let edges: &Vec<Edge<T::Value>> = entry.downcast_ref().expect("Type mismatch");
129        edges.iter().find(|&edge| edge.neighbor == neighbor)
130    }
131
132    fn get_edges<T: EdgeType + 'static>(&self, person: PersonId) -> Vec<Edge<T::Value>> {
133        if person.0 >= self.network.len() {
134            return Vec::new();
135        }
136
137        let entry = self.network[person.0].neighbors.get(&TypeId::of::<T>());
138        if entry.is_none() {
139            return Vec::new();
140        }
141
142        let edges: &Vec<Edge<T::Value>> = entry.unwrap().downcast_ref().expect("Type mismatch");
143        edges.clone()
144    }
145
146    fn find_people_by_degree<T: EdgeType + 'static>(&self, degree: usize) -> Vec<PersonId> {
147        let mut result = Vec::new();
148
149        for person_id in 0..self.network.len() {
150            let entry = self.network[person_id].neighbors.get(&TypeId::of::<T>());
151            if entry.is_none() {
152                continue;
153            }
154            let edges: &Vec<Edge<T::Value>> = entry.unwrap().downcast_ref().expect("Type mismatch");
155            if edges.len() == degree {
156                result.push(PersonId(person_id));
157            }
158        }
159        result
160    }
161}
162
/// Define a new edge type for use with `network`.
///
/// `define_edge_type!(Name, Value)` declares a public unit struct `Name`
/// (deriving `Debug`, `Copy` and `Clone`) and implements
/// `network::EdgeType` for it, with `Value` as the payload stored in each
/// edge's `inner` field. Use `()` for `Value` to have no inner type. A
/// trailing comma after `Value` is accepted.
#[macro_export]
macro_rules! define_edge_type {
    ($edge_type:ident, $value:ty $(,)?) => {
        #[derive(Debug, Copy, Clone)]
        pub struct $edge_type;

        impl $crate::network::EdgeType for $edge_type {
            type Value = $value;
        }
    };
}
179
180define_data_plugin!(NetworkPlugin, NetworkData, NetworkData::new());
181
182// Public API.
183pub trait ContextNetworkExt: PluginContext + ContextRandomExt {
184    /// Add an edge of type `T` between `person` and `neighbor` with a
185    /// given `weight`.  `inner` is a value of whatever type is
186    /// associated with `T`.
187    ///
188    /// # Errors
189    ///
190    /// Returns `IxaError` if:
191    ///
192    /// * `person` and `neighbor` are the same or an edge already
193    ///   exists between them.
194    /// * `weight` is invalid
195    fn add_edge<T: EdgeType + 'static>(
196        &mut self,
197        person: PersonId,
198        neighbor: PersonId,
199        weight: f32,
200        inner: T::Value,
201    ) -> Result<(), IxaError> {
202        let data_container = self.get_data_mut(NetworkPlugin);
203        data_container.add_edge::<T>(person, neighbor, weight, inner)
204    }
205
206    /// Add a pair of edges of type `T` between `person1` and
207    /// `neighbor2` with a given `weight`, one edge in each
208    /// direction. `inner` is a value of whatever type is associated
209    /// with `T`. This is syntactic sugar for calling `add_edge()`
210    /// twice.
211    ///
212    /// # Errors
213    ///
214    /// Returns `IxaError` if:
215    ///
216    /// * `person` and `neighbor` are the same or an edge already
217    ///   exists between them.
218    /// * `weight` is invalid
219    fn add_edge_bidi<T: EdgeType + 'static>(
220        &mut self,
221        person1: PersonId,
222        person2: PersonId,
223        weight: f32,
224        inner: T::Value,
225    ) -> Result<(), IxaError> {
226        let data_container = self.get_data_mut(NetworkPlugin);
227        data_container.add_edge::<T>(person1, person2, weight, inner)?;
228        data_container.add_edge::<T>(person2, person1, weight, inner)
229    }
230
231    /// Remove an edge of type `T` between `person` and `neighbor`
232    /// if one exists.
233    ///
234    /// # Errors
235    /// Returns `IxaError` if no edge exists.
236    fn remove_edge<T: EdgeType + 'static>(
237        &mut self,
238        person: PersonId,
239        neighbor: PersonId,
240    ) -> Result<(), IxaError> {
241        let data_container = self.get_data_mut(NetworkPlugin);
242        data_container.remove_edge::<T>(person, neighbor)
243    }
244
245    /// Get an edge of type `T` between `person` and `neighbor`
246    /// if one exists.
247    fn get_edge<T: EdgeType + 'static>(
248        &self,
249        person: PersonId,
250        neighbor: PersonId,
251    ) -> Option<&Edge<T::Value>> {
252        self.get_data(NetworkPlugin).get_edge::<T>(person, neighbor)
253    }
254
255    /// Get all edges of type `T` from `person`.
256    fn get_edges<T: EdgeType + 'static>(&self, person: PersonId) -> Vec<Edge<T::Value>> {
257        self.get_data(NetworkPlugin).get_edges::<T>(person)
258    }
259
260    /// Get all edges of type `T` from `person` that match the predicate
261    /// provided in `filter`. Note that because `filter` has access to
262    /// both the edge, which contains the neighbor and `Context`, it is
263    /// possible to filter on properties of the neighbor. The function
264    /// `context.matching_person()` might be helpful here.
265    ///
266    fn get_matching_edges<T: EdgeType + 'static>(
267        &self,
268        person: PersonId,
269        filter: impl Fn(&Context, &Edge<T::Value>) -> bool + 'static,
270    ) -> Vec<Edge<T::Value>>;
271
272    /// Find all people who have an edge of type `T` and degree `degree`.
273    fn find_people_by_degree<T: EdgeType + 'static>(&self, degree: usize) -> Vec<PersonId> {
274        self.get_data(NetworkPlugin)
275            .find_people_by_degree::<T>(degree)
276    }
277
278    /// Select a random edge out of the list of outgoing edges of type
279    /// `T` from `person_id`, weighted by the edge weights.
280    ///
281    /// # Errors
282    /// Returns `IxaError` if there are no edges.
283    fn select_random_edge<T: EdgeType + 'static, R: RngId + 'static>(
284        &self,
285        rng_id: R,
286        person_id: PersonId,
287    ) -> Result<Edge<T::Value>, IxaError>
288    where
289        R::RngType: Rng,
290    {
291        let edges = self.get_edges::<T>(person_id);
292        if edges.is_empty() {
293            return Err(IxaError::IxaError(String::from(
294                "Can't sample from empty list",
295            )));
296        }
297
298        let weights: Vec<_> = edges.iter().map(|x| x.weight).collect();
299        let index = self.sample_weighted(rng_id, &weights);
300        Ok(edges[index])
301    }
302}
303impl ContextNetworkExt for Context {
304    fn get_matching_edges<T: EdgeType + 'static>(
305        &self,
306        person: PersonId,
307        filter: impl Fn(&Context, &Edge<T::Value>) -> bool + 'static,
308    ) -> Vec<Edge<T::Value>> {
309        let edges = self.get_edges::<T>(person);
310        let mut result = Vec::new();
311        for edge in &edges {
312            if filter(self, edge) {
313                result.push(*edge);
314            }
315        }
316        result
317    }
318}
319
#[cfg(test)]
// Weights are stored and read back unchanged, so exact float equality is intended.
#[allow(clippy::float_cmp)]
// Tests for the inner core.
mod test_inner {
    use super::{Edge, NetworkData};
    use crate::error::IxaError;
    use crate::people::PersonId;

    define_edge_type!(EdgeType1, ());
    define_edge_type!(EdgeType2, ());
    define_edge_type!(EdgeType3, bool);

    #[test]
    fn add_edge() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);
    }

    #[test]
    fn add_edge_with_inner() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType3>(PersonId(1), PersonId(2), 0.01, true)
            .unwrap();
        let edge = nd.get_edge::<EdgeType3>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);
        assert!(edge.inner);
    }

    #[test]
    fn add_two_edges() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(3), 0.02, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(3)).unwrap();
        assert_eq!(edge.weight, 0.02);

        let edges = nd.get_edges::<EdgeType1>(PersonId(1));
        assert_eq!(
            edges,
            vec![
                Edge {
                    person: PersonId(1),
                    neighbor: PersonId(2),
                    weight: 0.01,
                    inner: ()
                },
                Edge {
                    person: PersonId(1),
                    neighbor: PersonId(3),
                    weight: 0.02,
                    inner: ()
                }
            ]
        );
    }

    #[test]
    fn add_two_edge_types() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        nd.add_edge::<EdgeType2>(PersonId(1), PersonId(2), 0.02, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);
        let edge = nd.get_edge::<EdgeType2>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.02);

        let edges = nd.get_edges::<EdgeType1>(PersonId(1));
        assert_eq!(
            edges,
            vec![Edge {
                person: PersonId(1),
                neighbor: PersonId(2),
                weight: 0.01,
                inner: ()
            }]
        );
    }

    #[test]
    fn add_edge_twice_fails() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);

        assert!(matches!(
            nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.02, ()),
            Err(IxaError::IxaError(_))
        ));
        // The failed insert must not have replaced the original edge.
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);
        assert_eq!(nd.get_edges::<EdgeType1>(PersonId(1)).len(), 1);
    }

    #[test]
    fn add_remove_add_edge() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.01);

        nd.remove_edge::<EdgeType1>(PersonId(1), PersonId(2))
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2));
        assert!(edge.is_none());

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.02, ())
            .unwrap();
        let edge = nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).unwrap();
        assert_eq!(edge.weight, 0.02);
    }

    #[test]
    fn remove_nonexistent_edge() {
        let mut nd = NetworkData::new();
        assert!(matches!(
            nd.remove_edge::<EdgeType1>(PersonId(1), PersonId(2)),
            Err(IxaError::IxaError(_))
        ));
    }

    #[test]
    fn remove_edge_of_wrong_type_fails() {
        let mut nd = NetworkData::new();
        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        assert!(matches!(
            nd.remove_edge::<EdgeType2>(PersonId(1), PersonId(2)),
            Err(IxaError::IxaError(_))
        ));
        assert!(nd.get_edge::<EdgeType1>(PersonId(1), PersonId(2)).is_some());
    }

    #[test]
    fn remove_edge_preserves_order() {
        let mut nd = NetworkData::new();
        for neighbor in 2..=4 {
            nd.add_edge::<EdgeType1>(PersonId(1), PersonId(neighbor), 0.01, ())
                .unwrap();
        }
        nd.remove_edge::<EdgeType1>(PersonId(1), PersonId(3))
            .unwrap();
        let neighbors: Vec<PersonId> = nd
            .get_edges::<EdgeType1>(PersonId(1))
            .iter()
            .map(|edge| edge.neighbor)
            .collect();
        assert_eq!(neighbors, vec![PersonId(2), PersonId(4)]);
    }

    #[test]
    fn missing_person_or_type_is_empty() {
        let mut nd = NetworkData::new();
        assert!(nd.get_edge::<EdgeType1>(PersonId(5), PersonId(6)).is_none());
        assert!(nd.get_edges::<EdgeType1>(PersonId(5)).is_empty());

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.01, ())
            .unwrap();
        // Same person, but a different edge type.
        assert!(nd.get_edge::<EdgeType2>(PersonId(1), PersonId(2)).is_none());
        assert!(nd.get_edges::<EdgeType2>(PersonId(1)).is_empty());
        // Person index beyond anything stored so far.
        assert!(nd.get_edges::<EdgeType1>(PersonId(100)).is_empty());
    }

    #[test]
    fn add_edge_to_self() {
        let mut nd = NetworkData::new();

        let result = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(1), 0.01, ());
        assert!(matches!(result, Err(IxaError::IxaError(_))));
    }

    #[test]
    fn add_edge_bogus_weight() {
        let mut nd = NetworkData::new();

        let result = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), -1.0, ());
        assert!(matches!(result, Err(IxaError::IxaError(_))));

        let result = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), f32::NAN, ());
        assert!(matches!(result, Err(IxaError::IxaError(_))));

        let result = nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), f32::INFINITY, ());
        assert!(matches!(result, Err(IxaError::IxaError(_))));
    }

    #[test]
    fn find_people_by_degree() {
        let mut nd = NetworkData::new();

        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(2), 0.0, ())
            .unwrap();
        nd.add_edge::<EdgeType1>(PersonId(1), PersonId(3), 0.0, ())
            .unwrap();
        nd.add_edge::<EdgeType1>(PersonId(2), PersonId(3), 0.0, ())
            .unwrap();
        nd.add_edge::<EdgeType1>(PersonId(3), PersonId(2), 0.0, ())
            .unwrap();

        let matches = nd.find_people_by_degree::<EdgeType1>(2);
        assert_eq!(matches, vec![PersonId(1)]);
        let matches = nd.find_people_by_degree::<EdgeType1>(1);
        assert_eq!(matches, vec![PersonId(2), PersonId(3)]);
    }
}
496
#[cfg(test)]
// Weights are stored and read back unchanged, so exact float equality is intended.
#[allow(clippy::float_cmp)]
// Tests for the API.
mod test_api {
    use crate::context::Context;
    use crate::define_rng;
    use crate::error::IxaError;
    use crate::network::{ContextNetworkExt, Edge};
    use crate::people::{define_person_property, ContextPeopleExt, PersonId};
    use crate::random::ContextRandomExt;

    define_edge_type!(EdgeType1, u32);
    define_person_property!(Age, u8);

    /// Builds a context with two people (ages 1 and 2).
    fn setup() -> (Context, PersonId, PersonId) {
        let mut context = Context::new();
        let person1 = context.add_person((Age, 1)).unwrap();
        let person2 = context.add_person((Age, 2)).unwrap();

        (context, person1, person2)
    }

    #[test]
    fn add_edge() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person1, person2)
                .unwrap()
                .weight,
            0.01
        );
        assert_eq!(
            context.get_edges::<EdgeType1>(person1),
            vec![Edge {
                person: person1,
                neighbor: person2,
                weight: 0.01,
                inner: 1
            }]
        );
    }

    #[test]
    fn remove_edge() {
        let (mut context, person1, person2) = setup();

        // Removing before anything has been added is an error.
        assert!(matches!(
            context.remove_edge::<EdgeType1>(person1, person2),
            Err(IxaError::IxaError(_))
        ));
        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context.remove_edge::<EdgeType1>(person1, person2).unwrap();
        assert!(context.get_edge::<EdgeType1>(person1, person2).is_none());
        assert_eq!(context.get_edges::<EdgeType1>(person1).len(), 0);
    }

    #[test]
    fn add_edge_bidi() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge_bidi::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person1, person2)
                .unwrap()
                .weight,
            0.01
        );
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person2, person1)
                .unwrap()
                .weight,
            0.01
        );
    }

    #[test]
    fn add_edge_bidi_twice_fails() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge_bidi::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        assert!(matches!(
            context.add_edge_bidi::<EdgeType1>(person1, person2, 0.02, 1),
            Err(IxaError::IxaError(_))
        ));
        // The existing edges are untouched by the failed call.
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person1, person2)
                .unwrap()
                .weight,
            0.01
        );
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person2, person1)
                .unwrap()
                .weight,
            0.01
        );
    }

    #[test]
    fn add_edge_different_weights() {
        let (mut context, person1, person2) = setup();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person2, person1, 0.02, 1)
            .unwrap();
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person1, person2)
                .unwrap()
                .weight,
            0.01
        );
        assert_eq!(
            context
                .get_edge::<EdgeType1>(person2, person1)
                .unwrap()
                .weight,
            0.02
        );
    }

    #[test]
    fn get_matching_edges_weight() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.03, 1)
            .unwrap();
        let edges =
            context.get_matching_edges::<EdgeType1>(person1, |_context, edge| edge.weight > 0.01);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].person, person1);
        assert_eq!(edges[0].neighbor, person3);
    }

    #[test]
    fn get_matching_edges_inner() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.03, 3)
            .unwrap();
        let edges =
            context.get_matching_edges::<EdgeType1>(person1, |_context, edge| edge.inner == 3);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].person, person1);
        assert_eq!(edges[0].neighbor, person3);
    }

    #[test]
    fn get_matching_edges_person_property() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.03, 3)
            .unwrap();
        let edges = context.get_matching_edges::<EdgeType1>(person1, |context, edge| {
            context.match_person(edge.neighbor, (Age, 3))
        });
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].person, person1);
        assert_eq!(edges[0].neighbor, person3);
    }

    #[test]
    fn find_people_by_degree() {
        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person2, person3, 0.01, 1)
            .unwrap();

        assert_eq!(context.find_people_by_degree::<EdgeType1>(2), vec![person1]);
        assert_eq!(context.find_people_by_degree::<EdgeType1>(1), vec![person2]);
    }

    #[test]
    fn select_random_edge() {
        define_rng!(NetworkTestRng);

        let (mut context, person1, person2) = setup();
        let person3 = context.add_person((Age, 3)).unwrap();
        context.init_random(42);

        context
            .add_edge::<EdgeType1>(person1, person2, 0.01, 1)
            .unwrap();
        context
            .add_edge::<EdgeType1>(person1, person3, 10_000_000.0, 3)
            .unwrap();

        let edge = context
            .select_random_edge::<EdgeType1, _>(NetworkTestRng, person1)
            .unwrap();
        assert_eq!(edge.person, person1);
        assert_eq!(edge.neighbor, person3);
    }

    #[test]
    fn select_random_edge_without_edges_fails() {
        define_rng!(NetworkEmptyTestRng);

        let (mut context, person1, _person2) = setup();
        context.init_random(42);

        assert!(matches!(
            context.select_random_edge::<EdgeType1, _>(NetworkEmptyTestRng, person1),
            Err(IxaError::IxaError(_))
        ));
    }
}