improve reconnect logic

This commit is contained in:
2025-12-30 18:33:42 -05:00
parent 6a5e3e2b24
commit a3aeff1d6f
9 changed files with 342 additions and 50 deletions

View File

@@ -12,7 +12,7 @@ serde = { workspace = true, features = ["derive"] }
derive_more = { workspace = true, features = ["from", "try_into"] }
uuid = { workspace = true, features = ["serde"] }
chrono = { workspace = true, features = ["serde"] }
tokio = { workspace = true, features = ["rt", "macros"] }
tokio = { workspace = true, features = ["rt", "macros", "time"] }
tokio-tungstenite = { workspace = true, features = ["rustls-tls-native-roots"] }
tokio-util = { workspace = true }
futures-util = { workspace = true }

View File

@@ -10,13 +10,15 @@ use crate::messages::{
use error::ConnectError;
use futures_util::stream::StreamExt;
use futures_util::SinkExt;
use log::{debug, error, warn};
use log::{debug, error, info, trace, warn};
use std::collections::HashMap;
use std::sync::mpsc::sync_channel;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot, RwLock, RwLockWriteGuard};
use tokio::sync::{mpsc, oneshot, watch, RwLock, RwLockWriteGuard};
use tokio::time::sleep;
use tokio::{select, spawn};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::handshake::client::Request;
@@ -42,11 +44,13 @@ struct OutgoingMessage {
pub struct Client {
cancel: CancellationToken,
channel: ClientChannel,
connected_state_rx: watch::Receiver<bool>,
}
struct ClientContext {
cancel: CancellationToken,
request: Request,
connected_state_tx: watch::Sender<bool>,
}
impl Client {
@@ -57,14 +61,20 @@ impl Client {
let (tx, _rx) = mpsc::channel(1);
let cancel = CancellationToken::new();
let channel = Arc::new(RwLock::new(tx));
let (connected_state_tx, connected_state_rx) = watch::channel(false);
let context = ClientContext {
cancel: cancel.clone(),
request: request.into_client_request()?,
connected_state_tx,
};
context.start(channel.clone())?;
Ok(Self { cancel, channel })
Ok(Self {
cancel,
channel,
connected_state_rx,
})
}
pub async fn send_message<M: ClientMessage>(&self, msg: M) -> Result<(), MessageError> {
@@ -226,6 +236,28 @@ impl Client {
Ok(())
}
pub async fn wait_connected(&self) {
let mut connected_rx = self.connected_state_rx.clone();
// If we aren't currently connected
if !*connected_rx.borrow_and_update() {
// Wait for a change notification
// If the channel is closed there is nothing we can do
let _ = connected_rx.changed().await;
}
}
pub async fn wait_disconnected(&self) {
let mut connected_rx = self.connected_state_rx.clone();
// If we are currently connected
if *connected_rx.borrow_and_update() {
// Wait for a change notification
// If the channel is closed there is nothing we can do
let _ = connected_rx.changed().await;
}
}
}
impl ClientContext {
@@ -263,24 +295,33 @@ impl ClientContext {
mut write_lock: RwLockWriteGuard<'a, mpsc::Sender<OutgoingMessage>>,
channel: &'a ClientChannel,
) -> RwLockWriteGuard<'a, mpsc::Sender<OutgoingMessage>> {
debug!("Attempting to Connect to {}", self.request.uri());
let mut ws = match connect_async(self.request.clone()).await {
Ok((ws, _)) => ws,
Err(e) => {
error!("Connect Error: {e}");
info!("Failed to Connect: {e}");
sleep(Duration::from_secs(1)).await;
return write_lock;
}
};
info!("Connected to {}", self.request.uri());
let (tx, rx) = mpsc::channel(128);
*write_lock = tx;
drop(write_lock);
// Don't care about the previous value
let _ = self.connected_state_tx.send_replace(true);
let close_connection = self.handle_connection(&mut ws, rx, channel).await;
let write_lock = channel.write().await;
// Send this after grabbing the lock - to prevent extra contention when others try to grab
// the lock to use that as a signal that we have reconnected
let _ = self.connected_state_tx.send_replace(false);
if close_connection {
if let Err(e) = ws.close(None).await {
println!("Close Error {e}");
error!("Failed to Close the Connection: {e}");
}
}
write_lock
@@ -301,6 +342,7 @@ impl ClientContext {
Ok(msg) => {
match msg {
Message::Text(msg) => {
trace!("Incoming: {msg}");
let msg: ResponseMessage = match serde_json::from_str(&msg) {
Ok(m) => m,
Err(e) => {
@@ -345,6 +387,7 @@ impl ClientContext {
break;
}
};
trace!("Outgoing: {msg}");
if let Err(e) = ws.send(Message::Text(msg.into())).await {
error!("Send Error {e}");
break;

View File

@@ -11,6 +11,8 @@ use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestMessage {
pub uuid: Uuid,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<Uuid>,
#[serde(flatten)]
pub payload: RequestMessagePayload,
@@ -19,6 +21,8 @@ pub struct RequestMessage {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseMessage {
pub uuid: Uuid,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<Uuid>,
#[serde(flatten)]
pub payload: ResponseMessagePayload,