add tests for the api

This commit is contained in:
2026-01-01 12:13:05 -05:00
parent 4aa86da14a
commit d3b882f56d
12 changed files with 739 additions and 42 deletions

View File

@@ -3,6 +3,7 @@ use crate::messages::command::CommandResponse;
use api_core::command::{CommandHeader, IntoCommandDefinition};
use std::fmt::Display;
use std::sync::Arc;
use tokio::select;
use tokio_util::sync::CancellationToken;
pub struct CommandRegistry {
@@ -41,26 +42,36 @@ impl CommandRegistry {
continue;
};
while let Some((cmd, responder)) = rx.recv().await {
let header = cmd.header.clone();
let response = match C::parse(cmd) {
Ok(cmd) => match callback(header, cmd) {
Ok(response) => CommandResponse {
success: true,
response,
},
Err(err) => CommandResponse {
success: false,
response: err.to_string(),
},
loop {
// select used so that this loop gets broken if the token is cancelled
select!(
rx_value = rx.recv() => {
if let Some((cmd, responder)) = rx_value {
let header = cmd.header.clone();
let response = match C::parse(cmd) {
Ok(cmd) => match callback(header, cmd) {
Ok(response) => CommandResponse {
success: true,
response,
},
Err(err) => CommandResponse {
success: false,
response: err.to_string(),
},
},
Err(err) => CommandResponse {
success: false,
response: err.to_string(),
},
};
// This should only err if we had an error elsewhere
let _ = responder.send(response);
} else {
break;
}
},
Err(err) => CommandResponse {
success: false,
response: err.to_string(),
},
};
// This should only err if we had an error elsewhere
let _ = responder.send(response);
_ = cancellation_token.cancelled() => { break; },
);
}
}
});
@@ -78,3 +89,366 @@ impl Drop for CommandHandle {
self.cancellation_token.cancel();
}
}
#[cfg(test)]
mod tests {
use crate::client::command::CommandRegistry;
use crate::client::tests::create_test_client;
use crate::client::Callback;
use crate::messages::callback::GenericCallbackError;
use crate::messages::command::CommandResponse;
use crate::messages::payload::RequestMessagePayload;
use crate::messages::telemetry_definition::TelemetryDefinitionResponse;
use crate::messages::ResponseMessage;
use api_core::command::{
Command, CommandDefinition, CommandHeader, CommandParameterDefinition,
IntoCommandDefinition, IntoCommandDefinitionError,
};
use api_core::data_type::DataType;
use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::oneshot;
use tokio::time::timeout;
use uuid::Uuid;
struct CmdType {
#[allow(unused)]
param1: f32,
}
impl IntoCommandDefinition for CmdType {
fn create(name: String) -> CommandDefinition {
CommandDefinition {
name,
parameters: vec![CommandParameterDefinition {
name: "param1".to_string(),
data_type: DataType::Float32,
}],
}
}
fn parse(command: Command) -> Result<Self, IntoCommandDefinitionError> {
Ok(Self {
param1: (*command.parameters.get("param1").ok_or_else(|| {
IntoCommandDefinitionError::ParameterMissing("param1".to_string())
})?)
.try_into()
.map_err(|_| IntoCommandDefinitionError::MismatchedType {
parameter: "param1".to_string(),
expected: DataType::Float32,
})?,
})
}
}
#[tokio::test]
async fn simple_handler() {
// if _c drops then we are disconnected
let (mut rx, _c, client) = create_test_client();
let cmd_reg = CommandRegistry::new(Arc::new(client));
let _cmd_handle = cmd_reg.register_handler("cmd", |_, _: CmdType| {
Ok("success".to_string()) as Result<_, Infallible>
});
let msg = timeout(Duration::from_secs(1), rx.recv())
.await
.unwrap()
.unwrap();
let Callback::Registered(callback) = msg.callback else {
panic!("Incorrect Callback Type");
};
let mut params = HashMap::new();
params.insert("param1".to_string(), 0.0f32.into());
let (response_tx, response_rx) = oneshot::channel();
timeout(
Duration::from_secs(1),
callback.send((
ResponseMessage {
uuid: Uuid::new_v4(),
response: Some(msg.msg.uuid),
payload: Command {
header: CommandHeader {
timestamp: Default::default(),
},
parameters: params,
}
.into(),
},
response_tx,
)),
)
.await
.unwrap()
.unwrap();
let response = timeout(Duration::from_secs(1), response_rx)
.await
.unwrap()
.unwrap();
let RequestMessagePayload::CommandResponse(CommandResponse { success, response }) =
response
else {
panic!("Unexpected Response Type");
};
assert!(success);
assert_eq!(response, "success");
}
#[tokio::test]
async fn handler_failed() {
// if _c drops then we are disconnected
let (mut rx, _c, client) = create_test_client();
let cmd_reg = CommandRegistry::new(Arc::new(client));
let _cmd_handle = cmd_reg.register_handler("cmd", |_, _: CmdType| {
Err("failure".into()) as Result<_, Box<dyn std::error::Error>>
});
let msg = timeout(Duration::from_secs(1), rx.recv())
.await
.unwrap()
.unwrap();
let Callback::Registered(callback) = msg.callback else {
panic!("Incorrect Callback Type");
};
let mut params = HashMap::new();
params.insert("param1".to_string(), 1.0f32.into());
let (response_tx, response_rx) = oneshot::channel();
timeout(
Duration::from_secs(1),
callback.send((
ResponseMessage {
uuid: Uuid::new_v4(),
response: Some(msg.msg.uuid),
payload: Command {
header: CommandHeader {
timestamp: Default::default(),
},
parameters: params,
}
.into(),
},
response_tx,
)),
)
.await
.unwrap()
.unwrap();
let response = timeout(Duration::from_secs(1), response_rx)
.await
.unwrap()
.unwrap();
let RequestMessagePayload::CommandResponse(CommandResponse { success, response }) =
response
else {
panic!("Unexpected Response Type");
};
assert!(!success);
assert_eq!(response, "failure");
}
#[tokio::test]
async fn parse_failed() {
// if _c drops then we are disconnected
let (mut rx, _c, client) = create_test_client();
let cmd_reg = CommandRegistry::new(Arc::new(client));
let _cmd_handle = cmd_reg.register_handler("cmd", |_, _: CmdType| {
Err("failure".into()) as Result<_, Box<dyn std::error::Error>>
});
let msg = timeout(Duration::from_secs(1), rx.recv())
.await
.unwrap()
.unwrap();
let Callback::Registered(callback) = msg.callback else {
panic!("Incorrect Callback Type");
};
let mut params = HashMap::new();
params.insert("param1".to_string(), 1.0f64.into());
let (response_tx, response_rx) = oneshot::channel();
timeout(
Duration::from_secs(1),
callback.send((
ResponseMessage {
uuid: Uuid::new_v4(),
response: Some(msg.msg.uuid),
payload: Command {
header: CommandHeader {
timestamp: Default::default(),
},
parameters: params,
}
.into(),
},
response_tx,
)),
)
.await
.unwrap()
.unwrap();
let response = timeout(Duration::from_secs(1), response_rx)
.await
.unwrap()
.unwrap();
let RequestMessagePayload::CommandResponse(CommandResponse {
success,
response: _,
}) = response
else {
panic!("Unexpected Response Type");
};
assert!(!success);
}
#[tokio::test]
async fn wrong_message() {
// if _c drops then we are disconnected
let (mut rx, _c, client) = create_test_client();
let cmd_reg = CommandRegistry::new(Arc::new(client));
let _cmd_handle =
cmd_reg.register_handler("cmd", |_, _: CmdType| -> Result<_, Infallible> {
panic!("This should not happen");
});
let msg = timeout(Duration::from_secs(1), rx.recv())
.await
.unwrap()
.unwrap();
let Callback::Registered(callback) = msg.callback else {
panic!("Incorrect Callback Type");
};
let (response_tx, response_rx) = oneshot::channel();
timeout(
Duration::from_secs(1),
callback.send((
ResponseMessage {
uuid: Uuid::new_v4(),
response: Some(msg.msg.uuid),
payload: TelemetryDefinitionResponse {
uuid: Uuid::new_v4(),
}
.into(),
},
response_tx,
)),
)
.await
.unwrap()
.unwrap();
let response = timeout(Duration::from_secs(1), response_rx)
.await
.unwrap()
.unwrap();
let RequestMessagePayload::GenericCallbackError(err) = response else {
panic!("Unexpected Response Type");
};
assert_eq!(err, GenericCallbackError::MismatchedType);
}
#[tokio::test]
async fn callback_closed() {
// if _c drops then we are disconnected
let (mut rx, _c, client) = create_test_client();
let cmd_reg = CommandRegistry::new(Arc::new(client));
let cmd_handle =
cmd_reg.register_handler("cmd", |_, _: CmdType| -> Result<_, Infallible> {
panic!("This should not happen");
});
let msg = timeout(Duration::from_secs(1), rx.recv())
.await
.unwrap()
.unwrap();
let Callback::Registered(callback) = msg.callback else {
panic!("Incorrect Callback Type");
};
// This should shut down the command handler
drop(cmd_handle);
// Send a command
let mut params = HashMap::new();
params.insert("param1".to_string(), 0.0f32.into());
let (response_tx, response_rx) = oneshot::channel();
timeout(
Duration::from_secs(1),
callback.send((
ResponseMessage {
uuid: Uuid::new_v4(),
response: Some(msg.msg.uuid),
payload: Command {
header: CommandHeader {
timestamp: Default::default(),
},
parameters: params,
}
.into(),
},
response_tx,
)),
)
.await
.unwrap()
.unwrap();
let response = timeout(Duration::from_secs(1), response_rx)
.await
.unwrap()
.unwrap();
let RequestMessagePayload::GenericCallbackError(err) = response else {
panic!("Unexpected Response Type");
};
assert_eq!(err, GenericCallbackError::CallbackClosed);
}
#[tokio::test]
async fn reconnect() {
// if _c drops then we are disconnected
let (mut rx, _c, client) = create_test_client();
let cmd_reg = CommandRegistry::new(Arc::new(client));
let _cmd_handle =
cmd_reg.register_handler("cmd", |_, _: CmdType| -> Result<_, Infallible> {
panic!("This should not happen");
});
let msg = timeout(Duration::from_secs(1), rx.recv())
.await
.unwrap()
.unwrap();
let Callback::Registered(callback) = msg.callback else {
panic!("Incorrect Callback Type");
};
println!("Dropping");
drop(callback);
println!("Dropped");
// The command re-registers itself
let msg = timeout(Duration::from_secs(1), rx.recv())
.await
.unwrap()
.unwrap();
let Callback::Registered(_) = msg.callback else {
panic!("Incorrect Callback Type");
};
}
}

View File

@@ -1,3 +1,4 @@
use api_core::data_type::DataType;
use thiserror::Error;
#[derive(Error, Debug)]
@@ -16,6 +17,11 @@ pub enum MessageError {
TokioTrySendError(#[from] tokio::sync::mpsc::error::TrySendError<()>),
#[error(transparent)]
TokioLockError(#[from] tokio::sync::TryLockError),
#[error("Incorrect Data Type. {expected} expected. {actual} actual.")]
IncorrectDataType {
expected: DataType,
actual: DataType,
},
}
#[derive(Error, Debug)]

View File

@@ -31,6 +31,7 @@ enum Callback {
Registered(RegisteredCallback),
}
#[derive(Debug)]
struct OutgoingMessage {
msg: RequestMessage,
callback: Callback,
@@ -192,6 +193,7 @@ impl Client {
break;
}
}
println!("Exited Loop");
});
Ok(outer_rx)
@@ -279,13 +281,12 @@ mod tests {
use chrono::Utc;
use futures_util::future::{select, Either};
use futures_util::FutureExt;
use std::future::Future;
use std::pin::{pin, Pin};
use std::pin::pin;
use std::time::Duration;
use tokio::join;
use tokio::time::{sleep, timeout};
fn create_test_client() -> (mpsc::Receiver<OutgoingMessage>, watch::Sender<bool>, Client) {
pub fn create_test_client() -> (mpsc::Receiver<OutgoingMessage>, watch::Sender<bool>, Client) {
let cancel = CancellationToken::new();
let (tx, rx) = mpsc::channel(1);
let channel = Arc::new(RwLock::new(tx));
@@ -362,7 +363,7 @@ mod tests {
async fn send_message_if_connected_not_connected() {
let (_, connected_state_tx, client) = create_test_client();
let lock = client.channel.write().await;
let _lock = client.channel.write().await;
connected_state_tx.send_replace(false);
let msg_to_send = TelemetryEntry {
@@ -449,7 +450,6 @@ mod tests {
.unwrap()
.unwrap();
let response_uuid = Uuid::new_v4();
let outgoing_rx = timeout(Duration::from_secs(1), async {
let msg = tx.recv().await.unwrap();
let Callback::Registered(cb) = msg.callback else {
@@ -484,7 +484,7 @@ mod tests {
let responder = timeout(Duration::from_secs(1), async {
for i in 0..5 {
let (cmd, responder) = response.recv().await.unwrap();
let (_cmd, responder) = response.recv().await.unwrap();
responder
.send(CommandResponse {
success: false,
@@ -510,7 +510,7 @@ mod tests {
let mut index = 0usize;
timeout(
Duration::from_secs(1),
client.register_callback_fn(msg_to_send, move |cmd| {
client.register_callback_fn(msg_to_send, move |_| {
index += 1;
CommandResponse {
success: false,

View File

@@ -75,6 +75,7 @@ impl TelemetryRegistry {
cancellation_token,
uuid: response_uuid,
client: stored_client,
data_type,
}
}
inner(self.client.clone(), name.into(), data_type).await
@@ -96,6 +97,7 @@ pub struct GenericTelemetryHandle {
cancellation_token: CancellationToken,
uuid: Arc<RwLock<Uuid>>,
client: Arc<Client>,
data_type: DataType,
}
impl GenericTelemetryHandle {
@@ -104,6 +106,12 @@ impl GenericTelemetryHandle {
value: DataValue,
timestamp: DateTime<Utc>,
) -> Result<(), MessageError> {
if value.to_data_type() != self.data_type {
return Err(MessageError::IncorrectDataType {
expected: self.data_type,
actual: value.to_data_type(),
});
}
let Ok(lock) = self.uuid.try_read() else {
return Ok(());
};
@@ -165,4 +173,292 @@ impl<T: Into<DataValue>> TelemetryHandle<T> {
}
#[cfg(test)]
mod tests {}
mod tests {
use crate::client::error::MessageError;
use crate::client::telemetry::TelemetryRegistry;
use crate::client::tests::create_test_client;
use crate::client::Callback;
use crate::messages::payload::RequestMessagePayload;
use crate::messages::telemetry_definition::{
TelemetryDefinitionRequest, TelemetryDefinitionResponse,
};
use crate::messages::telemetry_entry::TelemetryEntry;
use crate::messages::ResponseMessage;
use api_core::data_type::DataType;
use api_core::data_value::DataValue;
use futures_util::FutureExt;
use std::sync::Arc;
use std::time::Duration;
use tokio::task::yield_now;
use tokio::time::timeout;
use tokio::try_join;
use uuid::Uuid;
#[tokio::test]
async fn generic() {
// if _c drops then we are disconnected
let (mut rx, _c, client) = create_test_client();
let tlm = TelemetryRegistry::new(Arc::new(client));
let tlm_handle = tlm.register_generic("generic", DataType::Float32);
let tlm_uuid = Uuid::new_v4();
let expected_rx = async {
let msg = rx.recv().await.unwrap();
let Callback::Once(responder) = msg.callback else {
panic!("Expected Once Callback");
};
assert!(msg.msg.response.is_none());
let RequestMessagePayload::TelemetryDefinitionRequest(TelemetryDefinitionRequest {
name,
data_type,
}) = msg.msg.payload
else {
panic!("Expected Telemetry Definition Request")
};
assert_eq!(name, "generic".to_string());
assert_eq!(data_type, DataType::Float32);
responder
.send(ResponseMessage {
uuid: Uuid::new_v4(),
response: Some(msg.msg.uuid),
payload: TelemetryDefinitionResponse { uuid: tlm_uuid }.into(),
})
.unwrap();
};
let (tlm_handle, _) = try_join!(
timeout(Duration::from_secs(1), tlm_handle),
timeout(Duration::from_secs(1), expected_rx),
)
.unwrap();
assert_eq!(*tlm_handle.uuid.try_read().unwrap(), tlm_uuid);
// This should NOT block if there is space in the queue
tlm_handle
.publish_now(0.0f32.into())
.now_or_never()
.unwrap()
.unwrap();
let tlm_msg = timeout(Duration::from_secs(1), rx.recv())
.await
.unwrap()
.unwrap();
assert!(matches!(tlm_msg.callback, Callback::None));
match tlm_msg.msg.payload {
RequestMessagePayload::TelemetryEntry(TelemetryEntry { uuid, value, .. }) => {
assert_eq!(uuid, tlm_uuid);
assert_eq!(value, DataValue::Float32(0.0f32));
}
_ => panic!("Expected Telemetry Entry"),
}
}
#[tokio::test]
async fn mismatched_type() {
let (mut rx, _, client) = create_test_client();
let tlm = TelemetryRegistry::new(Arc::new(client));
let tlm_handle = tlm.register_generic("generic", DataType::Float32);
let tlm_uuid = Uuid::new_v4();
let expected_rx = async {
let msg = rx.recv().await.unwrap();
let Callback::Once(responder) = msg.callback else {
panic!("Expected Once Callback");
};
assert!(msg.msg.response.is_none());
let RequestMessagePayload::TelemetryDefinitionRequest(TelemetryDefinitionRequest {
name,
data_type,
}) = msg.msg.payload
else {
panic!("Expected Telemetry Definition Request")
};
assert_eq!(name, "generic".to_string());
assert_eq!(data_type, DataType::Float32);
responder
.send(ResponseMessage {
uuid: Uuid::new_v4(),
response: Some(msg.msg.uuid),
payload: TelemetryDefinitionResponse { uuid: tlm_uuid }.into(),
})
.unwrap();
};
let (tlm_handle, _) = try_join!(
timeout(Duration::from_secs(1), tlm_handle),
timeout(Duration::from_secs(1), expected_rx),
)
.unwrap();
assert_eq!(*tlm_handle.uuid.try_read().unwrap(), tlm_uuid);
match timeout(
Duration::from_secs(1),
tlm_handle.publish_now(0.0f64.into()),
)
.await
.unwrap()
{
Err(MessageError::IncorrectDataType { expected, actual }) => {
assert_eq!(expected, DataType::Float32);
assert_eq!(actual, DataType::Float64);
}
_ => panic!("Error Expected"),
}
}
#[tokio::test]
async fn typed() {
// if _c drops then we are disconnected
let (mut rx, _c, client) = create_test_client();
let tlm = TelemetryRegistry::new(Arc::new(client));
let tlm_handle = tlm.register::<f32>("typed");
let tlm_uuid = Uuid::new_v4();
let expected_rx = async {
let msg = rx.recv().await.unwrap();
let Callback::Once(responder) = msg.callback else {
panic!("Expected Once Callback");
};
assert!(msg.msg.response.is_none());
let RequestMessagePayload::TelemetryDefinitionRequest(TelemetryDefinitionRequest {
name,
data_type,
}) = msg.msg.payload
else {
panic!("Expected Telemetry Definition Request")
};
assert_eq!(name, "typed".to_string());
assert_eq!(data_type, DataType::Float32);
responder
.send(ResponseMessage {
uuid: Uuid::new_v4(),
response: Some(msg.msg.uuid),
payload: TelemetryDefinitionResponse { uuid: tlm_uuid }.into(),
})
.unwrap();
};
let (tlm_handle, _) = try_join!(
timeout(Duration::from_secs(1), tlm_handle),
timeout(Duration::from_secs(1), expected_rx),
)
.unwrap();
assert_eq!(*tlm_handle.as_generic().uuid.try_read().unwrap(), tlm_uuid);
// This should NOT block if there is space in the queue
tlm_handle
.publish_now(1.0f32.into())
.now_or_never()
.unwrap()
.unwrap();
// This should block as there should not be space in the queue
assert!(tlm_handle
.publish_now(2.0f32.into())
.now_or_never()
.is_none());
let tlm_msg = timeout(Duration::from_secs(1), rx.recv())
.await
.unwrap()
.unwrap();
assert!(matches!(tlm_msg.callback, Callback::None));
match tlm_msg.msg.payload {
RequestMessagePayload::TelemetryEntry(TelemetryEntry { uuid, value, .. }) => {
assert_eq!(uuid, tlm_uuid);
assert_eq!(value, DataValue::Float32(1.0f32));
}
_ => panic!("Expected Telemetry Entry"),
}
let _make_generic_again = tlm_handle.to_generic();
}
#[tokio::test]
async fn reconnect() {
// if _c drops then we are disconnected
let (mut rx, connected, client) = create_test_client();
let tlm = TelemetryRegistry::new(Arc::new(client));
let tlm_handle = tlm.register_generic("generic", DataType::Float32);
let tlm_uuid = Uuid::new_v4();
let expected_rx = async {
let msg = rx.recv().await.unwrap();
let Callback::Once(responder) = msg.callback else {
panic!("Expected Once Callback");
};
assert!(msg.msg.response.is_none());
let RequestMessagePayload::TelemetryDefinitionRequest(TelemetryDefinitionRequest {
name,
data_type,
}) = msg.msg.payload
else {
panic!("Expected Telemetry Definition Request")
};
assert_eq!(name, "generic".to_string());
assert_eq!(data_type, DataType::Float32);
responder
.send(ResponseMessage {
uuid: Uuid::new_v4(),
response: Some(msg.msg.uuid),
payload: TelemetryDefinitionResponse { uuid: tlm_uuid }.into(),
})
.unwrap();
};
let (tlm_handle, _) = try_join!(
timeout(Duration::from_secs(1), tlm_handle),
timeout(Duration::from_secs(1), expected_rx),
)
.unwrap();
assert_eq!(*tlm_handle.uuid.try_read().unwrap(), tlm_uuid);
// Notify Disconnect
connected.send_replace(false);
// Notify Reconnect
connected.send_replace(true);
{
let new_tlm_uuid = Uuid::new_v4();
let msg = rx.recv().await.unwrap();
let Callback::Once(responder) = msg.callback else {
panic!("Expected Once Callback");
};
assert!(msg.msg.response.is_none());
let RequestMessagePayload::TelemetryDefinitionRequest(TelemetryDefinitionRequest {
name,
data_type,
}) = msg.msg.payload
else {
panic!("Expected Telemetry Definition Request")
};
assert_eq!(name, "generic".to_string());
assert_eq!(data_type, DataType::Float32);
responder
.send(ResponseMessage {
uuid: Uuid::new_v4(),
response: Some(msg.msg.uuid),
payload: TelemetryDefinitionResponse { uuid: new_tlm_uuid }.into(),
})
.unwrap();
// Yield to the executor so that the UUIDs can be updated
yield_now().await;
assert_eq!(*tlm_handle.uuid.try_read().unwrap(), new_tlm_uuid);
}
}
}