Skip to content

Commit a865d5f

Browse files
committed
add websocket stream
1 parent e0ee47c commit a865d5f

File tree

4 files changed

+766
-11
lines changed

4 files changed

+766
-11
lines changed

examples/bargein_server.py

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
from __future__ import annotations
22

33
import base64
4+
import json
45
import logging
56
import os
7+
import time
68
from collections.abc import AsyncIterator
79
from contextlib import asynccontextmanager
810
from typing import Any
911

1012
import numpy as np
1113
import onnxruntime as ort
12-
from fastapi import FastAPI, HTTPException
14+
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
1315
from pydantic import BaseModel, Field
1416

1517
logging.basicConfig(level=logging.INFO)
@@ -153,7 +155,6 @@ async def detect_bargein(request: BargeinRequest) -> BargeinResponse:
153155
try:
154156
# Decode the waveform
155157
waveform = decode_waveform(request.waveform)
156-
logger.info(f"Decoded waveform shape: {waveform.shape}")
157158

158159
# Run inference
159160
is_bargein = run_inference(waveform, request.threshold, request.min_frames)
@@ -174,6 +175,120 @@ async def detect_bargein(request: BargeinRequest) -> BargeinResponse:
174175
) from None
175176

176177

178+
@app.websocket("/bargein")
179+
async def websocket_bargein(websocket: WebSocket) -> None:
180+
"""
181+
WebSocket endpoint for bargein detection.
182+
183+
Protocol:
184+
- Client sends: {"type": "session.create", "settings": {"sample_rate": "16000"}}
185+
- Server sends: {"type": "session.created"}
186+
- Client sends: {"type": "input_audio", "audio": "<base64>", "sample_rate": 16000, "num_channels": 1, "threshold": 0.95, "min_frames": 2}
187+
- Server sends: {"type": "bargein_detected"} (when detected)
188+
- Client sends: {"type": "session.finalize"}
189+
- Server sends: {"type": "session.finalized"}
190+
"""
191+
await websocket.accept()
192+
logger.info("WebSocket connection established")
193+
194+
if onnx_session is None:
195+
await websocket.send_json({"type": "error", "message": "ONNX model not loaded"})
196+
await websocket.close()
197+
return
198+
199+
try:
200+
# Wait for session.create message
201+
while True:
202+
try:
203+
data = await websocket.receive_text()
204+
msg = json.loads(data)
205+
msg_type = msg.get("type")
206+
207+
if msg_type == "session.create":
208+
logger.info("Session created")
209+
await websocket.send_json({"type": "session.created"})
210+
break
211+
else:
212+
await websocket.send_json(
213+
{"type": "error", "message": f"Expected session.create, got {msg_type}"}
214+
)
215+
await websocket.close()
216+
return
217+
except json.JSONDecodeError as e:
218+
await websocket.send_json({"type": "error", "message": f"Invalid JSON: {str(e)}"})
219+
await websocket.close()
220+
return
221+
222+
# Process audio frames
223+
while True:
224+
try:
225+
data = await websocket.receive_text()
226+
msg = json.loads(data)
227+
msg_type = msg.get("type")
228+
229+
if msg_type == "input_audio":
230+
# Decode and process audio
231+
232+
audio_b64 = msg.get("audio")
233+
threshold = msg.get("threshold", 0.95)
234+
min_frames = msg.get("min_frames", 2)
235+
created_at = msg.get("created_at", time.time())
236+
237+
if not audio_b64:
238+
await websocket.send_json(
239+
{"type": "error", "message": "Missing audio data"}
240+
)
241+
continue
242+
243+
try:
244+
waveform = decode_waveform(audio_b64)
245+
is_bargein = run_inference(waveform, threshold, min_frames)
246+
delta = time.time() - created_at
247+
248+
await websocket.send_json(
249+
{"type": "inference_done", "delta": delta, "is_bargein": is_bargein}
250+
)
251+
252+
if is_bargein:
253+
logger.info("Bargein detected via WebSocket")
254+
await websocket.send_json({"type": "bargein_detected"})
255+
256+
except Exception as e:
257+
logger.error(f"Error processing audio: {e}", exc_info=True)
258+
await websocket.send_json(
259+
{"type": "error", "message": f"Error processing audio: {str(e)}"}
260+
)
261+
262+
elif msg_type == "session.finalize":
263+
logger.info("Session finalized")
264+
await websocket.send_json({"type": "session.finalized"})
265+
break
266+
267+
else:
268+
logger.warning(f"Unknown message type: {msg_type}")
269+
await websocket.send_json(
270+
{"type": "error", "message": f"Unknown message type: {msg_type}"}
271+
)
272+
273+
except json.JSONDecodeError as e:
274+
await websocket.send_json({"type": "error", "message": f"Invalid JSON: {str(e)}"})
275+
continue
276+
277+
except WebSocketDisconnect:
278+
logger.info("WebSocket disconnected")
279+
except Exception as e:
280+
logger.error(f"WebSocket error: {e}", exc_info=True)
281+
try:
282+
await websocket.send_json({"type": "error", "message": f"Internal error: {str(e)}"})
283+
except Exception:
284+
pass
285+
finally:
286+
try:
287+
await websocket.close()
288+
except Exception:
289+
pass
290+
291+
177292
@app.get("/health")
178293
async def health() -> dict[str, Any]:
179294
"""Detailed health check endpoint."""

examples/voice_agents/basic_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
metrics,
1515
room_io,
1616
)
17+
from livekit.agents.inference.bargein import BargeinDetector
1718
from livekit.agents.llm import function_tool
1819
from livekit.plugins import silero
19-
from livekit.plugins.bargein_detector import BargeinDetector
2020
from livekit.plugins.turn_detector.multilingual import MultilingualModel
2121

2222
# uncomment to enable Krisp background voice/noise cancellation
@@ -94,7 +94,7 @@ async def entrypoint(ctx: JobContext):
9494
# See more at https://docs.livekit.io/agents/build/turns
9595
turn_detection=MultilingualModel(),
9696
vad=ctx.proc.userdata["vad"],
97-
bargein_detector=BargeinDetector(),
97+
bargein_detector=BargeinDetector(use_proxy=True),
9898
# allow the LLM to generate a response while waiting for the end of turn
9999
# See more at https://docs.livekit.io/agents/build/audio/#preemptive-generation
100100
preemptive_generation=True,

livekit-agents/livekit/agents/bargein.py

Lines changed: 110 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,16 @@
66
from collections.abc import AsyncIterator
77
from dataclasses import dataclass
88
from enum import Enum, unique
9-
from typing import Literal, Union
9+
from typing import Generic, Literal, TypeVar, Union
10+
11+
from pydantic import BaseModel, ConfigDict, Field
1012

1113
from livekit import rtc
1214

13-
from .utils import aio
15+
from ._exceptions import APIConnectionError, APIError
16+
from .log import logger
17+
from .types import APIConnectOptions
18+
from .utils import aio, log_exceptions
1419

1520

1621
@unique
@@ -38,10 +43,25 @@ class BargeinEvent:
3843
"""Time taken to perform the inference, in seconds."""
3944

4045

41-
class BargeinDetector(ABC, rtc.EventEmitter[Literal["bargein_detected"]]):
42-
def __init__(self) -> None:
46+
class BargeinError(BaseModel):
47+
model_config = ConfigDict(arbitrary_types_allowed=True)
48+
type: Literal["bargein_error"] = "bargein_error"
49+
timestamp: float
50+
label: str
51+
error: Exception = Field(..., exclude=True)
52+
recoverable: bool
53+
54+
55+
TEvent = TypeVar("TEvent")
56+
57+
58+
class BargeinDetector(
59+
ABC, rtc.EventEmitter[Union[Literal["bargein_detected", "error"], TEvent]], Generic[TEvent]
60+
):
61+
def __init__(self, *, sample_rate: int) -> None:
4362
super().__init__()
4463
self._label = f"{type(self).__module__}.{type(self).__name__}"
64+
self._sample_rate = sample_rate
4565

4666
@property
4767
def model(self) -> str:
@@ -55,6 +75,17 @@ def provider(self) -> str:
5575
def label(self) -> str:
5676
return self._label
5777

78+
def _emit_error(self, api_error: Exception, recoverable: bool) -> None:
79+
self.emit(
80+
"error",
81+
BargeinError(
82+
timestamp=time.time(),
83+
label=self._label,
84+
error=api_error,
85+
recoverable=recoverable,
86+
),
87+
)
88+
5889
@abstractmethod
5990
def stream(self) -> BargeinStream: ...
6091

@@ -75,7 +106,7 @@ class _OverlapSpeechEndedSentinel:
75106
class _FlushSentinel:
76107
pass
77108

78-
def __init__(self, bargein_detector: BargeinDetector) -> None:
109+
def __init__(self, bargein_detector: BargeinDetector, conn_options: APIConnectOptions) -> None:
79110
self._bargein_detector = bargein_detector
80111
self._last_activity_time = time.perf_counter()
81112
self._input_ch = aio.Chan[
@@ -91,9 +122,60 @@ def __init__(self, bargein_detector: BargeinDetector) -> None:
91122
self._event_ch = aio.Chan[BargeinEvent]()
92123
self._task = asyncio.create_task(self._main_task())
93124
self._task.add_done_callback(lambda _: self._event_ch.close())
125+
self._num_retries = 0
126+
self._conn_options = conn_options
127+
self._sample_rate = bargein_detector._sample_rate
128+
self._resampler: rtc.AudioResampler | None = None
94129

95130
@abstractmethod
96-
async def _main_task(self) -> None: ...
131+
async def _run(self) -> None: ...
132+
133+
@log_exceptions(logger=logger)
134+
async def _main_task(self) -> None:
135+
max_retries = self._conn_options.max_retry
136+
137+
while self._num_retries <= max_retries:
138+
try:
139+
return await self._run()
140+
except APIError as e:
141+
if max_retries == 0:
142+
self._emit_error(e, recoverable=False)
143+
raise
144+
elif self._num_retries == max_retries:
145+
self._emit_error(e, recoverable=False)
146+
raise APIConnectionError(
147+
f"failed to detect bargein after {self._num_retries} attempts",
148+
) from e
149+
else:
150+
self._emit_error(e, recoverable=True)
151+
152+
retry_interval = self._conn_options._interval_for_retry(self._num_retries)
153+
logger.warning(
154+
f"failed to detect bargein, retrying in {retry_interval}s",
155+
exc_info=e,
156+
extra={
157+
"bargein_detector": self._bargein_detector._label,
158+
"attempt": self._num_retries,
159+
},
160+
)
161+
await asyncio.sleep(retry_interval)
162+
163+
self._num_retries += 1
164+
165+
except Exception as e:
166+
self._emit_error(e, recoverable=False)
167+
raise
168+
169+
def _emit_error(self, api_error: Exception, recoverable: bool) -> None:
170+
self._bargein_detector.emit(
171+
"error",
172+
BargeinError(
173+
timestamp=time.time(),
174+
label=self._bargein_detector._label,
175+
error=api_error,
176+
recoverable=recoverable,
177+
),
178+
)
97179

98180
def start_agent_speech(self) -> None:
99181
"""Mark the start of the agent's speech"""
@@ -130,7 +212,28 @@ def push_frame(
130212
"""Push some audio frame to be analyzed"""
131213
self._check_input_not_ended()
132214
self._check_not_closed()
133-
self._input_ch.send_nowait(frame)
215+
216+
if not isinstance(frame, rtc.AudioFrame):
217+
self._input_ch.send_nowait(frame)
218+
return
219+
220+
if self._sample_rate != frame.sample_rate:
221+
if not self._resampler:
222+
self._resampler = rtc.AudioResampler(
223+
input_rate=frame.sample_rate,
224+
output_rate=self._sample_rate,
225+
num_channels=1,
226+
quality=rtc.AudioResamplerQuality.LOW,
227+
)
228+
elif self._resampler._input_rate != frame.sample_rate:
229+
raise ValueError("the sample rate of the input frames must be consistent")
230+
231+
if self._resampler:
232+
frames = self._resampler.push(frame)
233+
for frame in frames:
234+
self._input_ch.send_nowait(frame)
235+
else:
236+
self._input_ch.send_nowait(frame)
134237

135238
def flush(self) -> None:
136239
"""Mark the end of the current segment"""

0 commit comments

Comments
 (0)