use crate::client::config::ClientConfiguration; use crate::client::error::{ConnectError, MessageError}; use crate::client::{Callback, ClientChannel, OutgoingMessage, RegisteredCallback}; use crate::messages::callback::GenericCallbackError; use crate::messages::payload::RequestMessagePayload; use crate::messages::{RequestMessage, ResponseMessage}; use futures_util::{Sink, SinkExt, Stream, StreamExt}; use log::{debug, error, info, trace, warn}; use std::collections::HashMap; use std::fmt::Display; use std::sync::mpsc::sync_channel; use std::thread; use std::time::Duration; use tokio::sync::{mpsc, oneshot, watch, RwLockWriteGuard}; use tokio::time::sleep; use tokio::{select, spawn}; use tokio_tungstenite::connect_async; use tokio_tungstenite::tungstenite::handshake::client::{Request, Response as TungResponse}; use tokio_tungstenite::tungstenite::{Error as TungError, Message}; use tokio_util::sync::CancellationToken; use uuid::Uuid; pub struct ClientContext { pub cancel: CancellationToken, pub request: Request, pub connected_state_tx: watch::Sender, pub client_configuration: ClientConfiguration, } impl ClientContext { pub fn start(mut self, channel: ClientChannel) -> Result<(), ConnectError> { let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; let (tx, rx) = sync_channel::<()>(1); let _detached = thread::Builder::new() .name("tlm-client".to_string()) .spawn(move || { runtime.block_on(async { let mut write_lock = channel.write().await; // This cannot fail let _ = tx.send(()); while !self.cancel.is_cancelled() { write_lock = self .run_connection(write_lock, &channel, connect_async) .await; } drop(write_lock); }); })?; // This cannot fail let _ = rx.recv(); Ok(()) } async fn run_connection<'a, F, W, E>( &mut self, mut write_lock: RwLockWriteGuard<'a, mpsc::Sender>, channel: &'a ClientChannel, mut connection_fn: F, ) -> RwLockWriteGuard<'a, mpsc::Sender> where F: AsyncFnMut(Request) -> Result<(W, TungResponse), TungError>, W: Stream> + Sink + Unpin, E: Display, { debug!("Attempting to Connect to {}", self.request.uri()); let mut ws = match connection_fn(self.request.clone()).await { Ok((ws, _)) => ws, Err(e) => { info!("Failed to Connect: {e}"); sleep(Duration::from_secs(1)).await; return write_lock; } }; info!("Connected to {}", self.request.uri()); let (tx, rx) = mpsc::channel(self.client_configuration.send_buffer_size); *write_lock = tx; drop(write_lock); // Don't care about the previous value let _ = self.connected_state_tx.send_replace(true); let close_connection = self.handle_connection(&mut ws, rx, channel).await; let write_lock = channel.write().await; // Send this after grabbing the lock - to prevent extra contention when others try to grab // the lock to use that as a signal that we have reconnected let _ = self.connected_state_tx.send_replace(false); if close_connection { // Manually close to allow the impl trait to be used if let Err(e) = ws.send(Message::Close(None)).await { error!("Failed to Close the Connection: {e}"); } } write_lock } async fn handle_connection( &mut self, ws: &mut W, mut rx: mpsc::Receiver, channel: &ClientChannel, ) -> bool where W: Stream> + Sink + Unpin, >::Error: Display, { let mut callbacks = HashMap::::new(); loop { select! { _ = self.cancel.cancelled() => { break; }, Some(msg) = ws.next() => { match msg { Ok(msg) => { match msg { Message::Text(msg) => { trace!("Incoming: {msg}"); let msg: ResponseMessage = match serde_json::from_str(&msg) { Ok(m) => m, Err(e) => { error!("Failed to deserialize {e}"); break; } }; self.handle_incoming(msg, &mut callbacks, channel).await; } Message::Binary(_) => unimplemented!("Binary Data Not Implemented"), Message::Ping(data) => { if let Err(e) = ws.send(Message::Pong(data)).await { error!("Failed to send Pong {e}"); break; } } Message::Pong(_) => { // Intentionally Left Empty } Message::Close(_) => { debug!("Websocket Closed"); return false; } Message::Frame(_) => unreachable!("Not Possible"), } } Err(e) => { error!("Receive Error {e}"); break; } } } Some(msg) = rx.recv() => { // Insert a callback if it isn't a None callback if !matches!(msg.callback, Callback::None) { callbacks.insert(msg.msg.uuid, msg.callback); } let msg = match serde_json::to_string(&msg.msg) { Ok(m) => m, Err(e) => { error!("Encode Error {e}"); break; } }; trace!("Outgoing: {msg}"); if let Err(e) = ws.send(Message::Text(msg.into())).await { error!("Send Error {e}"); break; } } else => { break; }, } } true } async fn handle_incoming( &mut self, msg: ResponseMessage, callbacks: &mut HashMap, channel: &ClientChannel, ) { if let Some(response_uuid) = msg.response { match callbacks.get(&response_uuid) { Some(Callback::None) => { callbacks.remove(&response_uuid); unreachable!("We skip registering callbacks of None type"); } Some(Callback::Once(_)) => { let Some(Callback::Once(callback)) = callbacks.remove(&response_uuid) else { return; }; let _ = callback.send(msg); } Some(Callback::Registered(callback)) => { let callback = callback.clone(); spawn(Self::handle_registered_callback( callback, msg, channel.clone(), )); } None => { warn!("No Callback Registered for {response_uuid}"); } } } } async fn handle_registered_callback( callback: RegisteredCallback, msg: ResponseMessage, channel: ClientChannel, ) { let (tx, rx) = oneshot::channel(); let uuid = msg.uuid; let response = match callback.send((msg, tx)).await { Err(_) => GenericCallbackError::CallbackClosed.into(), Ok(()) => rx .await .unwrap_or_else(|_| GenericCallbackError::CallbackClosed.into()), }; if let Err(e) = Self::send_response(channel, response, uuid).await { error!("Failed to send response {e}"); } } async fn send_response( channel: ClientChannel, payload: RequestMessagePayload, response_uuid: Uuid, ) -> Result<(), MessageError> { // If this failed that means we're in the middle of reconnecting, so our callbacks // are all being cleaned up as-is. No response needed. let sender = channel.try_read()?; let data = sender.reserve().await?; data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: Some(response_uuid), payload, }, callback: Callback::None, }); Ok(()) } } #[cfg(test)] mod tests { use super::*; use crate::messages::telemetry_definition::{ TelemetryDefinitionRequest, TelemetryDefinitionResponse, }; use crate::test::mock_stream_sink::{create_mock_stream_sink, MockStreamSinkControl}; use api_core::data_type::DataType; use log::LevelFilter; use std::future::Future; use std::ops::Deref; use tokio::sync::mpsc::Sender; use tokio::sync::RwLock; use tokio::time::timeout; use tokio::try_join; use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tokio_util::bytes::Bytes; async fn assert_client_interaction(future: F) where F: Send + FnOnce( Sender, MockStreamSinkControl, Message>, CancellationToken, ) -> R + 'static, R: Future + Send, { let (control, stream_sink) = create_mock_stream_sink::, Message>(); let cancel_token = CancellationToken::new(); let inner_cancel_token = cancel_token.clone(); let (connected_state_tx, _connected_state_rx) = watch::channel(false); let mut context = ClientContext { cancel: cancel_token, request: "mock".into_client_request().unwrap(), connected_state_tx, client_configuration: Default::default(), }; let (tx, _rx) = mpsc::channel(1); let channel = ClientChannel::new(RwLock::new(tx)); let used_channel = channel.clone(); let write_lock = used_channel.write().await; let handle = spawn(async move { let channel = channel; let read = channel.read().await; let sender = read.deref().clone(); drop(read); future(sender, control, inner_cancel_token).await; }); let mut stream_sink = Some(stream_sink); let connection_fn = async |_: Request| { let stream_sink = stream_sink.take().ok_or(TungError::ConnectionClosed)?; Ok((stream_sink, TungResponse::default())) as Result<(_, _), TungError> }; let context_result = async { drop( context .run_connection(write_lock, &used_channel, connection_fn) .await, ); Ok(()) }; try_join!(context_result, timeout(Duration::from_secs(1), handle),) .unwrap() .1 .unwrap(); } #[tokio::test] async fn connection_closes_when_websocket_closes() { let _ = env_logger::builder() .is_test(true) .filter_level(LevelFilter::Trace) .try_init(); assert_client_interaction(|sender, mut control, _| async move { let msg = Uuid::new_v4(); sender .send(OutgoingMessage { msg: RequestMessage { uuid: msg, response: None, payload: TelemetryDefinitionRequest { name: "".to_string(), data_type: DataType::Float32, } .into(), }, callback: Callback::None, }) .await .unwrap(); // We expect an outgoing message assert!(matches!( control.outgoing.recv().await.unwrap(), Message::Text(_) )); // We receive an incoming close message control .incoming .send(Ok(Message::Close(None))) .await .unwrap(); // Then we expect the outgoing to close with no message assert!(control.outgoing.recv().await.is_none()); assert!(control.incoming.is_closed()); }) .await; } #[tokio::test] async fn connection_closes_when_cancelled() { let _ = env_logger::builder() .is_test(true) .filter_level(LevelFilter::Trace) .try_init(); assert_client_interaction(|_, mut control, cancel| async move { cancel.cancel(); // We expect an outgoing cancel message assert!(matches!( control.outgoing.recv().await.unwrap(), Message::Close(_) )); // Then we expect to close with no message assert!(control.outgoing.recv().await.is_none()); assert!(control.incoming.is_closed()); }) .await; } #[tokio::test] async fn callback_request() { let _ = env_logger::builder() .is_test(true) .filter_level(LevelFilter::Trace) .try_init(); assert_client_interaction(|sender, mut control, _| async move { let (callback_tx, callback_rx) = oneshot::channel(); let msg = Uuid::new_v4(); sender .send(OutgoingMessage { msg: RequestMessage { uuid: msg, response: None, payload: TelemetryDefinitionRequest { name: "".to_string(), data_type: DataType::Float32, } .into(), }, callback: Callback::Once(callback_tx), }) .await .unwrap(); // We expect an outgoing message assert!(matches!( control.outgoing.recv().await.unwrap(), Message::Text(_) )); // Then we get an incoming message for this callback let response_message = ResponseMessage { uuid: Uuid::new_v4(), response: Some(msg), payload: TelemetryDefinitionResponse { uuid: Uuid::new_v4(), } .into(), }; control .incoming .send(Ok(Message::Text( serde_json::to_string(&response_message).unwrap().into(), ))) .await .unwrap(); // We expect the callback to run let message = callback_rx.await.unwrap(); // And give us the message we provided it assert_eq!(message, response_message); // We receive an incoming close message control .incoming .send(Ok(Message::Close(None))) .await .unwrap(); // Then we expect the outgoing to close with no message assert!(control.outgoing.recv().await.is_none()); assert!(control.incoming.is_closed()); }) .await; } #[tokio::test] async fn callback_registered() { let _ = env_logger::builder() .is_test(true) .filter_level(LevelFilter::Trace) .try_init(); assert_client_interaction(|sender, mut control, _| async move { let (callback_tx, mut callback_rx) = mpsc::channel(1); let msg = Uuid::new_v4(); sender .send(OutgoingMessage { msg: RequestMessage { uuid: msg, response: None, payload: TelemetryDefinitionRequest { name: "".to_string(), data_type: DataType::Float32, } .into(), }, callback: Callback::Registered(callback_tx), }) .await .unwrap(); // We expect an outgoing message assert!(matches!( control.outgoing.recv().await.unwrap(), Message::Text(_) )); // We handle the callback a few times for _ in 0..5 { // Then we get an incoming message for this callback let response_message = ResponseMessage { uuid: Uuid::new_v4(), response: Some(msg), payload: TelemetryDefinitionResponse { uuid: Uuid::new_v4(), } .into(), }; control .incoming .send(Ok(Message::Text( serde_json::to_string(&response_message).unwrap().into(), ))) .await .unwrap(); // We expect the response let (rx, responder) = callback_rx.recv().await.unwrap(); // And give us the message we provided it assert_eq!(rx, response_message); // Then the response gets sent out responder .send( TelemetryDefinitionRequest { name: "".to_string(), data_type: DataType::Float32, } .into(), ) .unwrap(); // We expect an outgoing message assert!(matches!( control.outgoing.recv().await.unwrap(), Message::Text(_) )); } // We receive an incoming close message control .incoming .send(Ok(Message::Close(None))) .await .unwrap(); // Then we expect the outgoing to close with no message assert!(control.outgoing.recv().await.is_none()); assert!(control.incoming.is_closed()); }) .await; } #[tokio::test] async fn ping_pong() { let _ = env_logger::builder() .is_test(true) .filter_level(LevelFilter::Trace) .try_init(); assert_client_interaction(|_, mut control, _| async move { // Expect a pong in response to a ping let bytes = Bytes::from_owner(Uuid::new_v4().into_bytes()); control .incoming .send(Ok(Message::Ping(bytes.clone()))) .await .unwrap(); let Some(Message::Pong(pong_bytes)) = control.outgoing.recv().await else { panic!("Expected Pong Response"); }; assert_eq!(bytes, pong_bytes); // Nothing should happen control .incoming .send(Ok(Message::Pong(bytes.clone()))) .await .unwrap(); // We receive an incoming close message control .incoming .send(Ok(Message::Close(None))) .await .unwrap(); // Then we expect the outgoing to close with no message assert!(control.outgoing.recv().await.is_none()); assert!(control.incoming.is_closed()); }) .await; } }