Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 85 additions & 4 deletions livekit-agents/livekit/agents/stt/stream_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@


class StreamAdapter(STT):
def __init__(self, *, stt: STT, vad: VAD) -> None:
super().__init__(
capabilities=STTCapabilities(
def __init__(self, *, stt: STT, vad: VAD, force_stream: bool = False) -> None:
if stt.capabilities.streaming:
capabilities = stt.capabilities
else:
capabilities = STTCapabilities(
streaming=True,
interim_results=False,
diarization=False, # diarization requires streaming STT
)
)
super().__init__(capabilities=capabilities)
self._vad = vad
self._stt = stt
self._force_stream = force_stream

# TODO(theomonnom): The segment_id needs to be populated!
self._stt.on("metrics_collected", self._on_metrics_collected)
Expand All @@ -42,6 +45,10 @@ def model(self) -> str:
def provider(self) -> str:
    """Return the provider name reported by the wrapped STT instance."""
    wrapped = self._stt
    return wrapped.provider

@property
def force_stream(self) -> bool:
    """Return the ``force_stream`` flag stored on this adapter."""
    flag = self._force_stream
    return flag

async def _recognize_impl(
self,
buffer: utils.AudioBuffer,
Expand All @@ -65,6 +72,7 @@ def stream(
wrapped_stt=self._stt,
language=language,
conn_options=conn_options,
force_stream=self._force_stream,
)

def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None:
Expand All @@ -83,17 +91,90 @@ def __init__(
wrapped_stt: STT,
language: NotGivenOr[str],
conn_options: APIConnectOptions,
force_stream: bool = False,
) -> None:
super().__init__(stt=stt, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS)
self._vad = vad
self._wrapped_stt = wrapped_stt
self._wrapped_stt_conn_options = conn_options
self._language = language
self._force_stream = force_stream

async def _metrics_monitor_task(self, event_aiter: AsyncIterable[SpeechEvent]) -> None:
    """Do nothing; ``event_aiter`` is intentionally never consumed.

    NOTE(review): presumably metrics are reported by the wrapped STT and
    re-emitted by the adapter rather than measured here - confirm.
    """
    return None

async def _run(self) -> None:
    """Dispatch to the streaming or the batch implementation.

    The streaming path is taken only when the wrapped STT reports streaming
    support in its capabilities and ``_force_stream`` is set; every other
    combination runs the batch path.
    """
    # The wrapped provider is trusted to report honestly whether it is
    # capable of streaming.
    use_batch = not (self._wrapped_stt.capabilities.streaming and self._force_stream)
    if use_batch:
        await self._run_batch_impl()
        return
    await self._run_stream_impl()

async def _run_stream_impl(self) -> None:
    """Run the wrapped STT in streaming mode, with the VAD marking speech boundaries.

    Every input frame is pushed to both a VAD stream and a stream of the
    wrapped STT. START_OF_SPEECH / END_OF_SPEECH events are emitted from the
    VAD only; events of those two types coming from the wrapped STT are
    dropped. All other STT events are forwarded, but only while the gate
    opened by a VAD START_OF_SPEECH is still open.
    """
    vad_stream = self._vad.stream()
    # Use the language / connection options this adapter stream was created
    # with instead of silently falling back to the wrapped STT's defaults.
    stt_stream = self._wrapped_stt.stream(
        language=self._language,
        conn_options=self._wrapped_stt_conn_options,
    )

    # True from a VAD START_OF_SPEECH until the matching END_OF_SPEECH.
    user_speaking = False
    # Gate for STT events: opened by a VAD START_OF_SPEECH, closed by a
    # FINAL_TRANSCRIPT that arrives while the VAD reports no ongoing speech.
    transcript_expected = False

    async def _forward_input() -> None:
        """Forward input frames and flushes to both the VAD and the STT stream."""
        async for frame in self._input_ch:
            if isinstance(frame, self._FlushSentinel):
                vad_stream.flush()
                stt_stream.flush()
                continue
            vad_stream.push_frame(frame)
            stt_stream.push_frame(frame)

        vad_stream.end_input()
        stt_stream.end_input()

    async def _handle_vad_stream() -> None:
        """Translate VAD start/end events into speech events and track state."""
        nonlocal user_speaking, transcript_expected
        async for event in vad_stream:
            if event.type == VADEventType.START_OF_SPEECH:
                user_speaking = True
                transcript_expected = True
                self._event_ch.send_nowait(SpeechEvent(type=SpeechEventType.START_OF_SPEECH))
            elif event.type == VADEventType.END_OF_SPEECH:
                user_speaking = False
                self._event_ch.send_nowait(SpeechEvent(type=SpeechEventType.END_OF_SPEECH))

    async def _handle_stt_stream() -> None:
        """Forward gated STT events; speech boundaries are left to the VAD."""
        nonlocal transcript_expected
        async for event in stt_stream:
            # ignore if vad didn't signal start of speech
            if not transcript_expected:
                continue

            # we let vad handle these events
            if event.type in (
                SpeechEventType.START_OF_SPEECH,
                SpeechEventType.END_OF_SPEECH,
            ):
                continue

            # NOTE(review): a FINAL_TRANSCRIPT can arrive while the user is
            # still (or again) speaking, e.g. around a short pause:
            #
            #   [ speech 1 ] .. [pause] .. [ speech 2 ]
            #                              [F1] ....... [F2]
            #
            # Closing the gate unconditionally on F1 would drop F2, because
            # the VAD START_OF_SPEECH for speech 2 has already been consumed.
            # STT events can also be delayed relative to VAD events and their
            # reliability varies across vendors, so the gate is only closed
            # when the VAD reports that nobody is currently speaking.
            if event.type == SpeechEventType.FINAL_TRANSCRIPT and not user_speaking:
                transcript_expected = False

            self._event_ch.send_nowait(event)

    tasks = [
        asyncio.create_task(_forward_input(), name="forward_input"),
        asyncio.create_task(_handle_vad_stream(), name="handle_vad"),
        asyncio.create_task(_handle_stt_stream(), name="handle_stt"),
    ]
    try:
        await asyncio.gather(*tasks)
    finally:
        await utils.aio.cancel_and_wait(*tasks)
        # Close the STT stream even if closing the VAD stream raises.
        try:
            await vad_stream.aclose()
        finally:
            await stt_stream.aclose()

async def _run_batch_impl(self) -> None:
vad_stream = self._vad.stream()

async def _forward_input() -> None:
Expand Down
Loading