move cmd off of grpc
This commit is contained in:
@@ -1,20 +1,23 @@
|
||||
pub mod error;
|
||||
|
||||
use crate::client::error::{MessageError, RequestError};
|
||||
use crate::messages::callback::GenericCallbackError;
|
||||
use crate::messages::payload::RequestMessagePayload;
|
||||
use crate::messages::payload::ResponseMessagePayload;
|
||||
use crate::messages::{ClientMessage, RequestMessage, RequestResponse, ResponseMessage};
|
||||
use crate::messages::{
|
||||
ClientMessage, RegisterCallback, RequestMessage, RequestResponse, ResponseMessage,
|
||||
};
|
||||
use error::ConnectError;
|
||||
use futures_util::stream::StreamExt;
|
||||
use futures_util::SinkExt;
|
||||
use log::{debug, error};
|
||||
use log::{debug, error, warn};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::mpsc::sync_channel;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::select;
|
||||
use tokio::sync::mpsc::{Receiver, Sender};
|
||||
use tokio::sync::{mpsc, oneshot, RwLock, RwLockWriteGuard};
|
||||
use tokio::{select, spawn};
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
use tokio_tungstenite::tungstenite::handshake::client::Request;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
@@ -22,9 +25,13 @@ use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
type RegisteredCallback = mpsc::Sender<(ResponseMessage, oneshot::Sender<RequestMessagePayload>)>;
|
||||
type ClientChannel = Arc<RwLock<mpsc::Sender<OutgoingMessage>>>;
|
||||
|
||||
enum Callback {
|
||||
None,
|
||||
Once(oneshot::Sender<ResponseMessage>),
|
||||
Registered(RegisteredCallback),
|
||||
}
|
||||
|
||||
struct OutgoingMessage {
|
||||
@@ -34,7 +41,7 @@ struct OutgoingMessage {
|
||||
|
||||
pub struct Client {
|
||||
cancel: CancellationToken,
|
||||
channel: Arc<RwLock<Sender<OutgoingMessage>>>,
|
||||
channel: ClientChannel,
|
||||
}
|
||||
|
||||
struct ClientContext {
|
||||
@@ -128,10 +135,101 @@ impl Client {
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn register_callback_channel<M: RegisterCallback>(
|
||||
&self,
|
||||
msg: M,
|
||||
) -> Result<mpsc::Receiver<(M::Callback, oneshot::Sender<M::Response>)>, MessageError>
|
||||
where
|
||||
<M as RegisterCallback>::Callback: Send + 'static,
|
||||
<M as RegisterCallback>::Response: Send + 'static,
|
||||
<<M as RegisterCallback>::Callback as TryFrom<ResponseMessagePayload>>::Error: Send,
|
||||
{
|
||||
let sender = self.channel.read().await;
|
||||
let data = sender.reserve().await?;
|
||||
|
||||
let (inner_tx, mut inner_rx) = mpsc::channel(16);
|
||||
let (outer_tx, outer_rx) = mpsc::channel(1);
|
||||
|
||||
data.send(OutgoingMessage {
|
||||
msg: RequestMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: None,
|
||||
payload: msg.into(),
|
||||
},
|
||||
callback: Callback::Registered(inner_tx),
|
||||
});
|
||||
|
||||
spawn(async move {
|
||||
// If the handler was unregistered we can stop
|
||||
while let Some((msg, responder)) = inner_rx.recv().await {
|
||||
let response: RequestMessagePayload = match M::Callback::try_from(msg.payload) {
|
||||
Err(_) => GenericCallbackError::MismatchedType.into(),
|
||||
Ok(o) => {
|
||||
let (response_tx, response_rx) = oneshot::channel::<M::Response>();
|
||||
match outer_tx.send((o, response_tx)).await {
|
||||
Err(_) => GenericCallbackError::CallbackClosed.into(),
|
||||
Ok(()) => response_rx
|
||||
.await
|
||||
.map(M::Response::into)
|
||||
.unwrap_or_else(|_| GenericCallbackError::CallbackClosed.into()),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if responder.send(response).is_err() {
|
||||
// If the callback was unregistered we can stop
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(outer_rx)
|
||||
}
|
||||
|
||||
pub async fn register_callback_fn<M: RegisterCallback, F>(
|
||||
&self,
|
||||
msg: M,
|
||||
mut f: F,
|
||||
) -> Result<(), MessageError>
|
||||
where
|
||||
F: FnMut(M::Callback) -> M::Response + Send + 'static,
|
||||
{
|
||||
let sender = self.channel.read().await;
|
||||
let data = sender.reserve().await?;
|
||||
|
||||
let (inner_tx, mut inner_rx) = mpsc::channel(16);
|
||||
|
||||
data.send(OutgoingMessage {
|
||||
msg: RequestMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: None,
|
||||
payload: msg.into(),
|
||||
},
|
||||
callback: Callback::Registered(inner_tx),
|
||||
});
|
||||
|
||||
spawn(async move {
|
||||
// If the handler was unregistered we can stop
|
||||
while let Some((msg, responder)) = inner_rx.recv().await {
|
||||
let response: RequestMessagePayload = match M::Callback::try_from(msg.payload) {
|
||||
Err(_) => GenericCallbackError::MismatchedType.into(),
|
||||
Ok(o) => f(o).into(),
|
||||
};
|
||||
|
||||
if responder.send(response).is_err() {
|
||||
// If the callback was unregistered we can stop
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientContext {
|
||||
fn start(mut self, channel: Arc<RwLock<Sender<OutgoingMessage>>>) -> Result<(), ConnectError> {
|
||||
fn start(mut self, channel: ClientChannel) -> Result<(), ConnectError> {
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
@@ -162,9 +260,9 @@ impl ClientContext {
|
||||
|
||||
async fn run_connection<'a>(
|
||||
&mut self,
|
||||
mut write_lock: RwLockWriteGuard<'a, Sender<OutgoingMessage>>,
|
||||
channel: &'a Arc<RwLock<Sender<OutgoingMessage>>>,
|
||||
) -> RwLockWriteGuard<'a, Sender<OutgoingMessage>> {
|
||||
mut write_lock: RwLockWriteGuard<'a, mpsc::Sender<OutgoingMessage>>,
|
||||
channel: &'a ClientChannel,
|
||||
) -> RwLockWriteGuard<'a, mpsc::Sender<OutgoingMessage>> {
|
||||
let mut ws = match connect_async(self.request.clone()).await {
|
||||
Ok((ws, _)) => ws,
|
||||
Err(e) => {
|
||||
@@ -177,7 +275,7 @@ impl ClientContext {
|
||||
*write_lock = tx;
|
||||
drop(write_lock);
|
||||
|
||||
let close_connection = self.handle_connection(&mut ws, rx).await;
|
||||
let close_connection = self.handle_connection(&mut ws, rx, channel).await;
|
||||
|
||||
let write_lock = channel.write().await;
|
||||
if close_connection {
|
||||
@@ -191,7 +289,8 @@ impl ClientContext {
|
||||
async fn handle_connection(
|
||||
&mut self,
|
||||
ws: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
mut rx: Receiver<OutgoingMessage>,
|
||||
mut rx: mpsc::Receiver<OutgoingMessage>,
|
||||
channel: &ClientChannel,
|
||||
) -> bool {
|
||||
let mut callbacks = HashMap::<Uuid, Callback>::new();
|
||||
loop {
|
||||
@@ -209,7 +308,7 @@ impl ClientContext {
|
||||
break;
|
||||
}
|
||||
};
|
||||
self.handle_incoming(msg, &mut callbacks).await;
|
||||
self.handle_incoming(msg, &mut callbacks, channel).await;
|
||||
}
|
||||
Message::Binary(_) => unimplemented!("Binary Data Not Implemented"),
|
||||
Message::Ping(data) => {
|
||||
@@ -235,7 +334,10 @@ impl ClientContext {
|
||||
}
|
||||
}
|
||||
Some(msg) = rx.recv() => {
|
||||
callbacks.insert(msg.msg.uuid, msg.callback);
|
||||
// Insert a callback if it isn't a None callback
|
||||
if !matches!(msg.callback, Callback::None) {
|
||||
callbacks.insert(msg.msg.uuid, msg.callback);
|
||||
}
|
||||
let msg = match serde_json::to_string(&msg.msg) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
@@ -258,22 +360,77 @@ impl ClientContext {
|
||||
&mut self,
|
||||
msg: ResponseMessage,
|
||||
callbacks: &mut HashMap<Uuid, Callback>,
|
||||
channel: &ClientChannel,
|
||||
) {
|
||||
if let Some(response_uuid) = msg.response {
|
||||
match callbacks.get(&response_uuid) {
|
||||
Some(Callback::None) => {
|
||||
callbacks.remove(&response_uuid);
|
||||
unreachable!("We skip registering callbacks of None type");
|
||||
}
|
||||
Some(Callback::Once(_)) => {
|
||||
let Some(Callback::Once(callback)) = callbacks.remove(&response_uuid) else {
|
||||
return;
|
||||
};
|
||||
let _ = callback.send(msg);
|
||||
}
|
||||
Some(Callback::None) => {
|
||||
callbacks.remove(&response_uuid);
|
||||
Some(Callback::Registered(callback)) => {
|
||||
let callback = callback.clone();
|
||||
spawn(Self::handle_registered_callback(
|
||||
callback,
|
||||
msg,
|
||||
channel.clone(),
|
||||
));
|
||||
}
|
||||
None => {
|
||||
warn!("No Callback Registered for {response_uuid}");
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_registered_callback(
|
||||
callback: RegisteredCallback,
|
||||
msg: ResponseMessage,
|
||||
channel: ClientChannel,
|
||||
) {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
let uuid = msg.uuid;
|
||||
|
||||
let response = match callback.send((msg, tx)).await {
|
||||
Err(_) => GenericCallbackError::CallbackClosed.into(),
|
||||
Ok(()) => rx
|
||||
.await
|
||||
.unwrap_or_else(|_| GenericCallbackError::CallbackClosed.into()),
|
||||
};
|
||||
|
||||
if let Err(e) = Self::send_response(channel, response, uuid).await {
|
||||
error!("Failed to send response {e}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_response(
|
||||
channel: ClientChannel,
|
||||
payload: RequestMessagePayload,
|
||||
response_uuid: Uuid,
|
||||
) -> Result<(), MessageError> {
|
||||
// If this failed that means we're in the middle of reconnecting, so our callbacks
|
||||
// are all being cleaned up as-is. No response needed.
|
||||
let sender = channel.try_read()?;
|
||||
let data = sender.reserve().await?;
|
||||
|
||||
data.send(OutgoingMessage {
|
||||
msg: RequestMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: Some(response_uuid),
|
||||
payload,
|
||||
},
|
||||
callback: Callback::None,
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Client {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use derive_more::TryInto;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, TryInto)]
|
||||
pub enum DataValue {
|
||||
Float32(f32),
|
||||
Float64(f64),
|
||||
|
||||
7
api/src/messages/callback.rs
Normal file
7
api/src/messages/callback.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum GenericCallbackError {
|
||||
CallbackClosed,
|
||||
MismatchedType,
|
||||
}
|
||||
35
api/src/messages/command.rs
Normal file
35
api/src/messages/command.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use crate::data_type::DataType;
|
||||
use crate::data_value::DataValue;
|
||||
use crate::messages::RegisterCallback;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CommandParameterDefinition {
|
||||
pub name: String,
|
||||
pub data_type: DataType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CommandDefinition {
|
||||
pub name: String,
|
||||
pub parameters: Vec<CommandParameterDefinition>,
|
||||
}
|
||||
|
||||
impl RegisterCallback for CommandDefinition {
|
||||
type Callback = Command;
|
||||
type Response = CommandResponse;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Command {
|
||||
pub timestamp: DateTime<Utc>,
|
||||
pub parameters: HashMap<String, DataValue>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CommandResponse {
|
||||
pub success: bool,
|
||||
pub response: String,
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
pub mod callback;
|
||||
pub mod command;
|
||||
pub mod payload;
|
||||
pub mod telemetry_definition;
|
||||
pub mod telemetry_entry;
|
||||
@@ -28,7 +30,7 @@ pub trait RequestResponse: Into<RequestMessagePayload> {
|
||||
type Response: TryFrom<ResponseMessagePayload>;
|
||||
}
|
||||
|
||||
// pub trait RegisterCallback {
|
||||
// type Callback : TryFrom<ResponseMessagePayload>;
|
||||
// type Response : Into<RequestMessagePayload>;
|
||||
// }
|
||||
pub trait RegisterCallback: Into<RequestMessagePayload> {
|
||||
type Callback: TryFrom<ResponseMessagePayload>;
|
||||
type Response: Into<RequestMessagePayload>;
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use crate::messages::callback::GenericCallbackError;
|
||||
use crate::messages::command::{Command, CommandDefinition, CommandResponse};
|
||||
use crate::messages::telemetry_definition::{
|
||||
TelemetryDefinitionRequest, TelemetryDefinitionResponse,
|
||||
};
|
||||
@@ -9,9 +11,13 @@ use serde::{Deserialize, Serialize};
|
||||
pub enum RequestMessagePayload {
|
||||
TelemetryDefinitionRequest(TelemetryDefinitionRequest),
|
||||
TelemetryEntry(TelemetryEntry),
|
||||
GenericCallbackError(GenericCallbackError),
|
||||
CommandDefinition(CommandDefinition),
|
||||
CommandResponse(CommandResponse),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, From, TryInto)]
|
||||
pub enum ResponseMessagePayload {
|
||||
TelemetryDefinitionResponse(TelemetryDefinitionResponse),
|
||||
Command(Command),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user