add tests to api
This commit is contained in:
@@ -19,3 +19,6 @@ tokio = { workspace = true, features = ["rt", "macros", "time"] }
|
||||
tokio-tungstenite = { workspace = true, features = ["rustls-tls-native-roots"] }
|
||||
tokio-util = { workspace = true }
|
||||
uuid = { workspace = true, features = ["serde"] }
|
||||
|
||||
[dev-dependencies]
|
||||
env_logger = { workspace = true }
|
||||
|
||||
@@ -4,19 +4,19 @@ use crate::client::{Callback, ClientChannel, OutgoingMessage, RegisteredCallback
|
||||
use crate::messages::callback::GenericCallbackError;
|
||||
use crate::messages::payload::RequestMessagePayload;
|
||||
use crate::messages::{RequestMessage, ResponseMessage};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use futures_util::{Sink, SinkExt, Stream, StreamExt};
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
use std::sync::mpsc::sync_channel;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, oneshot, watch, RwLockWriteGuard};
|
||||
use tokio::time::sleep;
|
||||
use tokio::{select, spawn};
|
||||
use tokio_tungstenite::tungstenite::handshake::client::Request;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tokio_tungstenite::tungstenite::handshake::client::{Request, Response as TungResponse};
|
||||
use tokio_tungstenite::tungstenite::{Error as TungError, Message};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -45,7 +45,9 @@ impl ClientContext {
|
||||
let _ = tx.send(());
|
||||
|
||||
while !self.cancel.is_cancelled() {
|
||||
write_lock = self.run_connection(write_lock, &channel).await;
|
||||
write_lock = self
|
||||
.run_connection(write_lock, &channel, connect_async)
|
||||
.await;
|
||||
}
|
||||
drop(write_lock);
|
||||
});
|
||||
@@ -57,13 +59,19 @@ impl ClientContext {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_connection<'a>(
|
||||
async fn run_connection<'a, F, W, E>(
|
||||
&mut self,
|
||||
mut write_lock: RwLockWriteGuard<'a, mpsc::Sender<OutgoingMessage>>,
|
||||
channel: &'a ClientChannel,
|
||||
) -> RwLockWriteGuard<'a, mpsc::Sender<OutgoingMessage>> {
|
||||
mut connection_fn: F,
|
||||
) -> RwLockWriteGuard<'a, mpsc::Sender<OutgoingMessage>>
|
||||
where
|
||||
F: AsyncFnMut(Request) -> Result<(W, TungResponse), TungError>,
|
||||
W: Stream<Item = Result<Message, TungError>> + Sink<Message, Error = E> + Unpin,
|
||||
E: Display,
|
||||
{
|
||||
debug!("Attempting to Connect to {}", self.request.uri());
|
||||
let mut ws = match connect_async(self.request.clone()).await {
|
||||
let mut ws = match connection_fn(self.request.clone()).await {
|
||||
Ok((ws, _)) => ws,
|
||||
Err(e) => {
|
||||
info!("Failed to Connect: {e}");
|
||||
@@ -87,19 +95,24 @@ impl ClientContext {
|
||||
// the lock to use that as a signal that we have reconnected
|
||||
let _ = self.connected_state_tx.send_replace(false);
|
||||
if close_connection {
|
||||
if let Err(e) = ws.close(None).await {
|
||||
// Manually close to allow the impl trait to be used
|
||||
if let Err(e) = ws.send(Message::Close(None)).await {
|
||||
error!("Failed to Close the Connection: {e}");
|
||||
}
|
||||
}
|
||||
write_lock
|
||||
}
|
||||
|
||||
async fn handle_connection(
|
||||
async fn handle_connection<W>(
|
||||
&mut self,
|
||||
ws: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
ws: &mut W,
|
||||
mut rx: mpsc::Receiver<OutgoingMessage>,
|
||||
channel: &ClientChannel,
|
||||
) -> bool {
|
||||
) -> bool
|
||||
where
|
||||
W: Stream<Item = Result<Message, TungError>> + Sink<Message> + Unpin,
|
||||
<W as Sink<Message>>::Error: Display,
|
||||
{
|
||||
let mut callbacks = HashMap::<Uuid, Callback>::new();
|
||||
loop {
|
||||
select! {
|
||||
@@ -242,3 +255,340 @@ impl ClientContext {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::messages::telemetry_definition::{
|
||||
TelemetryDefinitionRequest, TelemetryDefinitionResponse,
|
||||
};
|
||||
use crate::test::mock_stream_sink::{create_mock_stream_sink, MockStreamSinkControl};
|
||||
use api_core::data_type::DataType;
|
||||
use log::LevelFilter;
|
||||
use std::future::Future;
|
||||
use std::ops::Deref;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::timeout;
|
||||
use tokio::try_join;
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
use tokio_util::bytes::Bytes;
|
||||
|
||||
async fn assert_client_interaction<F, R>(future: F)
|
||||
where
|
||||
F: Send
|
||||
+ FnOnce(
|
||||
Sender<OutgoingMessage>,
|
||||
MockStreamSinkControl<Result<Message, TungError>, Message>,
|
||||
CancellationToken,
|
||||
) -> R
|
||||
+ 'static,
|
||||
R: Future<Output = ()> + Send,
|
||||
{
|
||||
let (control, stream_sink) =
|
||||
create_mock_stream_sink::<Result<Message, TungError>, Message>();
|
||||
|
||||
let cancel_token = CancellationToken::new();
|
||||
let inner_cancel_token = cancel_token.clone();
|
||||
let (connected_state_tx, _connected_state_rx) = watch::channel(false);
|
||||
|
||||
let mut context = ClientContext {
|
||||
cancel: cancel_token,
|
||||
request: "mock".into_client_request().unwrap(),
|
||||
connected_state_tx,
|
||||
client_configuration: Default::default(),
|
||||
};
|
||||
|
||||
let (tx, _rx) = mpsc::channel(1);
|
||||
let channel = ClientChannel::new(RwLock::new(tx));
|
||||
let used_channel = channel.clone();
|
||||
|
||||
let write_lock = used_channel.write().await;
|
||||
|
||||
let handle = spawn(async move {
|
||||
let channel = channel;
|
||||
let read = channel.read().await;
|
||||
let sender = read.deref().clone();
|
||||
drop(read);
|
||||
future(sender, control, inner_cancel_token).await;
|
||||
});
|
||||
|
||||
let mut stream_sink = Some(stream_sink);
|
||||
|
||||
let connection_fn = async |_: Request| {
|
||||
let stream_sink = stream_sink.take().ok_or(TungError::ConnectionClosed)?;
|
||||
|
||||
Ok((stream_sink, TungResponse::default())) as Result<(_, _), TungError>
|
||||
};
|
||||
|
||||
let context_result = async {
|
||||
drop(
|
||||
context
|
||||
.run_connection(write_lock, &used_channel, connection_fn)
|
||||
.await,
|
||||
);
|
||||
Ok(())
|
||||
};
|
||||
|
||||
try_join!(context_result, timeout(Duration::from_secs(1), handle),)
|
||||
.unwrap()
|
||||
.1
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connection_closes_when_websocket_closes() {
|
||||
let _ = env_logger::builder()
|
||||
.is_test(true)
|
||||
.filter_level(LevelFilter::Trace)
|
||||
.try_init();
|
||||
|
||||
assert_client_interaction(|sender, mut control, _| async move {
|
||||
let msg = Uuid::new_v4();
|
||||
sender
|
||||
.send(OutgoingMessage {
|
||||
msg: RequestMessage {
|
||||
uuid: msg,
|
||||
response: None,
|
||||
payload: TelemetryDefinitionRequest {
|
||||
name: "".to_string(),
|
||||
data_type: DataType::Float32,
|
||||
}
|
||||
.into(),
|
||||
},
|
||||
callback: Callback::None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
// We expect an outgoing message
|
||||
assert!(matches!(
|
||||
control.outgoing.recv().await.unwrap(),
|
||||
Message::Text(_)
|
||||
));
|
||||
// We receive an incoming close message
|
||||
control
|
||||
.incoming
|
||||
.send(Ok(Message::Close(None)))
|
||||
.await
|
||||
.unwrap();
|
||||
// Then we expect the outgoing to close with no message
|
||||
assert!(control.outgoing.recv().await.is_none());
|
||||
assert!(control.incoming.is_closed());
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connection_closes_when_cancelled() {
|
||||
let _ = env_logger::builder()
|
||||
.is_test(true)
|
||||
.filter_level(LevelFilter::Trace)
|
||||
.try_init();
|
||||
|
||||
assert_client_interaction(|_, mut control, cancel| async move {
|
||||
cancel.cancel();
|
||||
// We expect an outgoing cancel message
|
||||
assert!(matches!(
|
||||
control.outgoing.recv().await.unwrap(),
|
||||
Message::Close(_)
|
||||
));
|
||||
// Then we expect to close with no message
|
||||
assert!(control.outgoing.recv().await.is_none());
|
||||
assert!(control.incoming.is_closed());
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn callback_request() {
|
||||
let _ = env_logger::builder()
|
||||
.is_test(true)
|
||||
.filter_level(LevelFilter::Trace)
|
||||
.try_init();
|
||||
|
||||
assert_client_interaction(|sender, mut control, _| async move {
|
||||
let (callback_tx, callback_rx) = oneshot::channel();
|
||||
let msg = Uuid::new_v4();
|
||||
sender
|
||||
.send(OutgoingMessage {
|
||||
msg: RequestMessage {
|
||||
uuid: msg,
|
||||
response: None,
|
||||
payload: TelemetryDefinitionRequest {
|
||||
name: "".to_string(),
|
||||
data_type: DataType::Float32,
|
||||
}
|
||||
.into(),
|
||||
},
|
||||
callback: Callback::Once(callback_tx),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// We expect an outgoing message
|
||||
assert!(matches!(
|
||||
control.outgoing.recv().await.unwrap(),
|
||||
Message::Text(_)
|
||||
));
|
||||
|
||||
// Then we get an incoming message for this callback
|
||||
let response_message = ResponseMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: Some(msg),
|
||||
payload: TelemetryDefinitionResponse {
|
||||
uuid: Uuid::new_v4(),
|
||||
}
|
||||
.into(),
|
||||
};
|
||||
control
|
||||
.incoming
|
||||
.send(Ok(Message::Text(
|
||||
serde_json::to_string(&response_message).unwrap().into(),
|
||||
)))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// We expect the callback to run
|
||||
let message = callback_rx.await.unwrap();
|
||||
// And give us the message we provided it
|
||||
assert_eq!(message, response_message);
|
||||
|
||||
// We receive an incoming close message
|
||||
control
|
||||
.incoming
|
||||
.send(Ok(Message::Close(None)))
|
||||
.await
|
||||
.unwrap();
|
||||
// Then we expect the outgoing to close with no message
|
||||
assert!(control.outgoing.recv().await.is_none());
|
||||
assert!(control.incoming.is_closed());
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn callback_registered() {
|
||||
let _ = env_logger::builder()
|
||||
.is_test(true)
|
||||
.filter_level(LevelFilter::Trace)
|
||||
.try_init();
|
||||
|
||||
assert_client_interaction(|sender, mut control, _| async move {
|
||||
let (callback_tx, mut callback_rx) = mpsc::channel(1);
|
||||
let msg = Uuid::new_v4();
|
||||
sender
|
||||
.send(OutgoingMessage {
|
||||
msg: RequestMessage {
|
||||
uuid: msg,
|
||||
response: None,
|
||||
payload: TelemetryDefinitionRequest {
|
||||
name: "".to_string(),
|
||||
data_type: DataType::Float32,
|
||||
}
|
||||
.into(),
|
||||
},
|
||||
callback: Callback::Registered(callback_tx),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// We expect an outgoing message
|
||||
assert!(matches!(
|
||||
control.outgoing.recv().await.unwrap(),
|
||||
Message::Text(_)
|
||||
));
|
||||
|
||||
// We handle the callback a few times
|
||||
for _ in 0..5 {
|
||||
// Then we get an incoming message for this callback
|
||||
let response_message = ResponseMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: Some(msg),
|
||||
payload: TelemetryDefinitionResponse {
|
||||
uuid: Uuid::new_v4(),
|
||||
}
|
||||
.into(),
|
||||
};
|
||||
control
|
||||
.incoming
|
||||
.send(Ok(Message::Text(
|
||||
serde_json::to_string(&response_message).unwrap().into(),
|
||||
)))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// We expect the response
|
||||
let (rx, responder) = callback_rx.recv().await.unwrap();
|
||||
// And give us the message we provided it
|
||||
assert_eq!(rx, response_message);
|
||||
// Then the response gets sent out
|
||||
responder
|
||||
.send(
|
||||
TelemetryDefinitionRequest {
|
||||
name: "".to_string(),
|
||||
data_type: DataType::Float32,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// We expect an outgoing message
|
||||
assert!(matches!(
|
||||
control.outgoing.recv().await.unwrap(),
|
||||
Message::Text(_)
|
||||
));
|
||||
}
|
||||
|
||||
// We receive an incoming close message
|
||||
control
|
||||
.incoming
|
||||
.send(Ok(Message::Close(None)))
|
||||
.await
|
||||
.unwrap();
|
||||
// Then we expect the outgoing to close with no message
|
||||
assert!(control.outgoing.recv().await.is_none());
|
||||
assert!(control.incoming.is_closed());
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ping_pong() {
|
||||
let _ = env_logger::builder()
|
||||
.is_test(true)
|
||||
.filter_level(LevelFilter::Trace)
|
||||
.try_init();
|
||||
|
||||
assert_client_interaction(|_, mut control, _| async move {
|
||||
// Expect a pong in response to a ping
|
||||
let bytes = Bytes::from_owner(Uuid::new_v4().into_bytes());
|
||||
control
|
||||
.incoming
|
||||
.send(Ok(Message::Ping(bytes.clone())))
|
||||
.await
|
||||
.unwrap();
|
||||
let Some(Message::Pong(pong_bytes)) = control.outgoing.recv().await else {
|
||||
panic!("Expected Pong Response");
|
||||
};
|
||||
assert_eq!(bytes, pong_bytes);
|
||||
|
||||
// Nothing should happen
|
||||
control
|
||||
.incoming
|
||||
.send(Ok(Message::Pong(bytes.clone())))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// We receive an incoming close message
|
||||
control
|
||||
.incoming
|
||||
.send(Ok(Message::Close(None)))
|
||||
.await
|
||||
.unwrap();
|
||||
// Then we expect the outgoing to close with no message
|
||||
assert!(control.outgoing.recv().await.is_none());
|
||||
assert!(control.incoming.is_closed());
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ use uuid::Uuid;
|
||||
type RegisteredCallback = mpsc::Sender<(ResponseMessage, oneshot::Sender<RequestMessagePayload>)>;
|
||||
type ClientChannel = Arc<RwLock<mpsc::Sender<OutgoingMessage>>>;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Callback {
|
||||
None,
|
||||
Once(oneshot::Sender<ResponseMessage>),
|
||||
@@ -264,3 +265,334 @@ impl Drop for Client {
|
||||
self.cancel.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::messages::command::CommandResponse;
|
||||
use crate::messages::telemetry_definition::{
|
||||
TelemetryDefinitionRequest, TelemetryDefinitionResponse,
|
||||
};
|
||||
use crate::messages::telemetry_entry::TelemetryEntry;
|
||||
use api_core::command::{Command, CommandDefinition, CommandHeader};
|
||||
use api_core::data_type::DataType;
|
||||
use chrono::Utc;
|
||||
use futures_util::future::{select, Either};
|
||||
use futures_util::FutureExt;
|
||||
use std::future::Future;
|
||||
use std::pin::{pin, Pin};
|
||||
use std::time::Duration;
|
||||
use tokio::join;
|
||||
use tokio::time::{sleep, timeout};
|
||||
|
||||
fn create_test_client() -> (mpsc::Receiver<OutgoingMessage>, watch::Sender<bool>, Client) {
|
||||
let cancel = CancellationToken::new();
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
let channel = Arc::new(RwLock::new(tx));
|
||||
let (connected_state_tx, connected_state_rx) = watch::channel(true);
|
||||
let client = Client {
|
||||
cancel,
|
||||
channel,
|
||||
connected_state_rx,
|
||||
};
|
||||
(rx, connected_state_tx, client)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_message() {
|
||||
let (mut rx, _, client) = create_test_client();
|
||||
|
||||
let msg_to_send = TelemetryEntry {
|
||||
uuid: Uuid::new_v4(),
|
||||
value: 0.0f32.into(),
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
let msg_send = timeout(
|
||||
Duration::from_secs(1),
|
||||
client.send_message(msg_to_send.clone()),
|
||||
);
|
||||
let msg_recv = timeout(Duration::from_secs(1), rx.recv());
|
||||
|
||||
let (send, recv) = join!(msg_send, msg_recv);
|
||||
send.unwrap().unwrap();
|
||||
let recv = recv.unwrap().unwrap();
|
||||
|
||||
assert!(matches!(recv.callback, Callback::None));
|
||||
assert!(recv.msg.response.is_none());
|
||||
// uuid should be random
|
||||
|
||||
let RequestMessagePayload::TelemetryEntry(recv) = recv.msg.payload else {
|
||||
panic!("Wrong Message Received")
|
||||
};
|
||||
|
||||
assert_eq!(recv, msg_to_send);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_message_if_connected() {
|
||||
let (mut rx, _, client) = create_test_client();
|
||||
|
||||
let msg_to_send = TelemetryEntry {
|
||||
uuid: Uuid::new_v4(),
|
||||
value: 0.0f32.into(),
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
let msg_send = timeout(
|
||||
Duration::from_secs(1),
|
||||
client.send_message_if_connected(msg_to_send.clone()),
|
||||
);
|
||||
let msg_recv = timeout(Duration::from_secs(1), rx.recv());
|
||||
|
||||
let (send, recv) = join!(msg_send, msg_recv);
|
||||
send.unwrap().unwrap();
|
||||
let recv = recv.unwrap().unwrap();
|
||||
|
||||
assert!(matches!(recv.callback, Callback::None));
|
||||
assert!(recv.msg.response.is_none());
|
||||
// uuid should be random
|
||||
|
||||
let RequestMessagePayload::TelemetryEntry(recv) = recv.msg.payload else {
|
||||
panic!("Wrong Message Received")
|
||||
};
|
||||
|
||||
assert_eq!(recv, msg_to_send);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_message_if_connected_not_connected() {
|
||||
let (_, connected_state_tx, client) = create_test_client();
|
||||
|
||||
let lock = client.channel.write().await;
|
||||
connected_state_tx.send_replace(false);
|
||||
|
||||
let msg_to_send = TelemetryEntry {
|
||||
uuid: Uuid::new_v4(),
|
||||
value: 0.0f32.into(),
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
let msg_send = timeout(
|
||||
Duration::from_secs(1),
|
||||
client.send_message_if_connected(msg_to_send.clone()),
|
||||
);
|
||||
|
||||
let Err(MessageError::TokioLockError(_)) = msg_send.await.unwrap() else {
|
||||
panic!("Expected to Err due to lock being unavailable")
|
||||
};
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn try_send_message() {
|
||||
let (_tx, _, client) = create_test_client();
|
||||
|
||||
let msg_to_send = TelemetryEntry {
|
||||
uuid: Uuid::new_v4(),
|
||||
value: 0.0f32.into(),
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
client.try_send_message(msg_to_send.clone()).unwrap();
|
||||
let Err(MessageError::TokioTrySendError(_)) = client.try_send_message(msg_to_send.clone())
|
||||
else {
|
||||
panic!("Expected the buffer to be full");
|
||||
};
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_request() {
|
||||
let (mut tx, _, client) = create_test_client();
|
||||
|
||||
let msg_to_send = TelemetryDefinitionRequest {
|
||||
name: "".to_string(),
|
||||
data_type: DataType::Float32,
|
||||
};
|
||||
let response = timeout(
|
||||
Duration::from_secs(1),
|
||||
client.send_request(msg_to_send.clone()),
|
||||
);
|
||||
|
||||
let response_uuid = Uuid::new_v4();
|
||||
let outgoing_rx = timeout(Duration::from_secs(1), async {
|
||||
let msg = tx.recv().await.unwrap();
|
||||
let Callback::Once(cb) = msg.callback else {
|
||||
panic!("Wrong Callback Type")
|
||||
};
|
||||
cb.send(ResponseMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: Some(msg.msg.uuid),
|
||||
payload: TelemetryDefinitionResponse {
|
||||
uuid: response_uuid,
|
||||
}
|
||||
.into(),
|
||||
})
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let (response, outgoing_rx) = join!(response, outgoing_rx);
|
||||
let response = response.unwrap().unwrap();
|
||||
outgoing_rx.unwrap();
|
||||
|
||||
assert_eq!(response.uuid, response_uuid);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn register_callback_channel() {
|
||||
let (mut tx, _, client) = create_test_client();
|
||||
|
||||
let msg_to_send = CommandDefinition {
|
||||
name: "".to_string(),
|
||||
parameters: vec![],
|
||||
};
|
||||
let mut response = timeout(
|
||||
Duration::from_secs(1),
|
||||
client.register_callback_channel(msg_to_send),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let response_uuid = Uuid::new_v4();
|
||||
let outgoing_rx = timeout(Duration::from_secs(1), async {
|
||||
let msg = tx.recv().await.unwrap();
|
||||
let Callback::Registered(cb) = msg.callback else {
|
||||
panic!("Wrong Callback Type")
|
||||
};
|
||||
|
||||
// Check that we get responses to the callback the expected number of times
|
||||
for i in 0..5 {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
cb.send((
|
||||
ResponseMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: Some(msg.msg.uuid),
|
||||
payload: Command {
|
||||
header: CommandHeader {
|
||||
timestamp: Utc::now(),
|
||||
},
|
||||
parameters: Default::default(),
|
||||
}
|
||||
.into(),
|
||||
},
|
||||
tx,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
let RequestMessagePayload::CommandResponse(response) = rx.await.unwrap() else {
|
||||
panic!("Unexpected Response Type");
|
||||
};
|
||||
assert_eq!(response.response, format!("{i}"));
|
||||
}
|
||||
});
|
||||
|
||||
let responder = timeout(Duration::from_secs(1), async {
|
||||
for i in 0..5 {
|
||||
let (cmd, responder) = response.recv().await.unwrap();
|
||||
responder
|
||||
.send(CommandResponse {
|
||||
success: false,
|
||||
response: format!("{i}"),
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let (response, outgoing_rx) = join!(responder, outgoing_rx);
|
||||
response.unwrap();
|
||||
outgoing_rx.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn register_callback_fn() {
|
||||
let (mut tx, _, client) = create_test_client();
|
||||
|
||||
let msg_to_send = CommandDefinition {
|
||||
name: "".to_string(),
|
||||
parameters: vec![],
|
||||
};
|
||||
let mut index = 0usize;
|
||||
timeout(
|
||||
Duration::from_secs(1),
|
||||
client.register_callback_fn(msg_to_send, move |cmd| {
|
||||
index += 1;
|
||||
CommandResponse {
|
||||
success: false,
|
||||
response: format!("{}", index - 1),
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
timeout(Duration::from_secs(1), async {
|
||||
let msg = tx.recv().await.unwrap();
|
||||
let Callback::Registered(cb) = msg.callback else {
|
||||
panic!("Wrong Callback Type")
|
||||
};
|
||||
|
||||
// Check that we get responses to the callback the expected number of times
|
||||
for i in 0..3 {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
cb.send((
|
||||
ResponseMessage {
|
||||
uuid: Uuid::new_v4(),
|
||||
response: Some(msg.msg.uuid),
|
||||
payload: Command {
|
||||
header: CommandHeader {
|
||||
timestamp: Utc::now(),
|
||||
},
|
||||
parameters: Default::default(),
|
||||
}
|
||||
.into(),
|
||||
},
|
||||
tx,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
let RequestMessagePayload::CommandResponse(response) = rx.await.unwrap() else {
|
||||
panic!("Unexpected Response Type");
|
||||
};
|
||||
assert_eq!(response.response, format!("{i}"));
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connected_disconnected() {
|
||||
let (_, connected, client) = create_test_client();
|
||||
|
||||
// When we're connected we should return immediately
|
||||
connected.send_replace(true);
|
||||
client.wait_connected().now_or_never().unwrap();
|
||||
|
||||
// When we're disconnected we should return immediately
|
||||
connected.send_replace(false);
|
||||
client.wait_disconnected().now_or_never().unwrap();
|
||||
|
||||
let c2 = connected.clone();
|
||||
// When we're disconnected, we should not return immediately
|
||||
let f1 = pin!(client.wait_connected());
|
||||
let f2 = pin!(async move {
|
||||
sleep(Duration::from_millis(1)).await;
|
||||
c2.send_replace(true);
|
||||
});
|
||||
let r = select(f1, f2).await;
|
||||
match r {
|
||||
Either::Left(_) => panic!("Wait Connected Finished Before Connection Changed"),
|
||||
Either::Right((_, other)) => timeout(Duration::from_secs(1), other).await.unwrap(),
|
||||
}
|
||||
|
||||
let c2 = connected.clone();
|
||||
// When we're disconnected, we should not return immediately
|
||||
let f1 = pin!(client.wait_disconnected());
|
||||
let f2 = pin!(async move {
|
||||
sleep(Duration::from_millis(1)).await;
|
||||
c2.send_replace(false);
|
||||
});
|
||||
let r = select(f1, f2).await;
|
||||
match r {
|
||||
Either::Left(_) => panic!("Wait Disconnected Finished Before Connection Changed"),
|
||||
Either::Right((_, other)) => timeout(Duration::from_secs(1), other).await.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,3 +163,6 @@ impl<T: Into<DataValue>> TelemetryHandle<T> {
|
||||
self.publish(value, Utc::now()).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {}
|
||||
|
||||
@@ -10,3 +10,6 @@ pub mod messages;
|
||||
pub mod macros {
|
||||
pub use api_proc_macro::IntoCommandDefinition;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum GenericCallbackError {
|
||||
CallbackClosed,
|
||||
MismatchedType,
|
||||
|
||||
@@ -8,7 +8,7 @@ impl RegisterCallback for CommandDefinition {
|
||||
type Response = CommandResponse;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct CommandResponse {
|
||||
pub success: bool,
|
||||
pub response: String,
|
||||
|
||||
@@ -18,7 +18,7 @@ pub struct RequestMessage {
|
||||
pub payload: RequestMessagePayload,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ResponseMessage {
|
||||
pub uuid: Uuid,
|
||||
#[serde(default)]
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::messages::telemetry_entry::TelemetryEntry;
|
||||
use derive_more::{From, TryInto};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, From)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, From)]
|
||||
pub enum RequestMessagePayload {
|
||||
TelemetryDefinitionRequest(TelemetryDefinitionRequest),
|
||||
TelemetryEntry(TelemetryEntry),
|
||||
@@ -16,7 +16,7 @@ pub enum RequestMessagePayload {
|
||||
CommandResponse(CommandResponse),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, From, TryInto)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, From, TryInto)]
|
||||
pub enum ResponseMessagePayload {
|
||||
TelemetryDefinitionResponse(TelemetryDefinitionResponse),
|
||||
Command(Command),
|
||||
|
||||
@@ -3,13 +3,13 @@ use crate::messages::RequestResponse;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct TelemetryDefinitionRequest {
|
||||
pub name: String,
|
||||
pub data_type: DataType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct TelemetryDefinitionResponse {
|
||||
pub uuid: Uuid,
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct TelemetryEntry {
|
||||
pub uuid: Uuid,
|
||||
pub value: DataValue,
|
||||
|
||||
82
api/src/test/mock_stream_sink.rs
Normal file
82
api/src/test/mock_stream_sink.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
use futures_util::sink::{unfold, Unfold};
|
||||
use futures_util::{Sink, SinkExt, Stream};
|
||||
use std::fmt::Display;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::mpsc::{Receiver, Sender};
|
||||
|
||||
pub struct MockStreamSinkControl<T, R> {
|
||||
pub incoming: Sender<T>,
|
||||
pub outgoing: Receiver<R>,
|
||||
}
|
||||
|
||||
pub struct MockStreamSink<T, U1, U2> {
|
||||
stream_rx: Receiver<T>,
|
||||
sink_tx: Pin<Box<Unfold<u32, U1, U2>>>,
|
||||
}
|
||||
|
||||
impl<T, U1, U2> Stream for MockStreamSink<T, U1, U2>
|
||||
where
|
||||
Self: Unpin,
|
||||
{
|
||||
type Item = T;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.stream_rx.poll_recv(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, R, U1, U2, E> Sink<R> for MockStreamSink<T, U1, U2>
|
||||
where
|
||||
U1: FnMut(u32, R) -> U2,
|
||||
U2: Future<Output = Result<u32, E>>,
|
||||
{
|
||||
type Error = E;
|
||||
|
||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.sink_tx.poll_ready_unpin(cx)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, item: R) -> Result<(), Self::Error> {
|
||||
self.sink_tx.start_send_unpin(item)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.sink_tx.poll_flush_unpin(cx)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.sink_tx.poll_close_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_mock_stream_sink<T: Send, R: Send + 'static>() -> (
|
||||
MockStreamSinkControl<T, R>,
|
||||
impl Stream<Item = T> + Sink<R, Error = impl Display>,
|
||||
) {
|
||||
let (stream_tx, stream_rx) = mpsc::channel::<T>(1);
|
||||
let (sink_tx, sink_rx) = mpsc::channel::<R>(1);
|
||||
|
||||
let sink_tx = Arc::new(sink_tx);
|
||||
|
||||
(
|
||||
MockStreamSinkControl {
|
||||
incoming: stream_tx,
|
||||
outgoing: sink_rx,
|
||||
},
|
||||
MockStreamSink::<T, _, _> {
|
||||
stream_rx,
|
||||
sink_tx: Box::pin(unfold(0u32, move |_, item| {
|
||||
let sink_tx = sink_tx.clone();
|
||||
async move {
|
||||
sink_tx.send(item).await?;
|
||||
Ok(0u32) as Result<_, SendError<R>>
|
||||
}
|
||||
})),
|
||||
},
|
||||
)
|
||||
}
|
||||
1
api/src/test/mod.rs
Normal file
1
api/src/test/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod mock_stream_sink;
|
||||
Reference in New Issue
Block a user