// ixa/profiling/computed_statistic.rs

1use std::fmt::Display;
2
3use serde::ser::{Serialize, Serializer};
4
5#[cfg(feature = "profiling")]
6use super::profiling_data;
7use super::ProfilingData;
8
9pub type CustomStatisticComputer<T> = Box<dyn (Fn(&ProfilingData) -> Option<T>) + Send + Sync>;
10pub type CustomStatisticPrinter<T> = Box<dyn Fn(T) + Send + Sync>;
11
12pub(super) enum ComputedStatisticFunctions {
13    USize {
14        computer: CustomStatisticComputer<usize>,
15        printer: CustomStatisticPrinter<usize>,
16    },
17    Int {
18        computer: CustomStatisticComputer<i64>,
19        printer: CustomStatisticPrinter<i64>,
20    },
21    Float {
22        computer: CustomStatisticComputer<f64>,
23        printer: CustomStatisticPrinter<f64>,
24    },
25}
26
27impl ComputedStatisticFunctions {
28    /// A type erased way to compute a statistic.
29    pub(super) fn compute(&self, container: &ProfilingData) -> Option<ComputedValue> {
30        match self {
31            ComputedStatisticFunctions::USize { computer, .. } => {
32                computer(container).map(ComputedValue::USize)
33            }
34            ComputedStatisticFunctions::Int { computer, .. } => {
35                computer(container).map(ComputedValue::Int)
36            }
37            ComputedStatisticFunctions::Float { computer, .. } => {
38                computer(container).map(ComputedValue::Float)
39            }
40        }
41    }
42
43    /// A type erased way to print a statistic.
44    pub(super) fn print(&self, value: ComputedValue) {
45        match value {
46            ComputedValue::USize(value) => {
47                let ComputedStatisticFunctions::USize { printer, .. } = self else {
48                    unreachable!()
49                };
50                (printer)(value);
51            }
52            ComputedValue::Int(value) => {
53                let ComputedStatisticFunctions::Int { printer, .. } = self else {
54                    unreachable!()
55                };
56                (printer)(value);
57            }
58            ComputedValue::Float(value) => {
59                let ComputedStatisticFunctions::Float { printer, .. } = self else {
60                    unreachable!()
61                };
62                (printer)(value);
63            }
64        }
65    }
66}
67
68pub(super) struct ComputedStatistic {
69    /// The label used for the statistic in the JSON report.
70    pub label: &'static str,
71    /// Description of the statistic. Used in the JSON report.
72    pub description: &'static str,
73    /// The computed value of the statistic.
74    pub value: Option<ComputedValue>,
75    /// The two functions used to compute the statistic and to print it to the console.
76    pub functions: ComputedStatisticFunctions,
77}
78
// Sealed-trait pattern: `SealedComputableType` can only be named from this module (and its
// descendants), so code elsewhere may use `ComputableType` as a bound but cannot add new
// implementations of it.
mod sealed {
    /// Marker supertrait restricting which types may implement `ComputableType`.
    pub(super) trait SealedComputableType {}
}
83#[allow(private_bounds)]
84pub trait ComputableType: sealed::SealedComputableType
85where
86    Self: Sized,
87{
88    // This method is only callable from within this crate.
89    #[allow(private_interfaces)]
90    fn new_functions(
91        computer: CustomStatisticComputer<Self>,
92        printer: CustomStatisticPrinter<Self>,
93    ) -> ComputedStatisticFunctions;
94}
95impl sealed::SealedComputableType for usize {}
96impl ComputableType for usize {
97    #[allow(private_interfaces)]
98    fn new_functions(
99        computer: CustomStatisticComputer<Self>,
100        printer: CustomStatisticPrinter<Self>,
101    ) -> ComputedStatisticFunctions {
102        ComputedStatisticFunctions::USize { computer, printer }
103    }
104}
105impl sealed::SealedComputableType for i64 {}
106impl ComputableType for i64 {
107    #[allow(private_interfaces)]
108    fn new_functions(
109        computer: CustomStatisticComputer<Self>,
110        printer: CustomStatisticPrinter<Self>,
111    ) -> ComputedStatisticFunctions {
112        ComputedStatisticFunctions::Int { computer, printer }
113    }
114}
115impl sealed::SealedComputableType for f64 {}
116impl ComputableType for f64 {
117    #[allow(private_interfaces)]
118    fn new_functions(
119        computer: CustomStatisticComputer<Self>,
120        printer: CustomStatisticPrinter<Self>,
121    ) -> ComputedStatisticFunctions {
122        ComputedStatisticFunctions::Float { computer, printer }
123    }
124}
125
126/// The computed value of a statistic. The "computer" returns a value of this type.
127#[derive(Copy, Clone, PartialEq, Debug)]
128pub(super) enum ComputedValue {
129    USize(usize),
130    Int(i64),
131    Float(f64),
132}
133
134impl Serialize for ComputedValue {
135    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
136        match self {
137            ComputedValue::USize(v) => serializer.serialize_u64(*v as u64),
138            ComputedValue::Int(v) => serializer.serialize_i64(*v),
139            ComputedValue::Float(v) => serializer.serialize_f64(*v),
140        }
141    }
142}
143
144impl Display for ComputedValue {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        match self {
147            ComputedValue::USize(value) => {
148                write!(f, "{}", value)
149            }
150
151            ComputedValue::Int(value) => {
152                write!(f, "{}", value)
153            }
154
155            ComputedValue::Float(value) => {
156                write!(f, "{}", value)
157            }
158        }
159    }
160}
161
162#[cfg(feature = "profiling")]
163pub fn add_computed_statistic<T: ComputableType>(
164    label: &'static str,
165    description: &'static str,
166    computer: CustomStatisticComputer<T>,
167    printer: CustomStatisticPrinter<T>,
168) {
169    let mut container = profiling_data();
170    container.add_computed_statistic(label, description, computer, printer);
171}
172#[cfg(not(feature = "profiling"))]
173pub fn add_computed_statistic<T: ComputableType>(
174    _label: &'static str,
175    _description: &'static str,
176    _computer: CustomStatisticComputer<T>,
177    _printer: CustomStatisticPrinter<T>,
178) {
179}
180
#[cfg(all(test, feature = "profiling"))]
mod tests {
    // Every test uses its own counter names and statistic labels (presumably because the
    // profiling data is shared process-wide and tests may run concurrently).
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;
    use crate::profiling::{get_profiling_data, increment_named_count};

    #[test]
    fn test_computed_statistic_usize() {
        // Three hits on a dedicated counter; the statistic just echoes that count.
        for _ in 0..3 {
            increment_named_count("comp_stat_events_usize_test");
        }

        add_computed_statistic::<usize>(
            "comp_stat_total_events",
            "Total number of events",
            Box::new(|data| data.get_named_count("comp_stat_events_usize_test")),
            Box::new(|value| println!("Total events: {}", value)),
        );

        let data = get_profiling_data();
        let stat = data
            .computed_statistics
            .iter()
            .flatten()
            .find(|candidate| candidate.label == "comp_stat_total_events")
            .expect("total_events statistic not found");

        assert_eq!(stat.functions.compute(&data), Some(ComputedValue::USize(3)));
    }

    #[test]
    fn test_computed_statistic_i64() {
        // Two "positive" hits and one "negative" hit, so the difference is 1.
        increment_named_count("comp_stat_positive_i64_test");
        increment_named_count("comp_stat_positive_i64_test");
        increment_named_count("comp_stat_negative_i64_test");

        add_computed_statistic::<i64>(
            "comp_stat_difference",
            "Difference between positive and negative",
            Box::new(|data| {
                let positives = data
                    .get_named_count("comp_stat_positive_i64_test")
                    .unwrap_or(0) as i64;
                let negatives = data
                    .get_named_count("comp_stat_negative_i64_test")
                    .unwrap_or(0) as i64;
                Some(positives - negatives)
            }),
            Box::new(|value| println!("Difference: {}", value)),
        );

        let data = get_profiling_data();
        let stat = data
            .computed_statistics
            .iter()
            .flatten()
            .find(|candidate| candidate.label == "comp_stat_difference")
            .expect("difference statistic not found");

        assert_eq!(stat.functions.compute(&data), Some(ComputedValue::Int(1)));
    }

    #[test]
    fn test_computed_statistic_f64() {
        // Populate the counters and register the statistic directly through one
        // `get_profiling_data()` handle instead of the free helper functions.
        let mut data = get_profiling_data();
        *data
            .counts
            .entry("comp_stat_successes_f64_test")
            .or_insert(0) += 3;
        *data.counts.entry("comp_stat_total_f64_test").or_insert(0) += 4;
        data.add_computed_statistic::<f64>(
            "comp_stat_success_rate",
            "Success rate as percentage",
            Box::new(|data| {
                let successes = data.get_named_count("comp_stat_successes_f64_test")? as f64;
                let total = data.get_named_count("comp_stat_total_f64_test")? as f64;
                Some(successes / total * 100.0)
            }),
            Box::new(|value| println!("Success rate: {:.2}%", value)),
        );

        let stat = data
            .computed_statistics
            .iter()
            .flatten()
            .find(|candidate| candidate.label == "comp_stat_success_rate")
            .expect("comp_stat_success_rate statistic not found");
        let computed = stat.functions.compute(&data);

        let Some(ComputedValue::Float(rate)) = computed else {
            panic!("Expected Float value, got {:?}", computed);
        };
        assert!((rate - 75.0).abs() < 0.01);
    }

    #[test]
    fn test_computed_statistic_returns_none() {
        // The counter below is never incremented, so the computer yields `None`.
        add_computed_statistic::<usize>(
            "comp_stat_missing_data",
            "Statistic with missing data",
            Box::new(|data| data.get_named_count("comp_stat_nonexistent")),
            Box::new(|value| println!("Value: {}", value)),
        );

        let data = get_profiling_data();
        let stat = data
            .computed_statistics
            .iter()
            .flatten()
            .find(|candidate| candidate.label == "comp_stat_missing_data")
            .expect("comp_stat_missing_data statistic not found");

        assert_eq!(stat.functions.compute(&data), None);
    }

    #[test]
    fn test_computed_value_display() {
        let cases = [
            (ComputedValue::USize(42), "42"),
            (ComputedValue::Int(-100), "-100"),
            (ComputedValue::Float(std::f64::consts::PI), "3.141592653589793"),
        ];
        for (value, expected) in cases.iter() {
            assert_eq!(format!("{}", value), *expected);
        }
    }

    #[test]
    fn test_computed_statistic_print_functions() {
        static PRINTED: AtomicBool = AtomicBool::new(false);

        // Start from a known state before invoking the printer.
        PRINTED.store(false, Ordering::SeqCst);

        increment_named_count("comp_stat_test_print_func");

        add_computed_statistic::<usize>(
            "comp_stat_test_stat",
            "Test statistic",
            Box::new(|data| data.get_named_count("comp_stat_test_print_func")),
            // Rather than printing, record that the printer ran.
            Box::new(|_value| PRINTED.store(true, Ordering::SeqCst)),
        );

        let data = get_profiling_data();
        let stat = data
            .computed_statistics
            .iter()
            .flatten()
            .find(|candidate| candidate.label == "comp_stat_test_stat")
            .expect("test_stat statistic not found");
        let computed = stat.functions.compute(&data).unwrap();
        stat.functions.print(computed);

        assert!(PRINTED.load(Ordering::SeqCst));
    }
}