1#[cfg(feature = "profiling")]
2use std::sync::{Mutex, MutexGuard, OnceLock};
3use std::time::{Duration, Instant};
4
5use super::computed_statistic::ComputableType;
6use super::Span;
7#[cfg(feature = "profiling")]
8use super::{
9 ComputedStatistic, ComputedValue, CustomStatisticComputer, CustomStatisticPrinter,
10 TOTAL_MEASURED,
11};
12use crate::HashMap;
13
/// Process-wide profiling state, created lazily on first access.
#[cfg(feature = "profiling")]
static PROFILING_DATA: OnceLock<Mutex<ProfilingData>> = OnceLock::new();

/// Locks and returns the process-wide [`ProfilingData`], creating it on first use.
///
/// A poisoned mutex (some thread panicked while holding the guard) is recovered
/// instead of propagated. Profiling is diagnostic, and a single panic, such as a
/// failing assertion made while the guard is held, must not turn every later
/// call into a panic as well. A panic while the lock is poisoned could otherwise
/// become a double panic (and abort) if this is reached during unwinding. The
/// data stays structurally valid; at worst the interrupted update is partial.
#[cfg(feature = "profiling")]
pub(super) fn profiling_data() -> MutexGuard<'static, ProfilingData> {
    PROFILING_DATA
        .get_or_init(|| Mutex::new(ProfilingData::new()))
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}
25
/// Accumulated profiling measurements.
///
/// `counts` and `spans` are keyed by the `&'static str` labels passed to the
/// counting and span functions. The feature-gated fields hold bookkeeping that
/// only exists when the `profiling` feature is enabled.
#[derive(Default)]
pub struct ProfilingData {
    /// When the first count was recorded or the first span opened; `None` until then.
    pub start_time: Option<Instant>,
    /// Number of occurrences recorded for each named event.
    pub counts: HashMap<&'static str, usize>,
    /// Per span label: total time spent in closed spans, and how many were closed.
    pub spans: HashMap<&'static str, (Duration, usize)>,

    /// Number of spans that have been opened but not yet closed.
    #[cfg(feature = "profiling")]
    pub(super) open_span_count: usize,
    /// Start of the current stretch during which at least one span is open;
    /// `None` while no span is open.
    #[cfg(feature = "profiling")]
    pub(super) coverage: Option<Instant>,
    /// Statistics registered through `add_computed_statistic`. The `Option`
    /// wrapper presumably lets a slot be taken out temporarily (TODO confirm).
    #[cfg(feature = "profiling")]
    pub(super) computed_statistics: Vec<Option<ComputedStatistic>>,
}
48
#[cfg(feature = "profiling")]
impl ProfilingData {
    /// Creates empty profiling state; equivalent to `Default::default()`.
    fn new() -> Self {
        Self::default()
    }

    /// Records one occurrence of the named event `key`. Starts the profiling
    /// clock if this is the first thing recorded.
    pub(super) fn increment_named_count(&mut self, key: &'static str) {
        self.init_start_time();
        *self.counts.entry(key).or_insert(0) += 1;
    }

    /// Returns how many times `key` has been counted, or `None` if it never was.
    pub(super) fn get_named_count(&self, key: &'static str) -> Option<usize> {
        self.counts.get(&key).copied()
    }

    /// Sets `start_time` to now if it is still unset; later calls are no-ops.
    fn init_start_time(&mut self) {
        self.start_time.get_or_insert_with(Instant::now);
    }

    /// Seconds elapsed since `start_time`, or `0.0` if nothing has been recorded yet.
    fn elapsed_secs(&self) -> f64 {
        self.start_time
            .map_or(0.0, |start| start.elapsed().as_secs_f64())
    }

    /// Returns `value / elapsed`, or `0.0` when `elapsed` is not positive.
    ///
    /// This keeps `inf`/`NaN` out of report rows. That would otherwise happen for
    /// a table requested before anything was recorded, or within the clock's
    /// resolution of the first event.
    fn per_elapsed(value: f64, elapsed: f64) -> f64 {
        if elapsed > 0.0 {
            value / elapsed
        } else {
            0.0
        }
    }

    /// Registers a newly opened span named `label` and returns it.
    ///
    /// The returned `Span` carries its own start time, which `close_span` reads.
    fn open_span(&mut self, label: &'static str) -> Span {
        self.init_start_time();
        // The outermost open span starts the coverage clock. Coverage is the wall
        // time during which at least one span is open; it is reported under
        // TOTAL_MEASURED when the last open span closes.
        if self.open_span_count == 0 {
            self.coverage = Some(Instant::now());
        }
        self.open_span_count += 1;
        Span::new(label)
    }

    /// Records `span` as closed, adding its elapsed time to its label's totals.
    ///
    /// When this closes the last open span, the coverage period that began at the
    /// matching outermost `open_span` is also recorded under `TOTAL_MEASURED`.
    pub(super) fn close_span(&mut self, span: &Span) {
        if self.open_span_count > 0 {
            self.open_span_count -= 1;
            if self.open_span_count == 0 {
                // `open_span` sets `coverage` whenever the open count leaves zero,
                // so it is normally `Some` here. Skip instead of panicking if that
                // was disturbed. This path is presumably reached from `Span`'s drop
                // (the free `close_span` only drops the span), where a panic during
                // unwinding would abort.
                if let Some(coverage) = self.coverage.take() {
                    self.close_span_without_coverage(TOTAL_MEASURED, coverage.elapsed());
                }
            }
        }
        self.close_span_without_coverage(span.label, span.start_time.elapsed());
    }

    /// Adds one closed span lasting `elapsed` to the totals for `label`.
    fn close_span_without_coverage(&mut self, label: &'static str, elapsed: Duration) {
        let (total, count) = self.spans.entry(label).or_insert((Duration::ZERO, 0));
        *total += elapsed;
        *count += 1;
    }

    /// Returns one `(label, count, events_per_second)` row per named counter.
    ///
    /// The rate is relative to the time since `start_time`. Row order follows the
    /// map's iteration order.
    pub(super) fn get_named_counts_table(&self) -> Vec<(String, usize, f64)> {
        let elapsed = self.elapsed_secs();
        self.counts
            .iter()
            .map(|(&key, &count)| {
                // f64 holds every integer up to 2^53 exactly, far beyond any
                // realistic event count.
                #[allow(clippy::cast_precision_loss)]
                let rate = Self::per_elapsed(count as f64, elapsed);
                (key.to_string(), count, rate)
            })
            .collect()
    }

    /// Returns one `(label, count, total_duration, percent_of_elapsed)` row per
    /// span label.
    ///
    /// The `TOTAL_MEASURED` row, if present, is always last. The remaining rows
    /// follow the map's iteration order.
    pub(super) fn get_named_spans_table(&self) -> Vec<(String, usize, Duration, f64)> {
        let elapsed = self.elapsed_secs();
        let row = |label: &str, duration: Duration, count: usize| {
            let percent = Self::per_elapsed(duration.as_secs_f64(), elapsed) * 100.0;
            (label.to_string(), count, duration, percent)
        };

        let mut rows: Vec<_> = self
            .spans
            .iter()
            .filter(|&(&label, _)| label != TOTAL_MEASURED)
            .map(|(&label, &(duration, count))| row(label, duration, count))
            .collect();

        if let Some(&(duration, count)) = self.spans.get(&TOTAL_MEASURED) {
            rows.push(row(TOTAL_MEASURED, duration, count));
        }

        rows
    }

    /// Registers a user-defined statistic with no value yet.
    ///
    /// `computer` and `printer` are packaged by `T::new_functions`.
    pub(super) fn add_computed_statistic<T: ComputableType>(
        &mut self,
        label: &'static str,
        description: &'static str,
        computer: CustomStatisticComputer<T>,
        printer: CustomStatisticPrinter<T>,
    ) {
        self.computed_statistics.push(Some(ComputedStatistic {
            label,
            description,
            value: None,
            functions: T::new_functions(computer, printer),
        }));
    }
}
177
/// Records one occurrence of the named event `key` in the shared profiling data.
#[cfg(feature = "profiling")]
pub fn increment_named_count(key: &'static str) {
    profiling_data().increment_named_count(key);
}
183
/// No-op stand-in used when the `profiling` feature is disabled, so call sites
/// compile unchanged.
#[cfg(not(feature = "profiling"))]
pub fn increment_named_count(_key: &'static str) {
    // Intentionally empty: counting is compiled out without `profiling`.
}
186
/// Opens a span named `label` and registers it with the shared profiling data.
#[cfg(feature = "profiling")]
pub fn open_span(label: &'static str) -> Span {
    profiling_data().open_span(label)
}
192
193#[cfg(not(feature = "profiling"))]
194pub fn open_span(label: &'static str) -> Span {
195 Span::new(label)
196}
197
198pub fn close_span(_span: Span) {
201 }
204
// Tests for the counter and span bookkeeping. They are compiled only for test
// builds with the `profiling` feature enabled.
//
// The profiling state is process-wide and the default test harness runs tests on
// parallel threads, so each test records under labels unique to itself.
// Assertions are made while holding the guard returned by `get_profiling_data()`
// (presumably the same lock that `profiling_data()` takes; confirm). If that is a
// std `MutexGuard`, a failing assertion poisons the mutex for tests that run
// afterwards.
#[cfg(all(test, feature = "profiling"))]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::profiling::{get_profiling_data, increment_named_count};

    /// A single span is recorded once, with at least the time slept inside it.
    #[test]
    fn test_span_basic() {
        {
            // Recorded when `_span` goes out of scope at the end of this block
            // (presumably via `Span`'s `Drop`).
            let _span = open_span("test_operation_basic");
            std::thread::sleep(Duration::from_millis(10));
        }

        let data = get_profiling_data();
        assert!(data.spans.contains_key("test_operation_basic"));
        let (duration, count) = data.spans.get("test_operation_basic").unwrap();
        assert_eq!(*count, 1);
        assert!(*duration >= Duration::from_millis(10));
    }

    /// Repeated spans under one label accumulate count and duration.
    #[test]
    fn test_span_multiple_calls() {
        for _ in 0..5 {
            let _span = open_span("repeated_operation_multi_test");
            std::thread::sleep(Duration::from_millis(5));
        }

        let data = get_profiling_data();
        let (duration, count) = data.spans.get("repeated_operation_multi_test").unwrap();
        // NOTE(review): the loop closes 5 spans of >= 5ms each, yet this only
        // requires 4 closes and 15ms. Confirm whether a close can legitimately be
        // lost; if not, these bounds could be tightened to 5 and 25ms.
        assert!(*count >= 4, "expected at least 4 drops, got {}", count);
        assert!(*duration >= Duration::from_millis(15));
    }

    /// `close_span` records the span before the end of its scope.
    #[test]
    fn test_span_explicit_close() {
        let span = open_span("explicit_close_test");
        std::thread::sleep(Duration::from_millis(10));
        close_span(span);

        let data = get_profiling_data();
        assert!(data.spans.contains_key("explicit_close_test"));
    }

    /// An outer span includes the time of a nested inner span plus its own sleeps.
    #[test]
    fn test_span_nesting() {
        {
            let _outer = open_span("outer_nesting_test");
            std::thread::sleep(Duration::from_millis(5));
            {
                let _inner = open_span("inner_nesting_test");
                std::thread::sleep(Duration::from_millis(5));
            }
            std::thread::sleep(Duration::from_millis(5));
        }

        let data = get_profiling_data();
        assert!(data.spans.contains_key("outer_nesting_test"));
        assert!(data.spans.contains_key("inner_nesting_test"));

        let (outer_duration, _) = data.spans.get("outer_nesting_test").unwrap();
        let (inner_duration, _) = data.spans.get("inner_nesting_test").unwrap();

        // The outer span covers all three 5ms sleeps; the inner span covers one.
        assert!(*outer_duration > *inner_duration);
        assert!(*outer_duration >= Duration::from_millis(15));
        assert!(*inner_duration >= Duration::from_millis(5));
    }

    /// Two disjoint spans separated by an unmeasured gap are each recorded.
    #[test]
    fn test_total_measured_span() {
        {
            let _span1 = open_span("operation1_total_measured");
            std::thread::sleep(Duration::from_millis(10));
        }

        // Gap with no span from this test open.
        std::thread::sleep(Duration::from_millis(5));

        {
            let _span2 = open_span("operation2_total_measured");
            std::thread::sleep(Duration::from_millis(10));
        }

        let data = get_profiling_data();

        // NOTE(review): despite its name, this test never inspects the
        // TOTAL_MEASURED entry. That is presumably because spans from parallel
        // tests can keep the open-span count above zero, so the entry may not
        // exist yet. Confirm before adding such an assertion.
        assert!(data.spans.contains_key("operation1_total_measured"));
        assert!(data.spans.contains_key("operation2_total_measured"));

        let (duration1, _) = data.spans.get("operation1_total_measured").unwrap();
        let (duration2, _) = data.spans.get("operation2_total_measured").unwrap();

        assert!(*duration1 >= Duration::from_millis(10));
        assert!(*duration2 >= Duration::from_millis(10));
    }

    /// The counts table reports per-label counts and a rate based on time since
    /// the first recorded event.
    #[test]
    fn test_get_named_counts_table() {
        // Snapshot the clock start before recording anything. The guard is scoped
        // to this block so it is released before `increment_named_count` locks
        // the profiling state again.
        let container_start = {
            let data = get_profiling_data();
            data.start_time
        };
        increment_named_count("event_a_counts_table_test");
        increment_named_count("event_a_counts_table_test");
        increment_named_count("event_b_counts_table_test");

        std::thread::sleep(Duration::from_millis(100));

        // If nothing had been recorded yet, the clock started at the increments
        // above, so roughly the 100ms sleep has elapsed.
        let elapsed = if let Some(start_time) = container_start {
            start_time.elapsed().as_secs_f64()
        } else {
            0.1
        };

        let data = get_profiling_data();
        let table = data.get_named_counts_table();

        let event_a = table
            .iter()
            .find(|(label, _, _)| label == "event_a_counts_table_test");
        assert!(event_a.is_some());
        let (_, count, rate) = event_a.unwrap();
        assert_eq!(*count, 2);
        let expected_rate = 2.0 / elapsed;
        println!(
            "Rate: {}, Expected: {}, Elapsed: {}",
            rate, expected_rate, elapsed
        );
        // 10% tolerance absorbs the small delay between measuring `elapsed` here
        // and building the table.
        assert!(*rate > expected_rate * 0.9 && *rate < expected_rate * 1.1);

        let event_b = table
            .iter()
            .find(|(label, _, _)| label == "event_b_counts_table_test");
        assert!(event_b.is_some());
        let (_, count, _) = event_b.unwrap();
        assert_eq!(*count, 1);
    }

    /// The spans table contains this test's span, ends with the total-measured
    /// row, and reports a percentage of time since the clock started.
    #[test]
    fn test_get_named_spans_table() {
        let container_start = {
            let data = get_profiling_data();
            data.start_time
        };

        {
            let _span = open_span("test_span_table");
            std::thread::sleep(Duration::from_millis(100));
        }

        std::thread::sleep(Duration::from_millis(100));

        let data = get_profiling_data();
        let table = data.get_named_spans_table();

        // At least this test's span plus the total-measured row.
        assert!(table.len() >= 2);

        let test_span = table
            .iter()
            .find(|(label, _, _, _)| label == "test_span_table");
        assert!(test_span.is_some());

        // `get_named_spans_table` appends the TOTAL_MEASURED row last; the literal
        // below presumably equals TOTAL_MEASURED (confirm). NOTE(review): that row
        // only exists once the open-span count has returned to zero, which
        // overlapping spans from parallel tests could delay, making this check racy.
        let last = table.last().unwrap();
        assert_eq!(last.0, "Total Measured");

        let (_, _, _, percent) = test_span.unwrap();
        // If this test started the clock, roughly the two 100ms sleeps have elapsed.
        let elapsed = if let Some(start_time) = container_start {
            start_time.elapsed().as_secs_f64()
        } else {
            0.2
        };
        let (duration, _) = data.spans.get("test_span_table").unwrap();
        let expected_percent = duration.as_secs_f64() / elapsed * 100.0;
        // Tolerance in percentage points, for timing jitter between the table
        // computation and `elapsed` above.
        assert!((*percent - expected_percent).abs() < 5.0);
    }
}