1313from __future__ import annotations
1414
1515import asyncio
16+ import concurrent .futures
17+ import contextlib
1618import os
1719from dataclasses import dataclass
18-
19- from amazon_transcribe .auth import AwsCrtCredentialResolver , CredentialResolver , Credentials
20- from amazon_transcribe .client import TranscribeStreamingClient
21- from amazon_transcribe .exceptions import BadRequestException
22- from amazon_transcribe .model import Result , StartStreamTranscriptionEventStream , TranscriptEvent
23- from awscrt .auth import AwsCredentialsProvider # type: ignore[import-untyped]
20+ from typing import Any
2421
2522from livekit import rtc
2623from livekit .agents import (
3532from .log import logger
3633from .utils import DEFAULT_REGION
3734
35+ try :
36+ from aws_sdk_transcribe_streaming .client import TranscribeStreamingClient # type: ignore
37+ from aws_sdk_transcribe_streaming .config import Config # type: ignore
38+ from aws_sdk_transcribe_streaming .models import ( # type: ignore
39+ AudioEvent ,
40+ AudioStream ,
41+ AudioStreamAudioEvent ,
42+ BadRequestException ,
43+ Result ,
44+ StartStreamTranscriptionInput ,
45+ TranscriptEvent ,
46+ TranscriptResultStream ,
47+ )
48+ from smithy_aws_core .identity .environment import EnvironmentCredentialsResolver
49+ from smithy_core .aio .interfaces .eventstream import (
50+ EventPublisher ,
51+ EventReceiver ,
52+ )
53+
54+ _AWS_SDK_AVAILABLE = True
55+ except ImportError :
56+ _AWS_SDK_AVAILABLE = False
57+
58+
@dataclass
class Credentials:
    """Explicit AWS credentials supplied by the caller.

    When an instance is provided, its fields are forwarded to the Transcribe
    client configuration as the access key id, secret access key and session
    token. The constructor signature, field names, defaults and equality
    semantics are the ones generated by ``@dataclass``.

    Only ``__repr__`` is customised: the generated one would print the secret
    key and session token verbatim, which leaks them into logs and tracebacks
    whenever the object is formatted.
    """

    # Access key identifier; not secret on its own, so it is kept in the repr.
    access_key_id: str
    # Secret half of the key pair; never rendered by ``__repr__``.
    secret_access_key: str
    # Optional session token; ``None`` when no token was supplied.
    session_token: str | None = None

    def __repr__(self) -> str:
        """Return a debug representation with secret material redacted."""
        # Keep "set vs. unset" visible for the token without exposing its value.
        token = "None" if self.session_token is None else "'***'"
        return (
            f"{type(self).__name__}("
            f"access_key_id={self.access_key_id!r}, "
            f"secret_access_key='***', "
            f"session_token={token})"
        )
64+
3865
3966@dataclass
4067class STTOptions :
@@ -76,6 +103,12 @@ def __init__(
76103 ):
77104 super ().__init__ (capabilities = stt .STTCapabilities (streaming = True , interim_results = True ))
78105
106+ if not _AWS_SDK_AVAILABLE :
107+ raise ImportError (
108+ "The 'aws_sdk_transcribe_streaming' package is not installed. "
109+ "This implementation requires Python 3.12+ and the 'aws_sdk_transcribe_streaming' dependency."
110+ )
111+
79112 if not is_given (region ):
80113 region = os .getenv ("AWS_REGION" ) or DEFAULT_REGION
81114
@@ -129,7 +162,10 @@ def stream(
129162 conn_options : APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS ,
130163 ) -> SpeechStream :
131164 return SpeechStream (
132- stt = self , conn_options = conn_options , opts = self ._config , credentials = self ._credentials
165+ stt = self ,
166+ conn_options = conn_options ,
167+ opts = self ._config ,
168+ credentials = self ._credentials ,
133169 )
134170
135171
@@ -145,36 +181,25 @@ def __init__(
145181 self ._opts = opts
146182 self ._credentials = credentials
147183
148- def _credential_resolver (self ) -> CredentialResolver :
149- if self ._credentials is None :
150- return AwsCrtCredentialResolver (None ) # type: ignore
151-
152- credentials = self ._credentials
153-
154- class CustomAwsCrtCredentialResolver (CredentialResolver ):
155- def __init__ (self ) -> None :
156- self ._crt_resolver = AwsCredentialsProvider .new_static (
157- credentials .access_key_id ,
158- credentials .secret_access_key ,
159- credentials .session_token ,
160- )
161-
162- async def get_credentials (self ) -> Credentials | None :
163- credentials = await asyncio .wrap_future (self ._crt_resolver .get_credentials ())
164- return credentials # type: ignore[no-any-return]
165-
166- return CustomAwsCrtCredentialResolver ()
167-
168184 async def _run (self ) -> None :
169185 while True :
170- client = TranscribeStreamingClient (
171- region = self ._opts .region ,
172- credential_resolver = self ._credential_resolver (),
186+ config_kwargs : dict [str , Any ] = {"region" : self ._opts .region }
187+ if self ._credentials :
188+ config_kwargs ["aws_access_key_id" ] = self ._credentials .access_key_id
189+ config_kwargs ["aws_secret_access_key" ] = self ._credentials .secret_access_key
190+ config_kwargs ["aws_session_token" ] = self ._credentials .session_token
191+ else :
192+ config_kwargs ["aws_credentials_identity_resolver" ] = (
193+ EnvironmentCredentialsResolver ()
194+ )
195+
196+ client : TranscribeStreamingClient = TranscribeStreamingClient (
197+ config = Config (** config_kwargs )
173198 )
174199
175200 live_config = {
176201 "language_code" : self ._opts .language ,
177- "media_sample_rate_hz " : self ._opts .sample_rate ,
202+ "media_sample_rate_hertz " : self ._opts .sample_rate ,
178203 "media_encoding" : self ._opts .encoding ,
179204 "vocabulary_name" : self ._opts .vocabulary_name ,
180205 "session_id" : self ._opts .session_id ,
@@ -183,30 +208,59 @@ async def _run(self) -> None:
183208 "show_speaker_label" : self ._opts .show_speaker_label ,
184209 "enable_channel_identification" : self ._opts .enable_channel_identification ,
185210 "number_of_channels" : self ._opts .number_of_channels ,
186- "enable_partial_results_stabilization" : self ._opts .enable_partial_results_stabilization , # noqa: E501
211+ "enable_partial_results_stabilization" : self ._opts .enable_partial_results_stabilization ,
187212 "partial_results_stability" : self ._opts .partial_results_stability ,
188213 "language_model_name" : self ._opts .language_model_name ,
189214 }
190215 filtered_config = {k : v for k , v in live_config .items () if v and is_given (v )}
191- stream = await client .start_stream_transcription (** filtered_config ) # type: ignore
192-
193- async def input_generator (stream : StartStreamTranscriptionEventStream ) -> None :
194- async for frame in self ._input_ch :
195- if isinstance (frame , rtc .AudioFrame ):
196- await stream .input_stream .send_audio_event (audio_chunk = frame .data .tobytes ())
197- await stream .input_stream .end_stream () # type: ignore
198-
199- async def handle_transcript_events (stream : StartStreamTranscriptionEventStream ) -> None :
200- async for event in stream .output_stream :
201- if isinstance (event , TranscriptEvent ):
202- self ._process_transcript_event (event )
203-
204- tasks = [
205- asyncio .create_task (input_generator (stream )),
206- asyncio .create_task (handle_transcript_events (stream )),
207- ]
216+
208217 try :
209- await asyncio .gather (* tasks )
218+ stream = await client .start_stream_transcription (
219+ input = StartStreamTranscriptionInput (** filtered_config )
220+ )
221+
222+ # Get the output stream
223+ _ , output_stream = await stream .await_output ()
224+
225+ async def input_generator (
226+ audio_stream : EventPublisher [AudioStream ],
227+ ) -> None :
228+ try :
229+ async for frame in self ._input_ch :
230+ if isinstance (frame , rtc .AudioFrame ):
231+ await audio_stream .send (
232+ AudioStreamAudioEvent (
233+ value = AudioEvent (audio_chunk = frame .data .tobytes ())
234+ )
235+ )
236+ # Send empty frame to close
237+ await audio_stream .send (
238+ AudioStreamAudioEvent (value = AudioEvent (audio_chunk = b"" ))
239+ )
240+ finally :
241+ with contextlib .suppress (Exception ):
242+ await audio_stream .close ()
243+
244+ async def handle_transcript_events (
245+ output_stream : EventReceiver [TranscriptResultStream ],
246+ ) -> None :
247+ try :
248+ async for event in output_stream :
249+ if isinstance (event .value , TranscriptEvent ):
250+ self ._process_transcript_event (event .value )
251+ except concurrent .futures .InvalidStateError :
252+ logger .warning (
253+ "AWS Transcribe stream closed unexpectedly (InvalidStateError)"
254+ )
255+ pass
256+
257+ tasks = [
258+ asyncio .create_task (input_generator (stream .input_stream )),
259+ asyncio .create_task (handle_transcript_events (output_stream )),
260+ ]
261+ gather_future = asyncio .gather (* tasks )
262+
263+ await asyncio .shield (gather_future )
210264 except BadRequestException as e :
211265 if e .message and e .message .startswith ("Your request timed out" ):
212266 # AWS times out after 15s of inactivity, this tends to happen
@@ -217,17 +271,31 @@ async def handle_transcript_events(stream: StartStreamTranscriptionEventStream)
217271 else :
218272 raise e
219273 finally :
220- await utils .aio .gracefully_cancel (* tasks )
274+ # Close input stream first
275+ await utils .aio .gracefully_cancel (tasks [0 ])
276+
277+ # Wait for output stream to close cleanly
278+ try :
279+ await asyncio .wait_for (tasks [1 ], timeout = 3.0 )
280+ except (asyncio .TimeoutError , asyncio .CancelledError ):
281+ await utils .aio .gracefully_cancel (tasks [1 ])
282+
283+ # Ensure gather future is retrieved to avoid "exception never retrieved"
284+ with contextlib .suppress (Exception ):
285+ await gather_future
221286
222287 def _process_transcript_event (self , transcript_event : TranscriptEvent ) -> None :
288+ if not transcript_event .transcript or not transcript_event .transcript .results :
289+ return
290+
223291 stream = transcript_event .transcript .results
224292 for resp in stream :
225- if resp .start_time and resp .start_time == 0.0 :
293+ if resp .start_time is not None and resp .start_time == 0.0 :
226294 self ._event_ch .send_nowait (
227295 stt .SpeechEvent (type = stt .SpeechEventType .START_OF_SPEECH )
228296 )
229297
230- if resp .end_time and resp .end_time > 0.0 :
298+ if resp .end_time is not None and resp .end_time > 0.0 :
231299 if resp .is_partial :
232300 self ._event_ch .send_nowait (
233301 stt .SpeechEvent (
0 commit comments