Replace gRPC Backend (#10)
**Rationale:** Having two separate servers and communication methods resulted in additional maintenance & the need to convert often between backend & frontend data types. By moving the backend communication off of gRPC and to just use websockets it both gives more control & allows for simplification of the implementation. #8 **Changes:** - Replaces gRPC backend. - New implementation automatically handles reconnect logic - Implements an api layer - Migrates examples to the api layer - Implements a proc macro to make command handling easier - Implements unit tests for the api layer (90+% coverage) - Implements integration tests for the proc macro (90+% coverage) Reviewed-on: #10 Co-authored-by: Sergey Savelyev <sergeysav.nn@gmail.com> Co-committed-by: Sergey Savelyev <sergeysav.nn@gmail.com>
This commit was merged in pull request #10.
This commit is contained in:
598
api/src/client/mod.rs
Normal file
598
api/src/client/mod.rs
Normal file
@@ -0,0 +1,598 @@
|
||||
pub mod command;
|
||||
mod config;
|
||||
mod context;
|
||||
pub mod error;
|
||||
pub mod telemetry;
|
||||
|
||||
use crate::client::config::ClientConfiguration;
|
||||
use crate::client::error::{MessageError, RequestError};
|
||||
use crate::messages::callback::GenericCallbackError;
|
||||
use crate::messages::payload::RequestMessagePayload;
|
||||
use crate::messages::payload::ResponseMessagePayload;
|
||||
use crate::messages::{
|
||||
ClientMessage, RegisterCallback, RequestMessage, RequestResponse, ResponseMessage,
|
||||
};
|
||||
use context::ClientContext;
|
||||
use error::ConnectError;
|
||||
use std::sync::Arc;
|
||||
use tokio::spawn;
|
||||
use tokio::sync::{mpsc, oneshot, watch, RwLock};
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
type RegisteredCallback = mpsc::Sender<(ResponseMessage, oneshot::Sender<RequestMessagePayload>)>;
|
||||
type ClientChannel = Arc<RwLock<mpsc::Sender<OutgoingMessage>>>;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Callback {
|
||||
None,
|
||||
Once(oneshot::Sender<ResponseMessage>),
|
||||
Registered(RegisteredCallback),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct OutgoingMessage {
|
||||
msg: RequestMessage,
|
||||
callback: Callback,
|
||||
}
|
||||
|
||||
pub struct Client {
|
||||
cancel: CancellationToken,
|
||||
channel: ClientChannel,
|
||||
connected_state_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub fn connect<R>(request: R) -> Result<Self, ConnectError>
|
||||
where
|
||||
R: IntoClientRequest,
|
||||
{
|
||||
Self::connect_with_config(request, ClientConfiguration::default())
|
||||
}
|
||||
|
||||
pub fn connect_with_config<R>(
|
||||
request: R,
|
||||
config: ClientConfiguration,
|
||||
) -> Result<Self, ConnectError>
|
||||
where
|
||||
R: IntoClientRequest,
|
||||
{
|
||||
let (tx, _rx) = mpsc::channel(1);
|
||||
let cancel = CancellationToken::new();
|
||||
let channel = Arc::new(RwLock::new(tx));
|
||||
let (connected_state_tx, connected_state_rx) = watch::channel(false);
|
||||
let context = ClientContext {
|
||||
cancel: cancel.clone(),
|
||||
request: request.into_client_request()?,
|
||||
connected_state_tx,
|
||||
client_configuration: config,
|
||||
};
|
||||
|
||||
context.start(channel.clone())?;
|
||||
|
||||
Ok(Self {
|
||||
cancel,
|
||||
channel,
|
||||
connected_state_rx,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn send_message<M: ClientMessage>(&self, msg: M) -> Result<(), MessageError> {
|
||||
let sender = self.channel.read().await;
|
||||
let data = sender.reserve().await?;
|
||||
data.send(OutgoingMessage {
|
||||
msg: RequestMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: None,
|
||||
payload: msg.into(),
|
||||
},
|
||||
callback: Callback::None,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn send_message_if_connected<M: ClientMessage>(
|
||||
&self,
|
||||
msg: M,
|
||||
) -> Result<(), MessageError> {
|
||||
let sender = self.channel.try_read()?;
|
||||
let data = sender.reserve().await?;
|
||||
data.send(OutgoingMessage {
|
||||
msg: RequestMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: None,
|
||||
payload: msg.into(),
|
||||
},
|
||||
callback: Callback::None,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn try_send_message<M: ClientMessage>(&self, msg: M) -> Result<(), MessageError> {
|
||||
let sender = self.channel.try_read()?;
|
||||
let data = sender.try_reserve()?;
|
||||
data.send(OutgoingMessage {
|
||||
msg: RequestMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: None,
|
||||
payload: msg.into(),
|
||||
},
|
||||
callback: Callback::None,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn send_request<M: RequestResponse>(
|
||||
&self,
|
||||
msg: M,
|
||||
) -> Result<M::Response, RequestError<<M::Response as TryFrom<ResponseMessagePayload>>::Error>>
|
||||
{
|
||||
let sender = self.channel.read().await;
|
||||
let data = sender.reserve().await?;
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
data.send(OutgoingMessage {
|
||||
msg: RequestMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: None,
|
||||
payload: msg.into(),
|
||||
},
|
||||
callback: Callback::Once(tx),
|
||||
});
|
||||
|
||||
let response = rx.await?;
|
||||
let response = M::Response::try_from(response.payload).map_err(RequestError::Inner)?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn register_callback_channel<M: RegisterCallback>(
|
||||
&self,
|
||||
msg: M,
|
||||
) -> Result<mpsc::Receiver<(M::Callback, oneshot::Sender<M::Response>)>, MessageError>
|
||||
where
|
||||
<M as RegisterCallback>::Callback: Send + 'static,
|
||||
<M as RegisterCallback>::Response: Send + 'static,
|
||||
<<M as RegisterCallback>::Callback as TryFrom<ResponseMessagePayload>>::Error: Send,
|
||||
{
|
||||
let sender = self.channel.read().await;
|
||||
let data = sender.reserve().await?;
|
||||
|
||||
let (inner_tx, mut inner_rx) = mpsc::channel(16);
|
||||
let (outer_tx, outer_rx) = mpsc::channel(1);
|
||||
|
||||
data.send(OutgoingMessage {
|
||||
msg: RequestMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: None,
|
||||
payload: msg.into(),
|
||||
},
|
||||
callback: Callback::Registered(inner_tx),
|
||||
});
|
||||
|
||||
spawn(async move {
|
||||
// If the handler was unregistered we can stop
|
||||
while let Some((msg, responder)) = inner_rx.recv().await {
|
||||
let response: RequestMessagePayload = match M::Callback::try_from(msg.payload) {
|
||||
Err(_) => GenericCallbackError::MismatchedType.into(),
|
||||
Ok(o) => {
|
||||
let (response_tx, response_rx) = oneshot::channel::<M::Response>();
|
||||
match outer_tx.send((o, response_tx)).await {
|
||||
Err(_) => GenericCallbackError::CallbackClosed.into(),
|
||||
Ok(()) => response_rx
|
||||
.await
|
||||
.map(M::Response::into)
|
||||
.unwrap_or_else(|_| GenericCallbackError::CallbackClosed.into()),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if responder.send(response).is_err() {
|
||||
// If the callback was unregistered we can stop
|
||||
break;
|
||||
}
|
||||
}
|
||||
println!("Exited Loop");
|
||||
});
|
||||
|
||||
Ok(outer_rx)
|
||||
}
|
||||
|
||||
pub async fn register_callback_fn<M: RegisterCallback, F>(
|
||||
&self,
|
||||
msg: M,
|
||||
mut f: F,
|
||||
) -> Result<(), MessageError>
|
||||
where
|
||||
F: FnMut(M::Callback) -> M::Response + Send + 'static,
|
||||
{
|
||||
let sender = self.channel.read().await;
|
||||
let data = sender.reserve().await?;
|
||||
|
||||
let (inner_tx, mut inner_rx) = mpsc::channel(16);
|
||||
|
||||
data.send(OutgoingMessage {
|
||||
msg: RequestMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: None,
|
||||
payload: msg.into(),
|
||||
},
|
||||
callback: Callback::Registered(inner_tx),
|
||||
});
|
||||
|
||||
spawn(async move {
|
||||
// If the handler was unregistered we can stop
|
||||
while let Some((msg, responder)) = inner_rx.recv().await {
|
||||
let response: RequestMessagePayload = match M::Callback::try_from(msg.payload) {
|
||||
Err(_) => GenericCallbackError::MismatchedType.into(),
|
||||
Ok(o) => f(o).into(),
|
||||
};
|
||||
|
||||
if responder.send(response).is_err() {
|
||||
// If the callback was unregistered we can stop
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn wait_connected(&self) {
|
||||
let mut connected_rx = self.connected_state_rx.clone();
|
||||
|
||||
// If we aren't currently connected
|
||||
if !*connected_rx.borrow_and_update() {
|
||||
// Wait for a change notification
|
||||
// If the channel is closed there is nothing we can do
|
||||
let _ = connected_rx.changed().await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn wait_disconnected(&self) {
|
||||
let mut connected_rx = self.connected_state_rx.clone();
|
||||
|
||||
// If we are currently connected
|
||||
if *connected_rx.borrow_and_update() {
|
||||
// Wait for a change notification
|
||||
// If the channel is closed there is nothing we can do
|
||||
let _ = connected_rx.changed().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Client {
|
||||
fn drop(&mut self) {
|
||||
self.cancel.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::messages::command::CommandResponse;
|
||||
use crate::messages::telemetry_definition::{
|
||||
TelemetryDefinitionRequest, TelemetryDefinitionResponse,
|
||||
};
|
||||
use crate::messages::telemetry_entry::TelemetryEntry;
|
||||
use api_core::command::{Command, CommandDefinition, CommandHeader};
|
||||
use api_core::data_type::DataType;
|
||||
use chrono::Utc;
|
||||
use futures_util::future::{select, Either};
|
||||
use futures_util::FutureExt;
|
||||
use std::pin::pin;
|
||||
use std::time::Duration;
|
||||
use tokio::join;
|
||||
use tokio::time::{sleep, timeout};
|
||||
|
||||
pub fn create_test_client() -> (mpsc::Receiver<OutgoingMessage>, watch::Sender<bool>, Client) {
|
||||
let cancel = CancellationToken::new();
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
let channel = Arc::new(RwLock::new(tx));
|
||||
let (connected_state_tx, connected_state_rx) = watch::channel(true);
|
||||
let client = Client {
|
||||
cancel,
|
||||
channel,
|
||||
connected_state_rx,
|
||||
};
|
||||
(rx, connected_state_tx, client)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_message() {
|
||||
let (mut rx, _, client) = create_test_client();
|
||||
|
||||
let msg_to_send = TelemetryEntry {
|
||||
uuid: Uuid::new_v4(),
|
||||
value: 0.0f32.into(),
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
let msg_send = timeout(
|
||||
Duration::from_secs(1),
|
||||
client.send_message(msg_to_send.clone()),
|
||||
);
|
||||
let msg_recv = timeout(Duration::from_secs(1), rx.recv());
|
||||
|
||||
let (send, recv) = join!(msg_send, msg_recv);
|
||||
send.unwrap().unwrap();
|
||||
let recv = recv.unwrap().unwrap();
|
||||
|
||||
assert!(matches!(recv.callback, Callback::None));
|
||||
assert!(recv.msg.response.is_none());
|
||||
// uuid should be random
|
||||
|
||||
let RequestMessagePayload::TelemetryEntry(recv) = recv.msg.payload else {
|
||||
panic!("Wrong Message Received")
|
||||
};
|
||||
|
||||
assert_eq!(recv, msg_to_send);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_message_if_connected() {
|
||||
let (mut rx, _, client) = create_test_client();
|
||||
|
||||
let msg_to_send = TelemetryEntry {
|
||||
uuid: Uuid::new_v4(),
|
||||
value: 0.0f32.into(),
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
let msg_send = timeout(
|
||||
Duration::from_secs(1),
|
||||
client.send_message_if_connected(msg_to_send.clone()),
|
||||
);
|
||||
let msg_recv = timeout(Duration::from_secs(1), rx.recv());
|
||||
|
||||
let (send, recv) = join!(msg_send, msg_recv);
|
||||
send.unwrap().unwrap();
|
||||
let recv = recv.unwrap().unwrap();
|
||||
|
||||
assert!(matches!(recv.callback, Callback::None));
|
||||
assert!(recv.msg.response.is_none());
|
||||
// uuid should be random
|
||||
|
||||
let RequestMessagePayload::TelemetryEntry(recv) = recv.msg.payload else {
|
||||
panic!("Wrong Message Received")
|
||||
};
|
||||
|
||||
assert_eq!(recv, msg_to_send);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_message_if_connected_not_connected() {
|
||||
let (_, connected_state_tx, client) = create_test_client();
|
||||
|
||||
let _lock = client.channel.write().await;
|
||||
connected_state_tx.send_replace(false);
|
||||
|
||||
let msg_to_send = TelemetryEntry {
|
||||
uuid: Uuid::new_v4(),
|
||||
value: 0.0f32.into(),
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
let msg_send = timeout(
|
||||
Duration::from_secs(1),
|
||||
client.send_message_if_connected(msg_to_send.clone()),
|
||||
);
|
||||
|
||||
let Err(MessageError::TokioLockError(_)) = msg_send.await.unwrap() else {
|
||||
panic!("Expected to Err due to lock being unavailable")
|
||||
};
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn try_send_message() {
|
||||
let (_tx, _, client) = create_test_client();
|
||||
|
||||
let msg_to_send = TelemetryEntry {
|
||||
uuid: Uuid::new_v4(),
|
||||
value: 0.0f32.into(),
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
client.try_send_message(msg_to_send.clone()).unwrap();
|
||||
let Err(MessageError::TokioTrySendError(_)) = client.try_send_message(msg_to_send.clone())
|
||||
else {
|
||||
panic!("Expected the buffer to be full");
|
||||
};
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_request() {
|
||||
let (mut tx, _, client) = create_test_client();
|
||||
|
||||
let msg_to_send = TelemetryDefinitionRequest {
|
||||
name: "".to_string(),
|
||||
data_type: DataType::Float32,
|
||||
};
|
||||
let response = timeout(
|
||||
Duration::from_secs(1),
|
||||
client.send_request(msg_to_send.clone()),
|
||||
);
|
||||
|
||||
let response_uuid = Uuid::new_v4();
|
||||
let outgoing_rx = timeout(Duration::from_secs(1), async {
|
||||
let msg = tx.recv().await.unwrap();
|
||||
let Callback::Once(cb) = msg.callback else {
|
||||
panic!("Wrong Callback Type")
|
||||
};
|
||||
cb.send(ResponseMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: Some(msg.msg.uuid),
|
||||
payload: TelemetryDefinitionResponse {
|
||||
uuid: response_uuid,
|
||||
}
|
||||
.into(),
|
||||
})
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let (response, outgoing_rx) = join!(response, outgoing_rx);
|
||||
let response = response.unwrap().unwrap();
|
||||
outgoing_rx.unwrap();
|
||||
|
||||
assert_eq!(response.uuid, response_uuid);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn register_callback_channel() {
|
||||
let (mut tx, _, client) = create_test_client();
|
||||
|
||||
let msg_to_send = CommandDefinition {
|
||||
name: "".to_string(),
|
||||
parameters: vec![],
|
||||
};
|
||||
let mut response = timeout(
|
||||
Duration::from_secs(1),
|
||||
client.register_callback_channel(msg_to_send),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let outgoing_rx = timeout(Duration::from_secs(1), async {
|
||||
let msg = tx.recv().await.unwrap();
|
||||
let Callback::Registered(cb) = msg.callback else {
|
||||
panic!("Wrong Callback Type")
|
||||
};
|
||||
|
||||
// Check that we get responses to the callback the expected number of times
|
||||
for i in 0..5 {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
cb.send((
|
||||
ResponseMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: Some(msg.msg.uuid),
|
||||
payload: Command {
|
||||
header: CommandHeader {
|
||||
timestamp: Utc::now(),
|
||||
},
|
||||
parameters: Default::default(),
|
||||
}
|
||||
.into(),
|
||||
},
|
||||
tx,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
let RequestMessagePayload::CommandResponse(response) = rx.await.unwrap() else {
|
||||
panic!("Unexpected Response Type");
|
||||
};
|
||||
assert_eq!(response.response, format!("{i}"));
|
||||
}
|
||||
});
|
||||
|
||||
let responder = timeout(Duration::from_secs(1), async {
|
||||
for i in 0..5 {
|
||||
let (_cmd, responder) = response.recv().await.unwrap();
|
||||
responder
|
||||
.send(CommandResponse {
|
||||
success: false,
|
||||
response: format!("{i}"),
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let (response, outgoing_rx) = join!(responder, outgoing_rx);
|
||||
response.unwrap();
|
||||
outgoing_rx.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn register_callback_fn() {
|
||||
let (mut tx, _, client) = create_test_client();
|
||||
|
||||
let msg_to_send = CommandDefinition {
|
||||
name: "".to_string(),
|
||||
parameters: vec![],
|
||||
};
|
||||
let mut index = 0usize;
|
||||
timeout(
|
||||
Duration::from_secs(1),
|
||||
client.register_callback_fn(msg_to_send, move |_| {
|
||||
index += 1;
|
||||
CommandResponse {
|
||||
success: false,
|
||||
response: format!("{}", index - 1),
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
timeout(Duration::from_secs(1), async {
|
||||
let msg = tx.recv().await.unwrap();
|
||||
let Callback::Registered(cb) = msg.callback else {
|
||||
panic!("Wrong Callback Type")
|
||||
};
|
||||
|
||||
// Check that we get responses to the callback the expected number of times
|
||||
for i in 0..3 {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
cb.send((
|
||||
ResponseMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: Some(msg.msg.uuid),
|
||||
payload: Command {
|
||||
header: CommandHeader {
|
||||
timestamp: Utc::now(),
|
||||
},
|
||||
parameters: Default::default(),
|
||||
}
|
||||
.into(),
|
||||
},
|
||||
tx,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
let RequestMessagePayload::CommandResponse(response) = rx.await.unwrap() else {
|
||||
panic!("Unexpected Response Type");
|
||||
};
|
||||
assert_eq!(response.response, format!("{i}"));
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connected_disconnected() {
|
||||
let (_, connected, client) = create_test_client();
|
||||
|
||||
// When we're connected we should return immediately
|
||||
connected.send_replace(true);
|
||||
client.wait_connected().now_or_never().unwrap();
|
||||
|
||||
// When we're disconnected we should return immediately
|
||||
connected.send_replace(false);
|
||||
client.wait_disconnected().now_or_never().unwrap();
|
||||
|
||||
let c2 = connected.clone();
|
||||
// When we're disconnected, we should not return immediately
|
||||
let f1 = pin!(client.wait_connected());
|
||||
let f2 = pin!(async move {
|
||||
sleep(Duration::from_millis(1)).await;
|
||||
c2.send_replace(true);
|
||||
});
|
||||
let r = select(f1, f2).await;
|
||||
match r {
|
||||
Either::Left(_) => panic!("Wait Connected Finished Before Connection Changed"),
|
||||
Either::Right((_, other)) => timeout(Duration::from_secs(1), other).await.unwrap(),
|
||||
}
|
||||
|
||||
let c2 = connected.clone();
|
||||
// When we're disconnected, we should not return immediately
|
||||
let f1 = pin!(client.wait_disconnected());
|
||||
let f2 = pin!(async move {
|
||||
sleep(Duration::from_millis(1)).await;
|
||||
c2.send_replace(false);
|
||||
});
|
||||
let r = select(f1, f2).await;
|
||||
match r {
|
||||
Either::Left(_) => panic!("Wait Disconnected Finished Before Connection Changed"),
|
||||
Either::Right((_, other)) => timeout(Duration::from_secs(1), other).await.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user