From f4f1f3e89c39ad99eaa5418eb06b7893a41c654d Mon Sep 17 00:00:00 2001 From: Long Chen Date: Fri, 21 Nov 2025 11:24:43 +0800 Subject: [PATCH 1/4] add use_realtime to 11labs STT and support scribe v2 realtime model --- ...me_scribe_v2.py => elevenlab_scribe_v2.py} | 14 +- .../livekit/plugins/elevenlabs/__init__.py | 11 +- .../livekit/plugins/elevenlabs/models.py | 16 +- .../livekit/plugins/elevenlabs/stt.py | 351 ++++++++++++++- .../livekit/plugins/elevenlabs/stt_v2.py | 417 ------------------ 5 files changed, 360 insertions(+), 449 deletions(-) rename examples/other/{realtime_scribe_v2.py => elevenlab_scribe_v2.py} (78%) delete mode 100644 livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt_v2.py diff --git a/examples/other/realtime_scribe_v2.py b/examples/other/elevenlab_scribe_v2.py similarity index 78% rename from examples/other/realtime_scribe_v2.py rename to examples/other/elevenlab_scribe_v2.py index 53392be166..4a7b8e0370 100644 --- a/examples/other/realtime_scribe_v2.py +++ b/examples/other/elevenlab_scribe_v2.py @@ -12,12 +12,14 @@ async def entrypoint(ctx: JobContext): - stt = elevenlabs.STTv2( - model_id="scribe_v2_realtime", - vad_silence_threshold_secs=0.5, - vad_threshold=0.5, - min_speech_duration_ms=100, - min_silence_duration_ms=300, + stt = elevenlabs.STT( + use_realtime=True, + server_vad={ + "vad_silence_threshold_secs": 0.5, + "vad_threshold": 0.5, + "min_speech_duration_ms": 100, + "min_silence_duration_ms": 300, + }, ) session = AgentSession( diff --git a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/__init__.py b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/__init__.py index 7c7c97de6c..6a04e40bc6 100644 --- a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/__init__.py +++ b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/__init__.py @@ -17,23 +17,20 @@ See https://docs.livekit.io/agents/integrations/tts/elevenlabs/ for more information. 
""" -from .models import STTAudioFormat, STTModels, TTSEncoding, TTSModels -from .stt import STT -from .stt_v2 import SpeechStreamv2, STTv2 +from .models import STTRealtimeSampleRates, TTSEncoding, TTSModels +from .stt import STT, SpeechStream from .tts import DEFAULT_VOICE_ID, TTS, Voice, VoiceSettings from .version import __version__ __all__ = [ "STT", - "STTv2", - "SpeechStreamv2", + "SpeechStream", "TTS", "Voice", "VoiceSettings", "TTSEncoding", "TTSModels", - "STTModels", - "STTAudioFormat", + "STTRealtimeSampleRates", "DEFAULT_VOICE_ID", "__version__", ] diff --git a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/models.py b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/models.py index 37c81ed656..ecf546a74c 100644 --- a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/models.py +++ b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/models.py @@ -21,13 +21,11 @@ "mp3_44100_192", ] -STTModels = Literal["scribe_v2_realtime",] - -STTAudioFormat = Literal[ - "pcm_8000", - "pcm_16000", - "pcm_22050", - "pcm_24000", - "pcm_44100", - "pcm_48000", +STTRealtimeSampleRates = Literal[ + 8000, + 16000, + 22050, + 24000, + 44100, + 48000, ] diff --git a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py index 8f601784dd..6539eaeeca 100644 --- a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py +++ b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py @@ -15,8 +15,12 @@ from __future__ import annotations import asyncio +import base64 +import json import os +import weakref from dataclasses import dataclass +from typing import TypedDict import aiohttp @@ -28,31 +32,52 @@ APIStatusError, APITimeoutError, stt, + utils, ) from livekit.agents.stt import SpeechEventType, STTCapabilities from livekit.agents.types import NOT_GIVEN, NotGivenOr from livekit.agents.utils import AudioBuffer, http_context, is_given +from .log import logger +from .models import STTRealtimeSampleRates + API_BASE_URL_V1 = "https://api.elevenlabs.io/v1" AUTHORIZATION_HEADER = "xi-api-key" +class VADOptions(TypedDict, total=False): + vad_silence_threshold_secs: float | None + """Silence threshold in seconds for VAD. Default to 1.5""" + vad_threshold: float | None + """Threshold for voice activity detection. Default to 0.4""" + min_speech_duration_ms: int | None + """Minimum speech duration in milliseconds. Default to 250""" + min_silence_duration_ms: int | None + """Minimum silence duration in milliseconds. Default to 2500""" + + @dataclass -class _STTOptions: +class STTOptions: api_key: str base_url: str - language_code: str | None = None - tag_audio_events: bool = True + language_code: str | None + tag_audio_events: bool + sample_rate: STTRealtimeSampleRates + server_vad: NotGivenOr[VADOptions | None] class STT(stt.STT): def __init__( self, + *, api_key: NotGivenOr[str] = NOT_GIVEN, base_url: NotGivenOr[str] = NOT_GIVEN, - http_session: aiohttp.ClientSession | None = None, language_code: NotGivenOr[str] = NOT_GIVEN, tag_audio_events: bool = True, + use_realtime: bool = False, + sample_rate: STTRealtimeSampleRates = 16000, + server_vad: NotGivenOr[VADOptions] = NOT_GIVEN, + http_session: aiohttp.ClientSession | None = None, ) -> None: """ Create a new instance of ElevenLabs STT. @@ -60,11 +85,19 @@ def __init__( Args: api_key (NotGivenOr[str]): ElevenLabs API key. 
Can be set via argument or `ELEVEN_API_KEY` environment variable. base_url (NotGivenOr[str]): Custom base URL for the API. Optional. - http_session (aiohttp.ClientSession | None): Custom HTTP session for API requests. Optional. language_code (NotGivenOr[str]): Language code for the STT model. Optional. - tag_audio_events (bool): Whether to tag audio events like (laughter), (footsteps), etc. in the transcription. Default is True. + tag_audio_events (bool): Whether to tag audio events like (laughter), (footsteps), etc. in the transcription. + Only supported for Scribe v1 model. Default is True. + use_realtime (bool): Whether to use "scribe_v2_realtime" model for streaming mode. Default is False. + sample_rate (STTRealtimeSampleRates): Audio sample rate in Hz. Default is 16000. + server_vad (NotGivenOr[VADOptions]): Server-side VAD options, only supported for Scribe v2 realtime model. + http_session (aiohttp.ClientSession | None): Custom HTTP session for API requests. Optional. """ # noqa: E501 - super().__init__(capabilities=STTCapabilities(streaming=False, interim_results=True)) + + super().__init__(capabilities=STTCapabilities(streaming=use_realtime, interim_results=True)) + + if not use_realtime and is_given(server_vad): + logger.warning("Server-side VAD is only supported for Scribe v2 realtime model") elevenlabs_api_key = api_key if is_given(api_key) else os.environ.get("ELEVEN_API_KEY") if not elevenlabs_api_key: @@ -72,14 +105,16 @@ def __init__( "ElevenLabs API key is required, either as argument or " "set ELEVEN_API_KEY environmental variable" ) - self._opts = _STTOptions( + self._opts = STTOptions( api_key=elevenlabs_api_key, base_url=base_url if is_given(base_url) else API_BASE_URL_V1, + language_code=language_code or None, tag_audio_events=tag_audio_events, + sample_rate=sample_rate, + server_vad=server_vad, ) - if is_given(language_code): - self._opts.language_code = language_code self._session = http_session + self._streams = weakref.WeakSet[SpeechStream]() @property def model(self) -> str: @@ -170,3 +205,299 @@ def _transcription_to_speech_event( ) ], ) + + def update_options( + self, + *, + tag_audio_events: NotGivenOr[bool] = NOT_GIVEN, + server_vad: NotGivenOr[VADOptions] = NOT_GIVEN, + ) -> None: + if is_given(tag_audio_events): + self._opts.tag_audio_events = tag_audio_events + + if is_given(server_vad): + self._opts.server_vad = server_vad + + for stream in self._streams: + stream.update_options(server_vad=server_vad) + + def stream( + self, + *, + language: NotGivenOr[str] = NOT_GIVEN, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + ) -> SpeechStream: + stream = SpeechStream( + stt=self, + opts=self._opts, + conn_options=conn_options, + language=language if is_given(language) else self._opts.language_code, + http_session=self._ensure_session(), + ) + self._streams.add(stream) + return stream + + +class SpeechStream(stt.SpeechStream): + """Streaming speech recognition using ElevenLabs Scribe v2 realtime API""" + + def __init__( + self, + *, + stt: STT, + opts: STTOptions, + conn_options: APIConnectOptions, + language: str | None, + http_session: aiohttp.ClientSession, + ) -> None: + super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate) + self._opts = opts + self._language = language + self._session = http_session + self._reconnect_event = asyncio.Event() + self._speaking = False # Track if we're currently in a speech segment + + def update_options( + self, + *, + server_vad: NotGivenOr[VADOptions] = NOT_GIVEN, + ) -> None: + 
if is_given(server_vad): + self._opts.server_vad = server_vad + self._reconnect_event.set() + + async def _run(self) -> None: + """Run the streaming transcription session""" + closing_ws = False + + async def keepalive_task(ws: aiohttp.ClientWebSocketResponse) -> None: + try: + while True: + await ws.ping() + await asyncio.sleep(30) + except Exception: + return + + @utils.log_exceptions(logger=logger) + async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: + nonlocal closing_ws + + # Buffer audio into chunks (50ms chunks) + samples_50ms = self._opts.sample_rate // 20 + audio_bstream = utils.audio.AudioByteStream( + sample_rate=self._opts.sample_rate, + num_channels=1, + samples_per_channel=samples_50ms, + ) + + async for data in self._input_ch: + # Write audio bytes to buffer and get 50ms frames + frames: list[rtc.AudioFrame] = [] + if isinstance(data, rtc.AudioFrame): + frames.extend(audio_bstream.write(data.data.tobytes())) + elif isinstance(data, self._FlushSentinel): + frames.extend(audio_bstream.flush()) + + for frame in frames: + audio_b64 = base64.b64encode(frame.data.tobytes()).decode("utf-8") + await ws.send_str( + json.dumps( + { + "message_type": "input_audio_chunk", + "audio_base_64": audio_b64, + "commit": False, + "sample_rate": self._opts.sample_rate, + } + ) + ) + + closing_ws = True + + @utils.log_exceptions(logger=logger) + async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None: + nonlocal closing_ws + + while True: + msg = await ws.receive() + + if msg.type in ( + aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSING, + ): + if closing_ws or self._session.closed: + return + raise APIStatusError(message="ElevenLabs STT connection closed unexpectedly") + + if msg.type != aiohttp.WSMsgType.TEXT: + logger.warning("unexpected ElevenLabs STT message type %s", msg.type) + continue + + try: + parsed = json.loads(msg.data) + self._process_stream_event(parsed) + except Exception: + logger.exception("failed to process ElevenLabs STT message") + + ws: aiohttp.ClientWebSocketResponse | None = None + + while True: + try: + ws = await self._connect_ws() + tasks = [ + asyncio.create_task(send_task(ws)), + asyncio.create_task(recv_task(ws)), + asyncio.create_task(keepalive_task(ws)), + ] + tasks_group = asyncio.gather(*tasks) + wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait()) + + try: + done, _ = await asyncio.wait( + (tasks_group, wait_reconnect_task), + return_when=asyncio.FIRST_COMPLETED, + ) + + for task in done: + if task != wait_reconnect_task: + task.result() + + if wait_reconnect_task not in done: + break + + self._reconnect_event.clear() + finally: + await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task) + tasks_group.cancel() + tasks_group.exception() # Retrieve exception to prevent it from being logged + finally: + if ws is not None: + await ws.close() + + async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: + """Establish WebSocket connection to ElevenLabs Scribe v2 API""" + commit_strategy = "manual" if self._opts.server_vad is None else "vad" + params = [ + "model_id=scribe_v2_realtime", + f"encoding=pcm_{self._opts.sample_rate}", + f"commit_strategy={commit_strategy}", + ] + + if server_vad := self._opts.server_vad: + if ( + vad_silence_threshold_secs := server_vad.get("vad_silence_threshold_secs") + ) is not None: + params.append(f"vad_silence_threshold_secs={vad_silence_threshold_secs}") + if (vad_threshold := server_vad.get("vad_threshold")) is not None: + 
params.append(f"vad_threshold={vad_threshold}") + if (min_speech_duration_ms := server_vad.get("min_speech_duration_ms")) is not None: + params.append(f"min_speech_duration_ms={min_speech_duration_ms}") + if (min_silence_duration_ms := server_vad.get("min_silence_duration_ms")) is not None: + params.append(f"min_silence_duration_ms={min_silence_duration_ms}") + + if self._language: + params.append(f"language_code={self._language}") + + query_string = "&".join(params) + + # Convert HTTPS URL to WSS + base_url = self._opts.base_url.replace("https://", "wss://").replace("http://", "ws://") + ws_url = f"{base_url}/speech-to-text/realtime?{query_string}" + + try: + ws = await asyncio.wait_for( + self._session.ws_connect( + ws_url, + headers={AUTHORIZATION_HEADER: self._opts.api_key}, + ), + self._conn_options.timeout, + ) + except (aiohttp.ClientConnectorError, asyncio.TimeoutError) as e: + raise APIConnectionError("Failed to connect to ElevenLabs") from e + + return ws + + def _process_stream_event(self, data: dict) -> None: + """Process incoming WebSocket messages from ElevenLabs""" + message_type = data.get("message_type") + text = data.get("text", "") + + speech_data = stt.SpeechData( + language=self._language or "en", + text=text, + ) + + if message_type == "partial_transcript": + logger.debug("Received message type partial_transcript: %s", data) + + if text: + # Send START_OF_SPEECH if we're not already speaking + if not self._speaking: + self._event_ch.send_nowait( + stt.SpeechEvent(type=SpeechEventType.START_OF_SPEECH) + ) + self._speaking = True + + # Send INTERIM_TRANSCRIPT + interim_event = stt.SpeechEvent( + type=SpeechEventType.INTERIM_TRANSCRIPT, + alternatives=[speech_data], + ) + self._event_ch.send_nowait(interim_event) + + elif message_type == "committed_transcript": + logger.debug("Received message type committed_transcript: %s", data) + + # Final committed transcripts - these are sent to the LLM/TTS layer in LiveKit agents + # and trigger agent responses (unlike partial transcripts which are UI-only) + + if text: + # Send START_OF_SPEECH if we're not already speaking + if not self._speaking: + self._event_ch.send_nowait( + stt.SpeechEvent(type=SpeechEventType.START_OF_SPEECH) + ) + self._speaking = True + + # Send FINAL_TRANSCRIPT but keep speaking=True + # Multiple commits can occur within the same speech segment + final_event = stt.SpeechEvent( + type=SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[speech_data], + ) + self._event_ch.send_nowait(final_event) + else: + # Empty commit signals end of speech segment (similar to Cartesia's is_final flag) + # This groups multiple committed transcripts into one speech segment + if self._speaking: + self._event_ch.send_nowait(stt.SpeechEvent(type=SpeechEventType.END_OF_SPEECH)) + self._speaking = False + + elif message_type == "session_started": + # Session initialization message - informational only + session_id = data.get("session_id", "unknown") + logger.debug("Session started with ID: %s", session_id) + + elif message_type == "committed_transcript_with_timestamps": + logger.debug("Received message type committed_transcript_with_timestamps: %s", data) + + # Error handling for known ElevenLabs error types + elif message_type in ( + "auth_error", + "quota_exceeded", + "transcriber_error", + "input_error", + "error", + ): + error_msg = data.get("message", "Unknown error") + error_details = data.get("details", "") + details_suffix = " - " + error_details if error_details else "" + logger.error( + "ElevenLabs STT error [%s]: %s%s", 
+ message_type, + error_msg, + details_suffix, + ) + raise APIConnectionError(f"{message_type}: {error_msg}{details_suffix}") + else: + logger.warning("ElevenLabs STT unknown message type: %s, data: %s", message_type, data) diff --git a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt_v2.py b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt_v2.py deleted file mode 100644 index 36236f3a83..0000000000 --- a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt_v2.py +++ /dev/null @@ -1,417 +0,0 @@ -# Copyright 2023 LiveKit, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import asyncio -import base64 -import json -import os -import typing -import weakref -from dataclasses import dataclass - -import aiohttp - -from livekit import rtc -from livekit.agents import ( - DEFAULT_API_CONNECT_OPTIONS, - APIConnectionError, - APIConnectOptions, - APIStatusError, - stt, - utils, -) -from livekit.agents.stt import SpeechEventType, STTCapabilities -from livekit.agents.types import NOT_GIVEN, NotGivenOr -from livekit.agents.utils import AudioBuffer, is_given - -from .log import logger -from .models import STTAudioFormat, STTModels - -API_BASE_URL_V1 = "https://api.elevenlabs.io/v1" -AUTHORIZATION_HEADER = "xi-api-key" - - -@dataclass -class STTOptions: - api_key: str - base_url: str - language_code: str | None = None - model_id: STTModels = "scribe_v2_realtime" - audio_format: STTAudioFormat = "pcm_16000" - sample_rate: int = 16000 - vad_silence_threshold_secs: float | None = None - vad_threshold: float | None = None - min_speech_duration_ms: int | None = None - min_silence_duration_ms: int | None = None - - -class STTv2(stt.STT): - def __init__( - self, - api_key: NotGivenOr[str] = NOT_GIVEN, - base_url: NotGivenOr[str] = NOT_GIVEN, - http_session: aiohttp.ClientSession | None = None, - language_code: NotGivenOr[str] = NOT_GIVEN, - model_id: STTModels = "scribe_v2_realtime", - sample_rate: int = 16000, - vad_silence_threshold_secs: NotGivenOr[float] = NOT_GIVEN, - vad_threshold: NotGivenOr[float] = NOT_GIVEN, - min_speech_duration_ms: NotGivenOr[int] = NOT_GIVEN, - min_silence_duration_ms: NotGivenOr[int] = NOT_GIVEN, - ) -> None: - """ - Create a new instance of ElevenLabs STT v2 with streaming support. - - Uses Voice Activity Detection (VAD) to automatically detect speech segments - and commit transcriptions when the user stops speaking. - - Args: - api_key (NotGivenOr[str]): ElevenLabs API key. Can be set via argument or `ELEVEN_API_KEY` environment variable. - base_url (NotGivenOr[str]): Custom base URL for the API. Optional. - http_session (aiohttp.ClientSession | None): Custom HTTP session for API requests. Optional. - language_code (NotGivenOr[str]): Language code for the STT model. Optional. - model_id (STTModels): Model ID for Scribe. Default is "scribe_v2_realtime". - sample_rate (int): Audio sample rate in Hz. Default is 16000. 
- vad_silence_threshold_secs (NotGivenOr[float]): Silence threshold in seconds for VAD (must be between 0.3 and 3.0). Optional. - vad_threshold (NotGivenOr[float]): Threshold for voice activity detection (must be between 0.1 and 0.9). Optional. - min_speech_duration_ms (NotGivenOr[int]): Minimum speech duration in milliseconds (must be between 50 and 2000). Optional. - min_silence_duration_ms (NotGivenOr[int]): Minimum silence duration in milliseconds (must be between 50 and 2000). Optional. - """ # noqa: E501 - super().__init__(capabilities=STTCapabilities(streaming=True, interim_results=True)) - - elevenlabs_api_key = api_key if is_given(api_key) else os.environ.get("ELEVEN_API_KEY") - if not elevenlabs_api_key: - raise ValueError( - "ElevenLabs API key is required, either as argument or " - "set ELEVEN_API_KEY environmental variable" - ) - - # Determine audio format based on sample rate - audio_format = typing.cast(STTAudioFormat, f"pcm_{sample_rate}") - - self._opts = STTOptions( - api_key=elevenlabs_api_key, - base_url=base_url if is_given(base_url) else API_BASE_URL_V1, - model_id=model_id, - audio_format=audio_format, - sample_rate=sample_rate, - vad_silence_threshold_secs=vad_silence_threshold_secs - if is_given(vad_silence_threshold_secs) - else None, - vad_threshold=vad_threshold if is_given(vad_threshold) else None, - min_speech_duration_ms=min_speech_duration_ms - if is_given(min_speech_duration_ms) - else None, - min_silence_duration_ms=min_silence_duration_ms - if is_given(min_silence_duration_ms) - else None, - ) - if is_given(language_code): - self._opts.language_code = language_code - self._session = http_session - self._streams = weakref.WeakSet[SpeechStreamv2]() - - @property - def model(self) -> str: - return self._opts.model_id - - @property - def provider(self) -> str: - return "ElevenLabs" - - def _ensure_session(self) -> aiohttp.ClientSession: - if not self._session: - self._session = utils.http_context.http_session() - - return self._session - - async def _recognize_impl( - self, - buffer: AudioBuffer, - *, - language: NotGivenOr[str] = NOT_GIVEN, - conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, - ) -> stt.SpeechEvent: - raise NotImplementedError( - "Scribe v2 API does not support non-streaming recognize. 
Use stream() instead or use the original STT class for Scribe v1" - ) - - def stream( - self, - *, - language: NotGivenOr[str] = NOT_GIVEN, - conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, - ) -> SpeechStreamv2: - return SpeechStreamv2( - stt=self, - opts=self._opts, - conn_options=conn_options, - language=language if is_given(language) else self._opts.language_code, - http_session=self._ensure_session(), - ) - - -class SpeechStreamv2(stt.SpeechStream): - """Streaming speech recognition using ElevenLabs Scribe v2 realtime API""" - - def __init__( - self, - *, - stt: STTv2, - opts: STTOptions, - conn_options: APIConnectOptions, - language: str | None, - http_session: aiohttp.ClientSession, - ) -> None: - super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate) - self._opts = opts - self._language = language - self._session = http_session - self._reconnect_event = asyncio.Event() - self._speaking = False # Track if we're currently in a speech segment - - async def _run(self) -> None: - """Run the streaming transcription session""" - closing_ws = False - - async def keepalive_task(ws: aiohttp.ClientWebSocketResponse) -> None: - try: - while True: - await ws.ping() - await asyncio.sleep(30) - except Exception: - return - - @utils.log_exceptions(logger=logger) - async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: - nonlocal closing_ws - - # Buffer audio into chunks (50ms chunks) - samples_50ms = self._opts.sample_rate // 20 - audio_bstream = utils.audio.AudioByteStream( - sample_rate=self._opts.sample_rate, - num_channels=1, - samples_per_channel=samples_50ms, - ) - - async for data in self._input_ch: - # Write audio bytes to buffer and get 50ms frames - frames: list[rtc.AudioFrame] = [] - if isinstance(data, rtc.AudioFrame): - frames.extend(audio_bstream.write(data.data.tobytes())) - elif isinstance(data, self._FlushSentinel): - frames.extend(audio_bstream.flush()) - - for frame in frames: - audio_b64 = base64.b64encode(frame.data.tobytes()).decode("utf-8") - await ws.send_str( - json.dumps( - { - "message_type": "input_audio_chunk", - "audio_base_64": audio_b64, - "commit": False, - "sample_rate": self._opts.sample_rate, - } - ) - ) - - closing_ws = True - - @utils.log_exceptions(logger=logger) - async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None: - nonlocal closing_ws - - while True: - msg = await ws.receive() - - if msg.type in ( - aiohttp.WSMsgType.CLOSED, - aiohttp.WSMsgType.CLOSE, - aiohttp.WSMsgType.CLOSING, - ): - if closing_ws or self._session.closed: - return - raise APIStatusError(message="ElevenLabs STT connection closed unexpectedly") - - if msg.type != aiohttp.WSMsgType.TEXT: - logger.warning("unexpected ElevenLabs STT message type %s", msg.type) - continue - - try: - parsed = json.loads(msg.data) - self._process_stream_event(parsed) - except Exception: - logger.exception("failed to process ElevenLabs STT message") - - ws: aiohttp.ClientWebSocketResponse | None = None - - while True: - try: - ws = await self._connect_ws() - tasks = [ - asyncio.create_task(send_task(ws)), - asyncio.create_task(recv_task(ws)), - asyncio.create_task(keepalive_task(ws)), - ] - tasks_group = asyncio.gather(*tasks) - wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait()) - - try: - done, _ = await asyncio.wait( - (tasks_group, wait_reconnect_task), - return_when=asyncio.FIRST_COMPLETED, - ) - - for task in done: - if task != wait_reconnect_task: - task.result() - - if wait_reconnect_task not in done: - break - - 
self._reconnect_event.clear() - finally: - await utils.aio.gracefully_cancel(*tasks, wait_reconnect_task) - tasks_group.cancel() - tasks_group.exception() # Retrieve exception to prevent it from being logged - finally: - if ws is not None: - await ws.close() - - async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: - """Establish WebSocket connection to ElevenLabs Scribe v2 API""" - # Build query parameters - params = [ - f"model_id={self._opts.model_id}", - f"encoding={self._opts.audio_format}", - f"sample_rate={self._opts.sample_rate}", - "commit_strategy=vad", # Always use VAD for automatic speech detection - ] - - if self._opts.vad_silence_threshold_secs is not None: - params.append(f"vad_silence_threshold_secs={self._opts.vad_silence_threshold_secs}") - if self._opts.vad_threshold is not None: - params.append(f"vad_threshold={self._opts.vad_threshold}") - if self._opts.min_speech_duration_ms is not None: - params.append(f"min_speech_duration_ms={self._opts.min_speech_duration_ms}") - if self._opts.min_silence_duration_ms is not None: - params.append(f"min_silence_duration_ms={self._opts.min_silence_duration_ms}") - if self._language: - params.append(f"language_code={self._language}") - - query_string = "&".join(params) - - # Convert HTTPS URL to WSS - base_url = self._opts.base_url.replace("https://", "wss://").replace("http://", "ws://") - ws_url = f"{base_url}/speech-to-text/realtime?{query_string}" - - try: - ws = await asyncio.wait_for( - self._session.ws_connect( - ws_url, - headers={AUTHORIZATION_HEADER: self._opts.api_key}, - ), - self._conn_options.timeout, - ) - except (aiohttp.ClientConnectorError, asyncio.TimeoutError) as e: - raise APIConnectionError("Failed to connect to ElevenLabs") from e - - return ws - - def _process_stream_event(self, data: dict) -> None: - """Process incoming WebSocket messages from ElevenLabs""" - message_type = data.get("message_type") - text = data.get("text", "") - - speech_data = stt.SpeechData( - language=self._language or "en", - text=text, - ) - - if message_type == "partial_transcript": - logger.debug("Received message type partial_transcript: %s", data) - - if text: - # Send START_OF_SPEECH if we're not already speaking - if not self._speaking: - self._event_ch.send_nowait( - stt.SpeechEvent(type=SpeechEventType.START_OF_SPEECH) - ) - self._speaking = True - - # Send INTERIM_TRANSCRIPT - interim_event = stt.SpeechEvent( - type=SpeechEventType.INTERIM_TRANSCRIPT, - alternatives=[speech_data], - ) - self._event_ch.send_nowait(interim_event) - - elif message_type == "committed_transcript": - logger.debug("Received message type committed_transcript: %s", data) - - # Final committed transcripts - these are sent to the LLM/TTS layer in LiveKit agents - # and trigger agent responses (unlike partial transcripts which are UI-only) - - if text: - # Send START_OF_SPEECH if we're not already speaking - if not self._speaking: - self._event_ch.send_nowait( - stt.SpeechEvent(type=SpeechEventType.START_OF_SPEECH) - ) - self._speaking = True - - # Send FINAL_TRANSCRIPT but keep speaking=True - # Multiple commits can occur within the same speech segment - final_event = stt.SpeechEvent( - type=SpeechEventType.FINAL_TRANSCRIPT, - alternatives=[speech_data], - ) - self._event_ch.send_nowait(final_event) - else: - # Empty commit signals end of speech segment (similar to Cartesia's is_final flag) - # This groups multiple committed transcripts into one speech segment - if self._speaking: - 
self._event_ch.send_nowait(stt.SpeechEvent(type=SpeechEventType.END_OF_SPEECH)) - self._speaking = False - - elif message_type == "session_started": - # Session initialization message - informational only - session_id = data.get("session_id", "unknown") - logger.info("STTv2: Session started with ID: %s", session_id) - - elif message_type == "committed_transcript_with_timestamps": - logger.debug("Received message type committed_transcript_with_timestamps: %s", data) - - # Error handling for known ElevenLabs error types - elif message_type in ( - "auth_error", - "quota_exceeded", - "transcriber_error", - "input_error", - "error", - ): - error_msg = data.get("message", "Unknown error") - error_details = data.get("details", "") - details_suffix = " - " + error_details if error_details else "" - logger.error( - "STTv2: ElevenLabs error [%s]: %s%s", - message_type, - error_msg, - details_suffix, - ) - raise APIConnectionError(f"{message_type}: {error_msg}{details_suffix}") - else: - logger.warning("STTv2: Unknown message type: %s, data: %s", message_type, data) From a18611ffbc74a911b575699529402a037d4a3392 Mon Sep 17 00:00:00 2001 From: Long Chen Date: Fri, 21 Nov 2025 11:33:45 +0800 Subject: [PATCH 2/4] raise error --- .../livekit/plugins/elevenlabs/stt.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py index 6539eaeeca..f1d48c6f49 100644 --- a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py +++ b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py @@ -155,6 +155,13 @@ async def _recognize_impl( headers={AUTHORIZATION_HEADER: self._opts.api_key}, ) as response: response_json = await response.json() + if response.status != 200: + raise APIStatusError( + message=response_json.get("detail", "Unknown ElevenLabs error"), + status_code=response.status, + request_id=None, + body=response_json, + ) extracted_text = response_json.get("text") language_code = response_json.get("language_code") speaker_id = None From 4576f381c71e6953c4a34a37ab0f728a8dab920b Mon Sep 17 00:00:00 2001 From: Long Chen Date: Fri, 21 Nov 2025 14:01:21 +0800 Subject: [PATCH 3/4] support using VAD with a streaming STT --- examples/other/elevenlab_scribe_v2.py | 37 +++-- examples/voice_agents/stream_stt_with_vad.py | 55 ++++++ .../livekit/agents/stt/stream_adapter.py | 157 +++++++++++++----- livekit-agents/livekit/agents/stt/stt.py | 1 + .../livekit/plugins/assemblyai/stt.py | 6 +- .../livekit/plugins/cartesia/stt.py | 7 +- .../livekit/plugins/deepgram/stt.py | 5 +- .../livekit/plugins/deepgram/stt_v2.py | 4 +- .../livekit/plugins/elevenlabs/stt.py | 28 +++- .../livekit/plugins/gladia/stt.py | 12 +- .../livekit/plugins/sarvam/stt.py | 4 +- .../livekit/plugins/soniox/stt.py | 40 ++--- .../livekit/plugins/speechmatics/stt.py | 5 +- 13 files changed, 263 insertions(+), 98 deletions(-) create mode 100644 examples/voice_agents/stream_stt_with_vad.py diff --git a/examples/other/elevenlab_scribe_v2.py b/examples/other/elevenlab_scribe_v2.py index 4a7b8e0370..55b73ac4d4 100644 --- a/examples/other/elevenlab_scribe_v2.py +++ b/examples/other/elevenlab_scribe_v2.py @@ -2,32 +2,32 @@ from dotenv import load_dotenv -from livekit.agents import Agent, AgentSession, JobContext, JobProcess, WorkerOptions, cli -from livekit.plugins import elevenlabs, openai, silero +from livekit.agents import Agent, AgentServer, AgentSession, 
JobContext, JobProcess, cli, stt +from livekit.plugins import elevenlabs, silero logger = logging.getLogger("realtime-scribe-v2") logger.setLevel(logging.INFO) load_dotenv() +server = AgentServer() -async def entrypoint(ctx: JobContext): - stt = elevenlabs.STT( - use_realtime=True, - server_vad={ - "vad_silence_threshold_secs": 0.5, - "vad_threshold": 0.5, - "min_speech_duration_ms": 100, - "min_silence_duration_ms": 300, - }, - ) +@server.rtc_session() +async def entrypoint(ctx: JobContext): session = AgentSession( - allow_interruptions=True, vad=ctx.proc.userdata["vad"], - stt=stt, - llm=openai.LLM(model="gpt-4.1-mini"), - tts=elevenlabs.TTS(model="eleven_turbo_v2_5"), + stt=stt.StreamAdapter( + stt=elevenlabs.STT( + use_realtime=True, + server_vad=None, # disable server-side VAD + language_code="en", + ), + vad=ctx.proc.userdata["vad"], + use_streaming=True, + ), + llm="openai/gpt-4.1-mini", + tts="elevenlabs", ) await session.start( agent=Agent(instructions="You are a somewhat helpful assistant."), room=ctx.room @@ -40,5 +40,8 @@ def prewarm(proc: JobProcess): proc.userdata["vad"] = silero.VAD.load() +server.setup_fnc = prewarm + + if __name__ == "__main__": - cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm)) + cli.run_app(server) diff --git a/examples/voice_agents/stream_stt_with_vad.py b/examples/voice_agents/stream_stt_with_vad.py new file mode 100644 index 0000000000..a203f14a02 --- /dev/null +++ b/examples/voice_agents/stream_stt_with_vad.py @@ -0,0 +1,55 @@ +import logging + +from dotenv import load_dotenv + +from livekit.agents import ( + Agent, + AgentServer, + AgentSession, + JobContext, + JobProcess, + cli, + stt, +) +from livekit.plugins import deepgram, silero + +logger = logging.getLogger("stream-stt-with-vad") + +# This example shows how to use a streaming STT with a VAD. +# Only the audio frames which are detected as speech by the VAD will be sent to the STT. +# This requires the STT to support streaming and flush, e.g. deepgram, cartesia, etc., +# check the `STT.capabilities` for more details. + +load_dotenv() + +server = AgentServer() + + +@server.rtc_session() +async def entrypoint(ctx: JobContext): + session = AgentSession( + vad=ctx.proc.userdata["vad"], + stt=stt.StreamAdapter( + stt=deepgram.STT(), + vad=ctx.proc.userdata["vad"], + use_streaming=True, # use streaming mode of the wrapped STT with VAD + ), + llm="openai/gpt-4.1-mini", + tts="elevenlabs", + ) + await session.start( + agent=Agent(instructions="You are a somewhat helpful assistant."), room=ctx.room + ) + + await session.say("Hello, how can I help you?") + + +def prewarm(proc: JobProcess): + proc.userdata["vad"] = silero.VAD.load() + + +server.setup_fnc = prewarm + + +if __name__ == "__main__": + cli.run_app(server) diff --git a/livekit-agents/livekit/agents/stt/stream_adapter.py b/livekit-agents/livekit/agents/stt/stream_adapter.py index 0a84cd7a76..2a5f31324f 100644 --- a/livekit-agents/livekit/agents/stt/stream_adapter.py +++ b/livekit-agents/livekit/agents/stt/stream_adapter.py @@ -2,11 +2,13 @@ import asyncio from collections.abc import AsyncIterable +from dataclasses import dataclass from typing import Any from .. 
import utils +from ..log import logger from ..types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, APIConnectOptions, NotGivenOr -from ..vad import VAD, VADEventType +from ..vad import VAD, VADEventType, VADStream from .stt import STT, RecognizeStream, SpeechEvent, SpeechEventType, STTCapabilities # already a retry mechanism in STT.recognize, don't retry in stream adapter @@ -15,17 +17,46 @@ ) +@dataclass +class StreamAdapterOptions: + use_streaming: bool = False + + class StreamAdapter(STT): - def __init__(self, *, stt: STT, vad: VAD) -> None: + def __init__( + self, + *, + stt: STT, + vad: VAD, + use_streaming: bool = False, + ) -> None: + """ + Create a new instance of StreamAdapter. + + Args: + stt: The STT to wrap. + vad: The VAD to use. + use_streaming: Whether to use streaming mode of the wrapped STT. Default is False. + """ super().__init__( capabilities=STTCapabilities( streaming=True, - interim_results=False, - diarization=False, # diarization requires streaming STT + interim_results=use_streaming, + diarization=stt.capabilities.diarization and use_streaming, ) ) self._vad = vad self._stt = stt + self._opts = StreamAdapterOptions(use_streaming=use_streaming) + if use_streaming and not stt.capabilities.streaming: + raise ValueError( + f"STT {stt.label} does not support streaming while use_streaming is enabled" + ) + if use_streaming and not stt.capabilities.flush: + logger.warning( + f"STT {stt.label} does not support flush while use_streaming is enabled, " + "this may cause incomplete transcriptions." + ) # TODO(theomonnom): The segment_id needs to be populated! self._stt.on("metrics_collected", self._on_metrics_collected) @@ -65,6 +96,7 @@ def stream( wrapped_stt=self._stt, language=language, conn_options=conn_options, + opts=self._opts, ) def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None: @@ -83,12 +115,14 @@ def __init__( wrapped_stt: STT, language: NotGivenOr[str], conn_options: APIConnectOptions, + opts: StreamAdapterOptions, ) -> None: super().__init__(stt=stt, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS) self._vad = vad self._wrapped_stt = wrapped_stt self._wrapped_stt_conn_options = conn_options self._language = language + self._opts = opts async def _metrics_monitor_task(self, event_aiter: AsyncIterable[SpeechEvent]) -> None: pass # do nothing @@ -106,43 +140,88 @@ async def _forward_input() -> None: vad_stream.end_input() - async def _recognize() -> None: - """recognize speech from vad""" - async for event in vad_stream: - if event.type == VADEventType.START_OF_SPEECH: - self._event_ch.send_nowait(SpeechEvent(SpeechEventType.START_OF_SPEECH)) - elif event.type == VADEventType.END_OF_SPEECH: - self._event_ch.send_nowait( - SpeechEvent( - type=SpeechEventType.END_OF_SPEECH, - ) - ) + async def _forward_stream_output(stream: RecognizeStream) -> None: + async for event in stream: + self._event_ch.send_nowait(event) + + stt_stream: RecognizeStream | None = None + forward_input_task = asyncio.create_task(_forward_input(), name="forward_input") + tasks = [] + if not self._opts.use_streaming: + tasks.append( + asyncio.create_task( + self._recognize_non_streaming(vad_stream), name="recognize_non_streaming" + ), + ) + else: + stt_stream = self._wrapped_stt.stream( + language=self._language, conn_options=self._wrapped_stt_conn_options + ) + tasks += [ + asyncio.create_task( + _forward_stream_output(stt_stream), name="forward_stream_output" + ), + asyncio.create_task( + self._recognize_streaming(vad_stream, stt_stream), + name="recognize_streaming", + 
), + ] - merged_frames = utils.merge_frames(event.frames) - t_event = await self._wrapped_stt.recognize( - buffer=merged_frames, - language=self._language, - conn_options=self._wrapped_stt_conn_options, + try: + await asyncio.gather(*tasks, forward_input_task) + finally: + await utils.aio.cancel_and_wait(forward_input_task) + await vad_stream.aclose() + if stt_stream is not None: + stt_stream.end_input() + await stt_stream.aclose() + await utils.aio.cancel_and_wait(*tasks) + + async def _recognize_streaming( + self, vad_stream: VADStream, stt_stream: RecognizeStream + ) -> None: + speaking = False + async for event in vad_stream: + frames = [] + if event.type == VADEventType.START_OF_SPEECH: + speaking = True + frames = event.frames + elif event.type == VADEventType.INFERENCE_DONE and speaking: + frames = event.frames + elif event.type == VADEventType.END_OF_SPEECH: + speaking = False + stt_stream.flush() + + for f in frames: + stt_stream.push_frame(f) + + async def _recognize_non_streaming(self, vad_stream: VADStream) -> None: + """recognize speech from vad""" + async for event in vad_stream: + if event.type == VADEventType.START_OF_SPEECH: + self._event_ch.send_nowait(SpeechEvent(SpeechEventType.START_OF_SPEECH)) + elif event.type == VADEventType.END_OF_SPEECH: + self._event_ch.send_nowait( + SpeechEvent( + type=SpeechEventType.END_OF_SPEECH, ) + ) - if len(t_event.alternatives) == 0: - continue - elif not t_event.alternatives[0].text: - continue + merged_frames = utils.merge_frames(event.frames) + t_event = await self._wrapped_stt.recognize( + buffer=merged_frames, + language=self._language, + conn_options=self._wrapped_stt_conn_options, + ) - self._event_ch.send_nowait( - SpeechEvent( - type=SpeechEventType.FINAL_TRANSCRIPT, - alternatives=[t_event.alternatives[0]], - ) - ) + if len(t_event.alternatives) == 0: + continue + elif not t_event.alternatives[0].text: + continue - tasks = [ - asyncio.create_task(_forward_input(), name="forward_input"), - asyncio.create_task(_recognize(), name="recognize"), - ] - try: - await asyncio.gather(*tasks) - finally: - await utils.aio.cancel_and_wait(*tasks) - await vad_stream.aclose() + self._event_ch.send_nowait( + SpeechEvent( + type=SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[t_event.alternatives[0]], + ) + ) diff --git a/livekit-agents/livekit/agents/stt/stt.py b/livekit-agents/livekit/agents/stt/stt.py index ba23e9ca55..aa77b765a6 100644 --- a/livekit-agents/livekit/agents/stt/stt.py +++ b/livekit-agents/livekit/agents/stt/stt.py @@ -72,6 +72,7 @@ class STTCapabilities: streaming: bool interim_results: bool diarization: bool = False + flush: bool = False class STTError(BaseModel): diff --git a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py index 239389ec8e..fabefdbbe0 100644 --- a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py +++ b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py @@ -76,7 +76,7 @@ def __init__( buffer_size_seconds: float = 0.05, ): super().__init__( - capabilities=stt.STTCapabilities(streaming=True, interim_results=False), + capabilities=stt.STTCapabilities(streaming=True, interim_results=False, flush=True), ) assemblyai_api_key = api_key if is_given(api_key) else os.environ.get("ASSEMBLYAI_API_KEY") if assemblyai_api_key is None: @@ -171,6 +171,7 @@ def update_options( class SpeechStream(stt.SpeechStream): # Used to close websocket _CLOSE_MSG: str = 
json.dumps({"type": "Terminate"}) + _FLUSH_MSG: str = json.dumps({"type": "ForceEndpoint"}) def __init__( self, @@ -241,6 +242,9 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: self._speech_duration += frame.duration await ws.send_bytes(frame.data.tobytes()) + if isinstance(data, self._FlushSentinel): + await ws.send_str(SpeechStream._FLUSH_MSG) + closing_ws = True await ws.send_str(SpeechStream._CLOSE_MSG) diff --git a/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/stt.py b/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/stt.py index ca88ecd2d8..411ac1856f 100644 --- a/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/stt.py +++ b/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/stt.py @@ -89,7 +89,9 @@ def __init__( Raises: ValueError: If no API key is provided or found in environment variables. """ - super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=False)) + super().__init__( + capabilities=stt.STTCapabilities(streaming=True, interim_results=False, flush=True) + ) cartesia_api_key = api_key or os.environ.get("CARTESIA_API_KEY") if not cartesia_api_key: @@ -249,6 +251,9 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: self._speech_duration += frame.duration await ws.send_bytes(frame.data.tobytes()) + if isinstance(data, self._FlushSentinel): + await ws.send_str("finalize") + closing_ws = True await ws.send_str("finalize") diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 4e5918ab84..696b1f6aca 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -131,7 +131,10 @@ def __init__( super().__init__( capabilities=stt.STTCapabilities( - streaming=True, interim_results=interim_results, diarization=enable_diarization + streaming=True, + interim_results=interim_results, + diarization=enable_diarization, + flush=True, ) ) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt_v2.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt_v2.py index eade7e7eb5..381d4eae0b 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt_v2.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt_v2.py @@ -96,7 +96,9 @@ def __init__( the DEEPGRAM_API_KEY environmental variable. 
""" # noqa: E501 - super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True)) + super().__init__( + capabilities=stt.STTCapabilities(streaming=True, interim_results=True, flush=True) + ) deepgram_api_key = api_key if is_given(api_key) else os.environ.get("DEEPGRAM_API_KEY") if not deepgram_api_key: diff --git a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py index f1d48c6f49..d1b3b4d9bb 100644 --- a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py +++ b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py @@ -76,7 +76,7 @@ def __init__( tag_audio_events: bool = True, use_realtime: bool = False, sample_rate: STTRealtimeSampleRates = 16000, - server_vad: NotGivenOr[VADOptions] = NOT_GIVEN, + server_vad: NotGivenOr[VADOptions | None] = NOT_GIVEN, http_session: aiohttp.ClientSession | None = None, ) -> None: """ @@ -90,11 +90,16 @@ def __init__( Only supported for Scribe v1 model. Default is True. use_realtime (bool): Whether to use "scribe_v2_realtime" model for streaming mode. Default is False. sample_rate (STTRealtimeSampleRates): Audio sample rate in Hz. Default is 16000. - server_vad (NotGivenOr[VADOptions]): Server-side VAD options, only supported for Scribe v2 realtime model. + server_vad (NotGivenOr[VADOptions | None]): Server-side VAD options, only supported for Scribe v2 realtime model. + If None, use the "manual" commit strategy. http_session (aiohttp.ClientSession | None): Custom HTTP session for API requests. Optional. """ # noqa: E501 - super().__init__(capabilities=STTCapabilities(streaming=use_realtime, interim_results=True)) + super().__init__( + capabilities=STTCapabilities( + streaming=use_realtime, interim_results=True, flush=use_realtime + ) + ) if not use_realtime and is_given(server_vad): logger.warning("Server-side VAD is only supported for Scribe v2 realtime model") @@ -217,7 +222,7 @@ def update_options( self, *, tag_audio_events: NotGivenOr[bool] = NOT_GIVEN, - server_vad: NotGivenOr[VADOptions] = NOT_GIVEN, + server_vad: NotGivenOr[VADOptions | None] = NOT_GIVEN, ) -> None: if is_given(tag_audio_events): self._opts.tag_audio_events = tag_audio_events @@ -267,7 +272,7 @@ def __init__( def update_options( self, *, - server_vad: NotGivenOr[VADOptions] = NOT_GIVEN, + server_vad: NotGivenOr[VADOptions | None] = NOT_GIVEN, ) -> None: if is_given(server_vad): self._opts.server_vad = server_vad @@ -300,10 +305,12 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: async for data in self._input_ch: # Write audio bytes to buffer and get 50ms frames frames: list[rtc.AudioFrame] = [] + commit = False if isinstance(data, rtc.AudioFrame): frames.extend(audio_bstream.write(data.data.tobytes())) elif isinstance(data, self._FlushSentinel): frames.extend(audio_bstream.flush()) + commit = True for frame in frames: audio_b64 = base64.b64encode(frame.data.tobytes()).decode("utf-8") @@ -317,6 +324,17 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: } ) ) + if commit: + await ws.send_str( + json.dumps( + { + "message_type": "input_audio_chunk", + "audio_base_64": "", + "commit": True, + "sample_rate": self._opts.sample_rate, + } + ) + ) closing_ws = True diff --git a/livekit-plugins/livekit-plugins-gladia/livekit/plugins/gladia/stt.py b/livekit-plugins/livekit-plugins-gladia/livekit/plugins/gladia/stt.py index 182c2f505b..373e195e8e 100644 --- 
a/livekit-plugins/livekit-plugins-gladia/livekit/plugins/gladia/stt.py +++ b/livekit-plugins/livekit-plugins-gladia/livekit/plugins/gladia/stt.py @@ -273,7 +273,9 @@ def __init__( ValueError: If no API key is provided or found in environment variables. """ super().__init__( - capabilities=stt.STTCapabilities(streaming=True, interim_results=interim_results) + capabilities=stt.STTCapabilities( + streaming=True, interim_results=interim_results, flush=True + ) ) self._base_url = base_url @@ -963,10 +965,10 @@ async def _send_audio_task(self) -> None: message = json.dumps({"type": "audio_chunk", "data": {"chunk": chunk_b64}}) await self._ws.send_str(message) - if has_ended: - self._audio_duration_collector.flush() - await self._ws.send_str(json.dumps({"type": "stop_recording"})) - has_ended = False + if has_ended: + self._audio_duration_collector.flush() + await self._ws.send_str(json.dumps({"type": "stop_recording"})) + has_ended = False # Tell Gladia we're done sending audio when the stream ends if self._ws: diff --git a/livekit-plugins/livekit-plugins-sarvam/livekit/plugins/sarvam/stt.py b/livekit-plugins/livekit-plugins-sarvam/livekit/plugins/sarvam/stt.py index 3d8036138e..e20c71f76d 100644 --- a/livekit-plugins/livekit-plugins-sarvam/livekit/plugins/sarvam/stt.py +++ b/livekit-plugins/livekit-plugins-sarvam/livekit/plugins/sarvam/stt.py @@ -173,7 +173,9 @@ def __init__( http_session: aiohttp.ClientSession | None = None, prompt: str | None = None, ) -> None: - super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True)) + super().__init__( + capabilities=stt.STTCapabilities(streaming=True, interim_results=True, flush=True) + ) self._api_key = api_key or os.environ.get("SARVAM_API_KEY") if not self._api_key: diff --git a/livekit-plugins/livekit-plugins-soniox/livekit/plugins/soniox/stt.py b/livekit-plugins/livekit-plugins-soniox/livekit/plugins/soniox/stt.py index 9b00c0991a..9938ec7478 100644 --- a/livekit-plugins/livekit-plugins-soniox/livekit/plugins/soniox/stt.py +++ b/livekit-plugins/livekit-plugins-soniox/livekit/plugins/soniox/stt.py @@ -98,6 +98,7 @@ class STTOptions: enable_language_identification: bool = True client_reference_id: str | None = None + enable_endpoint_detection: bool = True class STT(stt.STT): @@ -116,8 +117,9 @@ def __init__( api_key: str | None = None, base_url: str = BASE_URL, http_session: aiohttp.ClientSession | None = None, - vad: vad.VAD | None = None, params: STTOptions | None = None, + # deprecated + vad: vad.VAD | None = None, ): """Initialize instance of Soniox Speech-to-Text API service. @@ -126,17 +128,21 @@ def __init__( base_url: Base URL for Soniox Speech-to-Text API, default to BASE_URL defined in this module. http_session: Optional aiohttp.ClientSession to use for requests. - vad: If passed, enable Voice Activity Detection (VAD) for audio frames. params: Additional configuration parameters, such as model, language hints, context and speaker diarization. """ - super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True)) + super().__init__( + capabilities=stt.STTCapabilities(streaming=True, interim_results=True, flush=True) + ) self._api_key = api_key or os.getenv("SONIOX_API_KEY") self._base_url = base_url self._http_session = http_session - self._vad_stream = vad.stream() if vad else None self._params = params or STTOptions() + if vad is not None: + logger.warning( + "`vad` is deprecated. Use `stt.StreamAdapter(..., use_streaming=True)` instead." 
+ ) @property def model(self) -> str: @@ -198,8 +204,6 @@ def _ensure_session(self) -> aiohttp.ClientSession: async def _connect_ws(self): """Open a WebSocket connection to the Soniox Speech-to-Text API and send the initial configuration.""" - # If VAD was passed, disable endpoint detection, otherwise enable it. - enable_endpoint_detection = not self._stt._vad_stream context = self._stt._params.context if isinstance(context, ContextObject): @@ -211,7 +215,7 @@ async def _connect_ws(self): "model": self._stt._params.model, "audio_format": "pcm_s16le", "num_channels": self._stt._params.num_channels or 1, - "enable_endpoint_detection": enable_endpoint_detection, + "enable_endpoint_detection": self._stt._params.enable_endpoint_detection, "sample_rate": self._stt._params.sample_rate, "language_hints": self._stt._params.language_hints, "context": context, @@ -238,7 +242,6 @@ async def _run(self) -> None: # Create task for audio processing, voice turn detection and message handling. tasks = [ asyncio.create_task(self._prepare_audio_task()), - asyncio.create_task(self._handle_vad_task()), asyncio.create_task(self._send_audio_task()), asyncio.create_task(self._recv_messages_task()), asyncio.create_task(self._keepalive_task()), @@ -302,23 +305,18 @@ async def _keepalive_task(self): logger.error(f"Error while sending keep alive message: {e}") async def _prepare_audio_task(self): - """Read audio frames, process VAD, and enqueue PCM data for sending.""" + """Read audio frames and enqueue PCM data for sending.""" if not self._ws: logger.error("WebSocket connection to Soniox Speech-to-Text API is not established") return async for data in self._input_ch: - if self._stt._vad_stream: - # If VAD is enabled, push the audio frame to the VAD stream. - if isinstance(data, self._FlushSentinel): - self._stt._vad_stream.flush() - else: - self._stt._vad_stream.push_frame(data) - if isinstance(data, rtc.AudioFrame): # Get the raw bytes from the audio frame. 
pcm_data = data.data.tobytes() self.audio_queue.put_nowait(pcm_data) + else: + self.audio_queue.put_nowait(FINALIZE_MESSAGE) async def _send_audio_task(self): """Take queued audio data and transmit it over the WebSocket.""" @@ -340,16 +338,6 @@ async def _send_audio_task(self): logger.error(f"Error while sending audio data: {e}") break - async def _handle_vad_task(self): - """Listen for VAD events to trigger finalize or keepalive messages.""" - if not self._stt._vad_stream: - logger.debug("VAD stream is not enabled, skipping VAD task") - return - - async for event in self._stt._vad_stream: - if event.type == vad.VADEventType.END_OF_SPEECH: - self.audio_queue.put_nowait(FINALIZE_MESSAGE) - async def _recv_messages_task(self): """Receive transcription messages, handle tokens, errors, and dispatch events.""" diff --git a/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/stt.py b/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/stt.py index 55a6eb6b57..d01b1b8978 100644 --- a/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/stt.py +++ b/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/stt.py @@ -241,7 +241,7 @@ def __init__( super().__init__( capabilities=stt.STTCapabilities( - streaming=True, interim_results=True, diarization=enable_diarization + streaming=True, interim_results=True, diarization=enable_diarization, flush=True ), ) @@ -496,6 +496,9 @@ def _evt_on_end_of_utterance(message: dict[str, Any]) -> None: self._speech_duration += frame.duration await self._client.send_audio(frame.data.tobytes()) + if isinstance(data, self._FlushSentinel): + await self._client.force_end_of_utterance() + # TODO - handle the closing of the stream? def _handle_transcript(self, message: dict[str, Any], is_final: bool) -> None: From 72aae8305732265172e0bf36dcb986d14ecf41f8 Mon Sep 17 00:00:00 2001 From: Long Chen Date: Fri, 21 Nov 2025 18:51:25 +0800 Subject: [PATCH 4/4] add silence_mode --- .../livekit/agents/stt/stream_adapter.py | 42 ++++++++++++++++--- .../livekit/plugins/deepgram/stt.py | 8 ++-- .../livekit/plugins/deepgram/stt_v2.py | 6 +-- 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/livekit-agents/livekit/agents/stt/stream_adapter.py b/livekit-agents/livekit/agents/stt/stream_adapter.py index 2a5f31324f..4529569f5a 100644 --- a/livekit-agents/livekit/agents/stt/stream_adapter.py +++ b/livekit-agents/livekit/agents/stt/stream_adapter.py @@ -3,7 +3,9 @@ import asyncio from collections.abc import AsyncIterable from dataclasses import dataclass -from typing import Any +from typing import Any, Literal + +from livekit import rtc from .. import utils from ..log import logger @@ -17,9 +19,13 @@ ) +SilenceMode = Literal["drop", "zeros", "passthrough"] + + @dataclass class StreamAdapterOptions: - use_streaming: bool = False + use_streaming: bool + silence_mode: SilenceMode class StreamAdapter(STT): @@ -29,6 +35,7 @@ def __init__( stt: STT, vad: VAD, use_streaming: bool = False, + silence_mode: SilenceMode = "zeros", ) -> None: """ Create a new instance of StreamAdapter. @@ -37,6 +44,10 @@ def __init__( stt: The STT to wrap. vad: The VAD to use. use_streaming: Whether to use streaming mode of the wrapped STT. Default is False. 
+ silence_mode: How to handle audio frames during silent periods, only for use_streaming=True: + - "drop": Don't send silent frames to STT + - "zeros": Send zero-filled frames during silence (default) + - "passthrough": Send original frames even during silence """ super().__init__( capabilities=STTCapabilities( @@ -47,7 +58,10 @@ def __init__( ) self._vad = vad self._stt = stt - self._opts = StreamAdapterOptions(use_streaming=use_streaming) + self._opts = StreamAdapterOptions( + use_streaming=use_streaming, + silence_mode=silence_mode, + ) if use_streaming and not stt.capabilities.streaming: raise ValueError( f"STT {stt.label} does not support streaming while use_streaming is enabled" @@ -186,15 +200,31 @@ async def _recognize_streaming( if event.type == VADEventType.START_OF_SPEECH: speaking = True frames = event.frames - elif event.type == VADEventType.INFERENCE_DONE and speaking: - frames = event.frames elif event.type == VADEventType.END_OF_SPEECH: speaking = False - stt_stream.flush() + elif event.type == VADEventType.INFERENCE_DONE: + frames = event.frames + + if not speaking: + if self._opts.silence_mode == "drop": + frames.clear() + elif self._opts.silence_mode == "zeros": + frames = [ + rtc.AudioFrame( + data=b"\x00\x00" * f.samples_per_channel * f.num_channels, + sample_rate=f.sample_rate, + num_channels=f.num_channels, + samples_per_channel=f.samples_per_channel, + ) + for f in frames + ] for f in frames: stt_stream.push_frame(f) + if event.type == VADEventType.END_OF_SPEECH: + stt_stream.flush() + async def _recognize_non_streaming(self, vad_stream: VADStream) -> None: """recognize speech from vad""" async for event in vad_stream: diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 696b1f6aca..0679a1ed75 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -477,10 +477,10 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: self._audio_duration_collector.push(frame.duration) await ws.send_bytes(frame.data.tobytes()) - if has_ended: - self._audio_duration_collector.flush() - await ws.send_str(SpeechStream._FINALIZE_MSG) - has_ended = False + if has_ended: + self._audio_duration_collector.flush() + await ws.send_str(SpeechStream._FINALIZE_MSG) + has_ended = False # tell deepgram we are done sending audio/inputs closing_ws = True diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt_v2.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt_v2.py index 381d4eae0b..2e98618d68 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt_v2.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt_v2.py @@ -309,9 +309,9 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: self._audio_duration_collector.push(frame.duration) await ws.send_bytes(frame.data.tobytes()) - if has_ended: - self._audio_duration_collector.flush() - has_ended = False + if has_ended: + self._audio_duration_collector.flush() + has_ended = False # tell deepgram we are done sending audio/inputs closing_ws = True
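Usage note for reviewers (not part of the patches): a minimal sketch of how the options introduced across this series are intended to compose, assuming the patched livekit-agents and livekit-plugins-elevenlabs packages and a valid ELEVEN_API_KEY; it mirrors the bundled examples rather than defining any new API.

# Minimal sketch (assumptions: this patch series applied, ELEVEN_API_KEY set).
from livekit.agents import stt
from livekit.plugins import elevenlabs, silero

vad = silero.VAD.load()

# Option A: let ElevenLabs Scribe v2 realtime commit transcripts with its
# server-side VAD (the "vad" commit strategy from patch 1).
realtime_stt = elevenlabs.STT(
    use_realtime=True,
    server_vad={
        "vad_silence_threshold_secs": 0.5,
        "vad_threshold": 0.5,
        "min_speech_duration_ms": 100,
        "min_silence_duration_ms": 300,
    },
)

# Option B: disable server-side VAD (manual commit strategy) and drive commits
# from a local VAD through the StreamAdapter streaming mode added in patch 3;
# silence_mode (patch 4) controls what is sent between speech segments.
adapted_stt = stt.StreamAdapter(
    stt=elevenlabs.STT(use_realtime=True, server_vad=None),
    vad=vad,
    use_streaming=True,
    silence_mode="zeros",  # or "drop" / "passthrough"
)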