move cmd off of grpc

This commit is contained in:
2025-12-30 14:19:41 -05:00
parent 29f7f6d83b
commit 6980b7f6aa
26 changed files with 452 additions and 389 deletions

View File

@@ -1,20 +1,23 @@
pub mod error;
use crate::client::error::{MessageError, RequestError};
use crate::messages::callback::GenericCallbackError;
use crate::messages::payload::RequestMessagePayload;
use crate::messages::payload::ResponseMessagePayload;
use crate::messages::{ClientMessage, RequestMessage, RequestResponse, ResponseMessage};
use crate::messages::{
ClientMessage, RegisterCallback, RequestMessage, RequestResponse, ResponseMessage,
};
use error::ConnectError;
use futures_util::stream::StreamExt;
use futures_util::SinkExt;
use log::{debug, error};
use log::{debug, error, warn};
use std::collections::HashMap;
use std::sync::mpsc::sync_channel;
use std::sync::Arc;
use std::thread;
use tokio::net::TcpStream;
use tokio::select;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{mpsc, oneshot, RwLock, RwLockWriteGuard};
use tokio::{select, spawn};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::handshake::client::Request;
use tokio_tungstenite::tungstenite::Message;
@@ -22,9 +25,13 @@ use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
type RegisteredCallback = mpsc::Sender<(ResponseMessage, oneshot::Sender<RequestMessagePayload>)>;
type ClientChannel = Arc<RwLock<mpsc::Sender<OutgoingMessage>>>;
enum Callback {
None,
Once(oneshot::Sender<ResponseMessage>),
Registered(RegisteredCallback),
}
struct OutgoingMessage {
@@ -34,7 +41,7 @@ struct OutgoingMessage {
pub struct Client {
cancel: CancellationToken,
channel: Arc<RwLock<Sender<OutgoingMessage>>>,
channel: ClientChannel,
}
struct ClientContext {
@@ -128,10 +135,101 @@ impl Client {
Ok(response)
}
pub async fn register_callback_channel<M: RegisterCallback>(
&self,
msg: M,
) -> Result<mpsc::Receiver<(M::Callback, oneshot::Sender<M::Response>)>, MessageError>
where
<M as RegisterCallback>::Callback: Send + 'static,
<M as RegisterCallback>::Response: Send + 'static,
<<M as RegisterCallback>::Callback as TryFrom<ResponseMessagePayload>>::Error: Send,
{
let sender = self.channel.read().await;
let data = sender.reserve().await?;
let (inner_tx, mut inner_rx) = mpsc::channel(16);
let (outer_tx, outer_rx) = mpsc::channel(1);
data.send(OutgoingMessage {
msg: RequestMessage {
uuid: Uuid::new_v4(),
response: None,
payload: msg.into(),
},
callback: Callback::Registered(inner_tx),
});
spawn(async move {
// If the handler was unregistered we can stop
while let Some((msg, responder)) = inner_rx.recv().await {
let response: RequestMessagePayload = match M::Callback::try_from(msg.payload) {
Err(_) => GenericCallbackError::MismatchedType.into(),
Ok(o) => {
let (response_tx, response_rx) = oneshot::channel::<M::Response>();
match outer_tx.send((o, response_tx)).await {
Err(_) => GenericCallbackError::CallbackClosed.into(),
Ok(()) => response_rx
.await
.map(M::Response::into)
.unwrap_or_else(|_| GenericCallbackError::CallbackClosed.into()),
}
}
};
if responder.send(response).is_err() {
// If the callback was unregistered we can stop
break;
}
}
});
Ok(outer_rx)
}
pub async fn register_callback_fn<M: RegisterCallback, F>(
&self,
msg: M,
mut f: F,
) -> Result<(), MessageError>
where
F: FnMut(M::Callback) -> M::Response + Send + 'static,
{
let sender = self.channel.read().await;
let data = sender.reserve().await?;
let (inner_tx, mut inner_rx) = mpsc::channel(16);
data.send(OutgoingMessage {
msg: RequestMessage {
uuid: Uuid::new_v4(),
response: None,
payload: msg.into(),
},
callback: Callback::Registered(inner_tx),
});
spawn(async move {
// If the handler was unregistered we can stop
while let Some((msg, responder)) = inner_rx.recv().await {
let response: RequestMessagePayload = match M::Callback::try_from(msg.payload) {
Err(_) => GenericCallbackError::MismatchedType.into(),
Ok(o) => f(o).into(),
};
if responder.send(response).is_err() {
// If the callback was unregistered we can stop
break;
}
}
});
Ok(())
}
}
impl ClientContext {
fn start(mut self, channel: Arc<RwLock<Sender<OutgoingMessage>>>) -> Result<(), ConnectError> {
fn start(mut self, channel: ClientChannel) -> Result<(), ConnectError> {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
@@ -162,9 +260,9 @@ impl ClientContext {
async fn run_connection<'a>(
&mut self,
mut write_lock: RwLockWriteGuard<'a, Sender<OutgoingMessage>>,
channel: &'a Arc<RwLock<Sender<OutgoingMessage>>>,
) -> RwLockWriteGuard<'a, Sender<OutgoingMessage>> {
mut write_lock: RwLockWriteGuard<'a, mpsc::Sender<OutgoingMessage>>,
channel: &'a ClientChannel,
) -> RwLockWriteGuard<'a, mpsc::Sender<OutgoingMessage>> {
let mut ws = match connect_async(self.request.clone()).await {
Ok((ws, _)) => ws,
Err(e) => {
@@ -177,7 +275,7 @@ impl ClientContext {
*write_lock = tx;
drop(write_lock);
let close_connection = self.handle_connection(&mut ws, rx).await;
let close_connection = self.handle_connection(&mut ws, rx, channel).await;
let write_lock = channel.write().await;
if close_connection {
@@ -191,7 +289,8 @@ impl ClientContext {
async fn handle_connection(
&mut self,
ws: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
mut rx: Receiver<OutgoingMessage>,
mut rx: mpsc::Receiver<OutgoingMessage>,
channel: &ClientChannel,
) -> bool {
let mut callbacks = HashMap::<Uuid, Callback>::new();
loop {
@@ -209,7 +308,7 @@ impl ClientContext {
break;
}
};
self.handle_incoming(msg, &mut callbacks).await;
self.handle_incoming(msg, &mut callbacks, channel).await;
}
Message::Binary(_) => unimplemented!("Binary Data Not Implemented"),
Message::Ping(data) => {
@@ -235,7 +334,10 @@ impl ClientContext {
}
}
Some(msg) = rx.recv() => {
callbacks.insert(msg.msg.uuid, msg.callback);
// Insert a callback if it isn't a None callback
if !matches!(msg.callback, Callback::None) {
callbacks.insert(msg.msg.uuid, msg.callback);
}
let msg = match serde_json::to_string(&msg.msg) {
Ok(m) => m,
Err(e) => {
@@ -258,22 +360,77 @@ impl ClientContext {
&mut self,
msg: ResponseMessage,
callbacks: &mut HashMap<Uuid, Callback>,
channel: &ClientChannel,
) {
if let Some(response_uuid) = msg.response {
match callbacks.get(&response_uuid) {
Some(Callback::None) => {
callbacks.remove(&response_uuid);
unreachable!("We skip registering callbacks of None type");
}
Some(Callback::Once(_)) => {
let Some(Callback::Once(callback)) = callbacks.remove(&response_uuid) else {
return;
};
let _ = callback.send(msg);
}
Some(Callback::None) => {
callbacks.remove(&response_uuid);
Some(Callback::Registered(callback)) => {
let callback = callback.clone();
spawn(Self::handle_registered_callback(
callback,
msg,
channel.clone(),
));
}
None => {
warn!("No Callback Registered for {response_uuid}");
}
None => {}
}
}
}
async fn handle_registered_callback(
callback: RegisteredCallback,
msg: ResponseMessage,
channel: ClientChannel,
) {
let (tx, rx) = oneshot::channel();
let uuid = msg.uuid;
let response = match callback.send((msg, tx)).await {
Err(_) => GenericCallbackError::CallbackClosed.into(),
Ok(()) => rx
.await
.unwrap_or_else(|_| GenericCallbackError::CallbackClosed.into()),
};
if let Err(e) = Self::send_response(channel, response, uuid).await {
error!("Failed to send response {e}");
}
}
async fn send_response(
channel: ClientChannel,
payload: RequestMessagePayload,
response_uuid: Uuid,
) -> Result<(), MessageError> {
// If this failed that means we're in the middle of reconnecting, so our callbacks
// are all being cleaned up as-is. No response needed.
let sender = channel.try_read()?;
let data = sender.reserve().await?;
data.send(OutgoingMessage {
msg: RequestMessage {
uuid: Uuid::new_v4(),
response: Some(response_uuid),
payload,
},
callback: Callback::None,
});
Ok(())
}
}
impl Drop for Client {