use crate::command::service::CommandManagementService; use crate::core::client_side_command::Inner; use crate::core::command_service_server::CommandService; use crate::core::{ClientSideCommand, Command, CommandResponse, Uuid}; use log::{error, trace}; use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; use tokio::select; use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio_util::sync::CancellationToken; use tonic::codegen::tokio_stream::wrappers::ReceiverStream; use tonic::codegen::tokio_stream::{Stream, StreamExt}; use tonic::{Request, Response, Status, Streaming}; pub struct CoreCommandService { pub command_service: Arc, pub cancellation_token: CancellationToken, } #[tonic::async_trait] impl CommandService for CoreCommandService { type NewCommandStream = Pin> + Send>>; async fn new_command( &self, request: Request>, ) -> Result, Status> { trace!("CoreCommandService::new_command"); let cancel_token = self.cancellation_token.clone(); let mut in_stream = request.into_inner(); let cmd_request = select! { _ = cancel_token.cancelled() => return Err(Status::internal("Shutting Down")), Some(message) = in_stream.next() => { match message { Ok(ClientSideCommand { inner: Some(Inner::Request(cmd_request)) }) => cmd_request, Err(err) => { error!("Error in Stream: {err}"); return Err(Status::cancelled("Error in Stream")); }, _ => { return Err(Status::invalid_argument("First message must be request")); }, } }, else => return Err(Status::internal("Shutting Down")), }; let mut cmd_rx = match self.command_service.register_command(cmd_request).await { Ok(rx) => rx, Err(e) => { error!("Failed to register command: {e}"); return Err(Status::internal("Failed to register command")); } }; let (tx, rx) = mpsc::channel(128); tokio::spawn(async move { let mut result = Status::resource_exhausted("End of Command Stream"); let mut in_progress = HashMap::>::new(); loop { select! { _ = cancel_token.cancelled() => break, _ = tx.closed() => break, Some(message) = cmd_rx.recv() => { match message { None => break, Some(message) => { let key = message.0.uuid.clone().unwrap().value; in_progress.insert(key.clone(), message.1); match tx.send(Ok(message.0)).await { Ok(()) => {}, Err(e) => { error!("Failed to send command data: {e}"); if in_progress.remove(&key).unwrap().send(CommandResponse { uuid: Some(Uuid::from(key)), success: false, response: "Failed to send command data.".to_string(), }).is_err() { error!("Failed to send command response on failure to send command data"); } break; } } } } }, Some(message) = in_stream.next() => { match message { Ok(message) => { match message.inner { Some(Inner::Response(response)) => { if let Some(uuid) = &response.uuid { match in_progress.remove(&uuid.value) { Some(sender) => { if sender.send(response).is_err() { error!("Failed to send command response on success") } } None => { result = Status::invalid_argument("Invalid Command UUID"); break; } } } } _ => { result = Status::invalid_argument("Subsequent Message Must Be Command Responses"); break; } } } Err(e) => { error!("Received error from command handler {e}"); break }, } } else => break, } } cmd_rx.close(); if !tx.is_closed() { match tx.send(Err(result)).await { Ok(()) => {} Err(e) => { error!("Failed to close old command sender {e}"); } } } for (key, sender) in in_progress.drain() { if sender .send(CommandResponse { uuid: Some(Uuid::from(key)), success: false, response: "Command Handler Shut Down".to_string(), }) .is_err() { error!("Failed to send command response on shutdown"); } } }); Ok(Response::new(Box::pin(ReceiverStream::new(rx)))) } }