use crate::client::Client; use crate::messages::command::CommandResponse; use api_core::command::{CommandHeader, IntoCommandDefinition}; use std::fmt::Display; use std::sync::Arc; use tokio::select; use tokio_util::sync::CancellationToken; pub struct CommandRegistry { client: Arc, } impl CommandRegistry { pub fn new(client: Arc) -> Self { Self { client } } pub fn register_handler( &self, command_name: impl Into, mut callback: F, ) -> CommandHandle where F: FnMut(CommandHeader, C) -> Result + Send + 'static, { let cancellation_token = CancellationToken::new(); let result = CommandHandle { cancellation_token: cancellation_token.clone(), }; let client = self.client.clone(); let command_definition = C::create(command_name.into()); tokio::spawn(async move { while !cancellation_token.is_cancelled() { // This would only fail if the sender closed while trying to insert data // It would wait until space is made let Ok(mut rx) = client .register_callback_channel(command_definition.clone()) .await else { continue; }; loop { // select used so that this loop gets broken if the token is cancelled select!( rx_value = rx.recv() => { if let Some((cmd, responder)) = rx_value { let header = cmd.header.clone(); let response = match C::parse(cmd) { Ok(cmd) => match callback(header, cmd) { Ok(response) => CommandResponse { success: true, response, }, Err(err) => CommandResponse { success: false, response: err.to_string(), }, }, Err(err) => CommandResponse { success: false, response: err.to_string(), }, }; // This should only err if we had an error elsewhere let _ = responder.send(response); } else { break; } }, _ = cancellation_token.cancelled() => { break; }, ); } } }); result } } pub struct CommandHandle { cancellation_token: CancellationToken, } impl Drop for CommandHandle { fn drop(&mut self) { self.cancellation_token.cancel(); } } #[cfg(test)] mod tests { use crate::client::command::CommandRegistry; use crate::client::tests::create_test_client; use crate::client::Callback; use crate::messages::callback::GenericCallbackError; use crate::messages::command::CommandResponse; use crate::messages::payload::RequestMessagePayload; use crate::messages::telemetry_definition::TelemetryDefinitionResponse; use crate::messages::ResponseMessage; use api_core::command::{ Command, CommandDefinition, CommandHeader, CommandParameterDefinition, IntoCommandDefinition, IntoCommandDefinitionError, }; use api_core::data_type::DataType; use std::collections::HashMap; use std::convert::Infallible; use std::sync::Arc; use std::time::Duration; use tokio::sync::oneshot; use tokio::time::timeout; use uuid::Uuid; struct CmdType { #[allow(unused)] param1: f32, } impl IntoCommandDefinition for CmdType { fn create(name: String) -> CommandDefinition { CommandDefinition { name, parameters: vec![CommandParameterDefinition { name: "param1".to_string(), data_type: DataType::Float32, }], } } fn parse(command: Command) -> Result { Ok(Self { param1: (*command.parameters.get("param1").ok_or_else(|| { IntoCommandDefinitionError::ParameterMissing("param1".to_string()) })?) .try_into() .map_err(|_| IntoCommandDefinitionError::MismatchedType { parameter: "param1".to_string(), expected: DataType::Float32, })?, }) } } #[tokio::test] async fn simple_handler() { // if _c drops then we are disconnected let (mut rx, _c, client) = create_test_client(); let cmd_reg = CommandRegistry::new(Arc::new(client)); let _cmd_handle = cmd_reg.register_handler("cmd", |_, _: CmdType| { Ok("success".to_string()) as Result<_, Infallible> }); let msg = timeout(Duration::from_secs(1), rx.recv()) .await .unwrap() .unwrap(); let Callback::Registered(callback) = msg.callback else { panic!("Incorrect Callback Type"); }; let mut params = HashMap::new(); params.insert("param1".to_string(), 0.0f32.into()); let (response_tx, response_rx) = oneshot::channel(); timeout( Duration::from_secs(1), callback.send(( ResponseMessage { uuid: Uuid::new_v4(), response: Some(msg.msg.uuid), payload: Command { header: CommandHeader { timestamp: Default::default(), }, parameters: params, } .into(), }, response_tx, )), ) .await .unwrap() .unwrap(); let response = timeout(Duration::from_secs(1), response_rx) .await .unwrap() .unwrap(); let RequestMessagePayload::CommandResponse(CommandResponse { success, response }) = response else { panic!("Unexpected Response Type"); }; assert!(success); assert_eq!(response, "success"); } #[tokio::test] async fn handler_failed() { // if _c drops then we are disconnected let (mut rx, _c, client) = create_test_client(); let cmd_reg = CommandRegistry::new(Arc::new(client)); let _cmd_handle = cmd_reg.register_handler("cmd", |_, _: CmdType| { Err("failure".into()) as Result<_, Box> }); let msg = timeout(Duration::from_secs(1), rx.recv()) .await .unwrap() .unwrap(); let Callback::Registered(callback) = msg.callback else { panic!("Incorrect Callback Type"); }; let mut params = HashMap::new(); params.insert("param1".to_string(), 1.0f32.into()); let (response_tx, response_rx) = oneshot::channel(); timeout( Duration::from_secs(1), callback.send(( ResponseMessage { uuid: Uuid::new_v4(), response: Some(msg.msg.uuid), payload: Command { header: CommandHeader { timestamp: Default::default(), }, parameters: params, } .into(), }, response_tx, )), ) .await .unwrap() .unwrap(); let response = timeout(Duration::from_secs(1), response_rx) .await .unwrap() .unwrap(); let RequestMessagePayload::CommandResponse(CommandResponse { success, response }) = response else { panic!("Unexpected Response Type"); }; assert!(!success); assert_eq!(response, "failure"); } #[tokio::test] async fn parse_failed() { // if _c drops then we are disconnected let (mut rx, _c, client) = create_test_client(); let cmd_reg = CommandRegistry::new(Arc::new(client)); let _cmd_handle = cmd_reg.register_handler("cmd", |_, _: CmdType| { Err("failure".into()) as Result<_, Box> }); let msg = timeout(Duration::from_secs(1), rx.recv()) .await .unwrap() .unwrap(); let Callback::Registered(callback) = msg.callback else { panic!("Incorrect Callback Type"); }; let mut params = HashMap::new(); params.insert("param1".to_string(), 1.0f64.into()); let (response_tx, response_rx) = oneshot::channel(); timeout( Duration::from_secs(1), callback.send(( ResponseMessage { uuid: Uuid::new_v4(), response: Some(msg.msg.uuid), payload: Command { header: CommandHeader { timestamp: Default::default(), }, parameters: params, } .into(), }, response_tx, )), ) .await .unwrap() .unwrap(); let response = timeout(Duration::from_secs(1), response_rx) .await .unwrap() .unwrap(); let RequestMessagePayload::CommandResponse(CommandResponse { success, response: _, }) = response else { panic!("Unexpected Response Type"); }; assert!(!success); } #[tokio::test] async fn wrong_message() { // if _c drops then we are disconnected let (mut rx, _c, client) = create_test_client(); let cmd_reg = CommandRegistry::new(Arc::new(client)); let _cmd_handle = cmd_reg.register_handler("cmd", |_, _: CmdType| -> Result<_, Infallible> { panic!("This should not happen"); }); let msg = timeout(Duration::from_secs(1), rx.recv()) .await .unwrap() .unwrap(); let Callback::Registered(callback) = msg.callback else { panic!("Incorrect Callback Type"); }; let (response_tx, response_rx) = oneshot::channel(); timeout( Duration::from_secs(1), callback.send(( ResponseMessage { uuid: Uuid::new_v4(), response: Some(msg.msg.uuid), payload: TelemetryDefinitionResponse { uuid: Uuid::new_v4(), } .into(), }, response_tx, )), ) .await .unwrap() .unwrap(); let response = timeout(Duration::from_secs(1), response_rx) .await .unwrap() .unwrap(); let RequestMessagePayload::GenericCallbackError(err) = response else { panic!("Unexpected Response Type"); }; assert_eq!(err, GenericCallbackError::MismatchedType); } #[tokio::test] async fn callback_closed() { // if _c drops then we are disconnected let (mut rx, _c, client) = create_test_client(); let cmd_reg = CommandRegistry::new(Arc::new(client)); let cmd_handle = cmd_reg.register_handler("cmd", |_, _: CmdType| -> Result<_, Infallible> { panic!("This should not happen"); }); let msg = timeout(Duration::from_secs(1), rx.recv()) .await .unwrap() .unwrap(); let Callback::Registered(callback) = msg.callback else { panic!("Incorrect Callback Type"); }; // This should shut down the command handler drop(cmd_handle); // Send a command let mut params = HashMap::new(); params.insert("param1".to_string(), 0.0f32.into()); let (response_tx, response_rx) = oneshot::channel(); timeout( Duration::from_secs(1), callback.send(( ResponseMessage { uuid: Uuid::new_v4(), response: Some(msg.msg.uuid), payload: Command { header: CommandHeader { timestamp: Default::default(), }, parameters: params, } .into(), }, response_tx, )), ) .await .unwrap() .unwrap(); let response = timeout(Duration::from_secs(1), response_rx) .await .unwrap() .unwrap(); let RequestMessagePayload::GenericCallbackError(err) = response else { panic!("Unexpected Response Type"); }; assert_eq!(err, GenericCallbackError::CallbackClosed); } #[tokio::test] async fn reconnect() { // if _c drops then we are disconnected let (mut rx, _c, client) = create_test_client(); let cmd_reg = CommandRegistry::new(Arc::new(client)); let _cmd_handle = cmd_reg.register_handler("cmd", |_, _: CmdType| -> Result<_, Infallible> { panic!("This should not happen"); }); let msg = timeout(Duration::from_secs(1), rx.recv()) .await .unwrap() .unwrap(); let Callback::Registered(callback) = msg.callback else { panic!("Incorrect Callback Type"); }; println!("Dropping"); drop(callback); println!("Dropped"); // The command re-registers itself let msg = timeout(Duration::from_secs(1), rx.recv()) .await .unwrap() .unwrap(); let Callback::Registered(_) = msg.callback else { panic!("Incorrect Callback Type"); }; } }