Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 121 additions & 53 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,11 @@
from __future__ import annotations

import asyncio
import concurrent.futures
import contextlib
import os
from dataclasses import dataclass

from amazon_transcribe.auth import AwsCrtCredentialResolver, CredentialResolver, Credentials
from amazon_transcribe.client import TranscribeStreamingClient
from amazon_transcribe.exceptions import BadRequestException
from amazon_transcribe.model import Result, StartStreamTranscriptionEventStream, TranscriptEvent
from awscrt.auth import AwsCredentialsProvider # type: ignore[import-untyped]
from typing import Any

from livekit import rtc
from livekit.agents import (
Expand All @@ -35,6 +32,36 @@
from .log import logger
from .utils import DEFAULT_REGION

try:
from aws_sdk_transcribe_streaming.client import TranscribeStreamingClient # type: ignore
from aws_sdk_transcribe_streaming.config import Config # type: ignore
from aws_sdk_transcribe_streaming.models import ( # type: ignore
AudioEvent,
AudioStream,
AudioStreamAudioEvent,
BadRequestException,
Result,
StartStreamTranscriptionInput,
TranscriptEvent,
TranscriptResultStream,
)
from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver
from smithy_core.aio.interfaces.eventstream import (
EventPublisher,
EventReceiver,
)

_AWS_SDK_AVAILABLE = True
except ImportError:
_AWS_SDK_AVAILABLE = False


@dataclass
class Credentials:
access_key_id: str
secret_access_key: str
session_token: str | None = None


@dataclass
class STTOptions:
Expand Down Expand Up @@ -76,6 +103,12 @@ def __init__(
):
super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True))

if not _AWS_SDK_AVAILABLE:
raise ImportError(
"The 'aws_sdk_transcribe_streaming' package is not installed. "
"This implementation requires Python 3.12+ and the 'aws_sdk_transcribe_streaming' dependency."
)

if not is_given(region):
region = os.getenv("AWS_REGION") or DEFAULT_REGION

Expand Down Expand Up @@ -129,7 +162,10 @@ def stream(
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> SpeechStream:
return SpeechStream(
stt=self, conn_options=conn_options, opts=self._config, credentials=self._credentials
stt=self,
conn_options=conn_options,
opts=self._config,
credentials=self._credentials,
)


Expand All @@ -145,36 +181,25 @@ def __init__(
self._opts = opts
self._credentials = credentials

def _credential_resolver(self) -> CredentialResolver:
if self._credentials is None:
return AwsCrtCredentialResolver(None) # type: ignore

credentials = self._credentials

class CustomAwsCrtCredentialResolver(CredentialResolver):
def __init__(self) -> None:
self._crt_resolver = AwsCredentialsProvider.new_static(
credentials.access_key_id,
credentials.secret_access_key,
credentials.session_token,
)

async def get_credentials(self) -> Credentials | None:
credentials = await asyncio.wrap_future(self._crt_resolver.get_credentials())
return credentials # type: ignore[no-any-return]

return CustomAwsCrtCredentialResolver()

async def _run(self) -> None:
while True:
client = TranscribeStreamingClient(
region=self._opts.region,
credential_resolver=self._credential_resolver(),
config_kwargs: dict[str, Any] = {"region": self._opts.region}
if self._credentials:
config_kwargs["aws_access_key_id"] = self._credentials.access_key_id
config_kwargs["aws_secret_access_key"] = self._credentials.secret_access_key
config_kwargs["aws_session_token"] = self._credentials.session_token
else:
config_kwargs["aws_credentials_identity_resolver"] = (
EnvironmentCredentialsResolver()
)

client: TranscribeStreamingClient = TranscribeStreamingClient(
config=Config(**config_kwargs)
)

live_config = {
"language_code": self._opts.language,
"media_sample_rate_hz": self._opts.sample_rate,
"media_sample_rate_hertz": self._opts.sample_rate,
"media_encoding": self._opts.encoding,
"vocabulary_name": self._opts.vocabulary_name,
"session_id": self._opts.session_id,
Expand All @@ -183,30 +208,59 @@ async def _run(self) -> None:
"show_speaker_label": self._opts.show_speaker_label,
"enable_channel_identification": self._opts.enable_channel_identification,
"number_of_channels": self._opts.number_of_channels,
"enable_partial_results_stabilization": self._opts.enable_partial_results_stabilization, # noqa: E501
"enable_partial_results_stabilization": self._opts.enable_partial_results_stabilization,
"partial_results_stability": self._opts.partial_results_stability,
"language_model_name": self._opts.language_model_name,
}
filtered_config = {k: v for k, v in live_config.items() if v and is_given(v)}
stream = await client.start_stream_transcription(**filtered_config) # type: ignore

async def input_generator(stream: StartStreamTranscriptionEventStream) -> None:
async for frame in self._input_ch:
if isinstance(frame, rtc.AudioFrame):
await stream.input_stream.send_audio_event(audio_chunk=frame.data.tobytes())
await stream.input_stream.end_stream() # type: ignore

async def handle_transcript_events(stream: StartStreamTranscriptionEventStream) -> None:
async for event in stream.output_stream:
if isinstance(event, TranscriptEvent):
self._process_transcript_event(event)

tasks = [
asyncio.create_task(input_generator(stream)),
asyncio.create_task(handle_transcript_events(stream)),
]

try:
await asyncio.gather(*tasks)
stream = await client.start_stream_transcription(
input=StartStreamTranscriptionInput(**filtered_config)
)

# Get the output stream
_, output_stream = await stream.await_output()

async def input_generator(
audio_stream: EventPublisher[AudioStream],
) -> None:
try:
async for frame in self._input_ch:
if isinstance(frame, rtc.AudioFrame):
await audio_stream.send(
AudioStreamAudioEvent(
value=AudioEvent(audio_chunk=frame.data.tobytes())
)
)
# Send empty frame to close
await audio_stream.send(
AudioStreamAudioEvent(value=AudioEvent(audio_chunk=b""))
)
finally:
with contextlib.suppress(Exception):
await audio_stream.close()

async def handle_transcript_events(
output_stream: EventReceiver[TranscriptResultStream],
) -> None:
try:
async for event in output_stream:
if isinstance(event.value, TranscriptEvent):
self._process_transcript_event(event.value)
except concurrent.futures.InvalidStateError:
logger.warning(
"AWS Transcribe stream closed unexpectedly (InvalidStateError)"
)
pass

tasks = [
asyncio.create_task(input_generator(stream.input_stream)),
asyncio.create_task(handle_transcript_events(output_stream)),
]
gather_future = asyncio.gather(*tasks)

await asyncio.shield(gather_future)
except BadRequestException as e:
if e.message and e.message.startswith("Your request timed out"):
# AWS times out after 15s of inactivity, this tends to happen
Expand All @@ -217,17 +271,31 @@ async def handle_transcript_events(stream: StartStreamTranscriptionEventStream)
else:
raise e
finally:
await utils.aio.gracefully_cancel(*tasks)
# Close input stream first
await utils.aio.gracefully_cancel(tasks[0])

# Wait for output stream to close cleanly
try:
await asyncio.wait_for(tasks[1], timeout=3.0)
except (asyncio.TimeoutError, asyncio.CancelledError):
await utils.aio.gracefully_cancel(tasks[1])

# Ensure gather future is retrieved to avoid "exception never retrieved"
with contextlib.suppress(Exception):
await gather_future

def _process_transcript_event(self, transcript_event: TranscriptEvent) -> None:
if not transcript_event.transcript or not transcript_event.transcript.results:
return

stream = transcript_event.transcript.results
for resp in stream:
if resp.start_time and resp.start_time == 0.0:
if resp.start_time is not None and resp.start_time == 0.0:
self._event_ch.send_nowait(
stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
)

if resp.end_time and resp.end_time > 0.0:
if resp.end_time is not None and resp.end_time > 0.0:
if resp.is_partial:
self._event_ch.send_nowait(
stt.SpeechEvent(
Expand Down
9 changes: 5 additions & 4 deletions livekit-plugins/livekit-plugins-aws/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@ classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3 :: Only",
]
dependencies = [
"livekit-agents>=1.3.5",
"aioboto3>=14.1.0",
"amazon-transcribe>=0.6.4",
"aws_sdk_transcribe_streaming>=0.2.0; python_version >= '3.12'",
]

[project.optional-dependencies]
realtime = [
"aws-sdk-bedrock-runtime==0.0.2; python_version >= '3.12'",
"aws-sdk-signers==0.0.3; python_version >= '3.12'",
"aws-sdk-bedrock-runtime>=0.2.0; python_version >= '3.12'",
"aws-sdk-signers>=0.0.3; python_version >= '3.12'",
"boto3>1.35.10",
]

Expand All @@ -47,4 +48,4 @@ path = "livekit/plugins/aws/version.py"
packages = ["livekit"]

[tool.hatch.build.targets.sdist]
include = ["/livekit"]
include = ["/livekit"]
Loading
Loading