88import time
99import weakref
1010from dataclasses import dataclass
11- from time import perf_counter
11+ from time import perf_counter_ns
1212from typing import Any , Union
1313
1414import aiohttp
2424 APIError ,
2525 APIStatusError ,
2626 NotGivenOr ,
27- get_job_context ,
2827 utils ,
2928)
3029from ..bargein import (
@@ -284,20 +283,13 @@ async def _send_task() -> None:
284283 data_chan .close ()
285284
286285 async def predict (self , waveform : np .ndarray ) -> bool :
287- ctx = get_job_context ()
288- created_at = perf_counter ()
286+ created_at = perf_counter_ns ()
289287 request = {
290- "jobId" : ctx .job .id ,
291- "workerId" : ctx .worker_id ,
292288 "waveform" : self ._model .encode_waveform (waveform ),
293289 "threshold" : self ._opts .threshold ,
294290 "min_frames" : self ._opts .min_frames ,
295291 "created_at" : created_at ,
296292 }
297- agent_id = os .getenv ("LIVEKIT_AGENT_ID" )
298- if agent_id :
299- request ["agentId" ] = agent_id
300-
301293 async with utils .http_context .http_session ().post (
302294 url = f"{ self ._opts .base_url } /bargein" ,
303295 headers = {
@@ -306,19 +298,24 @@ async def predict(self, waveform: np.ndarray) -> bool:
306298 json = request ,
307299 timeout = aiohttp .ClientTimeout (total = self ._opts .inference_timeout ),
308300 ) as resp :
309- resp .raise_for_status ()
310- data = await resp .json ()
311- is_bargein : bool | None = data .get ("is_bargein" )
312- inference_duration = time .perf_counter () - created_at
313- if isinstance (is_bargein , bool ):
314- logger .debug (
315- "bargein prediction" ,
316- extra = {
317- "is_bargein" : is_bargein ,
318- "duration" : inference_duration ,
319- },
320- )
321- return is_bargein
301+ try :
302+ resp .raise_for_status ()
303+ data = await resp .json ()
304+ is_bargein : bool | None = data .get ("is_bargein" )
305+ inference_duration = (perf_counter_ns () - created_at ) / 1e9
306+ if isinstance (is_bargein , bool ):
307+ logger .debug (
308+ "bargein prediction" ,
309+ extra = {
310+ "is_bargein" : is_bargein ,
311+ "duration" : inference_duration ,
312+ },
313+ )
314+ return is_bargein
315+ except Exception as e :
316+ msg = await resp .text ()
317+ logger .error ("error during bargein prediction" , extra = {"response" : msg })
318+ raise APIError (f"error during bargein prediction: { e } " , body = msg ) from e
322319 return False
323320
324321
@@ -414,7 +411,7 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
414411 "num_channels" : 1 ,
415412 "threshold" : self ._model ._opts .threshold ,
416413 "min_frames" : self ._model ._opts .min_frames ,
417- "created_at" : perf_counter (),
414+ "created_at" : perf_counter_ns (),
418415 }
419416 await ws .send_str (json .dumps (msg ))
420417 accumulated_samples = 0
@@ -452,7 +449,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
452449 pass
453450 elif msg_type == "inference_done" :
454451 is_bargein_result = data .get ("is_bargein" , False )
455- inference_duration = time . perf_counter ( ) - created_at
452+ inference_duration = ( perf_counter_ns ( ) - created_at ) / 1e9
456453 logger .debug (
457454 "inference done" ,
458455 extra = {
0 commit comments