From d3b882f56d7df7550e80be1a0ba4853e44d894a4 Mon Sep 17 00:00:00 2001 From: Sergey Savelyev Date: Thu, 1 Jan 2026 12:13:05 -0500 Subject: [PATCH] add tests for the api --- Cargo.lock | 2 - api-core/Cargo.toml | 2 +- api-core/src/data_type.rs | 3 +- api-core/src/data_value.rs | 11 + api-proc-macro/Cargo.toml | 2 +- api-proc-macro/src/into_command_definition.rs | 10 +- ...st_derive_macro_into_command_definition.rs | 19 + api/src/client/command.rs | 412 +++++++++++++++++- api/src/client/error.rs | 6 + api/src/client/mod.rs | 14 +- api/src/client/telemetry.rs | 298 ++++++++++++- examples/simple_producer/Cargo.toml | 2 - 12 files changed, 739 insertions(+), 42 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cb48a77..70a8d29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2020,11 +2020,9 @@ dependencies = [ "chrono", "env_logger", "futures-util", - "log", "num-traits", "tokio", "tokio-util", - "uuid", ] [[package]] diff --git a/api-core/Cargo.toml b/api-core/Cargo.toml index 19e0b66..bedaa3a 100644 --- a/api-core/Cargo.toml +++ b/api-core/Cargo.toml @@ -7,6 +7,6 @@ authors = ["Sergey "] [dependencies] chrono = { workspace = true, features = ["serde"] } -derive_more = { workspace = true, features = ["from", "try_into"] } +derive_more = { workspace = true, features = ["display", "from", "try_into"] } serde = { workspace = true, features = ["derive"] } thiserror = { workspace = true } diff --git a/api-core/src/data_type.rs b/api-core/src/data_type.rs index ad4c584..807e084 100644 --- a/api-core/src/data_type.rs +++ b/api-core/src/data_type.rs @@ -1,7 +1,8 @@ use crate::data_value::DataValue; +use derive_more::Display; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Display)] pub enum DataType { Float32, Float64, diff --git a/api-core/src/data_value.rs b/api-core/src/data_value.rs index 622d7d6..9e14bd7 100644 --- a/api-core/src/data_value.rs +++ b/api-core/src/data_value.rs @@ -1,3 +1,4 @@ +use crate::data_type::DataType; use derive_more::{From, TryInto}; use serde::{Deserialize, Serialize}; @@ -7,3 +8,13 @@ pub enum DataValue { Float64(f64), Boolean(bool), } + +impl DataValue { + pub fn to_data_type(self) -> DataType { + match self { + DataValue::Float32(_) => DataType::Float32, + DataValue::Float64(_) => DataType::Float64, + DataValue::Boolean(_) => DataType::Boolean, + } + } +} diff --git a/api-proc-macro/Cargo.toml b/api-proc-macro/Cargo.toml index c1d1649..da77f88 100644 --- a/api-proc-macro/Cargo.toml +++ b/api-proc-macro/Cargo.toml @@ -15,5 +15,5 @@ quote = { workspace = true } syn = { workspace = true } [dev-dependencies] -trybuild = { workspace = true } api = { path = "../api" } +trybuild = { workspace = true } diff --git a/api-proc-macro/src/into_command_definition.rs b/api-proc-macro/src/into_command_definition.rs index 359ed4c..18ecf27 100644 --- a/api-proc-macro/src/into_command_definition.rs +++ b/api-proc-macro/src/into_command_definition.rs @@ -70,10 +70,7 @@ pub fn derive_into_command_definition_impl( }); quote! { #(#field_entries)* } } - Fields::Unnamed(fields) => abort!( - fields, - "IntoCommandDefinition not supported for unnamed structs" - ), + Fields::Unnamed(_) => unreachable!("Already checked this"), Fields::Unit => quote! {}, }; let param_name_stream = match &data.fields { @@ -84,10 +81,7 @@ pub fn derive_into_command_definition_impl( }); quote! { #(#field_entries)* } } - Fields::Unnamed(fields) => abort!( - fields, - "IntoCommandDefinition not supported for unnamed structs" - ), + Fields::Unnamed(_) => unreachable!("Already checked this"), Fields::Unit => quote! {}, }; diff --git a/api-proc-macro/tests/test_derive_macro_into_command_definition.rs b/api-proc-macro/tests/test_derive_macro_into_command_definition.rs index 4be746f..dda6e9e 100644 --- a/api-proc-macro/tests/test_derive_macro_into_command_definition.rs +++ b/api-proc-macro/tests/test_derive_macro_into_command_definition.rs @@ -149,3 +149,22 @@ fn test_generic_command() { .unwrap(); assert_eq!(result.a, true); } + +#[test] +fn test_unit_command() { + #[derive(IntoCommandDefinition)] + struct TestStruct; + + let command_definition = TestStruct::create("Test".to_string()); + + assert_eq!(command_definition.name, "Test"); + assert_eq!(command_definition.parameters.capacity(), 0); + + TestStruct::parse(Command { + header: CommandHeader { + timestamp: Default::default(), + }, + parameters: HashMap::new(), + }) + .unwrap(); +} diff --git a/api/src/client/command.rs b/api/src/client/command.rs index 84b3dc3..63d834e 100644 --- a/api/src/client/command.rs +++ b/api/src/client/command.rs @@ -3,6 +3,7 @@ use crate::messages::command::CommandResponse; use api_core::command::{CommandHeader, IntoCommandDefinition}; use std::fmt::Display; use std::sync::Arc; +use tokio::select; use tokio_util::sync::CancellationToken; pub struct CommandRegistry { @@ -41,26 +42,36 @@ impl CommandRegistry { continue; }; - while let Some((cmd, responder)) = rx.recv().await { - let header = cmd.header.clone(); - let response = match C::parse(cmd) { - Ok(cmd) => match callback(header, cmd) { - Ok(response) => CommandResponse { - success: true, - response, - }, - Err(err) => CommandResponse { - success: false, - response: err.to_string(), - }, + loop { + // select used so that this loop gets broken if the token is cancelled + select!( + rx_value = rx.recv() => { + if let Some((cmd, responder)) = rx_value { + let header = cmd.header.clone(); + let response = match C::parse(cmd) { + Ok(cmd) => match callback(header, cmd) { + Ok(response) => CommandResponse { + success: true, + response, + }, + Err(err) => CommandResponse { + success: false, + response: err.to_string(), + }, + }, + Err(err) => CommandResponse { + success: false, + response: err.to_string(), + }, + }; + // This should only err if we had an error elsewhere + let _ = responder.send(response); + } else { + break; + } }, - Err(err) => CommandResponse { - success: false, - response: err.to_string(), - }, - }; - // This should only err if we had an error elsewhere - let _ = responder.send(response); + _ = cancellation_token.cancelled() => { break; }, + ); } } }); @@ -78,3 +89,366 @@ impl Drop for CommandHandle { self.cancellation_token.cancel(); } } + +#[cfg(test)] +mod tests { + use crate::client::command::CommandRegistry; + use crate::client::tests::create_test_client; + use crate::client::Callback; + use crate::messages::callback::GenericCallbackError; + use crate::messages::command::CommandResponse; + use crate::messages::payload::RequestMessagePayload; + use crate::messages::telemetry_definition::TelemetryDefinitionResponse; + use crate::messages::ResponseMessage; + use api_core::command::{ + Command, CommandDefinition, CommandHeader, CommandParameterDefinition, + IntoCommandDefinition, IntoCommandDefinitionError, + }; + use api_core::data_type::DataType; + use std::collections::HashMap; + use std::convert::Infallible; + use std::sync::Arc; + use std::time::Duration; + use tokio::sync::oneshot; + use tokio::time::timeout; + use uuid::Uuid; + + struct CmdType { + #[allow(unused)] + param1: f32, + } + + impl IntoCommandDefinition for CmdType { + fn create(name: String) -> CommandDefinition { + CommandDefinition { + name, + parameters: vec![CommandParameterDefinition { + name: "param1".to_string(), + data_type: DataType::Float32, + }], + } + } + + fn parse(command: Command) -> Result { + Ok(Self { + param1: (*command.parameters.get("param1").ok_or_else(|| { + IntoCommandDefinitionError::ParameterMissing("param1".to_string()) + })?) + .try_into() + .map_err(|_| IntoCommandDefinitionError::MismatchedType { + parameter: "param1".to_string(), + expected: DataType::Float32, + })?, + }) + } + } + + #[tokio::test] + async fn simple_handler() { + // if _c drops then we are disconnected + let (mut rx, _c, client) = create_test_client(); + + let cmd_reg = CommandRegistry::new(Arc::new(client)); + + let _cmd_handle = cmd_reg.register_handler("cmd", |_, _: CmdType| { + Ok("success".to_string()) as Result<_, Infallible> + }); + + let msg = timeout(Duration::from_secs(1), rx.recv()) + .await + .unwrap() + .unwrap(); + let Callback::Registered(callback) = msg.callback else { + panic!("Incorrect Callback Type"); + }; + + let mut params = HashMap::new(); + params.insert("param1".to_string(), 0.0f32.into()); + + let (response_tx, response_rx) = oneshot::channel(); + timeout( + Duration::from_secs(1), + callback.send(( + ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg.msg.uuid), + payload: Command { + header: CommandHeader { + timestamp: Default::default(), + }, + parameters: params, + } + .into(), + }, + response_tx, + )), + ) + .await + .unwrap() + .unwrap(); + let response = timeout(Duration::from_secs(1), response_rx) + .await + .unwrap() + .unwrap(); + let RequestMessagePayload::CommandResponse(CommandResponse { success, response }) = + response + else { + panic!("Unexpected Response Type"); + }; + assert!(success); + assert_eq!(response, "success"); + } + + #[tokio::test] + async fn handler_failed() { + // if _c drops then we are disconnected + let (mut rx, _c, client) = create_test_client(); + + let cmd_reg = CommandRegistry::new(Arc::new(client)); + + let _cmd_handle = cmd_reg.register_handler("cmd", |_, _: CmdType| { + Err("failure".into()) as Result<_, Box> + }); + + let msg = timeout(Duration::from_secs(1), rx.recv()) + .await + .unwrap() + .unwrap(); + let Callback::Registered(callback) = msg.callback else { + panic!("Incorrect Callback Type"); + }; + + let mut params = HashMap::new(); + params.insert("param1".to_string(), 1.0f32.into()); + + let (response_tx, response_rx) = oneshot::channel(); + timeout( + Duration::from_secs(1), + callback.send(( + ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg.msg.uuid), + payload: Command { + header: CommandHeader { + timestamp: Default::default(), + }, + parameters: params, + } + .into(), + }, + response_tx, + )), + ) + .await + .unwrap() + .unwrap(); + let response = timeout(Duration::from_secs(1), response_rx) + .await + .unwrap() + .unwrap(); + let RequestMessagePayload::CommandResponse(CommandResponse { success, response }) = + response + else { + panic!("Unexpected Response Type"); + }; + assert!(!success); + assert_eq!(response, "failure"); + } + + #[tokio::test] + async fn parse_failed() { + // if _c drops then we are disconnected + let (mut rx, _c, client) = create_test_client(); + + let cmd_reg = CommandRegistry::new(Arc::new(client)); + + let _cmd_handle = cmd_reg.register_handler("cmd", |_, _: CmdType| { + Err("failure".into()) as Result<_, Box> + }); + + let msg = timeout(Duration::from_secs(1), rx.recv()) + .await + .unwrap() + .unwrap(); + let Callback::Registered(callback) = msg.callback else { + panic!("Incorrect Callback Type"); + }; + + let mut params = HashMap::new(); + params.insert("param1".to_string(), 1.0f64.into()); + + let (response_tx, response_rx) = oneshot::channel(); + timeout( + Duration::from_secs(1), + callback.send(( + ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg.msg.uuid), + payload: Command { + header: CommandHeader { + timestamp: Default::default(), + }, + parameters: params, + } + .into(), + }, + response_tx, + )), + ) + .await + .unwrap() + .unwrap(); + let response = timeout(Duration::from_secs(1), response_rx) + .await + .unwrap() + .unwrap(); + let RequestMessagePayload::CommandResponse(CommandResponse { + success, + response: _, + }) = response + else { + panic!("Unexpected Response Type"); + }; + assert!(!success); + } + + #[tokio::test] + async fn wrong_message() { + // if _c drops then we are disconnected + let (mut rx, _c, client) = create_test_client(); + + let cmd_reg = CommandRegistry::new(Arc::new(client)); + + let _cmd_handle = + cmd_reg.register_handler("cmd", |_, _: CmdType| -> Result<_, Infallible> { + panic!("This should not happen"); + }); + + let msg = timeout(Duration::from_secs(1), rx.recv()) + .await + .unwrap() + .unwrap(); + let Callback::Registered(callback) = msg.callback else { + panic!("Incorrect Callback Type"); + }; + + let (response_tx, response_rx) = oneshot::channel(); + timeout( + Duration::from_secs(1), + callback.send(( + ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg.msg.uuid), + payload: TelemetryDefinitionResponse { + uuid: Uuid::new_v4(), + } + .into(), + }, + response_tx, + )), + ) + .await + .unwrap() + .unwrap(); + let response = timeout(Duration::from_secs(1), response_rx) + .await + .unwrap() + .unwrap(); + let RequestMessagePayload::GenericCallbackError(err) = response else { + panic!("Unexpected Response Type"); + }; + assert_eq!(err, GenericCallbackError::MismatchedType); + } + + #[tokio::test] + async fn callback_closed() { + // if _c drops then we are disconnected + let (mut rx, _c, client) = create_test_client(); + + let cmd_reg = CommandRegistry::new(Arc::new(client)); + + let cmd_handle = + cmd_reg.register_handler("cmd", |_, _: CmdType| -> Result<_, Infallible> { + panic!("This should not happen"); + }); + + let msg = timeout(Duration::from_secs(1), rx.recv()) + .await + .unwrap() + .unwrap(); + let Callback::Registered(callback) = msg.callback else { + panic!("Incorrect Callback Type"); + }; + + // This should shut down the command handler + drop(cmd_handle); + + // Send a command + let mut params = HashMap::new(); + params.insert("param1".to_string(), 0.0f32.into()); + let (response_tx, response_rx) = oneshot::channel(); + timeout( + Duration::from_secs(1), + callback.send(( + ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg.msg.uuid), + payload: Command { + header: CommandHeader { + timestamp: Default::default(), + }, + parameters: params, + } + .into(), + }, + response_tx, + )), + ) + .await + .unwrap() + .unwrap(); + + let response = timeout(Duration::from_secs(1), response_rx) + .await + .unwrap() + .unwrap(); + let RequestMessagePayload::GenericCallbackError(err) = response else { + panic!("Unexpected Response Type"); + }; + assert_eq!(err, GenericCallbackError::CallbackClosed); + } + + #[tokio::test] + async fn reconnect() { + // if _c drops then we are disconnected + let (mut rx, _c, client) = create_test_client(); + + let cmd_reg = CommandRegistry::new(Arc::new(client)); + + let _cmd_handle = + cmd_reg.register_handler("cmd", |_, _: CmdType| -> Result<_, Infallible> { + panic!("This should not happen"); + }); + + let msg = timeout(Duration::from_secs(1), rx.recv()) + .await + .unwrap() + .unwrap(); + let Callback::Registered(callback) = msg.callback else { + panic!("Incorrect Callback Type"); + }; + + println!("Dropping"); + drop(callback); + println!("Dropped"); + + // The command re-registers itself + let msg = timeout(Duration::from_secs(1), rx.recv()) + .await + .unwrap() + .unwrap(); + let Callback::Registered(_) = msg.callback else { + panic!("Incorrect Callback Type"); + }; + } +} diff --git a/api/src/client/error.rs b/api/src/client/error.rs index 67e675e..e5d5738 100644 --- a/api/src/client/error.rs +++ b/api/src/client/error.rs @@ -1,3 +1,4 @@ +use api_core::data_type::DataType; use thiserror::Error; #[derive(Error, Debug)] @@ -16,6 +17,11 @@ pub enum MessageError { TokioTrySendError(#[from] tokio::sync::mpsc::error::TrySendError<()>), #[error(transparent)] TokioLockError(#[from] tokio::sync::TryLockError), + #[error("Incorrect Data Type. {expected} expected. {actual} actual.")] + IncorrectDataType { + expected: DataType, + actual: DataType, + }, } #[derive(Error, Debug)] diff --git a/api/src/client/mod.rs b/api/src/client/mod.rs index 95c8f9b..e4ac66b 100644 --- a/api/src/client/mod.rs +++ b/api/src/client/mod.rs @@ -31,6 +31,7 @@ enum Callback { Registered(RegisteredCallback), } +#[derive(Debug)] struct OutgoingMessage { msg: RequestMessage, callback: Callback, @@ -192,6 +193,7 @@ impl Client { break; } } + println!("Exited Loop"); }); Ok(outer_rx) @@ -279,13 +281,12 @@ mod tests { use chrono::Utc; use futures_util::future::{select, Either}; use futures_util::FutureExt; - use std::future::Future; - use std::pin::{pin, Pin}; + use std::pin::pin; use std::time::Duration; use tokio::join; use tokio::time::{sleep, timeout}; - fn create_test_client() -> (mpsc::Receiver, watch::Sender, Client) { + pub fn create_test_client() -> (mpsc::Receiver, watch::Sender, Client) { let cancel = CancellationToken::new(); let (tx, rx) = mpsc::channel(1); let channel = Arc::new(RwLock::new(tx)); @@ -362,7 +363,7 @@ mod tests { async fn send_message_if_connected_not_connected() { let (_, connected_state_tx, client) = create_test_client(); - let lock = client.channel.write().await; + let _lock = client.channel.write().await; connected_state_tx.send_replace(false); let msg_to_send = TelemetryEntry { @@ -449,7 +450,6 @@ mod tests { .unwrap() .unwrap(); - let response_uuid = Uuid::new_v4(); let outgoing_rx = timeout(Duration::from_secs(1), async { let msg = tx.recv().await.unwrap(); let Callback::Registered(cb) = msg.callback else { @@ -484,7 +484,7 @@ mod tests { let responder = timeout(Duration::from_secs(1), async { for i in 0..5 { - let (cmd, responder) = response.recv().await.unwrap(); + let (_cmd, responder) = response.recv().await.unwrap(); responder .send(CommandResponse { success: false, @@ -510,7 +510,7 @@ mod tests { let mut index = 0usize; timeout( Duration::from_secs(1), - client.register_callback_fn(msg_to_send, move |cmd| { + client.register_callback_fn(msg_to_send, move |_| { index += 1; CommandResponse { success: false, diff --git a/api/src/client/telemetry.rs b/api/src/client/telemetry.rs index 1ec66dd..641cf28 100644 --- a/api/src/client/telemetry.rs +++ b/api/src/client/telemetry.rs @@ -75,6 +75,7 @@ impl TelemetryRegistry { cancellation_token, uuid: response_uuid, client: stored_client, + data_type, } } inner(self.client.clone(), name.into(), data_type).await @@ -96,6 +97,7 @@ pub struct GenericTelemetryHandle { cancellation_token: CancellationToken, uuid: Arc>, client: Arc, + data_type: DataType, } impl GenericTelemetryHandle { @@ -104,6 +106,12 @@ impl GenericTelemetryHandle { value: DataValue, timestamp: DateTime, ) -> Result<(), MessageError> { + if value.to_data_type() != self.data_type { + return Err(MessageError::IncorrectDataType { + expected: self.data_type, + actual: value.to_data_type(), + }); + } let Ok(lock) = self.uuid.try_read() else { return Ok(()); }; @@ -165,4 +173,292 @@ impl> TelemetryHandle { } #[cfg(test)] -mod tests {} +mod tests { + use crate::client::error::MessageError; + use crate::client::telemetry::TelemetryRegistry; + use crate::client::tests::create_test_client; + use crate::client::Callback; + use crate::messages::payload::RequestMessagePayload; + use crate::messages::telemetry_definition::{ + TelemetryDefinitionRequest, TelemetryDefinitionResponse, + }; + use crate::messages::telemetry_entry::TelemetryEntry; + use crate::messages::ResponseMessage; + use api_core::data_type::DataType; + use api_core::data_value::DataValue; + use futures_util::FutureExt; + use std::sync::Arc; + use std::time::Duration; + use tokio::task::yield_now; + use tokio::time::timeout; + use tokio::try_join; + use uuid::Uuid; + + #[tokio::test] + async fn generic() { + // if _c drops then we are disconnected + let (mut rx, _c, client) = create_test_client(); + + let tlm = TelemetryRegistry::new(Arc::new(client)); + let tlm_handle = tlm.register_generic("generic", DataType::Float32); + + let tlm_uuid = Uuid::new_v4(); + + let expected_rx = async { + let msg = rx.recv().await.unwrap(); + let Callback::Once(responder) = msg.callback else { + panic!("Expected Once Callback"); + }; + assert!(msg.msg.response.is_none()); + let RequestMessagePayload::TelemetryDefinitionRequest(TelemetryDefinitionRequest { + name, + data_type, + }) = msg.msg.payload + else { + panic!("Expected Telemetry Definition Request") + }; + assert_eq!(name, "generic".to_string()); + assert_eq!(data_type, DataType::Float32); + responder + .send(ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg.msg.uuid), + payload: TelemetryDefinitionResponse { uuid: tlm_uuid }.into(), + }) + .unwrap(); + }; + + let (tlm_handle, _) = try_join!( + timeout(Duration::from_secs(1), tlm_handle), + timeout(Duration::from_secs(1), expected_rx), + ) + .unwrap(); + + assert_eq!(*tlm_handle.uuid.try_read().unwrap(), tlm_uuid); + + // This should NOT block if there is space in the queue + tlm_handle + .publish_now(0.0f32.into()) + .now_or_never() + .unwrap() + .unwrap(); + + let tlm_msg = timeout(Duration::from_secs(1), rx.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!(tlm_msg.callback, Callback::None)); + match tlm_msg.msg.payload { + RequestMessagePayload::TelemetryEntry(TelemetryEntry { uuid, value, .. }) => { + assert_eq!(uuid, tlm_uuid); + assert_eq!(value, DataValue::Float32(0.0f32)); + } + _ => panic!("Expected Telemetry Entry"), + } + } + + #[tokio::test] + async fn mismatched_type() { + let (mut rx, _, client) = create_test_client(); + + let tlm = TelemetryRegistry::new(Arc::new(client)); + let tlm_handle = tlm.register_generic("generic", DataType::Float32); + + let tlm_uuid = Uuid::new_v4(); + + let expected_rx = async { + let msg = rx.recv().await.unwrap(); + let Callback::Once(responder) = msg.callback else { + panic!("Expected Once Callback"); + }; + assert!(msg.msg.response.is_none()); + let RequestMessagePayload::TelemetryDefinitionRequest(TelemetryDefinitionRequest { + name, + data_type, + }) = msg.msg.payload + else { + panic!("Expected Telemetry Definition Request") + }; + assert_eq!(name, "generic".to_string()); + assert_eq!(data_type, DataType::Float32); + responder + .send(ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg.msg.uuid), + payload: TelemetryDefinitionResponse { uuid: tlm_uuid }.into(), + }) + .unwrap(); + }; + + let (tlm_handle, _) = try_join!( + timeout(Duration::from_secs(1), tlm_handle), + timeout(Duration::from_secs(1), expected_rx), + ) + .unwrap(); + + assert_eq!(*tlm_handle.uuid.try_read().unwrap(), tlm_uuid); + + match timeout( + Duration::from_secs(1), + tlm_handle.publish_now(0.0f64.into()), + ) + .await + .unwrap() + { + Err(MessageError::IncorrectDataType { expected, actual }) => { + assert_eq!(expected, DataType::Float32); + assert_eq!(actual, DataType::Float64); + } + _ => panic!("Error Expected"), + } + } + + #[tokio::test] + async fn typed() { + // if _c drops then we are disconnected + let (mut rx, _c, client) = create_test_client(); + + let tlm = TelemetryRegistry::new(Arc::new(client)); + let tlm_handle = tlm.register::("typed"); + + let tlm_uuid = Uuid::new_v4(); + + let expected_rx = async { + let msg = rx.recv().await.unwrap(); + let Callback::Once(responder) = msg.callback else { + panic!("Expected Once Callback"); + }; + assert!(msg.msg.response.is_none()); + let RequestMessagePayload::TelemetryDefinitionRequest(TelemetryDefinitionRequest { + name, + data_type, + }) = msg.msg.payload + else { + panic!("Expected Telemetry Definition Request") + }; + assert_eq!(name, "typed".to_string()); + assert_eq!(data_type, DataType::Float32); + responder + .send(ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg.msg.uuid), + payload: TelemetryDefinitionResponse { uuid: tlm_uuid }.into(), + }) + .unwrap(); + }; + + let (tlm_handle, _) = try_join!( + timeout(Duration::from_secs(1), tlm_handle), + timeout(Duration::from_secs(1), expected_rx), + ) + .unwrap(); + + assert_eq!(*tlm_handle.as_generic().uuid.try_read().unwrap(), tlm_uuid); + + // This should NOT block if there is space in the queue + tlm_handle + .publish_now(1.0f32.into()) + .now_or_never() + .unwrap() + .unwrap(); + // This should block as there should not be space in the queue + assert!(tlm_handle + .publish_now(2.0f32.into()) + .now_or_never() + .is_none()); + + let tlm_msg = timeout(Duration::from_secs(1), rx.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!(tlm_msg.callback, Callback::None)); + match tlm_msg.msg.payload { + RequestMessagePayload::TelemetryEntry(TelemetryEntry { uuid, value, .. }) => { + assert_eq!(uuid, tlm_uuid); + assert_eq!(value, DataValue::Float32(1.0f32)); + } + _ => panic!("Expected Telemetry Entry"), + } + + let _make_generic_again = tlm_handle.to_generic(); + } + + #[tokio::test] + async fn reconnect() { + // if _c drops then we are disconnected + let (mut rx, connected, client) = create_test_client(); + + let tlm = TelemetryRegistry::new(Arc::new(client)); + let tlm_handle = tlm.register_generic("generic", DataType::Float32); + + let tlm_uuid = Uuid::new_v4(); + + let expected_rx = async { + let msg = rx.recv().await.unwrap(); + let Callback::Once(responder) = msg.callback else { + panic!("Expected Once Callback"); + }; + assert!(msg.msg.response.is_none()); + let RequestMessagePayload::TelemetryDefinitionRequest(TelemetryDefinitionRequest { + name, + data_type, + }) = msg.msg.payload + else { + panic!("Expected Telemetry Definition Request") + }; + assert_eq!(name, "generic".to_string()); + assert_eq!(data_type, DataType::Float32); + responder + .send(ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg.msg.uuid), + payload: TelemetryDefinitionResponse { uuid: tlm_uuid }.into(), + }) + .unwrap(); + }; + + let (tlm_handle, _) = try_join!( + timeout(Duration::from_secs(1), tlm_handle), + timeout(Duration::from_secs(1), expected_rx), + ) + .unwrap(); + + assert_eq!(*tlm_handle.uuid.try_read().unwrap(), tlm_uuid); + + // Notify Disconnect + connected.send_replace(false); + // Notify Reconnect + connected.send_replace(true); + + { + let new_tlm_uuid = Uuid::new_v4(); + + let msg = rx.recv().await.unwrap(); + let Callback::Once(responder) = msg.callback else { + panic!("Expected Once Callback"); + }; + assert!(msg.msg.response.is_none()); + let RequestMessagePayload::TelemetryDefinitionRequest(TelemetryDefinitionRequest { + name, + data_type, + }) = msg.msg.payload + else { + panic!("Expected Telemetry Definition Request") + }; + assert_eq!(name, "generic".to_string()); + assert_eq!(data_type, DataType::Float32); + responder + .send(ResponseMessage { + uuid: Uuid::new_v4(), + response: Some(msg.msg.uuid), + payload: TelemetryDefinitionResponse { uuid: new_tlm_uuid }.into(), + }) + .unwrap(); + + // Yield to the executor so that the UUIDs can be updated + yield_now().await; + + assert_eq!(*tlm_handle.uuid.try_read().unwrap(), new_tlm_uuid); + } + } +} diff --git a/examples/simple_producer/Cargo.toml b/examples/simple_producer/Cargo.toml index 9746d13..c105fc6 100644 --- a/examples/simple_producer/Cargo.toml +++ b/examples/simple_producer/Cargo.toml @@ -12,5 +12,3 @@ futures-util = { workspace = true } num-traits = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "signal", "time", "macros"] } tokio-util = { workspace = true } -uuid = { workspace = true } -log = "0.4.29"