diff --git a/livekit-agents/livekit/agents/stt/stream_adapter.py b/livekit-agents/livekit/agents/stt/stream_adapter.py index 0a84cd7a76..96e741edc3 100644 --- a/livekit-agents/livekit/agents/stt/stream_adapter.py +++ b/livekit-agents/livekit/agents/stt/stream_adapter.py @@ -16,16 +16,19 @@ class StreamAdapter(STT): - def __init__(self, *, stt: STT, vad: VAD) -> None: - super().__init__( - capabilities=STTCapabilities( + def __init__(self, *, stt: STT, vad: VAD, force_stream: bool = False) -> None: + if stt.capabilities.streaming: + capabilities = stt.capabilities + else: + capabilities = STTCapabilities( streaming=True, interim_results=False, diarization=False, # diarization requires streaming STT ) - ) + super().__init__(capabilities=capabilities) self._vad = vad self._stt = stt + self._force_stream = force_stream # TODO(theomonnom): The segment_id needs to be populated! self._stt.on("metrics_collected", self._on_metrics_collected) @@ -42,6 +45,10 @@ def model(self) -> str: def provider(self) -> str: return self._stt.provider + @property + def force_stream(self) -> bool: + return self._force_stream + async def _recognize_impl( self, buffer: utils.AudioBuffer, @@ -65,6 +72,7 @@ def stream( wrapped_stt=self._stt, language=language, conn_options=conn_options, + force_stream=self._force_stream, ) def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None: @@ -83,17 +91,90 @@ def __init__( wrapped_stt: STT, language: NotGivenOr[str], conn_options: APIConnectOptions, + force_stream: bool = False, ) -> None: super().__init__(stt=stt, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS) self._vad = vad self._wrapped_stt = wrapped_stt self._wrapped_stt_conn_options = conn_options self._language = language + self._force_stream = force_stream async def _metrics_monitor_task(self, event_aiter: AsyncIterable[SpeechEvent]) -> None: pass # do nothing async def _run(self) -> None: + # we do expect stt providers to honestly tell us + # weather they are capable of streaming + if self._wrapped_stt.capabilities.streaming and self._force_stream: + await self._run_stream_impl() + else: + await self._run_batch_impl() + + async def _run_stream_impl(self): + vad_stream = self._vad.stream() + stt_stream = self._wrapped_stt.stream() + + start_of_speech_received = asyncio.Event() + + async def _forward_input() -> None: + """forward input to vad""" + async for input in self._input_ch: + if isinstance(input, self._FlushSentinel): + vad_stream.flush() + stt_stream.flush() + continue + vad_stream.push_frame(input) + stt_stream.push_frame(input) + + vad_stream.end_input() + stt_stream.end_input() + + async def _handle_vad_stream() -> None: + async for event in vad_stream: + if event.type == VADEventType.START_OF_SPEECH: + start_of_speech_received.set() + self._event_ch.send_nowait(SpeechEvent(type=SpeechEventType.START_OF_SPEECH)) + elif event.type == VADEventType.END_OF_SPEECH: + self._event_ch.send_nowait( + SpeechEvent( + type=SpeechEventType.END_OF_SPEECH, + ) + ) + + async def _handle_stt_stream() -> None: + async for event in stt_stream: + status = start_of_speech_received.is_set() + + # ignore if vad didn's signal start of speech + if not status: + continue + + # we let vad handle these events + if ( + event.type == SpeechEventType.START_OF_SPEECH + or event.type == SpeechEventType.END_OF_SPEECH + ): + continue + + if event.type == SpeechEventType.FINAL_TRANSCRIPT and status: + start_of_speech_received.clear() + + self._event_ch.send_nowait(event) + + tasks = [ + asyncio.create_task(_forward_input(), name="forward_input"), + asyncio.create_task(_handle_vad_stream(), name="handle_vad"), + asyncio.create_task(_handle_stt_stream(), name="handle_stt"), + ] + try: + await asyncio.gather(*tasks) + finally: + await utils.aio.cancel_and_wait(*tasks) + await vad_stream.aclose() + await stt_stream.aclose() + + async def _run_batch_impl(self) -> None: vad_stream = self._vad.stream() async def _forward_input() -> None: