diff --git a/examples/other/elevenlab_scribe_v2.py b/examples/other/elevenlab_scribe_v2.py index 4a7b8e0370..9a2527d679 100644 --- a/examples/other/elevenlab_scribe_v2.py +++ b/examples/other/elevenlab_scribe_v2.py @@ -2,32 +2,32 @@ from dotenv import load_dotenv -from livekit.agents import Agent, AgentSession, JobContext, JobProcess, WorkerOptions, cli -from livekit.plugins import elevenlabs, openai, silero +from livekit.agents import Agent, AgentServer, AgentSession, JobContext, JobProcess, cli, stt +from livekit.plugins import elevenlabs, silero logger = logging.getLogger("realtime-scribe-v2") logger.setLevel(logging.INFO) load_dotenv() +server = AgentServer() -async def entrypoint(ctx: JobContext): - stt = elevenlabs.STT( - use_realtime=True, - server_vad={ - "vad_silence_threshold_secs": 0.5, - "vad_threshold": 0.5, - "min_speech_duration_ms": 100, - "min_silence_duration_ms": 300, - }, - ) +@server.rtc_session() +async def entrypoint(ctx: JobContext): session = AgentSession( - allow_interruptions=True, vad=ctx.proc.userdata["vad"], - stt=stt, - llm=openai.LLM(model="gpt-4.1-mini"), - tts=elevenlabs.TTS(model="eleven_turbo_v2_5"), + stt=stt.StreamAdapter( # use local VAD with the STT + stt=elevenlabs.STT( + use_realtime=True, + server_vad=None, # disable server-side VAD + language_code="en", + ), + vad=ctx.proc.userdata["vad"], + use_streaming=True, + ), + llm="openai/gpt-4.1-mini", + tts="elevenlabs", ) await session.start( agent=Agent(instructions="You are a somewhat helpful assistant."), room=ctx.room @@ -40,5 +40,8 @@ def prewarm(proc: JobProcess): proc.userdata["vad"] = silero.VAD.load() +server.setup_fnc = prewarm + + if __name__ == "__main__": - cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm)) + cli.run_app(server) diff --git a/examples/voice_agents/stream_stt_with_vad.py b/examples/voice_agents/stream_stt_with_vad.py new file mode 100644 index 0000000000..a203f14a02 --- /dev/null +++ b/examples/voice_agents/stream_stt_with_vad.py @@ -0,0 +1,55 @@ +import logging + +from dotenv import load_dotenv + +from livekit.agents import ( + Agent, + AgentServer, + AgentSession, + JobContext, + JobProcess, + cli, + stt, +) +from livekit.plugins import deepgram, silero + +logger = logging.getLogger("stream-stt-with-vad") + +# This example shows how to use a streaming STT with a VAD. +# Only the audio frames which are detected as speech by the VAD will be sent to the STT. +# This requires the STT to support streaming and flush, e.g. deepgram, cartesia, etc., +# check the `STT.capabilities` for more details. + +load_dotenv() + +server = AgentServer() + + +@server.rtc_session() +async def entrypoint(ctx: JobContext): + session = AgentSession( + vad=ctx.proc.userdata["vad"], + stt=stt.StreamAdapter( + stt=deepgram.STT(), + vad=ctx.proc.userdata["vad"], + use_streaming=True, # use streaming mode of the wrapped STT with VAD + ), + llm="openai/gpt-4.1-mini", + tts="elevenlabs", + ) + await session.start( + agent=Agent(instructions="You are a somewhat helpful assistant."), room=ctx.room + ) + + await session.say("Hello, how can I help you?") + + +def prewarm(proc: JobProcess): + proc.userdata["vad"] = silero.VAD.load() + + +server.setup_fnc = prewarm + + +if __name__ == "__main__": + cli.run_app(server) diff --git a/livekit-agents/livekit/agents/stt/stream_adapter.py b/livekit-agents/livekit/agents/stt/stream_adapter.py index 0a84cd7a76..4529569f5a 100644 --- a/livekit-agents/livekit/agents/stt/stream_adapter.py +++ b/livekit-agents/livekit/agents/stt/stream_adapter.py @@ -2,11 +2,15 @@ import asyncio from collections.abc import AsyncIterable -from typing import Any +from dataclasses import dataclass +from typing import Any, Literal + +from livekit import rtc from .. import utils +from ..log import logger from ..types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, APIConnectOptions, NotGivenOr -from ..vad import VAD, VADEventType +from ..vad import VAD, VADEventType, VADStream from .stt import STT, RecognizeStream, SpeechEvent, SpeechEventType, STTCapabilities # already a retry mechanism in STT.recognize, don't retry in stream adapter @@ -15,17 +19,58 @@ ) +SilenceMode = Literal["drop", "zeros", "passthrough"] + + +@dataclass +class StreamAdapterOptions: + use_streaming: bool + silence_mode: SilenceMode + + class StreamAdapter(STT): - def __init__(self, *, stt: STT, vad: VAD) -> None: + def __init__( + self, + *, + stt: STT, + vad: VAD, + use_streaming: bool = False, + silence_mode: SilenceMode = "zeros", + ) -> None: + """ + Create a new instance of StreamAdapter. + + Args: + stt: The STT to wrap. + vad: The VAD to use. + use_streaming: Whether to use streaming mode of the wrapped STT. Default is False. + silence_mode: How to handle audio frames during silent periods, only for use_streaming=True: + - "drop": Don't send silent frames to STT + - "zeros": Send zero-filled frames during silence (default) + - "passthrough": Send original frames even during silence + """ super().__init__( capabilities=STTCapabilities( streaming=True, - interim_results=False, - diarization=False, # diarization requires streaming STT + interim_results=use_streaming, + diarization=stt.capabilities.diarization and use_streaming, ) ) self._vad = vad self._stt = stt + self._opts = StreamAdapterOptions( + use_streaming=use_streaming, + silence_mode=silence_mode, + ) + if use_streaming and not stt.capabilities.streaming: + raise ValueError( + f"STT {stt.label} does not support streaming while use_streaming is enabled" + ) + if use_streaming and not stt.capabilities.flush: + logger.warning( + f"STT {stt.label} does not support flush while use_streaming is enabled, " + "this may cause incomplete transcriptions." + ) # TODO(theomonnom): The segment_id needs to be populated! self._stt.on("metrics_collected", self._on_metrics_collected) @@ -65,6 +110,7 @@ def stream( wrapped_stt=self._stt, language=language, conn_options=conn_options, + opts=self._opts, ) def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None: @@ -83,12 +129,14 @@ def __init__( wrapped_stt: STT, language: NotGivenOr[str], conn_options: APIConnectOptions, + opts: StreamAdapterOptions, ) -> None: super().__init__(stt=stt, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS) self._vad = vad self._wrapped_stt = wrapped_stt self._wrapped_stt_conn_options = conn_options self._language = language + self._opts = opts async def _metrics_monitor_task(self, event_aiter: AsyncIterable[SpeechEvent]) -> None: pass # do nothing @@ -106,43 +154,104 @@ async def _forward_input() -> None: vad_stream.end_input() - async def _recognize() -> None: - """recognize speech from vad""" - async for event in vad_stream: - if event.type == VADEventType.START_OF_SPEECH: - self._event_ch.send_nowait(SpeechEvent(SpeechEventType.START_OF_SPEECH)) - elif event.type == VADEventType.END_OF_SPEECH: - self._event_ch.send_nowait( - SpeechEvent( - type=SpeechEventType.END_OF_SPEECH, - ) - ) + async def _forward_stream_output(stream: RecognizeStream) -> None: + async for event in stream: + self._event_ch.send_nowait(event) - merged_frames = utils.merge_frames(event.frames) - t_event = await self._wrapped_stt.recognize( - buffer=merged_frames, - language=self._language, - conn_options=self._wrapped_stt_conn_options, - ) + stt_stream: RecognizeStream | None = None + forward_input_task = asyncio.create_task(_forward_input(), name="forward_input") + tasks = [] + if not self._opts.use_streaming: + tasks.append( + asyncio.create_task( + self._recognize_non_streaming(vad_stream), name="recognize_non_streaming" + ), + ) + else: + stt_stream = self._wrapped_stt.stream( + language=self._language, conn_options=self._wrapped_stt_conn_options + ) + tasks += [ + asyncio.create_task( + _forward_stream_output(stt_stream), name="forward_stream_output" + ), + asyncio.create_task( + self._recognize_streaming(vad_stream, stt_stream), + name="recognize_streaming", + ), + ] + + try: + await asyncio.gather(*tasks, forward_input_task) + finally: + await utils.aio.cancel_and_wait(forward_input_task) + await vad_stream.aclose() + if stt_stream is not None: + stt_stream.end_input() + await stt_stream.aclose() + await utils.aio.cancel_and_wait(*tasks) - if len(t_event.alternatives) == 0: - continue - elif not t_event.alternatives[0].text: - continue + async def _recognize_streaming( + self, vad_stream: VADStream, stt_stream: RecognizeStream + ) -> None: + speaking = False + async for event in vad_stream: + frames = [] + if event.type == VADEventType.START_OF_SPEECH: + speaking = True + frames = event.frames + elif event.type == VADEventType.END_OF_SPEECH: + speaking = False + elif event.type == VADEventType.INFERENCE_DONE: + frames = event.frames - self._event_ch.send_nowait( - SpeechEvent( - type=SpeechEventType.FINAL_TRANSCRIPT, - alternatives=[t_event.alternatives[0]], + if not speaking: + if self._opts.silence_mode == "drop": + frames.clear() + elif self._opts.silence_mode == "zeros": + frames = [ + rtc.AudioFrame( + data=b"\x00\x00" * f.samples_per_channel * f.num_channels, + sample_rate=f.sample_rate, + num_channels=f.num_channels, + samples_per_channel=f.samples_per_channel, ) + for f in frames + ] + + for f in frames: + stt_stream.push_frame(f) + + if event.type == VADEventType.END_OF_SPEECH: + stt_stream.flush() + + async def _recognize_non_streaming(self, vad_stream: VADStream) -> None: + """recognize speech from vad""" + async for event in vad_stream: + if event.type == VADEventType.START_OF_SPEECH: + self._event_ch.send_nowait(SpeechEvent(SpeechEventType.START_OF_SPEECH)) + elif event.type == VADEventType.END_OF_SPEECH: + self._event_ch.send_nowait( + SpeechEvent( + type=SpeechEventType.END_OF_SPEECH, ) + ) - tasks = [ - asyncio.create_task(_forward_input(), name="forward_input"), - asyncio.create_task(_recognize(), name="recognize"), - ] - try: - await asyncio.gather(*tasks) - finally: - await utils.aio.cancel_and_wait(*tasks) - await vad_stream.aclose() + merged_frames = utils.merge_frames(event.frames) + t_event = await self._wrapped_stt.recognize( + buffer=merged_frames, + language=self._language, + conn_options=self._wrapped_stt_conn_options, + ) + + if len(t_event.alternatives) == 0: + continue + elif not t_event.alternatives[0].text: + continue + + self._event_ch.send_nowait( + SpeechEvent( + type=SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[t_event.alternatives[0]], + ) + ) diff --git a/livekit-agents/livekit/agents/stt/stt.py b/livekit-agents/livekit/agents/stt/stt.py index ba23e9ca55..aa77b765a6 100644 --- a/livekit-agents/livekit/agents/stt/stt.py +++ b/livekit-agents/livekit/agents/stt/stt.py @@ -72,6 +72,7 @@ class STTCapabilities: streaming: bool interim_results: bool diarization: bool = False + flush: bool = False class STTError(BaseModel): diff --git a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py index 239389ec8e..fabefdbbe0 100644 --- a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py +++ b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py @@ -76,7 +76,7 @@ def __init__( buffer_size_seconds: float = 0.05, ): super().__init__( - capabilities=stt.STTCapabilities(streaming=True, interim_results=False), + capabilities=stt.STTCapabilities(streaming=True, interim_results=False, flush=True), ) assemblyai_api_key = api_key if is_given(api_key) else os.environ.get("ASSEMBLYAI_API_KEY") if assemblyai_api_key is None: @@ -171,6 +171,7 @@ def update_options( class SpeechStream(stt.SpeechStream): # Used to close websocket _CLOSE_MSG: str = json.dumps({"type": "Terminate"}) + _FLUSH_MSG: str = json.dumps({"type": "ForceEndpoint"}) def __init__( self, @@ -241,6 +242,9 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: self._speech_duration += frame.duration await ws.send_bytes(frame.data.tobytes()) + if isinstance(data, self._FlushSentinel): + await ws.send_str(SpeechStream._FLUSH_MSG) + closing_ws = True await ws.send_str(SpeechStream._CLOSE_MSG) diff --git a/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/stt.py b/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/stt.py index ca88ecd2d8..411ac1856f 100644 --- a/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/stt.py +++ b/livekit-plugins/livekit-plugins-cartesia/livekit/plugins/cartesia/stt.py @@ -89,7 +89,9 @@ def __init__( Raises: ValueError: If no API key is provided or found in environment variables. """ - super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=False)) + super().__init__( + capabilities=stt.STTCapabilities(streaming=True, interim_results=False, flush=True) + ) cartesia_api_key = api_key or os.environ.get("CARTESIA_API_KEY") if not cartesia_api_key: @@ -249,6 +251,9 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: self._speech_duration += frame.duration await ws.send_bytes(frame.data.tobytes()) + if isinstance(data, self._FlushSentinel): + await ws.send_str("finalize") + closing_ws = True await ws.send_str("finalize") diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 4e5918ab84..0679a1ed75 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -131,7 +131,10 @@ def __init__( super().__init__( capabilities=stt.STTCapabilities( - streaming=True, interim_results=interim_results, diarization=enable_diarization + streaming=True, + interim_results=interim_results, + diarization=enable_diarization, + flush=True, ) ) @@ -474,10 +477,10 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: self._audio_duration_collector.push(frame.duration) await ws.send_bytes(frame.data.tobytes()) - if has_ended: - self._audio_duration_collector.flush() - await ws.send_str(SpeechStream._FINALIZE_MSG) - has_ended = False + if has_ended: + self._audio_duration_collector.flush() + await ws.send_str(SpeechStream._FINALIZE_MSG) + has_ended = False # tell deepgram we are done sending audio/inputs closing_ws = True diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt_v2.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt_v2.py index eade7e7eb5..2e98618d68 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt_v2.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt_v2.py @@ -96,7 +96,9 @@ def __init__( the DEEPGRAM_API_KEY environmental variable. """ # noqa: E501 - super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True)) + super().__init__( + capabilities=stt.STTCapabilities(streaming=True, interim_results=True, flush=True) + ) deepgram_api_key = api_key if is_given(api_key) else os.environ.get("DEEPGRAM_API_KEY") if not deepgram_api_key: @@ -307,9 +309,9 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: self._audio_duration_collector.push(frame.duration) await ws.send_bytes(frame.data.tobytes()) - if has_ended: - self._audio_duration_collector.flush() - has_ended = False + if has_ended: + self._audio_duration_collector.flush() + has_ended = False # tell deepgram we are done sending audio/inputs closing_ws = True diff --git a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py index f1d48c6f49..d1b3b4d9bb 100644 --- a/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py +++ b/livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py @@ -76,7 +76,7 @@ def __init__( tag_audio_events: bool = True, use_realtime: bool = False, sample_rate: STTRealtimeSampleRates = 16000, - server_vad: NotGivenOr[VADOptions] = NOT_GIVEN, + server_vad: NotGivenOr[VADOptions | None] = NOT_GIVEN, http_session: aiohttp.ClientSession | None = None, ) -> None: """ @@ -90,11 +90,16 @@ def __init__( Only supported for Scribe v1 model. Default is True. use_realtime (bool): Whether to use "scribe_v2_realtime" model for streaming mode. Default is False. sample_rate (STTRealtimeSampleRates): Audio sample rate in Hz. Default is 16000. - server_vad (NotGivenOr[VADOptions]): Server-side VAD options, only supported for Scribe v2 realtime model. + server_vad (NotGivenOr[VADOptions | None]): Server-side VAD options, only supported for Scribe v2 realtime model. + If None, use the "manual" commit strategy. http_session (aiohttp.ClientSession | None): Custom HTTP session for API requests. Optional. """ # noqa: E501 - super().__init__(capabilities=STTCapabilities(streaming=use_realtime, interim_results=True)) + super().__init__( + capabilities=STTCapabilities( + streaming=use_realtime, interim_results=True, flush=use_realtime + ) + ) if not use_realtime and is_given(server_vad): logger.warning("Server-side VAD is only supported for Scribe v2 realtime model") @@ -217,7 +222,7 @@ def update_options( self, *, tag_audio_events: NotGivenOr[bool] = NOT_GIVEN, - server_vad: NotGivenOr[VADOptions] = NOT_GIVEN, + server_vad: NotGivenOr[VADOptions | None] = NOT_GIVEN, ) -> None: if is_given(tag_audio_events): self._opts.tag_audio_events = tag_audio_events @@ -267,7 +272,7 @@ def __init__( def update_options( self, *, - server_vad: NotGivenOr[VADOptions] = NOT_GIVEN, + server_vad: NotGivenOr[VADOptions | None] = NOT_GIVEN, ) -> None: if is_given(server_vad): self._opts.server_vad = server_vad @@ -300,10 +305,12 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: async for data in self._input_ch: # Write audio bytes to buffer and get 50ms frames frames: list[rtc.AudioFrame] = [] + commit = False if isinstance(data, rtc.AudioFrame): frames.extend(audio_bstream.write(data.data.tobytes())) elif isinstance(data, self._FlushSentinel): frames.extend(audio_bstream.flush()) + commit = True for frame in frames: audio_b64 = base64.b64encode(frame.data.tobytes()).decode("utf-8") @@ -317,6 +324,17 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None: } ) ) + if commit: + await ws.send_str( + json.dumps( + { + "message_type": "input_audio_chunk", + "audio_base_64": "", + "commit": True, + "sample_rate": self._opts.sample_rate, + } + ) + ) closing_ws = True diff --git a/livekit-plugins/livekit-plugins-gladia/livekit/plugins/gladia/stt.py b/livekit-plugins/livekit-plugins-gladia/livekit/plugins/gladia/stt.py index 182c2f505b..373e195e8e 100644 --- a/livekit-plugins/livekit-plugins-gladia/livekit/plugins/gladia/stt.py +++ b/livekit-plugins/livekit-plugins-gladia/livekit/plugins/gladia/stt.py @@ -273,7 +273,9 @@ def __init__( ValueError: If no API key is provided or found in environment variables. """ super().__init__( - capabilities=stt.STTCapabilities(streaming=True, interim_results=interim_results) + capabilities=stt.STTCapabilities( + streaming=True, interim_results=interim_results, flush=True + ) ) self._base_url = base_url @@ -963,10 +965,10 @@ async def _send_audio_task(self) -> None: message = json.dumps({"type": "audio_chunk", "data": {"chunk": chunk_b64}}) await self._ws.send_str(message) - if has_ended: - self._audio_duration_collector.flush() - await self._ws.send_str(json.dumps({"type": "stop_recording"})) - has_ended = False + if has_ended: + self._audio_duration_collector.flush() + await self._ws.send_str(json.dumps({"type": "stop_recording"})) + has_ended = False # Tell Gladia we're done sending audio when the stream ends if self._ws: diff --git a/livekit-plugins/livekit-plugins-sarvam/livekit/plugins/sarvam/stt.py b/livekit-plugins/livekit-plugins-sarvam/livekit/plugins/sarvam/stt.py index 3d8036138e..e20c71f76d 100644 --- a/livekit-plugins/livekit-plugins-sarvam/livekit/plugins/sarvam/stt.py +++ b/livekit-plugins/livekit-plugins-sarvam/livekit/plugins/sarvam/stt.py @@ -173,7 +173,9 @@ def __init__( http_session: aiohttp.ClientSession | None = None, prompt: str | None = None, ) -> None: - super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True)) + super().__init__( + capabilities=stt.STTCapabilities(streaming=True, interim_results=True, flush=True) + ) self._api_key = api_key or os.environ.get("SARVAM_API_KEY") if not self._api_key: diff --git a/livekit-plugins/livekit-plugins-soniox/livekit/plugins/soniox/stt.py b/livekit-plugins/livekit-plugins-soniox/livekit/plugins/soniox/stt.py index 9b00c0991a..9938ec7478 100644 --- a/livekit-plugins/livekit-plugins-soniox/livekit/plugins/soniox/stt.py +++ b/livekit-plugins/livekit-plugins-soniox/livekit/plugins/soniox/stt.py @@ -98,6 +98,7 @@ class STTOptions: enable_language_identification: bool = True client_reference_id: str | None = None + enable_endpoint_detection: bool = True class STT(stt.STT): @@ -116,8 +117,9 @@ def __init__( api_key: str | None = None, base_url: str = BASE_URL, http_session: aiohttp.ClientSession | None = None, - vad: vad.VAD | None = None, params: STTOptions | None = None, + # deprecated + vad: vad.VAD | None = None, ): """Initialize instance of Soniox Speech-to-Text API service. @@ -126,17 +128,21 @@ def __init__( base_url: Base URL for Soniox Speech-to-Text API, default to BASE_URL defined in this module. http_session: Optional aiohttp.ClientSession to use for requests. - vad: If passed, enable Voice Activity Detection (VAD) for audio frames. params: Additional configuration parameters, such as model, language hints, context and speaker diarization. """ - super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True)) + super().__init__( + capabilities=stt.STTCapabilities(streaming=True, interim_results=True, flush=True) + ) self._api_key = api_key or os.getenv("SONIOX_API_KEY") self._base_url = base_url self._http_session = http_session - self._vad_stream = vad.stream() if vad else None self._params = params or STTOptions() + if vad is not None: + logger.warning( + "`vad` is deprecated. Use `stt.StreamAdapter(..., use_streaming=True)` instead." + ) @property def model(self) -> str: @@ -198,8 +204,6 @@ def _ensure_session(self) -> aiohttp.ClientSession: async def _connect_ws(self): """Open a WebSocket connection to the Soniox Speech-to-Text API and send the initial configuration.""" - # If VAD was passed, disable endpoint detection, otherwise enable it. - enable_endpoint_detection = not self._stt._vad_stream context = self._stt._params.context if isinstance(context, ContextObject): @@ -211,7 +215,7 @@ async def _connect_ws(self): "model": self._stt._params.model, "audio_format": "pcm_s16le", "num_channels": self._stt._params.num_channels or 1, - "enable_endpoint_detection": enable_endpoint_detection, + "enable_endpoint_detection": self._stt._params.enable_endpoint_detection, "sample_rate": self._stt._params.sample_rate, "language_hints": self._stt._params.language_hints, "context": context, @@ -238,7 +242,6 @@ async def _run(self) -> None: # Create task for audio processing, voice turn detection and message handling. tasks = [ asyncio.create_task(self._prepare_audio_task()), - asyncio.create_task(self._handle_vad_task()), asyncio.create_task(self._send_audio_task()), asyncio.create_task(self._recv_messages_task()), asyncio.create_task(self._keepalive_task()), @@ -302,23 +305,18 @@ async def _keepalive_task(self): logger.error(f"Error while sending keep alive message: {e}") async def _prepare_audio_task(self): - """Read audio frames, process VAD, and enqueue PCM data for sending.""" + """Read audio frames and enqueue PCM data for sending.""" if not self._ws: logger.error("WebSocket connection to Soniox Speech-to-Text API is not established") return async for data in self._input_ch: - if self._stt._vad_stream: - # If VAD is enabled, push the audio frame to the VAD stream. - if isinstance(data, self._FlushSentinel): - self._stt._vad_stream.flush() - else: - self._stt._vad_stream.push_frame(data) - if isinstance(data, rtc.AudioFrame): # Get the raw bytes from the audio frame. pcm_data = data.data.tobytes() self.audio_queue.put_nowait(pcm_data) + else: + self.audio_queue.put_nowait(FINALIZE_MESSAGE) async def _send_audio_task(self): """Take queued audio data and transmit it over the WebSocket.""" @@ -340,16 +338,6 @@ async def _send_audio_task(self): logger.error(f"Error while sending audio data: {e}") break - async def _handle_vad_task(self): - """Listen for VAD events to trigger finalize or keepalive messages.""" - if not self._stt._vad_stream: - logger.debug("VAD stream is not enabled, skipping VAD task") - return - - async for event in self._stt._vad_stream: - if event.type == vad.VADEventType.END_OF_SPEECH: - self.audio_queue.put_nowait(FINALIZE_MESSAGE) - async def _recv_messages_task(self): """Receive transcription messages, handle tokens, errors, and dispatch events.""" diff --git a/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/stt.py b/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/stt.py index 55a6eb6b57..d01b1b8978 100644 --- a/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/stt.py +++ b/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/stt.py @@ -241,7 +241,7 @@ def __init__( super().__init__( capabilities=stt.STTCapabilities( - streaming=True, interim_results=True, diarization=enable_diarization + streaming=True, interim_results=True, diarization=enable_diarization, flush=True ), ) @@ -496,6 +496,9 @@ def _evt_on_end_of_utterance(message: dict[str, Any]) -> None: self._speech_duration += frame.duration await self._client.send_audio(frame.data.tobytes()) + if isinstance(data, self._FlushSentinel): + await self._client.force_end_of_utterance() + # TODO - handle the closing of the stream? def _handle_transcript(self, message: dict[str, Any], is_final: bool) -> None: