Replace gRPC Backend #10
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -309,6 +309,7 @@ dependencies = [
|
|||||||
"api-proc-macro",
|
"api-proc-macro",
|
||||||
"chrono",
|
"chrono",
|
||||||
"derive_more",
|
"derive_more",
|
||||||
|
"env_logger",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"log",
|
"log",
|
||||||
"serde",
|
"serde",
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ sqlx = "0.8.6"
|
|||||||
syn = "2.0.112"
|
syn = "2.0.112"
|
||||||
thiserror = "2.0.17"
|
thiserror = "2.0.17"
|
||||||
tokio = { version = "1.48.0" }
|
tokio = { version = "1.48.0" }
|
||||||
|
tokio-test = "0.4.4"
|
||||||
|
tokio-stream = "0.1.17"
|
||||||
tokio-tungstenite = { version = "0.28.0" }
|
tokio-tungstenite = { version = "0.28.0" }
|
||||||
tokio-util = "0.7.17"
|
tokio-util = "0.7.17"
|
||||||
trybuild = "1.0.114"
|
trybuild = "1.0.114"
|
||||||
|
|||||||
@@ -11,25 +11,25 @@ pub struct CommandParameterDefinition {
|
|||||||
pub data_type: DataType,
|
pub data_type: DataType,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
pub struct CommandDefinition {
|
pub struct CommandDefinition {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub parameters: Vec<CommandParameterDefinition>,
|
pub parameters: Vec<CommandParameterDefinition>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
pub struct CommandHeader {
|
pub struct CommandHeader {
|
||||||
pub timestamp: DateTime<Utc>,
|
pub timestamp: DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct Command {
|
pub struct Command {
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
pub header: CommandHeader,
|
pub header: CommandHeader,
|
||||||
pub parameters: HashMap<String, DataValue>,
|
pub parameters: HashMap<String, DataValue>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, PartialEq, Eq, Error)]
|
||||||
pub enum IntoCommandDefinitionError {
|
pub enum IntoCommandDefinitionError {
|
||||||
#[error("Parameter Missing: {0}")]
|
#[error("Parameter Missing: {0}")]
|
||||||
ParameterMissing(String),
|
ParameterMissing(String),
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use derive_more::{From, TryInto};
|
use derive_more::{From, TryInto};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, From, TryInto)]
|
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, From, TryInto)]
|
||||||
pub enum DataValue {
|
pub enum DataValue {
|
||||||
Float32(f32),
|
Float32(f32),
|
||||||
Float64(f64),
|
Float64(f64),
|
||||||
|
|||||||
@@ -19,3 +19,6 @@ tokio = { workspace = true, features = ["rt", "macros", "time"] }
|
|||||||
tokio-tungstenite = { workspace = true, features = ["rustls-tls-native-roots"] }
|
tokio-tungstenite = { workspace = true, features = ["rustls-tls-native-roots"] }
|
||||||
tokio-util = { workspace = true }
|
tokio-util = { workspace = true }
|
||||||
uuid = { workspace = true, features = ["serde"] }
|
uuid = { workspace = true, features = ["serde"] }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
env_logger = { workspace = true }
|
||||||
|
|||||||
@@ -4,19 +4,19 @@ use crate::client::{Callback, ClientChannel, OutgoingMessage, RegisteredCallback
|
|||||||
use crate::messages::callback::GenericCallbackError;
|
use crate::messages::callback::GenericCallbackError;
|
||||||
use crate::messages::payload::RequestMessagePayload;
|
use crate::messages::payload::RequestMessagePayload;
|
||||||
use crate::messages::{RequestMessage, ResponseMessage};
|
use crate::messages::{RequestMessage, ResponseMessage};
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{Sink, SinkExt, Stream, StreamExt};
|
||||||
use log::{debug, error, info, trace, warn};
|
use log::{debug, error, info, trace, warn};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::fmt::Display;
|
||||||
use std::sync::mpsc::sync_channel;
|
use std::sync::mpsc::sync_channel;
|
||||||
use std::thread;
|
use std::thread;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::net::TcpStream;
|
|
||||||
use tokio::sync::{mpsc, oneshot, watch, RwLockWriteGuard};
|
use tokio::sync::{mpsc, oneshot, watch, RwLockWriteGuard};
|
||||||
use tokio::time::sleep;
|
use tokio::time::sleep;
|
||||||
use tokio::{select, spawn};
|
use tokio::{select, spawn};
|
||||||
use tokio_tungstenite::tungstenite::handshake::client::Request;
|
use tokio_tungstenite::connect_async;
|
||||||
use tokio_tungstenite::tungstenite::Message;
|
use tokio_tungstenite::tungstenite::handshake::client::{Request, Response as TungResponse};
|
||||||
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
|
use tokio_tungstenite::tungstenite::{Error as TungError, Message};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
@@ -45,7 +45,9 @@ impl ClientContext {
|
|||||||
let _ = tx.send(());
|
let _ = tx.send(());
|
||||||
|
|
||||||
while !self.cancel.is_cancelled() {
|
while !self.cancel.is_cancelled() {
|
||||||
write_lock = self.run_connection(write_lock, &channel).await;
|
write_lock = self
|
||||||
|
.run_connection(write_lock, &channel, connect_async)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
drop(write_lock);
|
drop(write_lock);
|
||||||
});
|
});
|
||||||
@@ -57,13 +59,19 @@ impl ClientContext {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_connection<'a>(
|
async fn run_connection<'a, F, W, E>(
|
||||||
&mut self,
|
&mut self,
|
||||||
mut write_lock: RwLockWriteGuard<'a, mpsc::Sender<OutgoingMessage>>,
|
mut write_lock: RwLockWriteGuard<'a, mpsc::Sender<OutgoingMessage>>,
|
||||||
channel: &'a ClientChannel,
|
channel: &'a ClientChannel,
|
||||||
) -> RwLockWriteGuard<'a, mpsc::Sender<OutgoingMessage>> {
|
mut connection_fn: F,
|
||||||
|
) -> RwLockWriteGuard<'a, mpsc::Sender<OutgoingMessage>>
|
||||||
|
where
|
||||||
|
F: AsyncFnMut(Request) -> Result<(W, TungResponse), TungError>,
|
||||||
|
W: Stream<Item = Result<Message, TungError>> + Sink<Message, Error = E> + Unpin,
|
||||||
|
E: Display,
|
||||||
|
{
|
||||||
debug!("Attempting to Connect to {}", self.request.uri());
|
debug!("Attempting to Connect to {}", self.request.uri());
|
||||||
let mut ws = match connect_async(self.request.clone()).await {
|
let mut ws = match connection_fn(self.request.clone()).await {
|
||||||
Ok((ws, _)) => ws,
|
Ok((ws, _)) => ws,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
info!("Failed to Connect: {e}");
|
info!("Failed to Connect: {e}");
|
||||||
@@ -87,19 +95,24 @@ impl ClientContext {
|
|||||||
// the lock to use that as a signal that we have reconnected
|
// the lock to use that as a signal that we have reconnected
|
||||||
let _ = self.connected_state_tx.send_replace(false);
|
let _ = self.connected_state_tx.send_replace(false);
|
||||||
if close_connection {
|
if close_connection {
|
||||||
if let Err(e) = ws.close(None).await {
|
// Manually close to allow the impl trait to be used
|
||||||
|
if let Err(e) = ws.send(Message::Close(None)).await {
|
||||||
error!("Failed to Close the Connection: {e}");
|
error!("Failed to Close the Connection: {e}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
write_lock
|
write_lock
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_connection(
|
async fn handle_connection<W>(
|
||||||
&mut self,
|
&mut self,
|
||||||
ws: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
|
ws: &mut W,
|
||||||
mut rx: mpsc::Receiver<OutgoingMessage>,
|
mut rx: mpsc::Receiver<OutgoingMessage>,
|
||||||
channel: &ClientChannel,
|
channel: &ClientChannel,
|
||||||
) -> bool {
|
) -> bool
|
||||||
|
where
|
||||||
|
W: Stream<Item = Result<Message, TungError>> + Sink<Message> + Unpin,
|
||||||
|
<W as Sink<Message>>::Error: Display,
|
||||||
|
{
|
||||||
let mut callbacks = HashMap::<Uuid, Callback>::new();
|
let mut callbacks = HashMap::<Uuid, Callback>::new();
|
||||||
loop {
|
loop {
|
||||||
select! {
|
select! {
|
||||||
@@ -242,3 +255,340 @@ impl ClientContext {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::messages::telemetry_definition::{
|
||||||
|
TelemetryDefinitionRequest, TelemetryDefinitionResponse,
|
||||||
|
};
|
||||||
|
use crate::test::mock_stream_sink::{create_mock_stream_sink, MockStreamSinkControl};
|
||||||
|
use api_core::data_type::DataType;
|
||||||
|
use log::LevelFilter;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::ops::Deref;
|
||||||
|
use tokio::sync::mpsc::Sender;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tokio::time::timeout;
|
||||||
|
use tokio::try_join;
|
||||||
|
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||||
|
use tokio_util::bytes::Bytes;
|
||||||
|
|
||||||
|
async fn assert_client_interaction<F, R>(future: F)
|
||||||
|
where
|
||||||
|
F: Send
|
||||||
|
+ FnOnce(
|
||||||
|
Sender<OutgoingMessage>,
|
||||||
|
MockStreamSinkControl<Result<Message, TungError>, Message>,
|
||||||
|
CancellationToken,
|
||||||
|
) -> R
|
||||||
|
+ 'static,
|
||||||
|
R: Future<Output = ()> + Send,
|
||||||
|
{
|
||||||
|
let (control, stream_sink) =
|
||||||
|
create_mock_stream_sink::<Result<Message, TungError>, Message>();
|
||||||
|
|
||||||
|
let cancel_token = CancellationToken::new();
|
||||||
|
let inner_cancel_token = cancel_token.clone();
|
||||||
|
let (connected_state_tx, _connected_state_rx) = watch::channel(false);
|
||||||
|
|
||||||
|
let mut context = ClientContext {
|
||||||
|
cancel: cancel_token,
|
||||||
|
request: "mock".into_client_request().unwrap(),
|
||||||
|
connected_state_tx,
|
||||||
|
client_configuration: Default::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let (tx, _rx) = mpsc::channel(1);
|
||||||
|
let channel = ClientChannel::new(RwLock::new(tx));
|
||||||
|
let used_channel = channel.clone();
|
||||||
|
|
||||||
|
let write_lock = used_channel.write().await;
|
||||||
|
|
||||||
|
let handle = spawn(async move {
|
||||||
|
let channel = channel;
|
||||||
|
let read = channel.read().await;
|
||||||
|
let sender = read.deref().clone();
|
||||||
|
drop(read);
|
||||||
|
future(sender, control, inner_cancel_token).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut stream_sink = Some(stream_sink);
|
||||||
|
|
||||||
|
let connection_fn = async |_: Request| {
|
||||||
|
let stream_sink = stream_sink.take().ok_or(TungError::ConnectionClosed)?;
|
||||||
|
|
||||||
|
Ok((stream_sink, TungResponse::default())) as Result<(_, _), TungError>
|
||||||
|
};
|
||||||
|
|
||||||
|
let context_result = async {
|
||||||
|
drop(
|
||||||
|
context
|
||||||
|
.run_connection(write_lock, &used_channel, connection_fn)
|
||||||
|
.await,
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
};
|
||||||
|
|
||||||
|
try_join!(context_result, timeout(Duration::from_secs(1), handle),)
|
||||||
|
.unwrap()
|
||||||
|
.1
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn connection_closes_when_websocket_closes() {
|
||||||
|
let _ = env_logger::builder()
|
||||||
|
.is_test(true)
|
||||||
|
.filter_level(LevelFilter::Trace)
|
||||||
|
.try_init();
|
||||||
|
|
||||||
|
assert_client_interaction(|sender, mut control, _| async move {
|
||||||
|
let msg = Uuid::new_v4();
|
||||||
|
sender
|
||||||
|
.send(OutgoingMessage {
|
||||||
|
msg: RequestMessage {
|
||||||
|
uuid: msg,
|
||||||
|
response: None,
|
||||||
|
payload: TelemetryDefinitionRequest {
|
||||||
|
name: "".to_string(),
|
||||||
|
data_type: DataType::Float32,
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
},
|
||||||
|
callback: Callback::None,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// We expect an outgoing message
|
||||||
|
assert!(matches!(
|
||||||
|
control.outgoing.recv().await.unwrap(),
|
||||||
|
Message::Text(_)
|
||||||
|
));
|
||||||
|
// We receive an incoming close message
|
||||||
|
control
|
||||||
|
.incoming
|
||||||
|
.send(Ok(Message::Close(None)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// Then we expect the outgoing to close with no message
|
||||||
|
assert!(control.outgoing.recv().await.is_none());
|
||||||
|
assert!(control.incoming.is_closed());
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn connection_closes_when_cancelled() {
|
||||||
|
let _ = env_logger::builder()
|
||||||
|
.is_test(true)
|
||||||
|
.filter_level(LevelFilter::Trace)
|
||||||
|
.try_init();
|
||||||
|
|
||||||
|
assert_client_interaction(|_, mut control, cancel| async move {
|
||||||
|
cancel.cancel();
|
||||||
|
// We expect an outgoing cancel message
|
||||||
|
assert!(matches!(
|
||||||
|
control.outgoing.recv().await.unwrap(),
|
||||||
|
Message::Close(_)
|
||||||
|
));
|
||||||
|
// Then we expect to close with no message
|
||||||
|
assert!(control.outgoing.recv().await.is_none());
|
||||||
|
assert!(control.incoming.is_closed());
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn callback_request() {
|
||||||
|
let _ = env_logger::builder()
|
||||||
|
.is_test(true)
|
||||||
|
.filter_level(LevelFilter::Trace)
|
||||||
|
.try_init();
|
||||||
|
|
||||||
|
assert_client_interaction(|sender, mut control, _| async move {
|
||||||
|
let (callback_tx, callback_rx) = oneshot::channel();
|
||||||
|
let msg = Uuid::new_v4();
|
||||||
|
sender
|
||||||
|
.send(OutgoingMessage {
|
||||||
|
msg: RequestMessage {
|
||||||
|
uuid: msg,
|
||||||
|
response: None,
|
||||||
|
payload: TelemetryDefinitionRequest {
|
||||||
|
name: "".to_string(),
|
||||||
|
data_type: DataType::Float32,
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
},
|
||||||
|
callback: Callback::Once(callback_tx),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// We expect an outgoing message
|
||||||
|
assert!(matches!(
|
||||||
|
control.outgoing.recv().await.unwrap(),
|
||||||
|
Message::Text(_)
|
||||||
|
));
|
||||||
|
|
||||||
|
// Then we get an incoming message for this callback
|
||||||
|
let response_message = ResponseMessage {
|
||||||
|
uuid: Uuid::new_v4(),
|
||||||
|
response: Some(msg),
|
||||||
|
payload: TelemetryDefinitionResponse {
|
||||||
|
uuid: Uuid::new_v4(),
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
};
|
||||||
|
control
|
||||||
|
.incoming
|
||||||
|
.send(Ok(Message::Text(
|
||||||
|
serde_json::to_string(&response_message).unwrap().into(),
|
||||||
|
)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// We expect the callback to run
|
||||||
|
let message = callback_rx.await.unwrap();
|
||||||
|
// And give us the message we provided it
|
||||||
|
assert_eq!(message, response_message);
|
||||||
|
|
||||||
|
// We receive an incoming close message
|
||||||
|
control
|
||||||
|
.incoming
|
||||||
|
.send(Ok(Message::Close(None)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// Then we expect the outgoing to close with no message
|
||||||
|
assert!(control.outgoing.recv().await.is_none());
|
||||||
|
assert!(control.incoming.is_closed());
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn callback_registered() {
|
||||||
|
let _ = env_logger::builder()
|
||||||
|
.is_test(true)
|
||||||
|
.filter_level(LevelFilter::Trace)
|
||||||
|
.try_init();
|
||||||
|
|
||||||
|
assert_client_interaction(|sender, mut control, _| async move {
|
||||||
|
let (callback_tx, mut callback_rx) = mpsc::channel(1);
|
||||||
|
let msg = Uuid::new_v4();
|
||||||
|
sender
|
||||||
|
.send(OutgoingMessage {
|
||||||
|
msg: RequestMessage {
|
||||||
|
uuid: msg,
|
||||||
|
response: None,
|
||||||
|
payload: TelemetryDefinitionRequest {
|
||||||
|
name: "".to_string(),
|
||||||
|
data_type: DataType::Float32,
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
},
|
||||||
|
callback: Callback::Registered(callback_tx),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// We expect an outgoing message
|
||||||
|
assert!(matches!(
|
||||||
|
control.outgoing.recv().await.unwrap(),
|
||||||
|
Message::Text(_)
|
||||||
|
));
|
||||||
|
|
||||||
|
// We handle the callback a few times
|
||||||
|
for _ in 0..5 {
|
||||||
|
// Then we get an incoming message for this callback
|
||||||
|
let response_message = ResponseMessage {
|
||||||
|
uuid: Uuid::new_v4(),
|
||||||
|
response: Some(msg),
|
||||||
|
payload: TelemetryDefinitionResponse {
|
||||||
|
uuid: Uuid::new_v4(),
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
};
|
||||||
|
control
|
||||||
|
.incoming
|
||||||
|
.send(Ok(Message::Text(
|
||||||
|
serde_json::to_string(&response_message).unwrap().into(),
|
||||||
|
)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// We expect the response
|
||||||
|
let (rx, responder) = callback_rx.recv().await.unwrap();
|
||||||
|
// And give us the message we provided it
|
||||||
|
assert_eq!(rx, response_message);
|
||||||
|
// Then the response gets sent out
|
||||||
|
responder
|
||||||
|
.send(
|
||||||
|
TelemetryDefinitionRequest {
|
||||||
|
name: "".to_string(),
|
||||||
|
data_type: DataType::Float32,
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// We expect an outgoing message
|
||||||
|
assert!(matches!(
|
||||||
|
control.outgoing.recv().await.unwrap(),
|
||||||
|
Message::Text(_)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// We receive an incoming close message
|
||||||
|
control
|
||||||
|
.incoming
|
||||||
|
.send(Ok(Message::Close(None)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// Then we expect the outgoing to close with no message
|
||||||
|
assert!(control.outgoing.recv().await.is_none());
|
||||||
|
assert!(control.incoming.is_closed());
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ping_pong() {
|
||||||
|
let _ = env_logger::builder()
|
||||||
|
.is_test(true)
|
||||||
|
.filter_level(LevelFilter::Trace)
|
||||||
|
.try_init();
|
||||||
|
|
||||||
|
assert_client_interaction(|_, mut control, _| async move {
|
||||||
|
// Expect a pong in response to a ping
|
||||||
|
let bytes = Bytes::from_owner(Uuid::new_v4().into_bytes());
|
||||||
|
control
|
||||||
|
.incoming
|
||||||
|
.send(Ok(Message::Ping(bytes.clone())))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let Some(Message::Pong(pong_bytes)) = control.outgoing.recv().await else {
|
||||||
|
panic!("Expected Pong Response");
|
||||||
|
};
|
||||||
|
assert_eq!(bytes, pong_bytes);
|
||||||
|
|
||||||
|
// Nothing should happen
|
||||||
|
control
|
||||||
|
.incoming
|
||||||
|
.send(Ok(Message::Pong(bytes.clone())))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// We receive an incoming close message
|
||||||
|
control
|
||||||
|
.incoming
|
||||||
|
.send(Ok(Message::Close(None)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// Then we expect the outgoing to close with no message
|
||||||
|
assert!(control.outgoing.recv().await.is_none());
|
||||||
|
assert!(control.incoming.is_closed());
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ use uuid::Uuid;
|
|||||||
type RegisteredCallback = mpsc::Sender<(ResponseMessage, oneshot::Sender<RequestMessagePayload>)>;
|
type RegisteredCallback = mpsc::Sender<(ResponseMessage, oneshot::Sender<RequestMessagePayload>)>;
|
||||||
type ClientChannel = Arc<RwLock<mpsc::Sender<OutgoingMessage>>>;
|
type ClientChannel = Arc<RwLock<mpsc::Sender<OutgoingMessage>>>;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
enum Callback {
|
enum Callback {
|
||||||
None,
|
None,
|
||||||
Once(oneshot::Sender<ResponseMessage>),
|
Once(oneshot::Sender<ResponseMessage>),
|
||||||
@@ -264,3 +265,334 @@ impl Drop for Client {
|
|||||||
self.cancel.cancel();
|
self.cancel.cancel();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::messages::command::CommandResponse;
|
||||||
|
use crate::messages::telemetry_definition::{
|
||||||
|
TelemetryDefinitionRequest, TelemetryDefinitionResponse,
|
||||||
|
};
|
||||||
|
use crate::messages::telemetry_entry::TelemetryEntry;
|
||||||
|
use api_core::command::{Command, CommandDefinition, CommandHeader};
|
||||||
|
use api_core::data_type::DataType;
|
||||||
|
use chrono::Utc;
|
||||||
|
use futures_util::future::{select, Either};
|
||||||
|
use futures_util::FutureExt;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::pin::{pin, Pin};
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::join;
|
||||||
|
use tokio::time::{sleep, timeout};
|
||||||
|
|
||||||
|
fn create_test_client() -> (mpsc::Receiver<OutgoingMessage>, watch::Sender<bool>, Client) {
|
||||||
|
let cancel = CancellationToken::new();
|
||||||
|
let (tx, rx) = mpsc::channel(1);
|
||||||
|
let channel = Arc::new(RwLock::new(tx));
|
||||||
|
let (connected_state_tx, connected_state_rx) = watch::channel(true);
|
||||||
|
let client = Client {
|
||||||
|
cancel,
|
||||||
|
channel,
|
||||||
|
connected_state_rx,
|
||||||
|
};
|
||||||
|
(rx, connected_state_tx, client)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn send_message() {
|
||||||
|
let (mut rx, _, client) = create_test_client();
|
||||||
|
|
||||||
|
let msg_to_send = TelemetryEntry {
|
||||||
|
uuid: Uuid::new_v4(),
|
||||||
|
value: 0.0f32.into(),
|
||||||
|
timestamp: Utc::now(),
|
||||||
|
};
|
||||||
|
let msg_send = timeout(
|
||||||
|
Duration::from_secs(1),
|
||||||
|
client.send_message(msg_to_send.clone()),
|
||||||
|
);
|
||||||
|
let msg_recv = timeout(Duration::from_secs(1), rx.recv());
|
||||||
|
|
||||||
|
let (send, recv) = join!(msg_send, msg_recv);
|
||||||
|
send.unwrap().unwrap();
|
||||||
|
let recv = recv.unwrap().unwrap();
|
||||||
|
|
||||||
|
assert!(matches!(recv.callback, Callback::None));
|
||||||
|
assert!(recv.msg.response.is_none());
|
||||||
|
// uuid should be random
|
||||||
|
|
||||||
|
let RequestMessagePayload::TelemetryEntry(recv) = recv.msg.payload else {
|
||||||
|
panic!("Wrong Message Received")
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(recv, msg_to_send);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn send_message_if_connected() {
|
||||||
|
let (mut rx, _, client) = create_test_client();
|
||||||
|
|
||||||
|
let msg_to_send = TelemetryEntry {
|
||||||
|
uuid: Uuid::new_v4(),
|
||||||
|
value: 0.0f32.into(),
|
||||||
|
timestamp: Utc::now(),
|
||||||
|
};
|
||||||
|
let msg_send = timeout(
|
||||||
|
Duration::from_secs(1),
|
||||||
|
client.send_message_if_connected(msg_to_send.clone()),
|
||||||
|
);
|
||||||
|
let msg_recv = timeout(Duration::from_secs(1), rx.recv());
|
||||||
|
|
||||||
|
let (send, recv) = join!(msg_send, msg_recv);
|
||||||
|
send.unwrap().unwrap();
|
||||||
|
let recv = recv.unwrap().unwrap();
|
||||||
|
|
||||||
|
assert!(matches!(recv.callback, Callback::None));
|
||||||
|
assert!(recv.msg.response.is_none());
|
||||||
|
// uuid should be random
|
||||||
|
|
||||||
|
let RequestMessagePayload::TelemetryEntry(recv) = recv.msg.payload else {
|
||||||
|
panic!("Wrong Message Received")
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(recv, msg_to_send);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn send_message_if_connected_not_connected() {
|
||||||
|
let (_, connected_state_tx, client) = create_test_client();
|
||||||
|
|
||||||
|
let lock = client.channel.write().await;
|
||||||
|
connected_state_tx.send_replace(false);
|
||||||
|
|
||||||
|
let msg_to_send = TelemetryEntry {
|
||||||
|
uuid: Uuid::new_v4(),
|
||||||
|
value: 0.0f32.into(),
|
||||||
|
timestamp: Utc::now(),
|
||||||
|
};
|
||||||
|
let msg_send = timeout(
|
||||||
|
Duration::from_secs(1),
|
||||||
|
client.send_message_if_connected(msg_to_send.clone()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let Err(MessageError::TokioLockError(_)) = msg_send.await.unwrap() else {
|
||||||
|
panic!("Expected to Err due to lock being unavailable")
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn try_send_message() {
|
||||||
|
let (_tx, _, client) = create_test_client();
|
||||||
|
|
||||||
|
let msg_to_send = TelemetryEntry {
|
||||||
|
uuid: Uuid::new_v4(),
|
||||||
|
value: 0.0f32.into(),
|
||||||
|
timestamp: Utc::now(),
|
||||||
|
};
|
||||||
|
client.try_send_message(msg_to_send.clone()).unwrap();
|
||||||
|
let Err(MessageError::TokioTrySendError(_)) = client.try_send_message(msg_to_send.clone())
|
||||||
|
else {
|
||||||
|
panic!("Expected the buffer to be full");
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn send_request() {
|
||||||
|
let (mut tx, _, client) = create_test_client();
|
||||||
|
|
||||||
|
let msg_to_send = TelemetryDefinitionRequest {
|
||||||
|
name: "".to_string(),
|
||||||
|
data_type: DataType::Float32,
|
||||||
|
};
|
||||||
|
let response = timeout(
|
||||||
|
Duration::from_secs(1),
|
||||||
|
client.send_request(msg_to_send.clone()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let response_uuid = Uuid::new_v4();
|
||||||
|
let outgoing_rx = timeout(Duration::from_secs(1), async {
|
||||||
|
let msg = tx.recv().await.unwrap();
|
||||||
|
let Callback::Once(cb) = msg.callback else {
|
||||||
|
panic!("Wrong Callback Type")
|
||||||
|
};
|
||||||
|
cb.send(ResponseMessage {
|
||||||
|
uuid: Uuid::new_v4(),
|
||||||
|
response: Some(msg.msg.uuid),
|
||||||
|
payload: TelemetryDefinitionResponse {
|
||||||
|
uuid: response_uuid,
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let (response, outgoing_rx) = join!(response, outgoing_rx);
|
||||||
|
let response = response.unwrap().unwrap();
|
||||||
|
outgoing_rx.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(response.uuid, response_uuid);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn register_callback_channel() {
|
||||||
|
let (mut tx, _, client) = create_test_client();
|
||||||
|
|
||||||
|
let msg_to_send = CommandDefinition {
|
||||||
|
name: "".to_string(),
|
||||||
|
parameters: vec![],
|
||||||
|
};
|
||||||
|
let mut response = timeout(
|
||||||
|
Duration::from_secs(1),
|
||||||
|
client.register_callback_channel(msg_to_send),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let response_uuid = Uuid::new_v4();
|
||||||
|
let outgoing_rx = timeout(Duration::from_secs(1), async {
|
||||||
|
let msg = tx.recv().await.unwrap();
|
||||||
|
let Callback::Registered(cb) = msg.callback else {
|
||||||
|
panic!("Wrong Callback Type")
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check that we get responses to the callback the expected number of times
|
||||||
|
for i in 0..5 {
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
cb.send((
|
||||||
|
ResponseMessage {
|
||||||
|
uuid: Uuid::new_v4(),
|
||||||
|
response: Some(msg.msg.uuid),
|
||||||
|
payload: Command {
|
||||||
|
header: CommandHeader {
|
||||||
|
timestamp: Utc::now(),
|
||||||
|
},
|
||||||
|
parameters: Default::default(),
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
},
|
||||||
|
tx,
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let RequestMessagePayload::CommandResponse(response) = rx.await.unwrap() else {
|
||||||
|
panic!("Unexpected Response Type");
|
||||||
|
};
|
||||||
|
assert_eq!(response.response, format!("{i}"));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let responder = timeout(Duration::from_secs(1), async {
|
||||||
|
for i in 0..5 {
|
||||||
|
let (cmd, responder) = response.recv().await.unwrap();
|
||||||
|
responder
|
||||||
|
.send(CommandResponse {
|
||||||
|
success: false,
|
||||||
|
response: format!("{i}"),
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let (response, outgoing_rx) = join!(responder, outgoing_rx);
|
||||||
|
response.unwrap();
|
||||||
|
outgoing_rx.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn register_callback_fn() {
|
||||||
|
let (mut tx, _, client) = create_test_client();
|
||||||
|
|
||||||
|
let msg_to_send = CommandDefinition {
|
||||||
|
name: "".to_string(),
|
||||||
|
parameters: vec![],
|
||||||
|
};
|
||||||
|
let mut index = 0usize;
|
||||||
|
timeout(
|
||||||
|
Duration::from_secs(1),
|
||||||
|
client.register_callback_fn(msg_to_send, move |cmd| {
|
||||||
|
index += 1;
|
||||||
|
CommandResponse {
|
||||||
|
success: false,
|
||||||
|
response: format!("{}", index - 1),
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
timeout(Duration::from_secs(1), async {
|
||||||
|
let msg = tx.recv().await.unwrap();
|
||||||
|
let Callback::Registered(cb) = msg.callback else {
|
||||||
|
panic!("Wrong Callback Type")
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check that we get responses to the callback the expected number of times
|
||||||
|
for i in 0..3 {
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
cb.send((
|
||||||
|
ResponseMessage {
|
||||||
|
uuid: Uuid::new_v4(),
|
||||||
|
response: Some(msg.msg.uuid),
|
||||||
|
payload: Command {
|
||||||
|
header: CommandHeader {
|
||||||
|
timestamp: Utc::now(),
|
||||||
|
},
|
||||||
|
parameters: Default::default(),
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
},
|
||||||
|
tx,
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let RequestMessagePayload::CommandResponse(response) = rx.await.unwrap() else {
|
||||||
|
panic!("Unexpected Response Type");
|
||||||
|
};
|
||||||
|
assert_eq!(response.response, format!("{i}"));
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn connected_disconnected() {
|
||||||
|
let (_, connected, client) = create_test_client();
|
||||||
|
|
||||||
|
// When we're connected we should return immediately
|
||||||
|
connected.send_replace(true);
|
||||||
|
client.wait_connected().now_or_never().unwrap();
|
||||||
|
|
||||||
|
// When we're disconnected we should return immediately
|
||||||
|
connected.send_replace(false);
|
||||||
|
client.wait_disconnected().now_or_never().unwrap();
|
||||||
|
|
||||||
|
let c2 = connected.clone();
|
||||||
|
// When we're disconnected, we should not return immediately
|
||||||
|
let f1 = pin!(client.wait_connected());
|
||||||
|
let f2 = pin!(async move {
|
||||||
|
sleep(Duration::from_millis(1)).await;
|
||||||
|
c2.send_replace(true);
|
||||||
|
});
|
||||||
|
let r = select(f1, f2).await;
|
||||||
|
match r {
|
||||||
|
Either::Left(_) => panic!("Wait Connected Finished Before Connection Changed"),
|
||||||
|
Either::Right((_, other)) => timeout(Duration::from_secs(1), other).await.unwrap(),
|
||||||
|
}
|
||||||
|
|
||||||
|
let c2 = connected.clone();
|
||||||
|
// When we're disconnected, we should not return immediately
|
||||||
|
let f1 = pin!(client.wait_disconnected());
|
||||||
|
let f2 = pin!(async move {
|
||||||
|
sleep(Duration::from_millis(1)).await;
|
||||||
|
c2.send_replace(false);
|
||||||
|
});
|
||||||
|
let r = select(f1, f2).await;
|
||||||
|
match r {
|
||||||
|
Either::Left(_) => panic!("Wait Disconnected Finished Before Connection Changed"),
|
||||||
|
Either::Right((_, other)) => timeout(Duration::from_secs(1), other).await.unwrap(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -163,3 +163,6 @@ impl<T: Into<DataValue>> TelemetryHandle<T> {
|
|||||||
self.publish(value, Utc::now()).await
|
self.publish(value, Utc::now()).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {}
|
||||||
|
|||||||
@@ -10,3 +10,6 @@ pub mod messages;
|
|||||||
pub mod macros {
|
pub mod macros {
|
||||||
pub use api_proc_macro::IntoCommandDefinition;
|
pub use api_proc_macro::IntoCommandDefinition;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub mod test;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
pub enum GenericCallbackError {
|
pub enum GenericCallbackError {
|
||||||
CallbackClosed,
|
CallbackClosed,
|
||||||
MismatchedType,
|
MismatchedType,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ impl RegisterCallback for CommandDefinition {
|
|||||||
type Response = CommandResponse;
|
type Response = CommandResponse;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
pub struct CommandResponse {
|
pub struct CommandResponse {
|
||||||
pub success: bool,
|
pub success: bool,
|
||||||
pub response: String,
|
pub response: String,
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ pub struct RequestMessage {
|
|||||||
pub payload: RequestMessagePayload,
|
pub payload: RequestMessagePayload,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct ResponseMessage {
|
pub struct ResponseMessage {
|
||||||
pub uuid: Uuid,
|
pub uuid: Uuid,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ use crate::messages::telemetry_entry::TelemetryEntry;
|
|||||||
use derive_more::{From, TryInto};
|
use derive_more::{From, TryInto};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, From)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, From)]
|
||||||
pub enum RequestMessagePayload {
|
pub enum RequestMessagePayload {
|
||||||
TelemetryDefinitionRequest(TelemetryDefinitionRequest),
|
TelemetryDefinitionRequest(TelemetryDefinitionRequest),
|
||||||
TelemetryEntry(TelemetryEntry),
|
TelemetryEntry(TelemetryEntry),
|
||||||
@@ -16,7 +16,7 @@ pub enum RequestMessagePayload {
|
|||||||
CommandResponse(CommandResponse),
|
CommandResponse(CommandResponse),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, From, TryInto)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, From, TryInto)]
|
||||||
pub enum ResponseMessagePayload {
|
pub enum ResponseMessagePayload {
|
||||||
TelemetryDefinitionResponse(TelemetryDefinitionResponse),
|
TelemetryDefinitionResponse(TelemetryDefinitionResponse),
|
||||||
Command(Command),
|
Command(Command),
|
||||||
|
|||||||
@@ -3,13 +3,13 @@ use crate::messages::RequestResponse;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
pub struct TelemetryDefinitionRequest {
|
pub struct TelemetryDefinitionRequest {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub data_type: DataType,
|
pub data_type: DataType,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
pub struct TelemetryDefinitionResponse {
|
pub struct TelemetryDefinitionResponse {
|
||||||
pub uuid: Uuid,
|
pub uuid: Uuid,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use chrono::{DateTime, Utc};
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct TelemetryEntry {
|
pub struct TelemetryEntry {
|
||||||
pub uuid: Uuid,
|
pub uuid: Uuid,
|
||||||
pub value: DataValue,
|
pub value: DataValue,
|
||||||
|
|||||||
82
api/src/test/mock_stream_sink.rs
Normal file
82
api/src/test/mock_stream_sink.rs
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
use futures_util::sink::{unfold, Unfold};
|
||||||
|
use futures_util::{Sink, SinkExt, Stream};
|
||||||
|
use std::fmt::Display;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio::sync::mpsc::error::SendError;
|
||||||
|
use tokio::sync::mpsc::{Receiver, Sender};
|
||||||
|
|
||||||
|
pub struct MockStreamSinkControl<T, R> {
|
||||||
|
pub incoming: Sender<T>,
|
||||||
|
pub outgoing: Receiver<R>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MockStreamSink<T, U1, U2> {
|
||||||
|
stream_rx: Receiver<T>,
|
||||||
|
sink_tx: Pin<Box<Unfold<u32, U1, U2>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, U1, U2> Stream for MockStreamSink<T, U1, U2>
|
||||||
|
where
|
||||||
|
Self: Unpin,
|
||||||
|
{
|
||||||
|
type Item = T;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
self.stream_rx.poll_recv(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, R, U1, U2, E> Sink<R> for MockStreamSink<T, U1, U2>
|
||||||
|
where
|
||||||
|
U1: FnMut(u32, R) -> U2,
|
||||||
|
U2: Future<Output = Result<u32, E>>,
|
||||||
|
{
|
||||||
|
type Error = E;
|
||||||
|
|
||||||
|
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.sink_tx.poll_ready_unpin(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(mut self: Pin<&mut Self>, item: R) -> Result<(), Self::Error> {
|
||||||
|
self.sink_tx.start_send_unpin(item)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.sink_tx.poll_flush_unpin(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.sink_tx.poll_close_unpin(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_mock_stream_sink<T: Send, R: Send + 'static>() -> (
|
||||||
|
MockStreamSinkControl<T, R>,
|
||||||
|
impl Stream<Item = T> + Sink<R, Error = impl Display>,
|
||||||
|
) {
|
||||||
|
let (stream_tx, stream_rx) = mpsc::channel::<T>(1);
|
||||||
|
let (sink_tx, sink_rx) = mpsc::channel::<R>(1);
|
||||||
|
|
||||||
|
let sink_tx = Arc::new(sink_tx);
|
||||||
|
|
||||||
|
(
|
||||||
|
MockStreamSinkControl {
|
||||||
|
incoming: stream_tx,
|
||||||
|
outgoing: sink_rx,
|
||||||
|
},
|
||||||
|
MockStreamSink::<T, _, _> {
|
||||||
|
stream_rx,
|
||||||
|
sink_tx: Box::pin(unfold(0u32, move |_, item| {
|
||||||
|
let sink_tx = sink_tx.clone();
|
||||||
|
async move {
|
||||||
|
sink_tx.send(item).await?;
|
||||||
|
Ok(0u32) as Result<_, SendError<R>>
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
1
api/src/test/mod.rs
Normal file
1
api/src/test/mod.rs
Normal file
@@ -0,0 +1 @@
|
|||||||
|
pub mod mock_stream_sink;
|
||||||
Reference in New Issue
Block a user