ixa/
web_api.rs

1use crate::context::Context;
2use crate::error::IxaError;
3use crate::external_api::{
4    breakpoint, global_properties, halt, next, people, population, r#continue, run_ext_api, time,
5    EmptyArgs,
6};
7use crate::rand::RngCore;
8use crate::{define_data_plugin, PluginContext};
9use crate::{HashMap, HashMapExt};
10use axum::extract::{Json, Path, State};
11use axum::response::Redirect;
12use axum::routing::get;
13use axum::{http::StatusCode, routing::post, Router};
14use serde_json::json;
15use std::thread;
16use tokio::sync::mpsc;
17use tokio::sync::oneshot;
18use tower_http::services::{ServeDir, ServeFile};
19
20pub type WebApiHandler =
21    dyn Fn(&mut Context, serde_json::Value) -> Result<serde_json::Value, IxaError>;
22
23fn register_api_handler<
24    T: crate::external_api::ExtApi<Args = A>,
25    A: serde::de::DeserializeOwned,
26>(
27    dc: &mut ApiData,
28    name: &str,
29) {
30    dc.handlers.insert(
31        name.to_string(),
32        Box::new(
33            |context, args_json| -> Result<serde_json::Value, IxaError> {
34                let args: A = serde_json::from_value(args_json)?;
35                let retval: T::Retval = run_ext_api::<T>(context, &args)?;
36                Ok(serde_json::to_value(retval)?)
37            },
38        ),
39    );
40}
41
42struct ApiData {
43    receiver: mpsc::UnboundedReceiver<ApiRequest>,
44    handlers: HashMap<String, Box<WebApiHandler>>,
45}
46
47/// This wrapper method allows simultaneous mutable access to the `ApiPlugin` and `context` at the
48/// same time.
49pub(crate) fn handle_web_api_with_plugin(context: &mut Context) {
50    // We temporarily swap out the `ApiPlugin` so we can have simultaneous mutable access to
51    // it and to `context`. We swap it back in at the end of the function.
52    let mut data_container = context.get_data_mut(ApiPlugin).take().unwrap();
53
54    handle_web_api(context, &mut data_container);
55
56    // Restore the `ApiPlugin`
57    let saved_data_container = context.get_data_mut(ApiPlugin);
58    *saved_data_container = Some(data_container);
59}
60
61define_data_plugin!(ApiPlugin, Option<ApiData>, None);
62
63// Input to the API handler.
64struct ApiRequest {
65    cmd: String,
66    arguments: serde_json::Value,
67    // This channel is used to send the response.
68    rx: oneshot::Sender<ApiResponse>,
69}
70
71// Output of the API handler.
72struct ApiResponse {
73    code: StatusCode,
74    response: serde_json::Value,
75}
76
77#[derive(Clone)]
78struct ApiEndpointServer {
79    sender: mpsc::UnboundedSender<ApiRequest>,
80}
81
82async fn process_cmd(
83    State(state): State<ApiEndpointServer>,
84    Path(path): Path<String>,
85    Json(payload): Json<serde_json::Value>,
86) -> (StatusCode, Json<serde_json::Value>) {
87    let (tx, rx) = oneshot::channel::<ApiResponse>();
88    let _ = state.sender.send(ApiRequest {
89        cmd: path,
90        arguments: payload,
91        rx: tx,
92    });
93
94    match rx.await {
95        Ok(response) => (response.code, Json(response.response)),
96        _ => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({}))),
97    }
98}
99
100#[tokio::main]
101async fn serve(
102    sender: mpsc::UnboundedSender<ApiRequest>,
103    port: u16,
104    prefix: &str,
105    ready: oneshot::Sender<Result<String, IxaError>>,
106) {
107    let state = ApiEndpointServer { sender };
108
109    // run our app with Axum, listening on `port`
110    let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}")).await;
111    if listener.is_err() {
112        ready
113            .send(Err(IxaError::IxaError(format!("Could not bind to {port}"))))
114            .unwrap();
115        return;
116    }
117
118    // build our application with a route
119    let path = format!("{}/{}", env!("CARGO_MANIFEST_DIR"), "static/");
120    let static_assets_path = std::path::Path::new(&path);
121    let home_path = format!("/{prefix}/static/index.html");
122    let app = Router::new()
123        .route(&format!("/{prefix}/cmd/{{command}}"), post(process_cmd))
124        .route(
125            &format!("/{prefix}/"),
126            get(|| async move { Redirect::temporary(&home_path) }),
127        )
128        .nest_service(
129            &format!("/{prefix}/static/"),
130            ServeDir::new(static_assets_path),
131        )
132        .nest_service(
133            "/favicon.ico",
134            ServeFile::new_with_mime(
135                static_assets_path.join(std::path::Path::new("favicon.ico")),
136                &mime::IMAGE_PNG,
137            ),
138        )
139        .with_state(state);
140
141    // Notify the caller that we are ready.
142    ready
143        .send(Ok(format!("http://127.0.0.1:{port}/{prefix}/")))
144        .unwrap();
145    axum::serve(listener.unwrap(), app).await.unwrap();
146}
147
148/// Starts the Web API, pausing execution until instructed
149/// to continue.
150fn handle_web_api(context: &mut Context, api: &mut ApiData) {
151    while let Some(req) = api.receiver.blocking_recv() {
152        if req.cmd == "continue" {
153            let _ = req.rx.send(ApiResponse {
154                code: StatusCode::OK,
155                response: json!({}),
156            });
157            break;
158        }
159
160        let handler = api.handlers.get(&req.cmd);
161        if handler.is_none() {
162            let _ = req.rx.send(ApiResponse {
163                code: StatusCode::NOT_FOUND,
164                response: json!({
165                    "error" : format!("No command {}", req.cmd)
166                }),
167            });
168            continue;
169        }
170
171        let handler = handler.unwrap();
172        match handler(context, req.arguments.clone()) {
173            Err(err) => {
174                let _ = req.rx.send(ApiResponse {
175                    code: StatusCode::BAD_REQUEST,
176                    response: json!({
177                        "error" : err.to_string()
178                    }),
179                });
180                continue;
181            }
182            Ok(response) => {
183                let _ = req.rx.send(ApiResponse {
184                    code: StatusCode::OK,
185                    response,
186                });
187            }
188        }
189
190        // Special case the functions which require exiting
191        // the loop.
192        if req.cmd == "continue" {
193            return;
194        }
195    }
196}
197
198pub trait ContextWebApiExt: PluginContext {
199    /// Set up the Web API and start the Web server.
200    ///
201    /// # Errors
202    /// `IxaError` on failure to bind to `port`
203    fn setup_web_api(&mut self, port: u16) -> Result<String, IxaError> {
204        // TODO(cym4@cdc.gov): Check on the limits here.
205        let (api_to_ctx_send, api_to_ctx_recv) = mpsc::unbounded_channel::<ApiRequest>();
206
207        let data_container = self.get_data_mut(ApiPlugin);
208        if data_container.is_some() {
209            return Err(IxaError::IxaError(String::from(
210                "HTTP API already initialized",
211            )));
212        }
213
214        // Start the API server
215        let mut random: [u8; 16] = [0; 16];
216        let mut rng = rand::rng();
217        rng.fill_bytes(&mut random);
218        let secret = uuid::Builder::from_random_bytes(random)
219            .into_uuid()
220            .to_string();
221
222        let (ready_tx, ready_rx) = oneshot::channel::<Result<String, IxaError>>();
223        thread::spawn(move || serve(api_to_ctx_send, port, &secret, ready_tx));
224        let url = ready_rx.blocking_recv().unwrap()?;
225
226        let mut api_data = ApiData {
227            receiver: api_to_ctx_recv,
228            handlers: HashMap::new(),
229        };
230
231        register_api_handler::<breakpoint::Api, breakpoint::Args>(&mut api_data, "breakpoint");
232        register_api_handler::<r#continue::Api, EmptyArgs>(&mut api_data, "continue");
233        register_api_handler::<global_properties::Api, global_properties::Args>(
234            &mut api_data,
235            "global",
236        );
237        register_api_handler::<halt::Api, EmptyArgs>(&mut api_data, "halt");
238        register_api_handler::<next::Api, EmptyArgs>(&mut api_data, "next");
239        register_api_handler::<people::Api, people::Args>(&mut api_data, "people");
240        register_api_handler::<population::Api, EmptyArgs>(&mut api_data, "population");
241        register_api_handler::<time::Api, EmptyArgs>(&mut api_data, "time");
242        // Record the data container.
243        *data_container = Some(api_data);
244
245        Ok(url)
246    }
247
248    /// Schedule the simulation to pause at time t and listen for
249    /// requests from the Web API.
250    fn schedule_web_api(&mut self, t: f64) {
251        self.add_plan(t, handle_web_api_with_plugin);
252    }
253
254    /// Add an API point.
255    /// # Errors
256    /// `IxaError` when the Web API has not been set up yet.
257    fn add_web_api_handler(
258        &mut self,
259        name: &str,
260        handler: impl Fn(&mut Context, serde_json::Value) -> Result<serde_json::Value, IxaError>
261            + 'static,
262    ) -> Result<(), IxaError> {
263        let data_container = self.get_data_mut(ApiPlugin);
264
265        match data_container {
266            Some(dc) => {
267                dc.handlers.insert(name.to_string(), Box::new(handler));
268                Ok(())
269            }
270            None => Err(IxaError::IxaError(String::from("Web API not yet set up"))),
271        }
272    }
273}
274impl ContextWebApiExt for Context {}
275
#[cfg(test)]
mod tests {
    use super::ContextWebApiExt;
    use crate::people::define_person_property;
    use crate::{define_global_property, ContextGlobalPropertiesExt};
    use crate::{Context, ContextPeopleExt};
    use reqwest::StatusCode;
    use serde::Serialize;
    use serde_json::json;
    use std::thread;

    define_global_property!(WebApiTestGlobal, String);
    define_person_property!(Age, u8);

    // Build a context for the tests. It has:
    // - the Web API listening on a fixed port, with a pause scheduled at t = 0;
    // - one global property;
    // - two people (ages 1 and 2);
    // - an "external" handler that echoes its arguments.
    // Returns the API base URL together with the context.
    fn setup() -> (String, Context) {
        let mut context = Context::new();
        let url = context.setup_web_api(33339).unwrap();
        context.schedule_web_api(0.0);
        context
            .set_global_property_value(WebApiTestGlobal, "foobar".to_string())
            .unwrap();
        context.add_person((Age, 1)).unwrap();
        context.add_person((Age, 2)).unwrap();
        context
            .add_web_api_handler("external", |_context, args| Ok(args))
            .unwrap();
        (url, context)
    }

    // Continue the simulation. Note that we don't check the response,
    // because there is a race condition between sending the final
    // response and program termination.
    fn send_continue(url: &str) {
        let client = reqwest::blocking::Client::new();
        client
            .post(format!("{url}cmd/continue"))
            // Send an empty JSON object, like the other argument-less commands.
            // Writing `&{}` would pass the unit value `()`, which serializes as `null`.
            .json(&json!({}))
            .send()
            .unwrap();
    }

    // Send a request, assert that it returned 200 OK, and return the JSON response body.
    fn send_request<T: Serialize + ?Sized>(url: &str, cmd: &str, req: &T) -> serde_json::Value {
        let client = reqwest::blocking::Client::new();
        let response = client
            .post(format!("{url}cmd/{cmd}"))
            .json(req)
            .send()
            .unwrap();
        let status = response.status();
        let response = response.json().unwrap();
        println!("{response:?}");
        assert_eq!(status, StatusCode::OK);
        response
    }

    // Send a raw, pre-serialized body as JSON and return the response without checking it.
    // Callers use this to exercise error handling for malformed or ill-typed input.
    fn send_request_text(url: &str, cmd: &str, req: String) -> reqwest::blocking::Response {
        let client = reqwest::blocking::Client::new();
        client
            .post(format!("{url}cmd/{cmd}"))
            .header("Content-Type", "application/json")
            .body(req)
            .send()
            .unwrap()
    }

    // We do all of the tests in one test block to avoid having to
    // start a lot of servers with different ports and having
    // to manage that. This may not be ideal, but we're doing it for now.
    // TODO(cym4@cdc.gov): Consider using some kind of static
    // object to isolate the test cases.
    // Because everything lives in this single test, it is long; hence the allow below.
    #[allow(clippy::too_many_lines)]
    #[test]
    fn web_api_test() {
        #[derive(Serialize)]
        struct PopulationResponse {
            population: usize,
        }

        // TODO(cym4@cdc.gov): If this thread fails
        // then the test will stall instead of
        // erroring out, but there's nothing that
        // should fail here.
        let (tx, rx) = std::sync::mpsc::channel::<String>();
        let ctx_thread = thread::spawn(move || {
            let (url, mut context) = setup();
            let _ = tx.send(url);
            context.execute();
        });

        let url = rx.recv().unwrap();
        // Test the population API point.
        let res = send_request(&url, "population", &json!({}));
        assert_eq!(json!(&PopulationResponse { population: 2 }), res);

        // Test the time API point.
        let res = send_request(&url, "time", &json!({}));
        assert_eq!(
            json!(
                { "time": 0.0 }
            ),
            res
        );

        // Test the global property list point. We can't do an
        // exact match because the return is every defined
        // global property anywhere in the code.
        let res = send_request(
            &url,
            "global",
            &json!({
                "Global": "List"
            }),
        );
        let list = res.get("List").unwrap().as_array().unwrap();
        let mut found = false;
        for prop in list {
            let prop_val = prop.as_str().unwrap();
            if prop_val == "ixa.WebApiTestGlobal" {
                found = true;
                break;
            }
        }
        assert!(found);

        // Test the global property get API point.
        let res = send_request(
            &url,
            "global",
            &json!({
                "Global": {
                    "Get" : {
                        "property" : "ixa.WebApiTestGlobal"
                    }
                }
            }),
        );
        // The extra quotes here are because we internally JSONify.
        // TODO(cym4@cdc.gov): Should we fix this internally?
        assert_eq!(
            res,
            json!({
                "Value": "\"foobar\""
            })
        );

        // Next
        let res = send_request(&url, "next", &json!({}));
        assert_eq!(res, json!("Ok"));

        // We test breakpoint commands as a group.
        // Breakpoint set
        let res = send_request(
            &url,
            "breakpoint",
            &json!({ "Breakpoint" : { "Set" : { "time": 1.0, "console": false} } }),
        );
        assert_eq!(res, json!("Ok"));

        let res = send_request(
            &url,
            "breakpoint",
            &json!({ "Breakpoint" : { "Set" : { "time": 2.0, "console": false} } }),
        );
        assert_eq!(res, json!("Ok"));

        let res = send_request(
            &url,
            "breakpoint",
            &json!({ "Breakpoint" : { "Delete" : { "id": 0, "all": false} } }),
        );
        assert_eq!(res, json!("Ok"));

        // Breakpoint list
        let res = send_request(&url, "breakpoint", &json!({"Breakpoint": "List"}));
        assert_eq!(
            res,
            json!({"List" : [
                "1: t=2 (First)"
            ]}
            )
        );

        let res = send_request(
            &url,
            "breakpoint",
            &json!({ "Breakpoint" : { "Delete" : { "all": true, } } }),
        );
        assert_eq!(res, json!("Ok"));

        // Check list again
        let res = send_request(&url, "breakpoint", &json!({"Breakpoint": "List"}));
        assert_eq!(
            res,
            json!({"List" : [/* empty list */ ]}
            )
        );

        let res = send_request(&url, "breakpoint", &json!({ "Breakpoint" : "Disable" }));
        assert_eq!(res, json!("Ok"));

        let res = send_request(&url, "breakpoint", &json!({ "Breakpoint" : "Enable" }));
        assert_eq!(res, json!("Ok"));

        // Person properties API.
        let res = send_request(
            &url,
            "people",
            &json!({
                "People" : {
                    "Get" : {
                        "person_id": 0,
                        "property" : "Age"
                    }
                }
            }),
        );
        assert_eq!(
            res,
            json!({"Properties" : [
                ( "Age",  "1" )
            ]}
            )
        );

        // List properties.
        let res = send_request(
            &url,
            "people",
            &json!({
                "People" : "Properties"
            }),
        );
        assert_eq!(
            res,
            json!({"PropertyNames" : [
                "Age"
            ]}
            )
        );

        // Tabulate API.
        let res = send_request(
            &url,
            "people",
            &json!({
                "People" : {
                    "Tabulate" : {
                        "properties": ["Age"]
                    }
                }
            }),
        );

        // This is a hack to deal with these arriving in
        // arbitrary order.
        assert!(
            (res == json!({"Tabulated" : [
                [{ "Age" :  "1" }, 1],
                [{ "Age" :  "2" }, 1]
            ]})) || (res
                == json!({"Tabulated" : [
                    [{ "Age" :  "2" }, 1],
                    [{ "Age" :  "1" }, 1]
                ]})),
        );

        // Valid JSON with the correct command shape, but `time` has the wrong type.
        let res = send_request_text(
            &url,
            "breakpoint",
            String::from(r#"{"Breakpoint": {"Set": {"time": "invalid", "console": false}}}"#),
        );
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);

        // Invalid JSON.
        let res = send_request_text(&url, "next", String::from("{]"));
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);

        // A generic externally added API handler
        let res = send_request(&url, "external", &json!({"External": [1]}));
        assert_eq!(res, json!({"External": [1]}));

        // Test continue, and make sure that the context exits cleanly.
        // A panic on the context thread fails the test.
        send_continue(&url);
        ctx_thread.join().expect("context thread panicked");
    }
}