From 4aa86da14a62b80b424df7deea4046cbebbaee73 Mon Sep 17 00:00:00 2001 From: Sergey Savelyev Date: Wed, 31 Dec 2025 18:45:46 -0500 Subject: [PATCH] add tests to api --- Cargo.lock | 1 + Cargo.toml | 2 + api-core/src/command.rs | 8 +- api-core/src/data_value.rs | 2 +- api/Cargo.toml | 3 + api/src/client/context.rs | 376 ++++++++++++++++++++++- api/src/client/mod.rs | 332 ++++++++++++++++++++ api/src/client/telemetry.rs | 3 + api/src/lib.rs | 3 + api/src/messages/callback.rs | 2 +- api/src/messages/command.rs | 2 +- api/src/messages/mod.rs | 2 +- api/src/messages/payload.rs | 4 +- api/src/messages/telemetry_definition.rs | 4 +- api/src/messages/telemetry_entry.rs | 2 +- api/src/test/mock_stream_sink.rs | 82 +++++ api/src/test/mod.rs | 1 + 17 files changed, 803 insertions(+), 26 deletions(-) create mode 100644 api/src/test/mock_stream_sink.rs create mode 100644 api/src/test/mod.rs diff --git a/Cargo.lock b/Cargo.lock index c40ed3d..cb48a77 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -309,6 +309,7 @@ dependencies = [ "api-proc-macro", "chrono", "derive_more", + "env_logger", "futures-util", "log", "serde", diff --git a/Cargo.toml b/Cargo.toml index 7516f1c..8dfc5ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,8 @@ sqlx = "0.8.6" syn = "2.0.112" thiserror = "2.0.17" tokio = { version = "1.48.0" } +tokio-test = "0.4.4" +tokio-stream = "0.1.17" tokio-tungstenite = { version = "0.28.0" } tokio-util = "0.7.17" trybuild = "1.0.114" diff --git a/api-core/src/command.rs b/api-core/src/command.rs index b46a1f9..e8a210e 100644 --- a/api-core/src/command.rs +++ b/api-core/src/command.rs @@ -11,25 +11,25 @@ pub struct CommandParameterDefinition { pub data_type: DataType, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct CommandDefinition { pub name: String, pub parameters: Vec, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct CommandHeader { pub timestamp: DateTime, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Command { #[serde(flatten)] pub header: CommandHeader, pub parameters: HashMap, } -#[derive(Debug, Error)] +#[derive(Debug, PartialEq, Eq, Error)] pub enum IntoCommandDefinitionError { #[error("Parameter Missing: {0}")] ParameterMissing(String), diff --git a/api-core/src/data_value.rs b/api-core/src/data_value.rs index 3f257eb..622d7d6 100644 --- a/api-core/src/data_value.rs +++ b/api-core/src/data_value.rs @@ -1,7 +1,7 @@ use derive_more::{From, TryInto}; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Copy, Serialize, Deserialize, From, TryInto)] +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, From, TryInto)] pub enum DataValue { Float32(f32), Float64(f64), diff --git a/api/Cargo.toml b/api/Cargo.toml index 1548081..4ca8c50 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -19,3 +19,6 @@ tokio = { workspace = true, features = ["rt", "macros", "time"] } tokio-tungstenite = { workspace = true, features = ["rustls-tls-native-roots"] } tokio-util = { workspace = true } uuid = { workspace = true, features = ["serde"] } + +[dev-dependencies] +env_logger = { workspace = true } diff --git a/api/src/client/context.rs b/api/src/client/context.rs index 362b92f..2299509 100644 --- a/api/src/client/context.rs +++ b/api/src/client/context.rs @@ -4,19 +4,19 @@ use crate::client::{Callback, ClientChannel, OutgoingMessage, RegisteredCallback use crate::messages::callback::GenericCallbackError; use crate::messages::payload::RequestMessagePayload; use crate::messages::{RequestMessage, ResponseMessage}; -use futures_util::{SinkExt, StreamExt}; +use futures_util::{Sink, SinkExt, Stream, StreamExt}; use log::{debug, error, info, trace, warn}; use std::collections::HashMap; +use std::fmt::Display; use std::sync::mpsc::sync_channel; use std::thread; use std::time::Duration; -use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot, watch, RwLockWriteGuard}; use tokio::time::sleep; use tokio::{select, spawn}; -use tokio_tungstenite::tungstenite::handshake::client::Request; -use tokio_tungstenite::tungstenite::Message; -use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::handshake::client::{Request, Response as TungResponse}; +use tokio_tungstenite::tungstenite::{Error as TungError, Message}; use tokio_util::sync::CancellationToken; use uuid::Uuid; @@ -45,7 +45,9 @@ impl ClientContext { let _ = tx.send(()); while !self.cancel.is_cancelled() { - write_lock = self.run_connection(write_lock, &channel).await; + write_lock = self + .run_connection(write_lock, &channel, connect_async) + .await; } drop(write_lock); }); @@ -57,13 +59,19 @@ impl ClientContext { Ok(()) } - async fn run_connection<'a>( + async fn run_connection<'a, F, W, E>( &mut self, mut write_lock: RwLockWriteGuard<'a, mpsc::Sender>, channel: &'a ClientChannel, - ) -> RwLockWriteGuard<'a, mpsc::Sender> { + mut connection_fn: F, + ) -> RwLockWriteGuard<'a, mpsc::Sender> + where + F: AsyncFnMut(Request) -> Result<(W, TungResponse), TungError>, + W: Stream> + Sink + Unpin, + E: Display, + { debug!("Attempting to Connect to {}", self.request.uri()); - let mut ws = match connect_async(self.request.clone()).await { + let mut ws = match connection_fn(self.request.clone()).await { Ok((ws, _)) => ws, Err(e) => { info!("Failed to Connect: {e}"); @@ -87,19 +95,24 @@ impl ClientContext { // the lock to use that as a signal that we have reconnected let _ = self.connected_state_tx.send_replace(false); if close_connection { - if let Err(e) = ws.close(None).await { + // Manually close to allow the impl trait to be used + if let Err(e) = ws.send(Message::Close(None)).await { error!("Failed to Close the Connection: {e}"); } } write_lock } - async fn handle_connection( + async fn handle_connection( &mut self, - ws: &mut WebSocketStream>, + ws: &mut W, mut rx: mpsc::Receiver, channel: &ClientChannel, - ) -> bool { + ) -> bool + where + W: Stream> + Sink + Unpin, + >::Error: Display, + { let mut callbacks = HashMap::::new(); loop { select! { @@ -242,3 +255,340 @@ impl ClientContext { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::messages::telemetry_definition::{ + TelemetryDefinitionRequest, TelemetryDefinitionResponse, + }; + use crate::test::mock_stream_sink::{create_mock_stream_sink, MockStreamSinkControl}; + use api_core::data_type::DataType; + use log::LevelFilter; + use std::future::Future; + use std::ops::Deref; + use tokio::sync::mpsc::Sender; + use tokio::sync::RwLock; + use tokio::time::timeout; + use tokio::try_join; + use tokio_tungstenite::tungstenite::client::IntoClientRequest; + use tokio_util::bytes::Bytes; + + async fn assert_client_interaction(future: F) + where + F: Send + + FnOnce( + Sender, + MockStreamSinkControl, Message>, + CancellationToken, + ) -> R + + 'static, + R: Future + Send, + { + let (control, stream_sink) = + create_mock_stream_sink::, Message>(); + + let cancel_token = CancellationToken::new(); + let inner_cancel_token = cancel_token.clone(); + let (connected_state_tx, _connected_state_rx) = watch::channel(false); + + let mut context = ClientContext { + cancel: cancel_token, + request: "mock".into_client_request().unwrap(), + connected_state_tx, + client_configuration: Default::default(), + }; + + let (tx, _rx) = mpsc::channel(1); + let channel = ClientChannel::new(RwLock::new(tx)); + let used_channel = channel.clone(); + + let write_lock = used_channel.write().await; + + let handle = spawn(async move { + let channel = channel; + let read = channel.read().await; + let sender = read.deref().clone(); + drop(read); + future(sender, control, inner_cancel_token).await; + }); + + let mut stream_sink = Some(stream_sink); + + let connection_fn = async |_: Request| { + let stream_sink = stream_sink.take().ok_or(TungError::ConnectionClosed)?; + + Ok((stream_sink, TungResponse::default())) as Result<(_, _), TungError> + }; + + let context_result = async { + drop( + context + .run_connection(write_lock, &used_channel, connection_fn) + .await, + ); + Ok(()) + }; + + try_join!(context_result, timeout(Duration::from_secs(1), handle),) + .unwrap() + .1 + .unwrap(); + } + + #[tokio::test] + async fn connection_closes_when_websocket_closes() { + let _ = env_logger::builder() + .is_test(true) + .filter_level(LevelFilter::Trace) + .try_init(); + + assert_client_interaction(|sender, mut control, _| async move { + let msg = Uuid::new_v4(); + sender + .send(OutgoingMessage { + msg: RequestMessage { + uuid: msg, + response: None, + payload: TelemetryDefinitionRequest { + name: "".to_string(), + data_type: DataType::Float32, + } + .into(), + }, + callback: Callback::None, + }) + .await + .unwrap(); + // We expect an outgoing message + assert!(matches!( + control.outgoing.recv().await.unwrap(), + Message::Text(_) + )); + // We receive an incoming close message + control + .incoming + .send(Ok(Message::Close(None))) + .await + .unwrap(); + // Then we expect the outgoing to close with no message + assert!(control.outgoing.recv().await.is_none()); + assert!(control.incoming.is_closed()); + }) + .await; + } + + #[tokio::test] + async fn connection_closes_when_cancelled() { + let _ = env_logger::builder() + .is_test(true) + .filter_level(LevelFilter::Trace) + .try_init(); + + assert_client_interaction(|_, mut control, cancel| async move { + cancel.cancel(); + // We expect an outgoing cancel message + assert!(matches!( + control.outgoing.recv().await.unwrap(), + Message::Close(_) + )); + // Then we expect to close with no message + assert!(control.outgoing.recv().await.is_none()); + assert!(control.incoming.is_closed()); + }) + .await; + } + + #[tokio::test] + async fn callback_request() { + let _ = env_logger::builder() + .is_test(true) + .filter_level(LevelFilter::Trace) + .try_init(); + + assert_client_interaction(|sender, mut control, _| async move { + let (callback_tx, callback_rx) = oneshot::channel(); + let msg = Uuid::new_v4(); + sender + .send(OutgoingMessage { + msg: RequestMessage { + uuid: msg, + response: None, + payload: TelemetryDefinitionRequest { + name: "".to_string(), + data_type: DataType::Float32, + } + .into(), + }, + callback: Callback::Once(callback_tx), + }) + .await + .unwrap(); + + // We expect an outgoing message + assert!(matches!( + control.outgoing.recv().await.unwrap(), + Message::Text(_) + )); + + // Then we get an incoming message for this callback + let response_message = ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg), + payload: TelemetryDefinitionResponse { + uuid: Uuid::new_v4(), + } + .into(), + }; + control + .incoming + .send(Ok(Message::Text( + serde_json::to_string(&response_message).unwrap().into(), + ))) + .await + .unwrap(); + + // We expect the callback to run + let message = callback_rx.await.unwrap(); + // And give us the message we provided it + assert_eq!(message, response_message); + + // We receive an incoming close message + control + .incoming + .send(Ok(Message::Close(None))) + .await + .unwrap(); + // Then we expect the outgoing to close with no message + assert!(control.outgoing.recv().await.is_none()); + assert!(control.incoming.is_closed()); + }) + .await; + } + + #[tokio::test] + async fn callback_registered() { + let _ = env_logger::builder() + .is_test(true) + .filter_level(LevelFilter::Trace) + .try_init(); + + assert_client_interaction(|sender, mut control, _| async move { + let (callback_tx, mut callback_rx) = mpsc::channel(1); + let msg = Uuid::new_v4(); + sender + .send(OutgoingMessage { + msg: RequestMessage { + uuid: msg, + response: None, + payload: TelemetryDefinitionRequest { + name: "".to_string(), + data_type: DataType::Float32, + } + .into(), + }, + callback: Callback::Registered(callback_tx), + }) + .await + .unwrap(); + + // We expect an outgoing message + assert!(matches!( + control.outgoing.recv().await.unwrap(), + Message::Text(_) + )); + + // We handle the callback a few times + for _ in 0..5 { + // Then we get an incoming message for this callback + let response_message = ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg), + payload: TelemetryDefinitionResponse { + uuid: Uuid::new_v4(), + } + .into(), + }; + control + .incoming + .send(Ok(Message::Text( + serde_json::to_string(&response_message).unwrap().into(), + ))) + .await + .unwrap(); + + // We expect the response + let (rx, responder) = callback_rx.recv().await.unwrap(); + // And give us the message we provided it + assert_eq!(rx, response_message); + // Then the response gets sent out + responder + .send( + TelemetryDefinitionRequest { + name: "".to_string(), + data_type: DataType::Float32, + } + .into(), + ) + .unwrap(); + + // We expect an outgoing message + assert!(matches!( + control.outgoing.recv().await.unwrap(), + Message::Text(_) + )); + } + + // We receive an incoming close message + control + .incoming + .send(Ok(Message::Close(None))) + .await + .unwrap(); + // Then we expect the outgoing to close with no message + assert!(control.outgoing.recv().await.is_none()); + assert!(control.incoming.is_closed()); + }) + .await; + } + + #[tokio::test] + async fn ping_pong() { + let _ = env_logger::builder() + .is_test(true) + .filter_level(LevelFilter::Trace) + .try_init(); + + assert_client_interaction(|_, mut control, _| async move { + // Expect a pong in response to a ping + let bytes = Bytes::from_owner(Uuid::new_v4().into_bytes()); + control + .incoming + .send(Ok(Message::Ping(bytes.clone()))) + .await + .unwrap(); + let Some(Message::Pong(pong_bytes)) = control.outgoing.recv().await else { + panic!("Expected Pong Response"); + }; + assert_eq!(bytes, pong_bytes); + + // Nothing should happen + control + .incoming + .send(Ok(Message::Pong(bytes.clone()))) + .await + .unwrap(); + + // We receive an incoming close message + control + .incoming + .send(Ok(Message::Close(None))) + .await + .unwrap(); + // Then we expect the outgoing to close with no message + assert!(control.outgoing.recv().await.is_none()); + assert!(control.incoming.is_closed()); + }) + .await; + } +} diff --git a/api/src/client/mod.rs b/api/src/client/mod.rs index d333c51..95c8f9b 100644 --- a/api/src/client/mod.rs +++ b/api/src/client/mod.rs @@ -24,6 +24,7 @@ use uuid::Uuid; type RegisteredCallback = mpsc::Sender<(ResponseMessage, oneshot::Sender)>; type ClientChannel = Arc>>; +#[derive(Debug)] enum Callback { None, Once(oneshot::Sender), @@ -264,3 +265,334 @@ impl Drop for Client { self.cancel.cancel(); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::messages::command::CommandResponse; + use crate::messages::telemetry_definition::{ + TelemetryDefinitionRequest, TelemetryDefinitionResponse, + }; + use crate::messages::telemetry_entry::TelemetryEntry; + use api_core::command::{Command, CommandDefinition, CommandHeader}; + use api_core::data_type::DataType; + use chrono::Utc; + use futures_util::future::{select, Either}; + use futures_util::FutureExt; + use std::future::Future; + use std::pin::{pin, Pin}; + use std::time::Duration; + use tokio::join; + use tokio::time::{sleep, timeout}; + + fn create_test_client() -> (mpsc::Receiver, watch::Sender, Client) { + let cancel = CancellationToken::new(); + let (tx, rx) = mpsc::channel(1); + let channel = Arc::new(RwLock::new(tx)); + let (connected_state_tx, connected_state_rx) = watch::channel(true); + let client = Client { + cancel, + channel, + connected_state_rx, + }; + (rx, connected_state_tx, client) + } + + #[tokio::test] + async fn send_message() { + let (mut rx, _, client) = create_test_client(); + + let msg_to_send = TelemetryEntry { + uuid: Uuid::new_v4(), + value: 0.0f32.into(), + timestamp: Utc::now(), + }; + let msg_send = timeout( + Duration::from_secs(1), + client.send_message(msg_to_send.clone()), + ); + let msg_recv = timeout(Duration::from_secs(1), rx.recv()); + + let (send, recv) = join!(msg_send, msg_recv); + send.unwrap().unwrap(); + let recv = recv.unwrap().unwrap(); + + assert!(matches!(recv.callback, Callback::None)); + assert!(recv.msg.response.is_none()); + // uuid should be random + + let RequestMessagePayload::TelemetryEntry(recv) = recv.msg.payload else { + panic!("Wrong Message Received") + }; + + assert_eq!(recv, msg_to_send); + } + + #[tokio::test] + async fn send_message_if_connected() { + let (mut rx, _, client) = create_test_client(); + + let msg_to_send = TelemetryEntry { + uuid: Uuid::new_v4(), + value: 0.0f32.into(), + timestamp: Utc::now(), + }; + let msg_send = timeout( + Duration::from_secs(1), + client.send_message_if_connected(msg_to_send.clone()), + ); + let msg_recv = timeout(Duration::from_secs(1), rx.recv()); + + let (send, recv) = join!(msg_send, msg_recv); + send.unwrap().unwrap(); + let recv = recv.unwrap().unwrap(); + + assert!(matches!(recv.callback, Callback::None)); + assert!(recv.msg.response.is_none()); + // uuid should be random + + let RequestMessagePayload::TelemetryEntry(recv) = recv.msg.payload else { + panic!("Wrong Message Received") + }; + + assert_eq!(recv, msg_to_send); + } + + #[tokio::test] + async fn send_message_if_connected_not_connected() { + let (_, connected_state_tx, client) = create_test_client(); + + let lock = client.channel.write().await; + connected_state_tx.send_replace(false); + + let msg_to_send = TelemetryEntry { + uuid: Uuid::new_v4(), + value: 0.0f32.into(), + timestamp: Utc::now(), + }; + let msg_send = timeout( + Duration::from_secs(1), + client.send_message_if_connected(msg_to_send.clone()), + ); + + let Err(MessageError::TokioLockError(_)) = msg_send.await.unwrap() else { + panic!("Expected to Err due to lock being unavailable") + }; + } + + #[tokio::test] + async fn try_send_message() { + let (_tx, _, client) = create_test_client(); + + let msg_to_send = TelemetryEntry { + uuid: Uuid::new_v4(), + value: 0.0f32.into(), + timestamp: Utc::now(), + }; + client.try_send_message(msg_to_send.clone()).unwrap(); + let Err(MessageError::TokioTrySendError(_)) = client.try_send_message(msg_to_send.clone()) + else { + panic!("Expected the buffer to be full"); + }; + } + + #[tokio::test] + async fn send_request() { + let (mut tx, _, client) = create_test_client(); + + let msg_to_send = TelemetryDefinitionRequest { + name: "".to_string(), + data_type: DataType::Float32, + }; + let response = timeout( + Duration::from_secs(1), + client.send_request(msg_to_send.clone()), + ); + + let response_uuid = Uuid::new_v4(); + let outgoing_rx = timeout(Duration::from_secs(1), async { + let msg = tx.recv().await.unwrap(); + let Callback::Once(cb) = msg.callback else { + panic!("Wrong Callback Type") + }; + cb.send(ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg.msg.uuid), + payload: TelemetryDefinitionResponse { + uuid: response_uuid, + } + .into(), + }) + .unwrap(); + }); + + let (response, outgoing_rx) = join!(response, outgoing_rx); + let response = response.unwrap().unwrap(); + outgoing_rx.unwrap(); + + assert_eq!(response.uuid, response_uuid); + } + + #[tokio::test] + async fn register_callback_channel() { + let (mut tx, _, client) = create_test_client(); + + let msg_to_send = CommandDefinition { + name: "".to_string(), + parameters: vec![], + }; + let mut response = timeout( + Duration::from_secs(1), + client.register_callback_channel(msg_to_send), + ) + .await + .unwrap() + .unwrap(); + + let response_uuid = Uuid::new_v4(); + let outgoing_rx = timeout(Duration::from_secs(1), async { + let msg = tx.recv().await.unwrap(); + let Callback::Registered(cb) = msg.callback else { + panic!("Wrong Callback Type") + }; + + // Check that we get responses to the callback the expected number of times + for i in 0..5 { + let (tx, rx) = oneshot::channel(); + cb.send(( + ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg.msg.uuid), + payload: Command { + header: CommandHeader { + timestamp: Utc::now(), + }, + parameters: Default::default(), + } + .into(), + }, + tx, + )) + .await + .unwrap(); + let RequestMessagePayload::CommandResponse(response) = rx.await.unwrap() else { + panic!("Unexpected Response Type"); + }; + assert_eq!(response.response, format!("{i}")); + } + }); + + let responder = timeout(Duration::from_secs(1), async { + for i in 0..5 { + let (cmd, responder) = response.recv().await.unwrap(); + responder + .send(CommandResponse { + success: false, + response: format!("{i}"), + }) + .unwrap(); + } + }); + + let (response, outgoing_rx) = join!(responder, outgoing_rx); + response.unwrap(); + outgoing_rx.unwrap(); + } + + #[tokio::test] + async fn register_callback_fn() { + let (mut tx, _, client) = create_test_client(); + + let msg_to_send = CommandDefinition { + name: "".to_string(), + parameters: vec![], + }; + let mut index = 0usize; + timeout( + Duration::from_secs(1), + client.register_callback_fn(msg_to_send, move |cmd| { + index += 1; + CommandResponse { + success: false, + response: format!("{}", index - 1), + } + }), + ) + .await + .unwrap() + .unwrap(); + + timeout(Duration::from_secs(1), async { + let msg = tx.recv().await.unwrap(); + let Callback::Registered(cb) = msg.callback else { + panic!("Wrong Callback Type") + }; + + // Check that we get responses to the callback the expected number of times + for i in 0..3 { + let (tx, rx) = oneshot::channel(); + cb.send(( + ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg.msg.uuid), + payload: Command { + header: CommandHeader { + timestamp: Utc::now(), + }, + parameters: Default::default(), + } + .into(), + }, + tx, + )) + .await + .unwrap(); + let RequestMessagePayload::CommandResponse(response) = rx.await.unwrap() else { + panic!("Unexpected Response Type"); + }; + assert_eq!(response.response, format!("{i}")); + } + }) + .await + .unwrap(); + } + + #[tokio::test] + async fn connected_disconnected() { + let (_, connected, client) = create_test_client(); + + // When we're connected we should return immediately + connected.send_replace(true); + client.wait_connected().now_or_never().unwrap(); + + // When we're disconnected we should return immediately + connected.send_replace(false); + client.wait_disconnected().now_or_never().unwrap(); + + let c2 = connected.clone(); + // When we're disconnected, we should not return immediately + let f1 = pin!(client.wait_connected()); + let f2 = pin!(async move { + sleep(Duration::from_millis(1)).await; + c2.send_replace(true); + }); + let r = select(f1, f2).await; + match r { + Either::Left(_) => panic!("Wait Connected Finished Before Connection Changed"), + Either::Right((_, other)) => timeout(Duration::from_secs(1), other).await.unwrap(), + } + + let c2 = connected.clone(); + // When we're disconnected, we should not return immediately + let f1 = pin!(client.wait_disconnected()); + let f2 = pin!(async move { + sleep(Duration::from_millis(1)).await; + c2.send_replace(false); + }); + let r = select(f1, f2).await; + match r { + Either::Left(_) => panic!("Wait Disconnected Finished Before Connection Changed"), + Either::Right((_, other)) => timeout(Duration::from_secs(1), other).await.unwrap(), + } + } +} diff --git a/api/src/client/telemetry.rs b/api/src/client/telemetry.rs index 7cf0b70..1ec66dd 100644 --- a/api/src/client/telemetry.rs +++ b/api/src/client/telemetry.rs @@ -163,3 +163,6 @@ impl> TelemetryHandle { self.publish(value, Utc::now()).await } } + +#[cfg(test)] +mod tests {} diff --git a/api/src/lib.rs b/api/src/lib.rs index e7a1279..d98d8ff 100644 --- a/api/src/lib.rs +++ b/api/src/lib.rs @@ -10,3 +10,6 @@ pub mod messages; pub mod macros { pub use api_proc_macro::IntoCommandDefinition; } + +#[cfg(test)] +pub mod test; diff --git a/api/src/messages/callback.rs b/api/src/messages/callback.rs index 950287d..6b63f4f 100644 --- a/api/src/messages/callback.rs +++ b/api/src/messages/callback.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum GenericCallbackError { CallbackClosed, MismatchedType, diff --git a/api/src/messages/command.rs b/api/src/messages/command.rs index 75b8a1e..fe56329 100644 --- a/api/src/messages/command.rs +++ b/api/src/messages/command.rs @@ -8,7 +8,7 @@ impl RegisterCallback for CommandDefinition { type Response = CommandResponse; } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct CommandResponse { pub success: bool, pub response: String, diff --git a/api/src/messages/mod.rs b/api/src/messages/mod.rs index 3112e83..c4abb6b 100644 --- a/api/src/messages/mod.rs +++ b/api/src/messages/mod.rs @@ -18,7 +18,7 @@ pub struct RequestMessage { pub payload: RequestMessagePayload, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ResponseMessage { pub uuid: Uuid, #[serde(default)] diff --git a/api/src/messages/payload.rs b/api/src/messages/payload.rs index 0dbe173..f525f24 100644 --- a/api/src/messages/payload.rs +++ b/api/src/messages/payload.rs @@ -7,7 +7,7 @@ use crate::messages::telemetry_entry::TelemetryEntry; use derive_more::{From, TryInto}; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Serialize, Deserialize, From)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, From)] pub enum RequestMessagePayload { TelemetryDefinitionRequest(TelemetryDefinitionRequest), TelemetryEntry(TelemetryEntry), @@ -16,7 +16,7 @@ pub enum RequestMessagePayload { CommandResponse(CommandResponse), } -#[derive(Debug, Clone, Serialize, Deserialize, From, TryInto)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, From, TryInto)] pub enum ResponseMessagePayload { TelemetryDefinitionResponse(TelemetryDefinitionResponse), Command(Command), diff --git a/api/src/messages/telemetry_definition.rs b/api/src/messages/telemetry_definition.rs index f3f1f2d..d6369fb 100644 --- a/api/src/messages/telemetry_definition.rs +++ b/api/src/messages/telemetry_definition.rs @@ -3,13 +3,13 @@ use crate::messages::RequestResponse; use serde::{Deserialize, Serialize}; use uuid::Uuid; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TelemetryDefinitionRequest { pub name: String, pub data_type: DataType, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TelemetryDefinitionResponse { pub uuid: Uuid, } diff --git a/api/src/messages/telemetry_entry.rs b/api/src/messages/telemetry_entry.rs index 3376e30..cb82ce9 100644 --- a/api/src/messages/telemetry_entry.rs +++ b/api/src/messages/telemetry_entry.rs @@ -4,7 +4,7 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use uuid::Uuid; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct TelemetryEntry { pub uuid: Uuid, pub value: DataValue, diff --git a/api/src/test/mock_stream_sink.rs b/api/src/test/mock_stream_sink.rs new file mode 100644 index 0000000..aa60dee --- /dev/null +++ b/api/src/test/mock_stream_sink.rs @@ -0,0 +1,82 @@ +use futures_util::sink::{unfold, Unfold}; +use futures_util::{Sink, SinkExt, Stream}; +use std::fmt::Display; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::mpsc; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::mpsc::{Receiver, Sender}; + +pub struct MockStreamSinkControl { + pub incoming: Sender, + pub outgoing: Receiver, +} + +pub struct MockStreamSink { + stream_rx: Receiver, + sink_tx: Pin>>, +} + +impl Stream for MockStreamSink +where + Self: Unpin, +{ + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.stream_rx.poll_recv(cx) + } +} + +impl Sink for MockStreamSink +where + U1: FnMut(u32, R) -> U2, + U2: Future>, +{ + type Error = E; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.sink_tx.poll_ready_unpin(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: R) -> Result<(), Self::Error> { + self.sink_tx.start_send_unpin(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.sink_tx.poll_flush_unpin(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.sink_tx.poll_close_unpin(cx) + } +} + +pub fn create_mock_stream_sink() -> ( + MockStreamSinkControl, + impl Stream + Sink, +) { + let (stream_tx, stream_rx) = mpsc::channel::(1); + let (sink_tx, sink_rx) = mpsc::channel::(1); + + let sink_tx = Arc::new(sink_tx); + + ( + MockStreamSinkControl { + incoming: stream_tx, + outgoing: sink_rx, + }, + MockStreamSink:: { + stream_rx, + sink_tx: Box::pin(unfold(0u32, move |_, item| { + let sink_tx = sink_tx.clone(); + async move { + sink_tx.send(item).await?; + Ok(0u32) as Result<_, SendError> + } + })), + }, + ) +} diff --git a/api/src/test/mod.rs b/api/src/test/mod.rs new file mode 100644 index 0000000..86ba89b --- /dev/null +++ b/api/src/test/mod.rs @@ -0,0 +1 @@ +pub mod mock_stream_sink;