1use std::thread;
2
3use axum::extract::{Json, Path, State};
4use axum::http::StatusCode;
5use axum::response::Redirect;
6use axum::routing::{get, post};
7use axum::Router;
8use serde_json::json;
9use tokio::sync::{mpsc, oneshot};
10use tower_http::services::{ServeDir, ServeFile};
11
12use crate::context::Context;
13use crate::error::IxaError;
14use crate::external_api::{
15 breakpoint, global_properties, halt, next, people, population, r#continue, run_ext_api, time,
16 EmptyArgs,
17};
18use crate::rand::RngCore;
19use crate::{define_data_plugin, HashMap, HashMapExt, PluginContext};
20
/// Type-erased handler for a single web API command.
///
/// A handler gets mutable access to the simulation [`Context`] and the JSON
/// body posted by the client. In this file, an `Ok` value is sent back as the
/// `200` response body and an `Err` is reported as `400` with an `error` field.
pub type WebApiHandler =
    dyn Fn(&mut Context, serde_json::Value) -> Result<serde_json::Value, IxaError>;
23
24fn register_api_handler<
25 T: crate::external_api::ExtApi<Args = A>,
26 A: serde::de::DeserializeOwned,
27>(
28 dc: &mut ApiData,
29 name: &str,
30) {
31 dc.handlers.insert(
32 name.to_string(),
33 Box::new(
34 |context, args_json| -> Result<serde_json::Value, IxaError> {
35 let args: A = serde_json::from_value(args_json)?;
36 let retval: T::Retval = run_ext_api::<T>(context, &args)?;
37 Ok(serde_json::to_value(retval)?)
38 },
39 ),
40 );
41}
42
/// Per-context web API state, stored in the `ApiPlugin` data plugin once
/// `setup_web_api` has run.
struct ApiData {
    /// Receiving half of the channel on which the HTTP server forwards requests.
    receiver: mpsc::UnboundedReceiver<ApiRequest>,
    /// Command name -> handler table consulted for each incoming request.
    handlers: HashMap<String, Box<WebApiHandler>>,
}
47
48pub(crate) fn handle_web_api_with_plugin(context: &mut Context) {
51 let mut data_container = context.get_data_mut(ApiPlugin).take().unwrap();
54
55 handle_web_api(context, &mut data_container);
56
57 let saved_data_container = context.get_data_mut(ApiPlugin);
59 *saved_data_container = Some(data_container);
60}
61
// Data plugin holding the web API state: `None` until `setup_web_api` stores
// an `ApiData` in it.
define_data_plugin!(ApiPlugin, Option<ApiData>, None);
63
/// A command forwarded from the HTTP server thread to the simulation thread.
struct ApiRequest {
    /// Command name, taken from the `cmd/{command}` path segment.
    cmd: String,
    /// JSON body of the POST request.
    arguments: serde_json::Value,
    /// One-shot channel used to deliver the reply. Despite the name, this is
    /// the *sending* half; the HTTP handler awaits the matching receiver.
    rx: oneshot::Sender<ApiResponse>,
}
71
/// Reply produced on the simulation thread for one [`ApiRequest`].
struct ApiResponse {
    /// HTTP status code returned to the client.
    code: StatusCode,
    /// JSON body returned to the client.
    response: serde_json::Value,
}
77
/// Axum router state: the sending half of the request channel to the
/// simulation thread.
#[derive(Clone)]
struct ApiEndpointServer {
    sender: mpsc::UnboundedSender<ApiRequest>,
}
82
83async fn process_cmd(
84 State(state): State<ApiEndpointServer>,
85 Path(path): Path<String>,
86 Json(payload): Json<serde_json::Value>,
87) -> (StatusCode, Json<serde_json::Value>) {
88 let (tx, rx) = oneshot::channel::<ApiResponse>();
89 let _ = state.sender.send(ApiRequest {
90 cmd: path,
91 arguments: payload,
92 rx: tx,
93 });
94
95 match rx.await {
96 Ok(response) => (response.code, Json(response.response)),
97 _ => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({}))),
98 }
99}
100
101#[tokio::main]
102async fn serve(
103 sender: mpsc::UnboundedSender<ApiRequest>,
104 port: u16,
105 prefix: &str,
106 ready: oneshot::Sender<Result<String, IxaError>>,
107) {
108 let state = ApiEndpointServer { sender };
109
110 let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{port}")).await;
112 if listener.is_err() {
113 ready
114 .send(Err(IxaError::IxaError(format!("Could not bind to {port}"))))
115 .unwrap();
116 return;
117 }
118
119 let path = format!("{}/{}", env!("CARGO_MANIFEST_DIR"), "static/");
121 let static_assets_path = std::path::Path::new(&path);
122 let home_path = format!("/{prefix}/static/index.html");
123 let app = Router::new()
124 .route(&format!("/{prefix}/cmd/{{command}}"), post(process_cmd))
125 .route(
126 &format!("/{prefix}/"),
127 get(|| async move { Redirect::temporary(&home_path) }),
128 )
129 .nest_service(
130 &format!("/{prefix}/static/"),
131 ServeDir::new(static_assets_path),
132 )
133 .nest_service(
134 "/favicon.ico",
135 ServeFile::new_with_mime(
136 static_assets_path.join(std::path::Path::new("favicon.ico")),
137 &mime::IMAGE_PNG,
138 ),
139 )
140 .with_state(state);
141
142 ready
144 .send(Ok(format!("http://127.0.0.1:{port}/{prefix}/")))
145 .unwrap();
146 axum::serve(listener.unwrap(), app).await.unwrap();
147}
148
149fn handle_web_api(context: &mut Context, api: &mut ApiData) {
152 while let Some(req) = api.receiver.blocking_recv() {
153 if req.cmd == "continue" {
154 let _ = req.rx.send(ApiResponse {
155 code: StatusCode::OK,
156 response: json!({}),
157 });
158 break;
159 }
160
161 let handler = api.handlers.get(&req.cmd);
162 if handler.is_none() {
163 let _ = req.rx.send(ApiResponse {
164 code: StatusCode::NOT_FOUND,
165 response: json!({
166 "error" : format!("No command {}", req.cmd)
167 }),
168 });
169 continue;
170 }
171
172 let handler = handler.unwrap();
173 match handler(context, req.arguments.clone()) {
174 Err(err) => {
175 let _ = req.rx.send(ApiResponse {
176 code: StatusCode::BAD_REQUEST,
177 response: json!({
178 "error" : err.to_string()
179 }),
180 });
181 continue;
182 }
183 Ok(response) => {
184 let _ = req.rx.send(ApiResponse {
185 code: StatusCode::OK,
186 response,
187 });
188 }
189 }
190
191 if req.cmd == "continue" {
194 return;
195 }
196 }
197}
198
199pub trait ContextWebApiExt: PluginContext {
200 fn setup_web_api(&mut self, port: u16) -> Result<String, IxaError> {
205 let (api_to_ctx_send, api_to_ctx_recv) = mpsc::unbounded_channel::<ApiRequest>();
207
208 let data_container = self.get_data_mut(ApiPlugin);
209 if data_container.is_some() {
210 return Err(IxaError::IxaError(String::from(
211 "HTTP API already initialized",
212 )));
213 }
214
215 let mut random: [u8; 16] = [0; 16];
217 let mut rng = rand::rng();
218 rng.fill_bytes(&mut random);
219 let secret = uuid::Builder::from_random_bytes(random)
220 .into_uuid()
221 .to_string();
222
223 let (ready_tx, ready_rx) = oneshot::channel::<Result<String, IxaError>>();
224 thread::spawn(move || serve(api_to_ctx_send, port, &secret, ready_tx));
225 let url = ready_rx.blocking_recv().unwrap()?;
226
227 let mut api_data = ApiData {
228 receiver: api_to_ctx_recv,
229 handlers: HashMap::new(),
230 };
231
232 register_api_handler::<breakpoint::Api, breakpoint::Args>(&mut api_data, "breakpoint");
233 register_api_handler::<r#continue::Api, EmptyArgs>(&mut api_data, "continue");
234 register_api_handler::<global_properties::Api, global_properties::Args>(
235 &mut api_data,
236 "global",
237 );
238 register_api_handler::<halt::Api, EmptyArgs>(&mut api_data, "halt");
239 register_api_handler::<next::Api, EmptyArgs>(&mut api_data, "next");
240 register_api_handler::<people::Api, people::Args>(&mut api_data, "people");
241 register_api_handler::<population::Api, EmptyArgs>(&mut api_data, "population");
242 register_api_handler::<time::Api, EmptyArgs>(&mut api_data, "time");
243 *data_container = Some(api_data);
245
246 Ok(url)
247 }
248
249 fn schedule_web_api(&mut self, t: f64) {
252 self.add_plan(t, handle_web_api_with_plugin);
253 }
254
255 fn add_web_api_handler(
259 &mut self,
260 name: &str,
261 handler: impl Fn(&mut Context, serde_json::Value) -> Result<serde_json::Value, IxaError>
262 + 'static,
263 ) -> Result<(), IxaError> {
264 let data_container = self.get_data_mut(ApiPlugin);
265
266 match data_container {
267 Some(dc) => {
268 dc.handlers.insert(name.to_string(), Box::new(handler));
269 Ok(())
270 }
271 None => Err(IxaError::IxaError(String::from("Web API not yet set up"))),
272 }
273 }
274}
275impl ContextWebApiExt for Context {}
276
#[cfg(test)]
mod tests {
    use std::thread;

    use reqwest::StatusCode;
    use serde::Serialize;
    use serde_json::json;

    use super::ContextWebApiExt;
    use crate::people::define_person_property;
    use crate::{define_global_property, Context, ContextGlobalPropertiesExt, ContextPeopleExt};

    define_global_property!(WebApiTestGlobal, String);
    define_person_property!(Age, u8);

    /// Builds a context with the web API on port 33339 and an API plan at t=0.
    /// It also has one global property, two people (Age 1 and 2) and an
    /// `external` handler that echoes its arguments.
    fn setup() -> (String, Context) {
        let mut context = Context::new();
        let url = context.setup_web_api(33339).unwrap();
        context.schedule_web_api(0.0);
        context
            .set_global_property_value(WebApiTestGlobal, "foobar".to_string())
            .unwrap();
        context.add_person((Age, 1)).unwrap();
        context.add_person((Age, 2)).unwrap();
        context
            .add_web_api_handler("external", |_context, args| Ok(args))
            .unwrap();
        (url, context)
    }

    /// Sends `continue`, which makes the context leave the API loop.
    fn send_continue(url: &str) {
        let client = reqwest::blocking::Client::new();
        client
            .post(format!("{url}cmd/continue"))
            // `&{}` would be Rust's unit value (serialized as `null`), not an
            // empty JSON object.
            .json(&json!({}))
            .send()
            .unwrap();
    }

    /// Posts `req` as JSON to `cmd`, asserts a 200 status and returns the body.
    fn send_request<T: Serialize + ?Sized>(url: &str, cmd: &str, req: &T) -> serde_json::Value {
        let client = reqwest::blocking::Client::new();
        let response = client
            .post(format!("{url}cmd/{cmd}"))
            .json(req)
            .send()
            .unwrap();
        let status = response.status();
        let response = response.json().unwrap();
        println!("{response:?}");
        assert_eq!(status, StatusCode::OK);
        response
    }

    /// Posts a raw (possibly malformed) JSON body to `cmd` and returns the
    /// response without checking its status.
    fn send_request_text(url: &str, cmd: &str, req: String) -> reqwest::blocking::Response {
        let client = reqwest::blocking::Client::new();
        client
            .post(format!("{url}cmd/{cmd}"))
            .header("Content-Type", "application/json")
            .body(req)
            .send()
            .unwrap()
    }

    // One long scenario on purpose: every step talks to the same live server,
    // and later steps depend on state (e.g. breakpoint ids) set up earlier.
    #[allow(clippy::too_many_lines)]
    #[test]
    fn web_api_test() {
        #[derive(Serialize)]
        struct PopulationResponse {
            population: usize,
        }

        // The context is not `Send`, so it is built and run on its own thread,
        // which hands the server URL back over this channel.
        let (tx, rx) = std::sync::mpsc::channel::<String>();
        let ctx_thread = thread::spawn(move || {
            let (url, mut context) = setup();
            let _ = tx.send(url);
            context.execute();
        });

        let url = rx.recv().unwrap();
        let res = send_request(&url, "population", &json!({}));
        assert_eq!(json!(&PopulationResponse { population: 2 }), res);

        let res = send_request(&url, "time", &json!({}));
        assert_eq!(
            json!(
                { "time": 0.0 }
            ),
            res
        );

        let res = send_request(
            &url,
            "global",
            &json!({
                "Global": "List"
            }),
        );
        let list = res.get("List").unwrap().as_array().unwrap();
        let found = list
            .iter()
            .any(|prop| prop.as_str().unwrap() == "ixa.WebApiTestGlobal");
        assert!(found);

        let res = send_request(
            &url,
            "global",
            &json!({
                "Global": {
                    "Get" : {
                        "property" : "ixa.WebApiTestGlobal"
                    }
                }
            }),
        );
        assert_eq!(
            res,
            json!({
                "Value": "\"foobar\""
            })
        );

        let res = send_request(&url, "next", &json!({}));
        assert_eq!(res, json!("Ok"));

        let res = send_request(
            &url,
            "breakpoint",
            &json!({ "Breakpoint" : { "Set" : { "time": 1.0, "console": false} } }),
        );
        assert_eq!(res, json!("Ok"));

        let res = send_request(
            &url,
            "breakpoint",
            &json!({ "Breakpoint" : { "Set" : { "time": 2.0, "console": false} } }),
        );
        assert_eq!(res, json!("Ok"));

        let res = send_request(
            &url,
            "breakpoint",
            &json!({ "Breakpoint" : { "Delete" : { "id": 0, "all": false} } }),
        );
        assert_eq!(res, json!("Ok"));

        let res = send_request(&url, "breakpoint", &json!({"Breakpoint": "List"}));
        assert_eq!(
            res,
            json!({"List" : [
                "1: t=2 (First)"
            ]}
            )
        );

        let res = send_request(
            &url,
            "breakpoint",
            &json!({ "Breakpoint" : { "Delete" : { "all": true, } } }),
        );
        assert_eq!(res, json!("Ok"));

        let res = send_request(&url, "breakpoint", &json!({"Breakpoint": "List"}));
        assert_eq!(
            res,
            json!({"List" : []}
            )
        );

        let res = send_request(&url, "breakpoint", &json!({ "Breakpoint" : "Disable" }));
        assert_eq!(res, json!("Ok"));

        let res = send_request(&url, "breakpoint", &json!({ "Breakpoint" : "Enable" }));
        assert_eq!(res, json!("Ok"));

        let res = send_request(
            &url,
            "people",
            &json!({
                "People" : {
                    "Get" : {
                        "person_id": 0,
                        "property" : "Age"
                    }
                }
            }),
        );
        assert_eq!(
            res,
            json!({"Properties" : [
                ( "Age", "1" )
            ]}
            )
        );

        let res = send_request(
            &url,
            "people",
            &json!({
                "People" : "Properties"
            }),
        );
        assert_eq!(
            res,
            json!({"PropertyNames" : [
                "Age"
            ]}
            )
        );

        let res = send_request(
            &url,
            "people",
            &json!({
                "People" : {
                    "Tabulate" : {
                        "properties": ["Age"]
                    }
                }
            }),
        );

        // Tabulation order is unspecified, so accept either ordering.
        assert!(
            (res == json!({"Tabulated" : [
                [{ "Age" : "1" }, 1],
                [{ "Age" : "2" }, 1]
            ]})) || (res
                == json!({"Tabulated" : [
                    [{ "Age" : "2" }, 1],
                    [{ "Age" : "1" }, 1]
                ]})),
        );

        let res = send_request_text(
            &url,
            "breakpoint",
            String::from("{\"Set\": {\"time\" : \"invalid\"}}"),
        );
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);

        let res = send_request_text(&url, "next", String::from("{]"));
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);

        let res = send_request(&url, "external", &json!({"External": [1]}));
        assert_eq!(res, json!({"External": [1]}));

        send_continue(&url);
        // Propagate any panic from the simulation thread instead of hiding it.
        ctx_thread.join().expect("simulation thread panicked");
    }
}