66from collections .abc import AsyncIterator
77from dataclasses import dataclass
88from enum import Enum , unique
9- from typing import Literal , Union
9+ from typing import Generic , Literal , TypeVar , Union
10+
11+ from pydantic import BaseModel , ConfigDict , Field
1012
1113from livekit import rtc
1214
13- from .utils import aio
15+ from ._exceptions import APIConnectionError , APIError
16+ from .log import logger
17+ from .types import APIConnectOptions
18+ from .utils import aio , log_exceptions
1419
1520
1621@unique
@@ -38,10 +43,25 @@ class BargeinEvent:
3843 """Time taken to perform the inference, in seconds."""
3944
4045
41- class BargeinDetector (ABC , rtc .EventEmitter [Literal ["bargein_detected" ]]):
42- def __init__ (self ) -> None :
class BargeinError(BaseModel):
    # Payload emitted on the "error" event when bargein detection fails.
    # arbitrary_types_allowed is enabled because `error` stores a raw
    # Exception instance rather than a pydantic-native type.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Constant tag identifying this event kind.
    type: Literal["bargein_error"] = "bargein_error"
    # Time of the failure; the emitters in this file fill it with time.time().
    timestamp: float
    # Label of the detector the error is reported for.
    label: str
    # The underlying exception; exclude=True keeps it out of serialized output.
    error: Exception = Field(..., exclude=True)
    # True when the stream will retry after this error, False when it gives up.
    recoverable: bool
53+
54+
# Additional event names a detector may emit; BargeinDetector unions this with
# its built-in "bargein_detected" and "error" events in the EventEmitter type.
TEvent = TypeVar("TEvent")
56+
57+
58+ class BargeinDetector (
59+ ABC , rtc .EventEmitter [Union [Literal ["bargein_detected" , "error" ], TEvent ]], Generic [TEvent ]
60+ ):
def __init__(self, *, sample_rate: int) -> None:
    """Initialise the emitter base and record detector-wide settings.

    Args:
        sample_rate: Sample rate stored on the detector as ``_sample_rate``.
    """
    super().__init__()
    self._sample_rate = sample_rate
    # The label is the fully qualified name of the concrete detector class.
    detector_cls = type(self)
    self._label = f"{detector_cls.__module__}.{detector_cls.__name__}"
4565
4666 @property
4767 def model (self ) -> str :
@@ -55,6 +75,17 @@ def provider(self) -> str:
5575 def label (self ) -> str :
5676 return self ._label
5777
def _emit_error(self, api_error: Exception, recoverable: bool) -> None:
    """Publish a ``BargeinError`` on this detector's ``"error"`` event.

    Args:
        api_error: The exception to attach to the event.
        recoverable: Value copied into the event's ``recoverable`` field.
    """
    error_event = BargeinError(
        timestamp=time.time(),
        label=self._label,
        error=api_error,
        recoverable=recoverable,
    )
    self.emit("error", error_event)
88+
# Abstract: every concrete detector must implement this and return a
# BargeinStream.
@abstractmethod
def stream(self) -> BargeinStream: ...
6091
@@ -75,7 +106,7 @@ class _OverlapSpeechEndedSentinel:
75106 class _FlushSentinel :
76107 pass
77108
78- def __init__ (self , bargein_detector : BargeinDetector ) -> None :
109+ def __init__ (self , bargein_detector : BargeinDetector , conn_options : APIConnectOptions ) -> None :
79110 self ._bargein_detector = bargein_detector
80111 self ._last_activity_time = time .perf_counter ()
81112 self ._input_ch = aio .Chan [
@@ -91,9 +122,60 @@ def __init__(self, bargein_detector: BargeinDetector) -> None:
91122 self ._event_ch = aio .Chan [BargeinEvent ]()
92123 self ._task = asyncio .create_task (self ._main_task ())
93124 self ._task .add_done_callback (lambda _ : self ._event_ch .close ())
125+ self ._num_retries = 0
126+ self ._conn_options = conn_options
127+ self ._sample_rate = bargein_detector ._sample_rate
128+ self ._resampler : rtc .AudioResampler | None = None
94129
95130 @abstractmethod
96- async def _main_task (self ) -> None : ...
131+ async def _run (self ) -> None : ...
132+
@log_exceptions(logger=logger)
async def _main_task(self) -> None:
    """Run ``_run`` and retry it on ``APIError`` as allowed by the connect options.

    Each failure is reported through ``_emit_error`` before it is either
    retried or re-raised. The retry budget is ``conn_options.max_retry`` and
    the delay between attempts comes from
    ``conn_options._interval_for_retry``.

    Raises:
        APIError: Re-raised unchanged when retries are disabled
            (``max_retry == 0``) or the error is marked as not retryable.
        APIConnectionError: When the last allowed retry also fails; the
            original error is chained as ``__cause__``.
        Exception: Any non-``APIError`` failure from ``_run`` is reported as
            unrecoverable and re-raised as is.
    """
    max_retries = self._conn_options.max_retry

    while self._num_retries <= max_retries:
        try:
            return await self._run()
        except APIError as e:
            # NOTE(review): livekit-agents' APIError is expected to carry a
            # `retryable` flag - confirm against ._exceptions. It is read with
            # getattr so an error type without the flag keeps being retried.
            retryable = getattr(e, "retryable", True)

            # Retrying an error the provider marked as permanent only delays
            # the failure, so surface it immediately.
            if max_retries == 0 or not retryable:
                self._emit_error(e, recoverable=False)
                raise

            if self._num_retries == max_retries:
                self._emit_error(e, recoverable=False)
                # `_num_retries` counts retries, so the number of attempts
                # actually made is one more than that.
                raise APIConnectionError(
                    f"failed to detect bargein after {self._num_retries + 1} attempts",
                ) from e

            self._emit_error(e, recoverable=True)

            retry_interval = self._conn_options._interval_for_retry(self._num_retries)
            logger.warning(
                f"failed to detect bargein, retrying in {retry_interval}s",
                exc_info=e,
                extra={
                    "bargein_detector": self._bargein_detector._label,
                    "attempt": self._num_retries,
                },
            )
            await asyncio.sleep(retry_interval)

            self._num_retries += 1

        except Exception as e:
            # Anything that is not an APIError is treated as fatal.
            self._emit_error(e, recoverable=False)
            raise
168+
def _emit_error(self, api_error: Exception, recoverable: bool) -> None:
    """Report a stream failure as an ``"error"`` event on the owning detector.

    Args:
        api_error: The exception to attach to the event.
        recoverable: Value copied into the event's ``recoverable`` field.
    """
    detector = self._bargein_detector
    error_event = BargeinError(
        timestamp=time.time(),
        label=detector._label,
        error=api_error,
        recoverable=recoverable,
    )
    detector.emit("error", error_event)
97179
98180 def start_agent_speech (self ) -> None :
99181 """Mark the start of the agent's speech"""
@@ -130,7 +212,28 @@ def push_frame(
130212 """Push some audio frame to be analyzed"""
131213 self ._check_input_not_ended ()
132214 self ._check_not_closed ()
133- self ._input_ch .send_nowait (frame )
215+
216+ if not isinstance (frame , rtc .AudioFrame ):
217+ self ._input_ch .send_nowait (frame )
218+ return
219+
220+ if self ._sample_rate != frame .sample_rate :
221+ if not self ._resampler :
222+ self ._resampler = rtc .AudioResampler (
223+ input_rate = frame .sample_rate ,
224+ output_rate = self ._sample_rate ,
225+ num_channels = 1 ,
226+ quality = rtc .AudioResamplerQuality .LOW ,
227+ )
228+ elif self ._resampler ._input_rate != frame .sample_rate :
229+ raise ValueError ("the sample rate of the input frames must be consistent" )
230+
231+ if self ._resampler :
232+ frames = self ._resampler .push (frame )
233+ for frame in frames :
234+ self ._input_ch .send_nowait (frame )
235+ else :
236+ self ._input_ch .send_nowait (frame )
134237
135238 def flush (self ) -> None :
136239 """Mark the end of the current segment"""
0 commit comments