ixa/
web_api.rs

1use std::thread;
2
3use axum::extract::{Json, Path, State};
4use axum::http::StatusCode;
5use axum::response::Redirect;
6use axum::routing::{get, post};
7use axum::Router;
8use serde_json::json;
9use tokio::sync::{mpsc, oneshot};
10use tower_http::services::{ServeDir, ServeFile};
11
12use crate::context::Context;
13use crate::error::IxaError;
14use crate::external_api::{
15    breakpoint, global_properties, halt, next, people, population, r#continue, run_ext_api, time,
16    EmptyArgs,
17};
18use crate::rand::RngCore;
19use crate::{define_data_plugin, HashMap, HashMapExt, PluginContext};
20
21pub type WebApiHandler =
22    dyn Fn(&mut Context, serde_json::Value) -> Result<serde_json::Value, IxaError>;
23
24fn register_api_handler<
25    T: crate::external_api::ExtApi<Args = A>,
26    A: serde::de::DeserializeOwned,
27>(
28    dc: &mut ApiData,
29    name: &str,
30) {
31    dc.handlers.insert(
32        name.to_string(),
33        Box::new(
34            |context, args_json| -> Result<serde_json::Value, IxaError> {
35                let args: A = serde_json::from_value(args_json)?;
36                let retval: T::Retval = run_ext_api::<T>(context, &args)?;
37                Ok(serde_json::to_value(retval)?)
38            },
39        ),
40    );
41}
42
43struct ApiData {
44    receiver: mpsc::UnboundedReceiver<ApiRequest>,
45    handlers: HashMap<String, Box<WebApiHandler>>,
46}
47
48/// This wrapper method allows simultaneous mutable access to the [`ApiPlugin`] and `context` at the
49/// same time.
50pub(crate) fn handle_web_api_with_plugin(context: &mut Context) {
51    // We temporarily swap out the `ApiPlugin` so we can have simultaneous mutable access to
52    // it and to `context`. We swap it back in at the end of the function.
53    let mut data_container = context.get_data_mut(ApiPlugin).take().unwrap();
54
55    handle_web_api(context, &mut data_container);
56
57    // Restore the `ApiPlugin`
58    let saved_data_container = context.get_data_mut(ApiPlugin);
59    *saved_data_container = Some(data_container);
60}
61
62define_data_plugin!(ApiPlugin, Option<ApiData>, None);
63
64// Input to the API handler.
65struct ApiRequest {
66    cmd: String,
67    arguments: serde_json::Value,
68    // This channel is used to send the response.
69    rx: oneshot::Sender<ApiResponse>,
70}
71
72// Output of the API handler.
73struct ApiResponse {
74    code: StatusCode,
75    response: serde_json::Value,
76}
77
78#[derive(Clone)]
79struct ApiEndpointServer {
80    sender: mpsc::UnboundedSender<ApiRequest>,
81}
82
83async fn process_cmd(
84    State(state): State<ApiEndpointServer>,
85    Path(path): Path<String>,
86    Json(payload): Json<serde_json::Value>,
87) -> (StatusCode, Json<serde_json::Value>) {
88    let (tx, rx) = oneshot::channel::<ApiResponse>();
89    let _ = state.sender.send(ApiRequest {
90        cmd: path,
91        arguments: payload,
92        rx: tx,
93    });
94
95    match rx.await {
96        Ok(response) => (response.code, Json(response.response)),
97        _ => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({}))),
98    }
99}
100
101#[tokio::main]
102async fn serve(
103    sender: mpsc::UnboundedSender<ApiRequest>,
104    port: u16,
105    prefix: &str,
106    ready: oneshot::Sender<Result<String, IxaError>>,
107) {
108    let state = ApiEndpointServer { sender };
109
110    // run our app with Axum, listening on `port`
111    let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}")).await;
112    if listener.is_err() {
113        ready
114            .send(Err(IxaError::IxaError(format!("Could not bind to {port}"))))
115            .unwrap();
116        return;
117    }
118
119    // build our application with a route
120    let path = format!("{}/{}", env!("CARGO_MANIFEST_DIR"), "static/");
121    let static_assets_path = std::path::Path::new(&path);
122    let home_path = format!("/{prefix}/static/index.html");
123    let app = Router::new()
124        .route(&format!("/{prefix}/cmd/{{command}}"), post(process_cmd))
125        .route(
126            &format!("/{prefix}/"),
127            get(|| async move { Redirect::temporary(&home_path) }),
128        )
129        .nest_service(
130            &format!("/{prefix}/static/"),
131            ServeDir::new(static_assets_path),
132        )
133        .nest_service(
134            "/favicon.ico",
135            ServeFile::new_with_mime(
136                static_assets_path.join(std::path::Path::new("favicon.ico")),
137                &mime::IMAGE_PNG,
138            ),
139        )
140        .with_state(state);
141
142    // Notify the caller that we are ready.
143    ready
144        .send(Ok(format!("http://127.0.0.1:{port}/{prefix}/")))
145        .unwrap();
146    axum::serve(listener.unwrap(), app).await.unwrap();
147}
148
149/// Starts the Web API, pausing execution until instructed
150/// to continue.
151fn handle_web_api(context: &mut Context, api: &mut ApiData) {
152    while let Some(req) = api.receiver.blocking_recv() {
153        if req.cmd == "continue" {
154            let _ = req.rx.send(ApiResponse {
155                code: StatusCode::OK,
156                response: json!({}),
157            });
158            break;
159        }
160
161        let handler = api.handlers.get(&req.cmd);
162        if handler.is_none() {
163            let _ = req.rx.send(ApiResponse {
164                code: StatusCode::NOT_FOUND,
165                response: json!({
166                    "error" : format!("No command {}", req.cmd)
167                }),
168            });
169            continue;
170        }
171
172        let handler = handler.unwrap();
173        match handler(context, req.arguments.clone()) {
174            Err(err) => {
175                let _ = req.rx.send(ApiResponse {
176                    code: StatusCode::BAD_REQUEST,
177                    response: json!({
178                        "error" : err.to_string()
179                    }),
180                });
181                continue;
182            }
183            Ok(response) => {
184                let _ = req.rx.send(ApiResponse {
185                    code: StatusCode::OK,
186                    response,
187                });
188            }
189        }
190
191        // Special case the functions which require exiting
192        // the loop.
193        if req.cmd == "continue" {
194            return;
195        }
196    }
197}
198
199pub trait ContextWebApiExt: PluginContext {
200    /// Set up the Web API and start the Web server.
201    ///
202    /// # Errors
203    /// [`IxaError`] on failure to bind to `port`
204    fn setup_web_api(&mut self, port: u16) -> Result<String, IxaError> {
205        // TODO(cym4@cdc.gov): Check on the limits here.
206        let (api_to_ctx_send, api_to_ctx_recv) = mpsc::unbounded_channel::<ApiRequest>();
207
208        let data_container = self.get_data_mut(ApiPlugin);
209        if data_container.is_some() {
210            return Err(IxaError::IxaError(String::from(
211                "HTTP API already initialized",
212            )));
213        }
214
215        // Start the API server
216        let mut random: [u8; 16] = [0; 16];
217        let mut rng = rand::rng();
218        rng.fill_bytes(&mut random);
219        let secret = uuid::Builder::from_random_bytes(random)
220            .into_uuid()
221            .to_string();
222
223        let (ready_tx, ready_rx) = oneshot::channel::<Result<String, IxaError>>();
224        thread::spawn(move || serve(api_to_ctx_send, port, &secret, ready_tx));
225        let url = ready_rx.blocking_recv().unwrap()?;
226
227        let mut api_data = ApiData {
228            receiver: api_to_ctx_recv,
229            handlers: HashMap::new(),
230        };
231
232        register_api_handler::<breakpoint::Api, breakpoint::Args>(&mut api_data, "breakpoint");
233        register_api_handler::<r#continue::Api, EmptyArgs>(&mut api_data, "continue");
234        register_api_handler::<global_properties::Api, global_properties::Args>(
235            &mut api_data,
236            "global",
237        );
238        register_api_handler::<halt::Api, EmptyArgs>(&mut api_data, "halt");
239        register_api_handler::<next::Api, EmptyArgs>(&mut api_data, "next");
240        register_api_handler::<people::Api, people::Args>(&mut api_data, "people");
241        register_api_handler::<population::Api, EmptyArgs>(&mut api_data, "population");
242        register_api_handler::<time::Api, EmptyArgs>(&mut api_data, "time");
243        // Record the data container.
244        *data_container = Some(api_data);
245
246        Ok(url)
247    }
248
249    /// Schedule the simulation to pause at time t and listen for
250    /// requests from the Web API.
251    fn schedule_web_api(&mut self, t: f64) {
252        self.add_plan(t, handle_web_api_with_plugin);
253    }
254
255    /// Add an API point.
256    /// # Errors
257    /// [`IxaError`] when the Web API has not been set up yet.
258    fn add_web_api_handler(
259        &mut self,
260        name: &str,
261        handler: impl Fn(&mut Context, serde_json::Value) -> Result<serde_json::Value, IxaError>
262            + 'static,
263    ) -> Result<(), IxaError> {
264        let data_container = self.get_data_mut(ApiPlugin);
265
266        match data_container {
267            Some(dc) => {
268                dc.handlers.insert(name.to_string(), Box::new(handler));
269                Ok(())
270            }
271            None => Err(IxaError::IxaError(String::from("Web API not yet set up"))),
272        }
273    }
274}
275impl ContextWebApiExt for Context {}
276
#[cfg(test)]
mod tests {
    use std::thread;

    use reqwest::StatusCode;
    use serde::Serialize;
    use serde_json::json;

    use super::ContextWebApiExt;
    use crate::people::define_person_property;
    use crate::{define_global_property, Context, ContextGlobalPropertiesExt, ContextPeopleExt};

    define_global_property!(WebApiTestGlobal, String);
    define_person_property!(Age, u8);

    /// Builds a context whose Web API listens on a fixed port and pauses at t=0.
    ///
    /// The context holds one global property, two people, and an `external` handler
    /// that echoes its arguments. Returns the API base URL together with the context.
    fn setup() -> (String, Context) {
        let mut context = Context::new();
        let base_url = context.setup_web_api(33339).unwrap();
        context.schedule_web_api(0.0);
        context
            .set_global_property_value(WebApiTestGlobal, "foobar".to_string())
            .unwrap();
        context.add_person((Age, 1)).unwrap();
        context.add_person((Age, 2)).unwrap();
        context
            .add_web_api_handler("external", |_context, args| Ok(args))
            .unwrap();
        (base_url, context)
    }

    /// POSTs `body`, serialized as JSON, to the `cmd` endpoint under `url`.
    fn post_json<T: Serialize + ?Sized>(
        url: &str,
        cmd: &str,
        body: &T,
    ) -> reqwest::blocking::Response {
        reqwest::blocking::Client::new()
            .post(format!("{url}cmd/{cmd}"))
            .json(body)
            .send()
            .unwrap()
    }

    // Continue the simulation. The response is deliberately not checked, because
    // there is a race condition between sending the final response and program
    // termination.
    fn send_continue(url: &str) {
        let _ = post_json(url, "continue", &());
    }

    /// Sends `req` to `cmd` and prints the decoded JSON reply.
    /// Asserts that the status is HTTP 200, then returns the reply.
    fn send_request<T: Serialize + ?Sized>(url: &str, cmd: &str, req: &T) -> serde_json::Value {
        let response = post_json(url, cmd, req);
        let status = response.status();
        let body: serde_json::Value = response.json().unwrap();
        println!("{body:?}");
        assert_eq!(status, StatusCode::OK);
        body
    }

    /// Sends `req` verbatim with a JSON content type and returns the raw response.
    /// Sending the body unparsed lets tests submit malformed JSON.
    fn send_request_text(url: &str, cmd: &str, req: String) -> reqwest::blocking::Response {
        let endpoint = format!("{url}cmd/{cmd}");
        reqwest::blocking::Client::new()
            .post(endpoint)
            .header("Content-Type", "application/json")
            .body(req)
            .send()
            .unwrap()
    }

    // All API checks live in a single test, so only one server on one port has to
    // be started and managed. This may not be ideal, but we're doing it for now.
    // TODO(cym4@cdc.gov): Consider using some kind of static
    // object to isolate the test cases.
    // The lint is allowed because this test is intentionally one long sequence.
    #[allow(clippy::too_many_lines)]
    #[test]
    fn web_api_test() {
        // TODO(cym4@cdc.gov): If this thread fails then the test will stall instead
        // of erroring out, but there's nothing that should fail here.
        let (url_tx, url_rx) = std::sync::mpsc::channel::<String>();
        let ctx_thread = thread::spawn(move || {
            let (url, mut context) = setup();
            let _ = url_tx.send(url);
            context.execute();
        });
        let url = url_rx.recv().unwrap();

        // Helper for commands whose reply is the bare string "Ok".
        let expect_ok = |cmd: &str, body: serde_json::Value| {
            assert_eq!(send_request(&url, cmd, &body), json!("Ok"));
        };

        // Population.
        assert_eq!(
            send_request(&url, "population", &json!({})),
            json!({ "population": 2 })
        );

        // Time.
        assert_eq!(
            send_request(&url, "time", &json!({})),
            json!({ "time": 0.0 })
        );

        // Global property list. Every global property defined anywhere in the code
        // is returned, so only check that ours is present.
        let res = send_request(&url, "global", &json!({ "Global": "List" }));
        let listed = res.get("List").unwrap().as_array().unwrap();
        assert!(listed
            .iter()
            .any(|prop| prop.as_str().unwrap() == "ixa.WebApiTestGlobal"));

        // Global property get. The value is JSON-encoded internally, hence the
        // extra quotes.
        // TODO(cym4@cdc.gov): Should we fix this internally?
        let res = send_request(
            &url,
            "global",
            &json!({ "Global": { "Get": { "property": "ixa.WebApiTestGlobal" } } }),
        );
        assert_eq!(res, json!({ "Value": "\"foobar\"" }));

        // Next.
        expect_ok("next", json!({}));

        // Breakpoints: set two, delete the first, and list what remains.
        for time in [1.0, 2.0] {
            expect_ok(
                "breakpoint",
                json!({ "Breakpoint": { "Set": { "time": time, "console": false } } }),
            );
        }
        expect_ok(
            "breakpoint",
            json!({ "Breakpoint": { "Delete": { "id": 0, "all": false } } }),
        );
        assert_eq!(
            send_request(&url, "breakpoint", &json!({ "Breakpoint": "List" })),
            json!({ "List": ["1: t=2 (First)"] })
        );

        // Delete everything; the list is then empty.
        expect_ok(
            "breakpoint",
            json!({ "Breakpoint": { "Delete": { "all": true } } }),
        );
        assert_eq!(
            send_request(&url, "breakpoint", &json!({ "Breakpoint": "List" })),
            json!({ "List": [] })
        );

        expect_ok("breakpoint", json!({ "Breakpoint": "Disable" }));
        expect_ok("breakpoint", json!({ "Breakpoint": "Enable" }));

        // Person property get.
        let res = send_request(
            &url,
            "people",
            &json!({ "People": { "Get": { "person_id": 0, "property": "Age" } } }),
        );
        assert_eq!(res, json!({ "Properties": [["Age", "1"]] }));

        // Person property names.
        let res = send_request(&url, "people", &json!({ "People": "Properties" }));
        assert_eq!(res, json!({ "PropertyNames": ["Age"] }));

        // Tabulation. The rows may arrive in either order.
        let res = send_request(
            &url,
            "people",
            &json!({ "People": { "Tabulate": { "properties": ["Age"] } } }),
        );
        let ascending = json!({ "Tabulated": [[{ "Age": "1" }, 1], [{ "Age": "2" }, 1]] });
        let descending = json!({ "Tabulated": [[{ "Age": "2" }, 1], [{ "Age": "1" }, 1]] });
        assert!(res == ascending || res == descending);

        // Well-formed JSON of the wrong shape is rejected.
        let res = send_request_text(
            &url,
            "breakpoint",
            String::from("{\"Set\": {\"time\" : \"invalid\"}}"),
        );
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);

        // Malformed JSON is rejected.
        let res = send_request_text(&url, "next", String::from("{]"));
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);

        // A handler added with `add_web_api_handler` echoes its arguments.
        assert_eq!(
            send_request(&url, "external", &json!({ "External": [1] })),
            json!({ "External": [1] })
        );

        // Continue the simulation and wait for the context thread to exit.
        send_continue(&url);
        let _ = ctx_thread.join();
    }
}