pub mod command; mod config; mod context; pub mod error; pub mod telemetry; use crate::client::config::ClientConfiguration; use crate::client::error::{MessageError, RequestError}; use crate::messages::callback::GenericCallbackError; use crate::messages::payload::RequestMessagePayload; use crate::messages::payload::ResponseMessagePayload; use crate::messages::{ ClientMessage, RegisterCallback, RequestMessage, RequestResponse, ResponseMessage, }; use context::ClientContext; use error::ConnectError; use std::sync::Arc; use tokio::spawn; use tokio::sync::{mpsc, oneshot, watch, RwLock}; use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tokio_util::sync::CancellationToken; use uuid::Uuid; type RegisteredCallback = mpsc::Sender<(ResponseMessage, oneshot::Sender)>; type ClientChannel = Arc>>; #[derive(Debug)] enum Callback { None, Once(oneshot::Sender), Registered(RegisteredCallback), } #[derive(Debug)] struct OutgoingMessage { msg: RequestMessage, callback: Callback, } pub struct Client { cancel: CancellationToken, channel: ClientChannel, connected_state_rx: watch::Receiver, } impl Client { pub fn connect(request: R) -> Result where R: IntoClientRequest, { Self::connect_with_config(request, ClientConfiguration::default()) } pub fn connect_with_config( request: R, config: ClientConfiguration, ) -> Result where R: IntoClientRequest, { let (tx, _rx) = mpsc::channel(1); let cancel = CancellationToken::new(); let channel = Arc::new(RwLock::new(tx)); let (connected_state_tx, connected_state_rx) = watch::channel(false); let context = ClientContext { cancel: cancel.clone(), request: request.into_client_request()?, connected_state_tx, client_configuration: config, }; context.start(channel.clone())?; Ok(Self { cancel, channel, connected_state_rx, }) } pub async fn send_message(&self, msg: M) -> Result<(), MessageError> { let sender = self.channel.read().await; let data = sender.reserve().await?; data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: None, payload: msg.into(), }, callback: Callback::None, }); Ok(()) } pub async fn send_message_if_connected( &self, msg: M, ) -> Result<(), MessageError> { let sender = self.channel.try_read()?; let data = sender.reserve().await?; data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: None, payload: msg.into(), }, callback: Callback::None, }); Ok(()) } pub fn try_send_message(&self, msg: M) -> Result<(), MessageError> { let sender = self.channel.try_read()?; let data = sender.try_reserve()?; data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: None, payload: msg.into(), }, callback: Callback::None, }); Ok(()) } pub async fn send_request( &self, msg: M, ) -> Result>::Error>> { let sender = self.channel.read().await; let data = sender.reserve().await?; let (tx, rx) = oneshot::channel(); data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: None, payload: msg.into(), }, callback: Callback::Once(tx), }); let response = rx.await?; let response = M::Response::try_from(response.payload).map_err(RequestError::Inner)?; Ok(response) } pub async fn register_callback_channel( &self, msg: M, ) -> Result)>, MessageError> where ::Callback: Send + 'static, ::Response: Send + 'static, <::Callback as TryFrom>::Error: Send, { let sender = self.channel.read().await; let data = sender.reserve().await?; let (inner_tx, mut inner_rx) = mpsc::channel(16); let (outer_tx, outer_rx) = mpsc::channel(1); data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: None, payload: msg.into(), }, callback: Callback::Registered(inner_tx), }); spawn(async move { // If the handler was unregistered we can stop while let Some((msg, responder)) = inner_rx.recv().await { let response: RequestMessagePayload = match M::Callback::try_from(msg.payload) { Err(_) => GenericCallbackError::MismatchedType.into(), Ok(o) => { let (response_tx, response_rx) = oneshot::channel::(); match outer_tx.send((o, response_tx)).await { Err(_) => GenericCallbackError::CallbackClosed.into(), Ok(()) => response_rx .await .map(M::Response::into) .unwrap_or_else(|_| GenericCallbackError::CallbackClosed.into()), } } }; if responder.send(response).is_err() { // If the callback was unregistered we can stop break; } } }); Ok(outer_rx) } pub async fn register_callback_fn( &self, msg: M, mut f: F, ) -> Result<(), MessageError> where F: FnMut(M::Callback) -> M::Response + Send + 'static, { let sender = self.channel.read().await; let data = sender.reserve().await?; let (inner_tx, mut inner_rx) = mpsc::channel(16); data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: None, payload: msg.into(), }, callback: Callback::Registered(inner_tx), }); spawn(async move { // If the handler was unregistered we can stop while let Some((msg, responder)) = inner_rx.recv().await { let response: RequestMessagePayload = match M::Callback::try_from(msg.payload) { Err(_) => GenericCallbackError::MismatchedType.into(), Ok(o) => f(o).into(), }; if responder.send(response).is_err() { // If the callback was unregistered we can stop break; } } }); Ok(()) } pub async fn wait_connected(&self) { let mut connected_rx = self.connected_state_rx.clone(); // If we aren't currently connected if !*connected_rx.borrow_and_update() { // Wait for a change notification // If the channel is closed there is nothing we can do let _ = connected_rx.changed().await; } } pub async fn wait_disconnected(&self) { let mut connected_rx = self.connected_state_rx.clone(); // If we are currently connected if *connected_rx.borrow_and_update() { // Wait for a change notification // If the channel is closed there is nothing we can do let _ = connected_rx.changed().await; } } } impl Drop for Client { fn drop(&mut self) { self.cancel.cancel(); } } #[cfg(test)] mod tests { use super::*; use crate::messages::command::CommandResponse; use crate::messages::telemetry_definition::{ TelemetryDefinitionRequest, TelemetryDefinitionResponse, }; use crate::messages::telemetry_entry::TelemetryEntry; use api_core::command::{Command, CommandDefinition, CommandHeader}; use api_core::data_type::DataType; use chrono::Utc; use futures_util::future::{select, Either}; use futures_util::FutureExt; use std::pin::pin; use std::time::Duration; use tokio::join; use tokio::time::{sleep, timeout}; pub fn create_test_client() -> (mpsc::Receiver, watch::Sender, Client) { let cancel = CancellationToken::new(); let (tx, rx) = mpsc::channel(1); let channel = Arc::new(RwLock::new(tx)); let (connected_state_tx, connected_state_rx) = watch::channel(true); let client = Client { cancel, channel, connected_state_rx, }; (rx, connected_state_tx, client) } #[tokio::test] async fn send_message() { let (mut rx, _, client) = create_test_client(); let msg_to_send = TelemetryEntry { uuid: Uuid::new_v4(), value: 0.0f32.into(), timestamp: Utc::now(), }; let msg_send = timeout( Duration::from_secs(1), client.send_message(msg_to_send.clone()), ); let msg_recv = timeout(Duration::from_secs(1), rx.recv()); let (send, recv) = join!(msg_send, msg_recv); send.unwrap().unwrap(); let recv = recv.unwrap().unwrap(); assert!(matches!(recv.callback, Callback::None)); assert!(recv.msg.response.is_none()); // uuid should be random let RequestMessagePayload::TelemetryEntry(recv) = recv.msg.payload else { panic!("Wrong Message Received") }; assert_eq!(recv, msg_to_send); } #[tokio::test] async fn send_message_if_connected() { let (mut rx, _, client) = create_test_client(); let msg_to_send = TelemetryEntry { uuid: Uuid::new_v4(), value: 0.0f32.into(), timestamp: Utc::now(), }; let msg_send = timeout( Duration::from_secs(1), client.send_message_if_connected(msg_to_send.clone()), ); let msg_recv = timeout(Duration::from_secs(1), rx.recv()); let (send, recv) = join!(msg_send, msg_recv); send.unwrap().unwrap(); let recv = recv.unwrap().unwrap(); assert!(matches!(recv.callback, Callback::None)); assert!(recv.msg.response.is_none()); // uuid should be random let RequestMessagePayload::TelemetryEntry(recv) = recv.msg.payload else { panic!("Wrong Message Received") }; assert_eq!(recv, msg_to_send); } #[tokio::test] async fn send_message_if_connected_not_connected() { let (_, connected_state_tx, client) = create_test_client(); let _lock = client.channel.write().await; connected_state_tx.send_replace(false); let msg_to_send = TelemetryEntry { uuid: Uuid::new_v4(), value: 0.0f32.into(), timestamp: Utc::now(), }; let msg_send = timeout( Duration::from_secs(1), client.send_message_if_connected(msg_to_send.clone()), ); let Err(MessageError::TokioLockError(_)) = msg_send.await.unwrap() else { panic!("Expected to Err due to lock being unavailable") }; } #[tokio::test] async fn try_send_message() { let (_tx, _, client) = create_test_client(); let msg_to_send = TelemetryEntry { uuid: Uuid::new_v4(), value: 0.0f32.into(), timestamp: Utc::now(), }; client.try_send_message(msg_to_send.clone()).unwrap(); let Err(MessageError::TokioTrySendError(_)) = client.try_send_message(msg_to_send.clone()) else { panic!("Expected the buffer to be full"); }; } #[tokio::test] async fn send_request() { let (mut tx, _, client) = create_test_client(); let msg_to_send = TelemetryDefinitionRequest { name: "".to_string(), data_type: DataType::Float32, }; let response = timeout( Duration::from_secs(1), client.send_request(msg_to_send.clone()), ); let response_uuid = Uuid::new_v4(); let outgoing_rx = timeout(Duration::from_secs(1), async { let msg = tx.recv().await.unwrap(); let Callback::Once(cb) = msg.callback else { panic!("Wrong Callback Type") }; cb.send(ResponseMessage { uuid: Uuid::new_v4(), response: Some(msg.msg.uuid), payload: TelemetryDefinitionResponse { uuid: response_uuid, } .into(), }) .unwrap(); }); let (response, outgoing_rx) = join!(response, outgoing_rx); let response = response.unwrap().unwrap(); outgoing_rx.unwrap(); assert_eq!(response.uuid, response_uuid); } #[tokio::test] async fn register_callback_channel() { let (mut tx, _, client) = create_test_client(); let msg_to_send = CommandDefinition { name: "".to_string(), parameters: vec![], }; let mut response = timeout( Duration::from_secs(1), client.register_callback_channel(msg_to_send), ) .await .unwrap() .unwrap(); let outgoing_rx = timeout(Duration::from_secs(1), async { let msg = tx.recv().await.unwrap(); let Callback::Registered(cb) = msg.callback else { panic!("Wrong Callback Type") }; // Check that we get responses to the callback the expected number of times for i in 0..5 { let (tx, rx) = oneshot::channel(); cb.send(( ResponseMessage { uuid: Uuid::new_v4(), response: Some(msg.msg.uuid), payload: Command { header: CommandHeader { timestamp: Utc::now(), }, parameters: Default::default(), } .into(), }, tx, )) .await .unwrap(); let RequestMessagePayload::CommandResponse(response) = rx.await.unwrap() else { panic!("Unexpected Response Type"); }; assert_eq!(response.response, format!("{i}")); } }); let responder = timeout(Duration::from_secs(1), async { for i in 0..5 { let (_cmd, responder) = response.recv().await.unwrap(); responder .send(CommandResponse { success: false, response: format!("{i}"), }) .unwrap(); } }); let (response, outgoing_rx) = join!(responder, outgoing_rx); response.unwrap(); outgoing_rx.unwrap(); } #[tokio::test] async fn register_callback_fn() { let (mut tx, _, client) = create_test_client(); let msg_to_send = CommandDefinition { name: "".to_string(), parameters: vec![], }; let mut index = 0usize; timeout( Duration::from_secs(1), client.register_callback_fn(msg_to_send, move |_| { index += 1; CommandResponse { success: false, response: format!("{}", index - 1), } }), ) .await .unwrap() .unwrap(); timeout(Duration::from_secs(1), async { let msg = tx.recv().await.unwrap(); let Callback::Registered(cb) = msg.callback else { panic!("Wrong Callback Type") }; // Check that we get responses to the callback the expected number of times for i in 0..3 { let (tx, rx) = oneshot::channel(); cb.send(( ResponseMessage { uuid: Uuid::new_v4(), response: Some(msg.msg.uuid), payload: Command { header: CommandHeader { timestamp: Utc::now(), }, parameters: Default::default(), } .into(), }, tx, )) .await .unwrap(); let RequestMessagePayload::CommandResponse(response) = rx.await.unwrap() else { panic!("Unexpected Response Type"); }; assert_eq!(response.response, format!("{i}")); } }) .await .unwrap(); } #[tokio::test] async fn connected_disconnected() { let (_, connected, client) = create_test_client(); // When we're connected we should return immediately connected.send_replace(true); client.wait_connected().now_or_never().unwrap(); // When we're disconnected we should return immediately connected.send_replace(false); client.wait_disconnected().now_or_never().unwrap(); let c2 = connected.clone(); // When we're disconnected, we should not return immediately let f1 = pin!(client.wait_connected()); let f2 = pin!(async move { sleep(Duration::from_millis(1)).await; c2.send_replace(true); }); let r = select(f1, f2).await; match r { Either::Left(_) => panic!("Wait Connected Finished Before Connection Changed"), Either::Right((_, other)) => timeout(Duration::from_secs(1), other).await.unwrap(), } let c2 = connected.clone(); // When we're disconnected, we should not return immediately let f1 = pin!(client.wait_disconnected()); let f2 = pin!(async move { sleep(Duration::from_millis(1)).await; c2.send_replace(false); }); let r = select(f1, f2).await; match r { Either::Left(_) => panic!("Wait Disconnected Finished Before Connection Changed"), Either::Right((_, other)) => timeout(Duration::from_secs(1), other).await.unwrap(), } } }