Skip to content

Commit 391bbdd

Browse files
authored
Replaced deprecated amazon-transcribe SDK with new aws-sdk-transcribe-streaming (#4111)
1 parent 13580b3 commit 391bbdd

File tree

3 files changed

+208
-131
lines changed

3 files changed

+208
-131
lines changed

livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/stt.py

Lines changed: 121 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,11 @@
1313
from __future__ import annotations
1414

1515
import asyncio
16+
import concurrent.futures
17+
import contextlib
1618
import os
1719
from dataclasses import dataclass
18-
19-
from amazon_transcribe.auth import AwsCrtCredentialResolver, CredentialResolver, Credentials
20-
from amazon_transcribe.client import TranscribeStreamingClient
21-
from amazon_transcribe.exceptions import BadRequestException
22-
from amazon_transcribe.model import Result, StartStreamTranscriptionEventStream, TranscriptEvent
23-
from awscrt.auth import AwsCredentialsProvider # type: ignore[import-untyped]
20+
from typing import Any
2421

2522
from livekit import rtc
2623
from livekit.agents import (
@@ -35,6 +32,36 @@
3532
from .log import logger
3633
from .utils import DEFAULT_REGION
3734

35+
try:
36+
from aws_sdk_transcribe_streaming.client import TranscribeStreamingClient # type: ignore
37+
from aws_sdk_transcribe_streaming.config import Config # type: ignore
38+
from aws_sdk_transcribe_streaming.models import ( # type: ignore
39+
AudioEvent,
40+
AudioStream,
41+
AudioStreamAudioEvent,
42+
BadRequestException,
43+
Result,
44+
StartStreamTranscriptionInput,
45+
TranscriptEvent,
46+
TranscriptResultStream,
47+
)
48+
from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver
49+
from smithy_core.aio.interfaces.eventstream import (
50+
EventPublisher,
51+
EventReceiver,
52+
)
53+
54+
_AWS_SDK_AVAILABLE = True
55+
except ImportError:
56+
_AWS_SDK_AVAILABLE = False
57+
58+
59+
@dataclass
60+
class Credentials:
61+
access_key_id: str
62+
secret_access_key: str
63+
session_token: str | None = None
64+
3865

3966
@dataclass
4067
class STTOptions:
@@ -76,6 +103,12 @@ def __init__(
76103
):
77104
super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True))
78105

106+
if not _AWS_SDK_AVAILABLE:
107+
raise ImportError(
108+
"The 'aws_sdk_transcribe_streaming' package is not installed. "
109+
"This implementation requires Python 3.12+ and the 'aws_sdk_transcribe_streaming' dependency."
110+
)
111+
79112
if not is_given(region):
80113
region = os.getenv("AWS_REGION") or DEFAULT_REGION
81114

@@ -129,7 +162,10 @@ def stream(
129162
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
130163
) -> SpeechStream:
131164
return SpeechStream(
132-
stt=self, conn_options=conn_options, opts=self._config, credentials=self._credentials
165+
stt=self,
166+
conn_options=conn_options,
167+
opts=self._config,
168+
credentials=self._credentials,
133169
)
134170

135171

@@ -145,36 +181,25 @@ def __init__(
145181
self._opts = opts
146182
self._credentials = credentials
147183

148-
def _credential_resolver(self) -> CredentialResolver:
149-
if self._credentials is None:
150-
return AwsCrtCredentialResolver(None) # type: ignore
151-
152-
credentials = self._credentials
153-
154-
class CustomAwsCrtCredentialResolver(CredentialResolver):
155-
def __init__(self) -> None:
156-
self._crt_resolver = AwsCredentialsProvider.new_static(
157-
credentials.access_key_id,
158-
credentials.secret_access_key,
159-
credentials.session_token,
160-
)
161-
162-
async def get_credentials(self) -> Credentials | None:
163-
credentials = await asyncio.wrap_future(self._crt_resolver.get_credentials())
164-
return credentials # type: ignore[no-any-return]
165-
166-
return CustomAwsCrtCredentialResolver()
167-
168184
async def _run(self) -> None:
169185
while True:
170-
client = TranscribeStreamingClient(
171-
region=self._opts.region,
172-
credential_resolver=self._credential_resolver(),
186+
config_kwargs: dict[str, Any] = {"region": self._opts.region}
187+
if self._credentials:
188+
config_kwargs["aws_access_key_id"] = self._credentials.access_key_id
189+
config_kwargs["aws_secret_access_key"] = self._credentials.secret_access_key
190+
config_kwargs["aws_session_token"] = self._credentials.session_token
191+
else:
192+
config_kwargs["aws_credentials_identity_resolver"] = (
193+
EnvironmentCredentialsResolver()
194+
)
195+
196+
client: TranscribeStreamingClient = TranscribeStreamingClient(
197+
config=Config(**config_kwargs)
173198
)
174199

175200
live_config = {
176201
"language_code": self._opts.language,
177-
"media_sample_rate_hz": self._opts.sample_rate,
202+
"media_sample_rate_hertz": self._opts.sample_rate,
178203
"media_encoding": self._opts.encoding,
179204
"vocabulary_name": self._opts.vocabulary_name,
180205
"session_id": self._opts.session_id,
@@ -183,30 +208,59 @@ async def _run(self) -> None:
183208
"show_speaker_label": self._opts.show_speaker_label,
184209
"enable_channel_identification": self._opts.enable_channel_identification,
185210
"number_of_channels": self._opts.number_of_channels,
186-
"enable_partial_results_stabilization": self._opts.enable_partial_results_stabilization, # noqa: E501
211+
"enable_partial_results_stabilization": self._opts.enable_partial_results_stabilization,
187212
"partial_results_stability": self._opts.partial_results_stability,
188213
"language_model_name": self._opts.language_model_name,
189214
}
190215
filtered_config = {k: v for k, v in live_config.items() if v and is_given(v)}
191-
stream = await client.start_stream_transcription(**filtered_config) # type: ignore
192-
193-
async def input_generator(stream: StartStreamTranscriptionEventStream) -> None:
194-
async for frame in self._input_ch:
195-
if isinstance(frame, rtc.AudioFrame):
196-
await stream.input_stream.send_audio_event(audio_chunk=frame.data.tobytes())
197-
await stream.input_stream.end_stream() # type: ignore
198-
199-
async def handle_transcript_events(stream: StartStreamTranscriptionEventStream) -> None:
200-
async for event in stream.output_stream:
201-
if isinstance(event, TranscriptEvent):
202-
self._process_transcript_event(event)
203-
204-
tasks = [
205-
asyncio.create_task(input_generator(stream)),
206-
asyncio.create_task(handle_transcript_events(stream)),
207-
]
216+
208217
try:
209-
await asyncio.gather(*tasks)
218+
stream = await client.start_stream_transcription(
219+
input=StartStreamTranscriptionInput(**filtered_config)
220+
)
221+
222+
# Get the output stream
223+
_, output_stream = await stream.await_output()
224+
225+
async def input_generator(
226+
audio_stream: EventPublisher[AudioStream],
227+
) -> None:
228+
try:
229+
async for frame in self._input_ch:
230+
if isinstance(frame, rtc.AudioFrame):
231+
await audio_stream.send(
232+
AudioStreamAudioEvent(
233+
value=AudioEvent(audio_chunk=frame.data.tobytes())
234+
)
235+
)
236+
# Send an empty audio chunk to signal end-of-stream to the service
237+
await audio_stream.send(
238+
AudioStreamAudioEvent(value=AudioEvent(audio_chunk=b""))
239+
)
240+
finally:
241+
with contextlib.suppress(Exception):
242+
await audio_stream.close()
243+
244+
async def handle_transcript_events(
245+
output_stream: EventReceiver[TranscriptResultStream],
246+
) -> None:
247+
try:
248+
async for event in output_stream:
249+
if isinstance(event.value, TranscriptEvent):
250+
self._process_transcript_event(event.value)
251+
except concurrent.futures.InvalidStateError:
252+
logger.warning(
253+
"AWS Transcribe stream closed unexpectedly (InvalidStateError)"
254+
)
255+
pass
256+
257+
tasks = [
258+
asyncio.create_task(input_generator(stream.input_stream)),
259+
asyncio.create_task(handle_transcript_events(output_stream)),
260+
]
261+
gather_future = asyncio.gather(*tasks)
262+
263+
await asyncio.shield(gather_future)
210264
except BadRequestException as e:
211265
if e.message and e.message.startswith("Your request timed out"):
212266
# AWS times out after 15s of inactivity, this tends to happen
@@ -217,17 +271,31 @@ async def handle_transcript_events(stream: StartStreamTranscriptionEventStream)
217271
else:
218272
raise e
219273
finally:
220-
await utils.aio.gracefully_cancel(*tasks)
274+
# Close input stream first
275+
await utils.aio.gracefully_cancel(tasks[0])
276+
277+
# Wait for output stream to close cleanly
278+
try:
279+
await asyncio.wait_for(tasks[1], timeout=3.0)
280+
except (asyncio.TimeoutError, asyncio.CancelledError):
281+
await utils.aio.gracefully_cancel(tasks[1])
282+
283+
# Await the gather future so any exception it holds is retrieved; otherwise asyncio logs an "exception was never retrieved" warning
284+
with contextlib.suppress(Exception):
285+
await gather_future
221286

222287
def _process_transcript_event(self, transcript_event: TranscriptEvent) -> None:
288+
if not transcript_event.transcript or not transcript_event.transcript.results:
289+
return
290+
223291
stream = transcript_event.transcript.results
224292
for resp in stream:
225-
if resp.start_time and resp.start_time == 0.0:
293+
if resp.start_time is not None and resp.start_time == 0.0:
226294
self._event_ch.send_nowait(
227295
stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
228296
)
229297

230-
if resp.end_time and resp.end_time > 0.0:
298+
if resp.end_time is not None and resp.end_time > 0.0:
231299
if resp.is_partial:
232300
self._event_ch.send_nowait(
233301
stt.SpeechEvent(

livekit-plugins/livekit-plugins-aws/pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,19 @@ classifiers = [
2020
"Programming Language :: Python :: 3",
2121
"Programming Language :: Python :: 3.9",
2222
"Programming Language :: Python :: 3.10",
23+
"Programming Language :: Python :: 3.12",
2324
"Programming Language :: Python :: 3 :: Only",
2425
]
2526
dependencies = [
2627
"livekit-agents>=1.3.5",
2728
"aioboto3>=14.1.0",
28-
"amazon-transcribe>=0.6.4",
29+
"aws_sdk_transcribe_streaming>=0.2.0; python_version >= '3.12'",
2930
]
3031

3132
[project.optional-dependencies]
3233
realtime = [
33-
"aws-sdk-bedrock-runtime==0.0.2; python_version >= '3.12'",
34-
"aws-sdk-signers==0.0.3; python_version >= '3.12'",
34+
"aws-sdk-bedrock-runtime>=0.2.0; python_version >= '3.12'",
35+
"aws-sdk-signers>=0.0.3; python_version >= '3.12'",
3536
"boto3>1.35.10",
3637
]
3738

@@ -47,4 +48,4 @@ path = "livekit/plugins/aws/version.py"
4748
packages = ["livekit"]
4849

4950
[tool.hatch.build.targets.sdist]
50-
include = ["/livekit"]
51+
include = ["/livekit"]

0 commit comments

Comments
 (0)