pub mod error; use crate::client::error::{MessageError, RequestError}; use crate::messages::callback::GenericCallbackError; use crate::messages::payload::RequestMessagePayload; use crate::messages::payload::ResponseMessagePayload; use crate::messages::{ ClientMessage, RegisterCallback, RequestMessage, RequestResponse, ResponseMessage, }; use error::ConnectError; use futures_util::stream::StreamExt; use futures_util::SinkExt; use log::{debug, error, warn}; use std::collections::HashMap; use std::sync::mpsc::sync_channel; use std::sync::Arc; use std::thread; use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot, RwLock, RwLockWriteGuard}; use tokio::{select, spawn}; use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tokio_tungstenite::tungstenite::handshake::client::Request; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tokio_util::sync::CancellationToken; use uuid::Uuid; type RegisteredCallback = mpsc::Sender<(ResponseMessage, oneshot::Sender)>; type ClientChannel = Arc>>; enum Callback { None, Once(oneshot::Sender), Registered(RegisteredCallback), } struct OutgoingMessage { msg: RequestMessage, callback: Callback, } pub struct Client { cancel: CancellationToken, channel: ClientChannel, } struct ClientContext { cancel: CancellationToken, request: Request, } impl Client { pub fn connect(request: R) -> Result where R: IntoClientRequest, { let (tx, _rx) = mpsc::channel(1); let cancel = CancellationToken::new(); let channel = Arc::new(RwLock::new(tx)); let context = ClientContext { cancel: cancel.clone(), request: request.into_client_request()?, }; context.start(channel.clone())?; Ok(Self { cancel, channel }) } pub async fn send_message(&self, msg: M) -> Result<(), MessageError> { let sender = self.channel.read().await; let data = sender.reserve().await?; data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: None, payload: msg.into(), }, callback: Callback::None, }); Ok(()) } pub async fn send_message_if_connected( &self, msg: M, ) -> Result<(), MessageError> { let sender = self.channel.try_read()?; let data = sender.reserve().await?; data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: None, payload: msg.into(), }, callback: Callback::None, }); Ok(()) } pub fn try_send_message(&self, msg: M) -> Result<(), MessageError> { let sender = self.channel.try_read()?; let data = sender.try_reserve()?; data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: None, payload: msg.into(), }, callback: Callback::None, }); Ok(()) } pub async fn send_request( &self, msg: M, ) -> Result>::Error>> { let sender = self.channel.read().await; let data = sender.reserve().await?; let (tx, rx) = oneshot::channel(); data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: None, payload: msg.into(), }, callback: Callback::Once(tx), }); let response = rx.await?; let response = M::Response::try_from(response.payload).map_err(RequestError::Inner)?; Ok(response) } pub async fn register_callback_channel( &self, msg: M, ) -> Result)>, MessageError> where ::Callback: Send + 'static, ::Response: Send + 'static, <::Callback as TryFrom>::Error: Send, { let sender = self.channel.read().await; let data = sender.reserve().await?; let (inner_tx, mut inner_rx) = mpsc::channel(16); let (outer_tx, outer_rx) = mpsc::channel(1); data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: None, payload: msg.into(), }, callback: Callback::Registered(inner_tx), }); spawn(async move { // If the handler was unregistered we can stop while let Some((msg, responder)) = inner_rx.recv().await { let response: RequestMessagePayload = match M::Callback::try_from(msg.payload) { Err(_) => GenericCallbackError::MismatchedType.into(), Ok(o) => { let (response_tx, response_rx) = oneshot::channel::(); match outer_tx.send((o, response_tx)).await { Err(_) => GenericCallbackError::CallbackClosed.into(), Ok(()) => response_rx .await .map(M::Response::into) .unwrap_or_else(|_| GenericCallbackError::CallbackClosed.into()), } } }; if responder.send(response).is_err() { // If the callback was unregistered we can stop break; } } }); Ok(outer_rx) } pub async fn register_callback_fn( &self, msg: M, mut f: F, ) -> Result<(), MessageError> where F: FnMut(M::Callback) -> M::Response + Send + 'static, { let sender = self.channel.read().await; let data = sender.reserve().await?; let (inner_tx, mut inner_rx) = mpsc::channel(16); data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: None, payload: msg.into(), }, callback: Callback::Registered(inner_tx), }); spawn(async move { // If the handler was unregistered we can stop while let Some((msg, responder)) = inner_rx.recv().await { let response: RequestMessagePayload = match M::Callback::try_from(msg.payload) { Err(_) => GenericCallbackError::MismatchedType.into(), Ok(o) => f(o).into(), }; if responder.send(response).is_err() { // If the callback was unregistered we can stop break; } } }); Ok(()) } } impl ClientContext { fn start(mut self, channel: ClientChannel) -> Result<(), ConnectError> { let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; let (tx, rx) = sync_channel::<()>(1); let _detached = thread::Builder::new() .name("tlm-client".to_string()) .spawn(move || { runtime.block_on(async { let mut write_lock = channel.write().await; // This cannot fail let _ = tx.send(()); while !self.cancel.is_cancelled() { write_lock = self.run_connection(write_lock, &channel).await; } drop(write_lock); }); })?; // This cannot fail let _ = rx.recv(); Ok(()) } async fn run_connection<'a>( &mut self, mut write_lock: RwLockWriteGuard<'a, mpsc::Sender>, channel: &'a ClientChannel, ) -> RwLockWriteGuard<'a, mpsc::Sender> { let mut ws = match connect_async(self.request.clone()).await { Ok((ws, _)) => ws, Err(e) => { error!("Connect Error: {e}"); return write_lock; } }; let (tx, rx) = mpsc::channel(128); *write_lock = tx; drop(write_lock); let close_connection = self.handle_connection(&mut ws, rx, channel).await; let write_lock = channel.write().await; if close_connection { if let Err(e) = ws.close(None).await { println!("Close Error {e}"); } } write_lock } async fn handle_connection( &mut self, ws: &mut WebSocketStream>, mut rx: mpsc::Receiver, channel: &ClientChannel, ) -> bool { let mut callbacks = HashMap::::new(); loop { select! { _ = self.cancel.cancelled() => { break; }, Some(msg) = ws.next() => { match msg { Ok(msg) => { match msg { Message::Text(msg) => { let msg: ResponseMessage = match serde_json::from_str(&msg) { Ok(m) => m, Err(e) => { error!("Failed to deserialize {e}"); break; } }; self.handle_incoming(msg, &mut callbacks, channel).await; } Message::Binary(_) => unimplemented!("Binary Data Not Implemented"), Message::Ping(data) => { if let Err(e) = ws.send(Message::Pong(data)).await { error!("Failed to send Pong {e}"); break; } } Message::Pong(_) => { // Intentionally Left Empty } Message::Close(_) => { debug!("Websocket Closed"); return false; } Message::Frame(_) => unreachable!("Not Possible"), } } Err(e) => { error!("Receive Error {e}"); break; } } } Some(msg) = rx.recv() => { // Insert a callback if it isn't a None callback if !matches!(msg.callback, Callback::None) { callbacks.insert(msg.msg.uuid, msg.callback); } let msg = match serde_json::to_string(&msg.msg) { Ok(m) => m, Err(e) => { error!("Encode Error {e}"); break; } }; if let Err(e) = ws.send(Message::Text(msg.into())).await { error!("Send Error {e}"); break; } } else => { break; }, } } true } async fn handle_incoming( &mut self, msg: ResponseMessage, callbacks: &mut HashMap, channel: &ClientChannel, ) { if let Some(response_uuid) = msg.response { match callbacks.get(&response_uuid) { Some(Callback::None) => { callbacks.remove(&response_uuid); unreachable!("We skip registering callbacks of None type"); } Some(Callback::Once(_)) => { let Some(Callback::Once(callback)) = callbacks.remove(&response_uuid) else { return; }; let _ = callback.send(msg); } Some(Callback::Registered(callback)) => { let callback = callback.clone(); spawn(Self::handle_registered_callback( callback, msg, channel.clone(), )); } None => { warn!("No Callback Registered for {response_uuid}"); } } } } async fn handle_registered_callback( callback: RegisteredCallback, msg: ResponseMessage, channel: ClientChannel, ) { let (tx, rx) = oneshot::channel(); let uuid = msg.uuid; let response = match callback.send((msg, tx)).await { Err(_) => GenericCallbackError::CallbackClosed.into(), Ok(()) => rx .await .unwrap_or_else(|_| GenericCallbackError::CallbackClosed.into()), }; if let Err(e) = Self::send_response(channel, response, uuid).await { error!("Failed to send response {e}"); } } async fn send_response( channel: ClientChannel, payload: RequestMessagePayload, response_uuid: Uuid, ) -> Result<(), MessageError> { // If this failed that means we're in the middle of reconnecting, so our callbacks // are all being cleaned up as-is. No response needed. let sender = channel.try_read()?; let data = sender.reserve().await?; data.send(OutgoingMessage { msg: RequestMessage { uuid: Uuid::new_v4(), response: Some(response_uuid), payload, }, callback: Callback::None, }); Ok(()) } } impl Drop for Client { fn drop(&mut self) { self.cancel.cancel(); } }