Skip to content

Commit a785e42

Browse files
committed
update created_at to ns
1 parent a9ea7d9 commit a785e42

File tree

2 files changed

+23
-26
lines changed

examples/voice_agents/basic_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ async def entrypoint(ctx: JobContext):
9494
# See more at https://docs.livekit.io/agents/build/turns
9595
turn_detection=MultilingualModel(),
9696
vad=ctx.proc.userdata["vad"],
97-
bargein_detector=BargeinDetector(use_proxy=True),
97+
bargein_detector=BargeinDetector(),
9898
# allow the LLM to generate a response while waiting for the end of turn
9999
# See more at https://docs.livekit.io/agents/build/audio/#preemptive-generation
100100
preemptive_generation=True,

livekit-agents/livekit/agents/inference/bargein.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import time
99
import weakref
1010
from dataclasses import dataclass
11-
from time import perf_counter
11+
from time import perf_counter_ns
1212
from typing import Any, Union
1313

1414
import aiohttp
@@ -24,7 +24,6 @@
2424
APIError,
2525
APIStatusError,
2626
NotGivenOr,
27-
get_job_context,
2827
utils,
2928
)
3029
from ..bargein import (
@@ -284,20 +283,13 @@ async def _send_task() -> None:
284283
data_chan.close()
285284

286285
async def predict(self, waveform: np.ndarray) -> bool:
287-
ctx = get_job_context()
288-
created_at = perf_counter()
286+
created_at = perf_counter_ns()
289287
request = {
290-
"jobId": ctx.job.id,
291-
"workerId": ctx.worker_id,
292288
"waveform": self._model.encode_waveform(waveform),
293289
"threshold": self._opts.threshold,
294290
"min_frames": self._opts.min_frames,
295291
"created_at": created_at,
296292
}
297-
agent_id = os.getenv("LIVEKIT_AGENT_ID")
298-
if agent_id:
299-
request["agentId"] = agent_id
300-
301293
async with utils.http_context.http_session().post(
302294
url=f"{self._opts.base_url}/bargein",
303295
headers={
@@ -306,19 +298,24 @@ async def predict(self, waveform: np.ndarray) -> bool:
306298
json=request,
307299
timeout=aiohttp.ClientTimeout(total=self._opts.inference_timeout),
308300
) as resp:
309-
resp.raise_for_status()
310-
data = await resp.json()
311-
is_bargein: bool | None = data.get("is_bargein")
312-
inference_duration = time.perf_counter() - created_at
313-
if isinstance(is_bargein, bool):
314-
logger.debug(
315-
"bargein prediction",
316-
extra={
317-
"is_bargein": is_bargein,
318-
"duration": inference_duration,
319-
},
320-
)
321-
return is_bargein
301+
try:
302+
resp.raise_for_status()
303+
data = await resp.json()
304+
is_bargein: bool | None = data.get("is_bargein")
305+
inference_duration = (perf_counter_ns() - created_at) / 1e9
306+
if isinstance(is_bargein, bool):
307+
logger.debug(
308+
"bargein prediction",
309+
extra={
310+
"is_bargein": is_bargein,
311+
"duration": inference_duration,
312+
},
313+
)
314+
return is_bargein
315+
except Exception as e:
316+
msg = await resp.text()
317+
logger.error("error during bargein prediction", extra={"response": msg})
318+
raise APIError(f"error during bargein prediction: {e}", body=msg) from e
322319
return False
323320

324321

@@ -414,7 +411,7 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
414411
"num_channels": 1,
415412
"threshold": self._model._opts.threshold,
416413
"min_frames": self._model._opts.min_frames,
417-
"created_at": perf_counter(),
414+
"created_at": perf_counter_ns(),
418415
}
419416
await ws.send_str(json.dumps(msg))
420417
accumulated_samples = 0
@@ -452,7 +449,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
452449
pass
453450
elif msg_type == "inference_done":
454451
is_bargein_result = data.get("is_bargein", False)
455-
inference_duration = time.perf_counter() - created_at
452+
inference_duration = (perf_counter_ns() - created_at) / 1e9
456453
logger.debug(
457454
"inference done",
458455
extra={

0 commit comments

Comments (0)