diff --git a/crates/transcribe-whisper-local/src/service/streaming.rs b/crates/transcribe-whisper-local/src/service/streaming.rs index 0c8294169f..21dc01a183 100644 --- a/crates/transcribe-whisper-local/src/service/streaming.rs +++ b/crates/transcribe-whisper-local/src/service/streaming.rs @@ -113,7 +113,7 @@ where } }; - let guard = connection_manager.acquire_connection(); + let guard = connection_manager.acquire_connection().await; Ok(ws_upgrade .on_upgrade(move |socket| async move { diff --git a/crates/ws-utils/src/lib.rs b/crates/ws-utils/src/lib.rs index 362cf3f2d9..2d2f8aa1a2 100644 --- a/crates/ws-utils/src/lib.rs +++ b/crates/ws-utils/src/lib.rs @@ -1,4 +1,10 @@ +//! Utilities for WebSocket-based audio streaming. +//! +//! This crate provides connection management and audio source abstractions +//! for WebSocket-based audio clients. + mod manager; + pub use manager::*; use std::pin::Pin; diff --git a/crates/ws-utils/src/manager.rs b/crates/ws-utils/src/manager.rs index 1b66b72039..dd1a67d4dc 100644 --- a/crates/ws-utils/src/manager.rs +++ b/crates/ws-utils/src/manager.rs @@ -1,22 +1,23 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; +use tokio::sync::RwLock; use tokio_util::sync::CancellationToken; #[derive(Clone)] pub struct ConnectionManager { - inner: Arc>>, + token: Arc>>, } impl Default for ConnectionManager { fn default() -> Self { Self { - inner: Arc::new(Mutex::new(None)), + token: Arc::new(RwLock::new(None)), } } } impl ConnectionManager { - pub fn acquire_connection(&self) -> ConnectionGuard { - let mut slot = self.inner.lock().unwrap(); + pub async fn acquire_connection(&self) -> ConnectionGuard { + let mut slot = self.token.write().await; if let Some(old) = slot.take() { old.cancel(); @@ -27,6 +28,13 @@ impl ConnectionManager { ConnectionGuard { token } } + + pub async fn cancel_all(&self) { + let mut slot = self.token.write().await; + if let Some(token) = slot.take() { + token.cancel(); + } + } } pub struct ConnectionGuard { @@ -34,7 +42,15 @@ pub struct ConnectionGuard { } impl ConnectionGuard { + pub fn is_cancelled(&self) -> bool { + self.token.is_cancelled() + } + pub async fn cancelled(&self) { self.token.cancelled().await } + + pub fn child_token(&self) -> CancellationToken { + self.token.child_token() + } } diff --git a/crates/ws/Cargo.toml b/crates/ws/Cargo.toml index 4922b0de70..6cc559af7e 100644 --- a/crates/ws/Cargo.toml +++ b/crates/ws/Cargo.toml @@ -11,6 +11,7 @@ server = [] [dependencies] bytes = { workspace = true } serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } thiserror = { workspace = true } async-stream = { workspace = true } @@ -19,6 +20,3 @@ futures-util = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "time", "sync", "macros"] } tokio-tungstenite = { workspace = true, features = ["native-tls-vendored"] } tracing = { workspace = true } - -[dev-dependencies] -serde_json.workspace = true diff --git a/crates/ws/src/client.rs b/crates/ws/src/client.rs index 1a35f8b4cb..31b5ce08a8 100644 --- a/crates/ws/src/client.rs +++ b/crates/ws/src/client.rs @@ -7,6 +7,7 @@ use futures_util::{ }; use tokio_tungstenite::{connect_async, tungstenite::client::IntoClientRequest}; +pub use crate::config::{ConnectionConfig, KeepAliveConfig, RetryConfig}; pub use tokio_tungstenite::tungstenite::{protocol::Message, ClientRequestBuilder, Utf8Bytes}; #[derive(Debug)] @@ -14,12 +15,6 @@ enum ControlCommand { Finalize(Option), } -#[derive(Clone)] -struct KeepAliveConfig { - interval: std::time::Duration, - message: Message, -} - #[derive(Clone)] pub struct WebSocketHandle { control_tx: tokio::sync::mpsc::UnboundedSender, @@ -27,12 +22,44 @@ pub struct WebSocketHandle { impl WebSocketHandle { pub async fn finalize_with_text(&self, text: Utf8Bytes) { - let _ = self + if self .control_tx - .send(ControlCommand::Finalize(Some(Message::Text(text)))); + .send(ControlCommand::Finalize(Some(Message::Text(text)))) + .is_err() + { + tracing::warn!("control channel closed, cannot send finalize command"); + } } } +pub struct SendTask { + handle: tokio::task::JoinHandle>, +} + +impl SendTask { + pub async fn wait(self) -> Result<(), crate::Error> { + match self.handle.await { + Ok(result) => result, + Err(join_err) if join_err.is_panic() => { + std::panic::resume_unwind(join_err.into_panic()); + } + Err(join_err) => { + tracing::error!("send task cancelled: {:?}", join_err); + Err(crate::Error::UnexpectedClose) + } + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum DecodeError { + #[error("unsupported message type")] + UnsupportedType, + + #[error("deserialization failed: {0}")] + DeserializationError(#[from] serde_json::Error), +} + pub trait WebSocketIO: Send + 'static { type Data: Send; type Input: Send; @@ -40,12 +67,13 @@ pub trait WebSocketIO: Send + 'static { fn to_input(data: Self::Data) -> Self::Input; fn to_message(input: Self::Input) -> Message; - fn from_message(msg: Message) -> Option; + fn decode(msg: Message) -> Result; } pub struct WebSocketClient { request: ClientRequestBuilder, keep_alive: Option, + config: ConnectionConfig, } impl WebSocketClient { @@ -53,9 +81,20 @@ impl WebSocketClient { Self { request, keep_alive: None, + config: ConnectionConfig::default(), } } + pub fn with_config(mut self, config: ConnectionConfig) -> Self { + self.config = config; + self + } + + pub fn with_keep_alive(mut self, config: KeepAliveConfig) -> Self { + self.keep_alive = Some(config); + self + } + pub fn with_keep_alive_message( mut self, interval: std::time::Duration, @@ -73,15 +112,18 @@ impl WebSocketClient { ( impl Stream>, WebSocketHandle, + SendTask, ), crate::Error, > { let keep_alive_config = self.keep_alive.clone(); + let close_grace_period = self.config.close_grace_period; + let retry_config = self.config.retry_config.clone(); let ws_stream = (|| self.try_connect(self.request.clone())) .retry( ConstantBuilder::default() - .with_max_times(5) - .with_delay(std::time::Duration::from_millis(500)), + .with_max_times(retry_config.max_attempts) + .with_delay(retry_config.delay), ) .when(|e| { tracing::error!("ws_connect_failed: {:?}", e); @@ -96,12 +138,16 @@ impl WebSocketClient { let (error_tx, mut error_rx) = tokio::sync::mpsc::unbounded_channel::(); let handle = WebSocketHandle { control_tx }; - let _send_task = tokio::spawn(async move { + let send_task = tokio::spawn(async move { if let Some(msg) = initial_message { if let Err(e) = ws_sender.send(msg).await { tracing::error!("ws_initial_message_failed: {:?}", e); - let _ = error_tx.send(e.into()); - return; + if error_tx.send(e.into()).is_err() { + tracing::warn!("output stream already closed, cannot propagate error"); + } + return Err(crate::Error::DataSend { + context: "initial message".to_string(), + }); } } @@ -120,7 +166,9 @@ impl WebSocketClient { if let Some(cfg) = keep_alive_config.as_ref() { if let Err(e) = ws_sender.send(cfg.message.clone()).await { tracing::error!("ws_keepalive_failed: {:?}", e); - let _ = error_tx.send(e.into()); + if error_tx.send(e.into()).is_err() { + tracing::warn!("output stream already closed, cannot propagate keepalive error"); + } break; } last_outbound_at = tokio::time::Instant::now(); @@ -132,7 +180,9 @@ impl WebSocketClient { if let Err(e) = ws_sender.send(msg).await { tracing::error!("ws_send_failed: {:?}", e); - let _ = error_tx.send(e.into()); + if error_tx.send(e.into()).is_err() { + tracing::warn!("output stream already closed, cannot propagate send error"); + } break; } last_outbound_at = tokio::time::Instant::now(); @@ -141,7 +191,9 @@ impl WebSocketClient { if let Some(msg) = maybe_msg { if let Err(e) = ws_sender.send(msg).await { tracing::error!("ws_finalize_failed: {:?}", e); - let _ = error_tx.send(e.into()); + if error_tx.send(e.into()).is_err() { + tracing::warn!("output stream already closed, cannot propagate finalize error"); + } break; } last_outbound_at = tokio::time::Instant::now(); @@ -151,36 +203,32 @@ impl WebSocketClient { } } - // Wait 5 seconds before closing the connection - // TODO: This might not be enough to ensure receiving remaining transcripts from the server. - tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; - let _ = ws_sender.close().await; + tracing::debug!("draining remaining messages before close"); + tokio::time::sleep(close_grace_period).await; + if let Err(e) = ws_sender.close().await { + tracing::debug!("ws_close_failed: {:?}", e); + } + Ok(()) }); + let send_task_handle = SendTask { handle: send_task }; + let output_stream = async_stream::stream! { loop { tokio::select! { Some(msg_result) = ws_receiver.next() => { match msg_result { Ok(msg) => { - let is_text = matches!(msg, Message::Text(_)); - let is_binary = matches!(msg, Message::Binary(_)); - let text_preview = if let Message::Text(ref t) = msg { - Some(t.to_string()) - } else { - None - }; - match msg { Message::Text(_) | Message::Binary(_) => { - if let Some(output) = T::from_message(msg) { - yield Ok(output); - } else if is_text { - if let Some(text) = text_preview { - tracing::warn!("ws_message_parse_failed: {}", text); + match T::decode(msg) { + Ok(output) => yield Ok(output), + Err(DecodeError::UnsupportedType) => { + tracing::debug!("ws_message_unsupported_type"); + } + Err(DecodeError::DeserializationError(e)) => { + tracing::warn!("ws_message_parse_failed: {}", e); } - } else if is_binary { - tracing::warn!("ws_binary_message_parse_failed"); } }, Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue, @@ -207,7 +255,7 @@ impl WebSocketClient { } }; - Ok((output_stream, handle)) + Ok((output_stream, handle, send_task_handle)) } async fn try_connect( @@ -219,12 +267,17 @@ impl WebSocketClient { >, crate::Error, > { - let req = req.into_client_request().unwrap(); + let req = req + .into_client_request() + .map_err(|e| crate::Error::InvalidRequest(e.to_string()))?; tracing::info!("connect_async: {:?}", req.uri()); - let (ws_stream, _) = - tokio::time::timeout(std::time::Duration::from_secs(8), connect_async(req)).await??; + let timeout_duration = self.config.connect_timeout; + let (ws_stream, _) = tokio::time::timeout(timeout_duration, connect_async(req)) + .await + .map_err(|e| crate::Error::timeout(e, timeout_duration))? + .map_err(crate::Error::Connection)?; Ok(ws_stream) } diff --git a/crates/ws/src/config.rs b/crates/ws/src/config.rs new file mode 100644 index 0000000000..772509e039 --- /dev/null +++ b/crates/ws/src/config.rs @@ -0,0 +1,40 @@ +use std::time::Duration; +use tokio_tungstenite::tungstenite::protocol::Message; + +#[derive(Clone, Debug)] +pub struct ConnectionConfig { + pub connect_timeout: Duration, + pub retry_config: RetryConfig, + pub close_grace_period: Duration, +} + +impl Default for ConnectionConfig { + fn default() -> Self { + Self { + connect_timeout: Duration::from_secs(8), + retry_config: RetryConfig::default(), + close_grace_period: Duration::from_secs(5), + } + } +} + +#[derive(Clone, Debug)] +pub struct RetryConfig { + pub max_attempts: usize, + pub delay: Duration, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_attempts: 5, + delay: Duration::from_millis(500), + } + } +} + +#[derive(Clone, Debug)] +pub struct KeepAliveConfig { + pub interval: Duration, + pub message: Message, +} diff --git a/crates/ws/src/error.rs b/crates/ws/src/error.rs index f6de2e31b0..25e023a26c 100644 --- a/crates/ws/src/error.rs +++ b/crates/ws/src/error.rs @@ -1,11 +1,50 @@ +use std::time::Duration; + #[derive(thiserror::Error, Debug)] pub enum Error { - #[error("unknown error")] - Unknown, - #[error("connection error")] + #[error("connection failed: {0}")] Connection(#[from] tokio_tungstenite::tungstenite::Error), - #[error("timeout error")] - Timeout(#[from] tokio::time::error::Elapsed), - #[error("send error")] - SendError(#[from] tokio::sync::mpsc::error::SendError<()>), + + #[error("connection timeout after {timeout:?}")] + Timeout { + #[source] + source: tokio::time::error::Elapsed, + timeout: Duration, + }, + + #[error("failed to send control command")] + ControlSend, + + #[error("failed to send data: {context}")] + DataSend { context: String }, + + #[error("connection closed unexpectedly")] + UnexpectedClose, + + #[error("invalid client request: {0}")] + InvalidRequest(String), + + #[error("message parsing failed: {message}")] + ParseError { message: String }, +} + +impl Error { + pub fn timeout(elapsed: tokio::time::error::Elapsed, duration: Duration) -> Self { + Self::Timeout { + source: elapsed, + timeout: duration, + } + } + + pub fn data_send(context: impl Into) -> Self { + Self::DataSend { + context: context.into(), + } + } + + pub fn parse_error(message: impl Into) -> Self { + Self::ParseError { + message: message.into(), + } + } } diff --git a/crates/ws/src/lib.rs b/crates/ws/src/lib.rs index f46f9b5d5c..a5ae249c85 100644 --- a/crates/ws/src/lib.rs +++ b/crates/ws/src/lib.rs @@ -1,6 +1,14 @@ +//! WebSocket client and server utilities for real-time audio streaming. +//! +//! This crate provides a high-level WebSocket client for streaming audio data +//! to speech-to-text services with automatic retry, keep-alive, and graceful shutdown. + #[cfg(feature = "client")] pub mod client; +#[cfg(feature = "client")] +pub mod config; + #[cfg(feature = "server")] pub mod server; diff --git a/crates/ws/tests/client_tests.rs b/crates/ws/tests/client_tests.rs index 180d66b712..bbd93df940 100644 --- a/crates/ws/tests/client_tests.rs +++ b/crates/ws/tests/client_tests.rs @@ -6,7 +6,7 @@ use tokio_tungstenite::{ accept_async, tungstenite::{protocol::Message, ClientRequestBuilder}, }; -use ws::client::{WebSocketClient, WebSocketIO}; +use ws::client::{DecodeError, WebSocketClient, WebSocketIO}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] struct TestMessage { @@ -29,10 +29,12 @@ impl WebSocketIO for TestIO { Message::Text(serde_json::to_string(&input).unwrap().into()) } - fn from_message(msg: Message) -> Option { + fn decode(msg: Message) -> Result { match msg { - Message::Text(text) => serde_json::from_str(&text).ok(), - _ => None, + Message::Text(text) => { + serde_json::from_str(&text).map_err(DecodeError::DeserializationError) + } + _ => Err(DecodeError::UnsupportedType), } } } @@ -98,7 +100,7 @@ async fn test_basic_echo() { ]; let stream = futures_util::stream::iter(messages.clone()); - let (output, _handle) = client.from_audio::(None, stream).await.unwrap(); + let (output, _handle, _send_task) = client.from_audio::(None, stream).await.unwrap(); let received = collect_messages::(output, 2).await; assert_eq!(received, messages); @@ -115,7 +117,7 @@ async fn test_finalize() { text: "initial".to_string(), count: 1, }]); - let (output, handle) = client.from_audio::(None, stream).await.unwrap(); + let (output, handle, _send_task) = client.from_audio::(None, stream).await.unwrap(); let final_msg = TestMessage { text: "final".to_string(), @@ -169,7 +171,7 @@ async fn test_keep_alive() { ); let stream = futures_util::stream::pending::(); - let (output, _handle) = client.from_audio::(None, stream).await.unwrap(); + let (output, _handle, _send_task) = client.from_audio::(None, stream).await.unwrap(); let received = collect_messages::(output, 1).await; assert_eq!(received[0].text, "done"); @@ -216,7 +218,7 @@ async fn test_retry() { text: "retry_test".to_string(), count: 1, }]); - let (output, _handle) = client.from_audio::(None, stream).await.unwrap(); + let (output, _handle, _send_task) = client.from_audio::(None, stream).await.unwrap(); let received = collect_messages::(output, 1).await; assert_eq!(received[0].text, "retry_test"); diff --git a/owhisper/owhisper-client/src/live.rs b/owhisper/owhisper-client/src/live.rs index 9d4ec447e0..36e58efd9b 100644 --- a/owhisper/owhisper-client/src/live.rs +++ b/owhisper/owhisper-client/src/live.rs @@ -133,10 +133,10 @@ impl WebSocketIO for ListenClientIO { } } - fn from_message(msg: Message) -> Option { + fn decode(msg: Message) -> Result { match msg { - Message::Text(text) => Some(text.to_string()), - _ => None, + Message::Text(text) => Ok(text.to_string()), + _ => Err(hypr_ws::client::DecodeError::UnsupportedType), } } } @@ -167,10 +167,10 @@ impl WebSocketIO for ListenClientDualIO { } } - fn from_message(msg: Message) -> Option { + fn decode(msg: Message) -> Result { match msg { - Message::Text(text) => Some(text.to_string()), - _ => None, + Message::Text(text) => Ok(text.to_string()), + _ => Err(hypr_ws::client::DecodeError::UnsupportedType), } } } @@ -194,7 +194,7 @@ impl ListenClient { > { let finalize_text = extract_finalize_text(&self.adapter); let ws = websocket_client_with_keep_alive(&self.request, &self.adapter); - let (raw_stream, inner) = ws + let (raw_stream, inner, _send_task) = ws .from_audio::(self.initial_message, audio_stream) .await?; @@ -236,7 +236,7 @@ impl ListenClientDual { ) -> Result<(DualOutputStream, DualHandle), hypr_ws::Error> { let finalize_text = extract_finalize_text(&self.adapter); let ws = websocket_client_with_keep_alive(&self.request, &self.adapter); - let (raw_stream, inner) = ws + let (raw_stream, inner, _send_task) = ws .from_audio::(self.initial_message, stream) .await?; @@ -275,7 +275,7 @@ impl ListenClientDual { mic_ws.from_audio::(self.initial_message.clone(), mic_outbound); let spk_connect = spk_ws.from_audio::(self.initial_message, spk_outbound); - let ((mic_raw, mic_handle), (spk_raw, spk_handle)) = + let ((mic_raw, mic_handle, _mic_send_task), (spk_raw, spk_handle, _spk_send_task)) = tokio::try_join!(mic_connect, spk_connect)?; tokio::spawn(forward_dual_to_single(stream, mic_tx, spk_tx));