Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/transcribe-whisper-local/src/service/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ where
}
};

let guard = connection_manager.acquire_connection();
let guard = connection_manager.acquire_connection().await;

Ok(ws_upgrade
.on_upgrade(move |socket| async move {
Expand Down
6 changes: 6 additions & 0 deletions crates/ws-utils/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
//! Utilities for WebSocket-based audio streaming.
//!
//! This crate provides connection management and audio source abstractions
//! for WebSocket-based audio clients.

mod manager;

pub use manager::*;

use std::pin::Pin;
Expand Down
26 changes: 21 additions & 5 deletions crates/ws-utils/src/manager.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;

#[derive(Clone)]
pub struct ConnectionManager {
inner: Arc<Mutex<Option<CancellationToken>>>,
token: Arc<RwLock<Option<CancellationToken>>>,
}

impl Default for ConnectionManager {
fn default() -> Self {
Self {
inner: Arc::new(Mutex::new(None)),
token: Arc::new(RwLock::new(None)),
}
}
}

impl ConnectionManager {
pub fn acquire_connection(&self) -> ConnectionGuard {
let mut slot = self.inner.lock().unwrap();
pub async fn acquire_connection(&self) -> ConnectionGuard {
let mut slot = self.token.write().await;

if let Some(old) = slot.take() {
old.cancel();
Expand All @@ -27,14 +28,29 @@ impl ConnectionManager {

ConnectionGuard { token }
}

pub async fn cancel_all(&self) {
let mut slot = self.token.write().await;
if let Some(token) = slot.take() {
token.cancel();
}
}
}

pub struct ConnectionGuard {
token: CancellationToken,
}

impl ConnectionGuard {
pub fn is_cancelled(&self) -> bool {
self.token.is_cancelled()
}

pub async fn cancelled(&self) {
self.token.cancelled().await
}

pub fn child_token(&self) -> CancellationToken {
self.token.child_token()
}
}
4 changes: 1 addition & 3 deletions crates/ws/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ server = []
[dependencies]
bytes = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }

async-stream = { workspace = true }
Expand All @@ -19,6 +20,3 @@ futures-util = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "time", "sync", "macros"] }
tokio-tungstenite = { workspace = true, features = ["native-tls-vendored"] }
tracing = { workspace = true }

[dev-dependencies]
serde_json.workspace = true
133 changes: 93 additions & 40 deletions crates/ws/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,55 +7,94 @@ use futures_util::{
};
use tokio_tungstenite::{connect_async, tungstenite::client::IntoClientRequest};

pub use crate::config::{ConnectionConfig, KeepAliveConfig, RetryConfig};
pub use tokio_tungstenite::tungstenite::{protocol::Message, ClientRequestBuilder, Utf8Bytes};

#[derive(Debug)]
enum ControlCommand {
Finalize(Option<Message>),
}

#[derive(Clone)]
struct KeepAliveConfig {
interval: std::time::Duration,
message: Message,
}

#[derive(Clone)]
pub struct WebSocketHandle {
control_tx: tokio::sync::mpsc::UnboundedSender<ControlCommand>,
}

impl WebSocketHandle {
pub async fn finalize_with_text(&self, text: Utf8Bytes) {
let _ = self
if self
.control_tx
.send(ControlCommand::Finalize(Some(Message::Text(text))));
.send(ControlCommand::Finalize(Some(Message::Text(text))))
.is_err()
{
tracing::warn!("control channel closed, cannot send finalize command");
}
}
}

pub struct SendTask {
handle: tokio::task::JoinHandle<Result<(), crate::Error>>,
}

impl SendTask {
pub async fn wait(self) -> Result<(), crate::Error> {
match self.handle.await {
Ok(result) => result,
Err(join_err) if join_err.is_panic() => {
std::panic::resume_unwind(join_err.into_panic());
}
Err(join_err) => {
tracing::error!("send task cancelled: {:?}", join_err);
Err(crate::Error::UnexpectedClose)
}
}
}
}

#[derive(Debug, thiserror::Error)]
pub enum DecodeError {
#[error("unsupported message type")]
UnsupportedType,

#[error("deserialization failed: {0}")]
DeserializationError(#[from] serde_json::Error),
}

pub trait WebSocketIO: Send + 'static {
type Data: Send;
type Input: Send;
type Output: DeserializeOwned;

fn to_input(data: Self::Data) -> Self::Input;
fn to_message(input: Self::Input) -> Message;
fn from_message(msg: Message) -> Option<Self::Output>;
fn decode(msg: Message) -> Result<Self::Output, DecodeError>;
}

pub struct WebSocketClient {
request: ClientRequestBuilder,
keep_alive: Option<KeepAliveConfig>,
config: ConnectionConfig,
}

impl WebSocketClient {
pub fn new(request: ClientRequestBuilder) -> Self {
Self {
request,
keep_alive: None,
config: ConnectionConfig::default(),
}
}

pub fn with_config(mut self, config: ConnectionConfig) -> Self {
self.config = config;
self
}

pub fn with_keep_alive(mut self, config: KeepAliveConfig) -> Self {
self.keep_alive = Some(config);
self
}

pub fn with_keep_alive_message(
mut self,
interval: std::time::Duration,
Expand All @@ -73,15 +112,18 @@ impl WebSocketClient {
(
impl Stream<Item = Result<T::Output, crate::Error>>,
WebSocketHandle,
SendTask,
),
crate::Error,
> {
let keep_alive_config = self.keep_alive.clone();
let close_grace_period = self.config.close_grace_period;
let retry_config = self.config.retry_config.clone();
let ws_stream = (|| self.try_connect(self.request.clone()))
.retry(
ConstantBuilder::default()
.with_max_times(5)
.with_delay(std::time::Duration::from_millis(500)),
.with_max_times(retry_config.max_attempts)
.with_delay(retry_config.delay),
)
.when(|e| {
tracing::error!("ws_connect_failed: {:?}", e);
Expand All @@ -96,12 +138,16 @@ impl WebSocketClient {
let (error_tx, mut error_rx) = tokio::sync::mpsc::unbounded_channel::<crate::Error>();
let handle = WebSocketHandle { control_tx };

let _send_task = tokio::spawn(async move {
let send_task = tokio::spawn(async move {
if let Some(msg) = initial_message {
if let Err(e) = ws_sender.send(msg).await {
tracing::error!("ws_initial_message_failed: {:?}", e);
let _ = error_tx.send(e.into());
return;
if error_tx.send(e.into()).is_err() {
tracing::warn!("output stream already closed, cannot propagate error");
}
return Err(crate::Error::DataSend {
context: "initial message".to_string(),
});
}
}

Expand All @@ -120,7 +166,9 @@ impl WebSocketClient {
if let Some(cfg) = keep_alive_config.as_ref() {
if let Err(e) = ws_sender.send(cfg.message.clone()).await {
tracing::error!("ws_keepalive_failed: {:?}", e);
let _ = error_tx.send(e.into());
if error_tx.send(e.into()).is_err() {
tracing::warn!("output stream already closed, cannot propagate keepalive error");
}
break;
}
last_outbound_at = tokio::time::Instant::now();
Expand All @@ -132,7 +180,9 @@ impl WebSocketClient {

if let Err(e) = ws_sender.send(msg).await {
tracing::error!("ws_send_failed: {:?}", e);
let _ = error_tx.send(e.into());
if error_tx.send(e.into()).is_err() {
tracing::warn!("output stream already closed, cannot propagate send error");
}
break;
}
last_outbound_at = tokio::time::Instant::now();
Expand All @@ -141,7 +191,9 @@ impl WebSocketClient {
if let Some(msg) = maybe_msg {
if let Err(e) = ws_sender.send(msg).await {
tracing::error!("ws_finalize_failed: {:?}", e);
let _ = error_tx.send(e.into());
if error_tx.send(e.into()).is_err() {
tracing::warn!("output stream already closed, cannot propagate finalize error");
}
break;
}
last_outbound_at = tokio::time::Instant::now();
Expand All @@ -151,36 +203,32 @@ impl WebSocketClient {
}
}

// Wait 5 seconds before closing the connection
// TODO: This might not be enough to ensure receiving remaining transcripts from the server.
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
let _ = ws_sender.close().await;
tracing::debug!("draining remaining messages before close");
tokio::time::sleep(close_grace_period).await;
if let Err(e) = ws_sender.close().await {
tracing::debug!("ws_close_failed: {:?}", e);
}
Ok(())
});

let send_task_handle = SendTask { handle: send_task };

let output_stream = async_stream::stream! {
loop {
tokio::select! {
Some(msg_result) = ws_receiver.next() => {
match msg_result {
Ok(msg) => {
let is_text = matches!(msg, Message::Text(_));
let is_binary = matches!(msg, Message::Binary(_));
let text_preview = if let Message::Text(ref t) = msg {
Some(t.to_string())
} else {
None
};

match msg {
Message::Text(_) | Message::Binary(_) => {
if let Some(output) = T::from_message(msg) {
yield Ok(output);
} else if is_text {
if let Some(text) = text_preview {
tracing::warn!("ws_message_parse_failed: {}", text);
match T::decode(msg) {
Ok(output) => yield Ok(output),
Err(DecodeError::UnsupportedType) => {
tracing::debug!("ws_message_unsupported_type");
}
Err(DecodeError::DeserializationError(e)) => {
tracing::warn!("ws_message_parse_failed: {}", e);
}
} else if is_binary {
tracing::warn!("ws_binary_message_parse_failed");
}
},
Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
Expand All @@ -207,7 +255,7 @@ impl WebSocketClient {
}
};

Ok((output_stream, handle))
Ok((output_stream, handle, send_task_handle))
}

async fn try_connect(
Expand All @@ -219,12 +267,17 @@ impl WebSocketClient {
>,
crate::Error,
> {
let req = req.into_client_request().unwrap();
let req = req
.into_client_request()
.map_err(|e| crate::Error::InvalidRequest(e.to_string()))?;

tracing::info!("connect_async: {:?}", req.uri());

let (ws_stream, _) =
tokio::time::timeout(std::time::Duration::from_secs(8), connect_async(req)).await??;
let timeout_duration = self.config.connect_timeout;
let (ws_stream, _) = tokio::time::timeout(timeout_duration, connect_async(req))
.await
.map_err(|e| crate::Error::timeout(e, timeout_duration))?
.map_err(crate::Error::Connection)?;

Ok(ws_stream)
}
Expand Down
40 changes: 40 additions & 0 deletions crates/ws/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use std::time::Duration;
use tokio_tungstenite::tungstenite::protocol::Message;

#[derive(Clone, Debug)]
pub struct ConnectionConfig {
pub connect_timeout: Duration,
pub retry_config: RetryConfig,
pub close_grace_period: Duration,
}

impl Default for ConnectionConfig {
fn default() -> Self {
Self {
connect_timeout: Duration::from_secs(8),
retry_config: RetryConfig::default(),
close_grace_period: Duration::from_secs(5),
}
}
}

#[derive(Clone, Debug)]
pub struct RetryConfig {
pub max_attempts: usize,
pub delay: Duration,
}

impl Default for RetryConfig {
fn default() -> Self {
Self {
max_attempts: 5,
delay: Duration::from_millis(500),
}
}
}

#[derive(Clone, Debug)]
pub struct KeepAliveConfig {
pub interval: Duration,
pub message: Message,
}
Loading