Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions livekit-agents/livekit/agents/inference/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import json
import os
import weakref
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, replace
from typing import Any, Literal, TypedDict, Union, overload

import aiohttp
from typing_extensions import Required

from livekit import rtc

Expand Down Expand Up @@ -69,6 +71,55 @@ class AssemblyaiOptions(TypedDict, total=False):
STTLanguages = Literal["multi", "en", "de", "es", "fr", "ja", "pt", "zh", "hi"]


class ConnectionOptions(TypedDict, total=False):
"""Connection options for fallback attempts."""

timeout: float
"""Connection timeout in seconds."""

retries: int
"""Number of retries per model."""


class FallbackModel(TypedDict, total=False):
"""A fallback model with optional extra configuration.

Extra fields are passed through to the provider.

Example:
>>> FallbackModel(name="deepgram/nova-3", extra_kwargs={"keywords": ["livekit"]})
"""

name: Required[str]
"""Model name (e.g. "deepgram/nova-3", "assemblyai/universal-streaming", "cartesia/ink-whisper")."""

extra_kwargs: dict[str, Any]
"""Extra configuration for the model."""


FallbackModelType = Union[FallbackModel, str]
class Fallback(TypedDict, total=False):
"""Configuration for fallback models when the primary model fails."""

models: Required[Sequence[FallbackModelType]]
"""Fallback models in priority order."""

connection: ConnectionOptions
"""Connection options for fallback attempts."""

def _normalize_fallback(fallback: FallbackType) -> Fallback:
options: ConnectionOptions = {}
models: Sequence[FallbackModelType]
if isinstance(fallback, Mapping):
models = fallback.get("models", ())
options = fallback.get("connection", options)
else:
models = fallback
models_list = [
FallbackModel(name=m) if isinstance(m, str) else m for m in models
]
return Fallback(models=models_list, connection=options)

STTModels = Union[
DeepgramModels,
CartesiaModels,
Expand All @@ -77,6 +128,8 @@ class AssemblyaiOptions(TypedDict, total=False):
]
STTEncoding = Literal["pcm_s16le"]

FallbackType = Union[Sequence[FallbackModelType], Fallback]

DEFAULT_ENCODING: STTEncoding = "pcm_s16le"
DEFAULT_SAMPLE_RATE: int = 16000
DEFAULT_BASE_URL = "https://agent-gateway.livekit.cloud/v1"
Expand All @@ -92,6 +145,7 @@ class STTOptions:
api_key: str
api_secret: str
extra_kwargs: dict[str, Any]
fallback: NotGivenOr[Fallback]


class STT(stt.STT):
Expand All @@ -108,6 +162,7 @@ def __init__(
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
extra_kwargs: NotGivenOr[CartesiaOptions] = NOT_GIVEN,
fallback: NotGivenOr[FallbackType] = NOT_GIVEN,
) -> None: ...

@overload
Expand All @@ -123,6 +178,7 @@ def __init__(
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
extra_kwargs: NotGivenOr[DeepgramOptions] = NOT_GIVEN,
fallback: NotGivenOr[FallbackType] = NOT_GIVEN,
) -> None: ...

@overload
Expand All @@ -138,6 +194,7 @@ def __init__(
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
extra_kwargs: NotGivenOr[AssemblyaiOptions] = NOT_GIVEN,
fallback: NotGivenOr[FallbackType] = NOT_GIVEN,
) -> None: ...

@overload
Expand All @@ -153,6 +210,7 @@ def __init__(
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
fallback: NotGivenOr[FallbackType] = NOT_GIVEN,
) -> None: ...

def __init__(
Expand All @@ -169,6 +227,7 @@ def __init__(
extra_kwargs: NotGivenOr[
dict[str, Any] | CartesiaOptions | DeepgramOptions | AssemblyaiOptions
] = NOT_GIVEN,
fallback: NotGivenOr[FallbackType] = NOT_GIVEN,
) -> None:
"""Livekit Cloud Inference STT

Expand All @@ -182,6 +241,8 @@ def __init__(
api_secret (str, optional): LIVEKIT_API_SECRET, if not provided, read from environment variable.
http_session (aiohttp.ClientSession, optional): HTTP session to use.
extra_kwargs (dict, optional): Extra kwargs to pass to the STT model.
fallback (FallbackType, optional): Fallback models - either a list of model names,
a list of FallbackModel instances or an Fallback object for full configuration.
"""
super().__init__(
capabilities=stt.STTCapabilities(streaming=True, interim_results=True),
Expand Down Expand Up @@ -213,6 +274,10 @@ def __init__(
"api_secret is required, either as argument or set LIVEKIT_API_SECRET environmental variable"
)

fallback_model: NotGivenOr[Fallback] = NOT_GIVEN
if is_given(fallback):
fallback_model = _normalize_fallback(fallback) # type: ignore[arg-type]

self._opts = STTOptions(
model=model,
language=language,
Expand All @@ -222,6 +287,7 @@ def __init__(
api_key=lk_api_key,
api_secret=lk_api_secret,
extra_kwargs=dict(extra_kwargs) if is_given(extra_kwargs) else {},
fallback=fallback_model,
)

self._session = http_session
Expand Down Expand Up @@ -459,6 +525,9 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
if self._opts.language:
params["settings"]["language"] = self._opts.language

if self._opts.fallback:
params["fallback"] = self._opts.fallback

base_url = self._opts.base_url
if base_url.startswith(("http://", "https://")):
base_url = base_url.replace("http", "ws", 1)
Expand Down
106 changes: 99 additions & 7 deletions livekit-agents/livekit/agents/inference/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import json
import os
import weakref
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, replace
from typing import Any, Literal, TypedDict, Union, overload

import aiohttp
from typing_extensions import Required

from .. import tokenize, tts, utils
from .._exceptions import APIConnectionError, APIError, APIStatusError, APITimeoutError
Expand Down Expand Up @@ -42,6 +44,87 @@
"inworld/inworld-tts-1",
]

TTSModels = Union[CartesiaModels, ElevenlabsModels, RimeModels, InworldModels]

def parse_model_string(model: str) -> tuple[str, str | None]:
"""Parse a model string into a model and voice
Args:
model (str): Model string to parse
Returns:
tuple[str, str | None]: Model and voice (voice is None if not specified)
"""
voice: str | None = None
if (idx := model.rfind(":")) != -1:
voice = model[idx + 1 :]
model = model[:idx]
return model, voice


class ConnectionOptions(TypedDict, total=False):
"""Connection options for fallback attempts."""

timeout: float
"""Connection timeout in seconds."""

retries: int
"""Number of retries per model."""


class FallbackModel(TypedDict, total=False):
"""A fallback model with optional extra configuration.

Extra fields are passed through to the provider.

Example:
>>> FallbackModel(name="cartesia/sonic", voice="")
"""

name: Required[str]
"""Model name (e.g. "cartesia/sonic", "elevenlabs/eleven_flash_v2", "rime/arcana")."""

voice: Required[str | None]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the voice be None? Is this different than omitting it (removing the Required flag)

"""Voice to use for the model."""

extra_kwargs: dict[str, Any]
"""Extra configuration for the model."""


FallbackModelType = Union[FallbackModel, str]


class Fallback(TypedDict, total=False):
"""Configuration for fallback models when the primary model fails."""

models: Required[Sequence[FallbackModelType]]
"""Fallback models in priority order."""

connection: ConnectionOptions
"""Connection options for fallback attempts."""
Comment on lines +95 to +102
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if this class is really useful.

Maybe the argument inside the TTS/STT can directly accept a Sequence of FallbackModelType and the ConnectionOptions can directly be inside the constructor?

Probably nitpicking but maybe we can also re-use

class APIConnectOptions:
?

I'm not 100% sure, wdyt?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now, we already have conn_options as a parameter in stream or _recognize_impl, but it is often used with the default value somewhere behind the scenes (not exposed to the user). I think it is a good idea to expose this at the constructor level for all of them (we can start with inference in this PR) with the existing APIConnectOptions.



FallbackType = Union[Sequence[FallbackModelType], Fallback]


def _normalize_fallback(fallback: FallbackType) -> Fallback:
options: ConnectionOptions = {}
models: Sequence[FallbackModelType]
if isinstance(fallback, Mapping):
models = fallback.get("models", ())
options = fallback.get("connection", options)
else:
models = fallback

models_list: list[FallbackModel] = []
for m in models:
if isinstance(m, str):
model, voice = parse_model_string(m)
fm: FallbackModel = {"name": model, "voice": voice}
else:
fm = m
models_list.append(fm)

return Fallback(models=models_list, connection=options)


class CartesiaOptions(TypedDict, total=False):
duration: float # max duration of audio in seconds
Expand All @@ -61,8 +144,6 @@ class InworldOptions(TypedDict, total=False):
pass


TTSModels = Union[CartesiaModels, ElevenlabsModels, RimeModels, InworldModels]

TTSEncoding = Literal["pcm_s16le"]

DEFAULT_ENCODING: TTSEncoding = "pcm_s16le"
Expand All @@ -81,6 +162,7 @@ class _TTSOptions:
api_key: str
api_secret: str
extra_kwargs: dict[str, Any]
fallback: NotGivenOr[Fallback]


class TTS(tts.TTS):
Expand All @@ -97,6 +179,7 @@ def __init__(
api_key: NotGivenOr[str] = NOT_GIVEN,
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
fallback: NotGivenOr[FallbackType] = NOT_GIVEN,
extra_kwargs: NotGivenOr[CartesiaOptions] = NOT_GIVEN,
) -> None:
pass
Expand All @@ -114,6 +197,7 @@ def __init__(
api_key: NotGivenOr[str] = NOT_GIVEN,
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
fallback: NotGivenOr[FallbackType] = NOT_GIVEN,
extra_kwargs: NotGivenOr[ElevenlabsOptions] = NOT_GIVEN,
) -> None:
pass
Expand All @@ -131,6 +215,7 @@ def __init__(
api_key: NotGivenOr[str] = NOT_GIVEN,
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
fallback: NotGivenOr[FallbackType] = NOT_GIVEN,
extra_kwargs: NotGivenOr[RimeOptions] = NOT_GIVEN,
) -> None:
pass
Expand All @@ -148,6 +233,7 @@ def __init__(
api_key: NotGivenOr[str] = NOT_GIVEN,
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
fallback: NotGivenOr[FallbackType] = NOT_GIVEN,
extra_kwargs: NotGivenOr[InworldOptions] = NOT_GIVEN,
) -> None:
pass
Expand All @@ -165,6 +251,7 @@ def __init__(
api_key: NotGivenOr[str] = NOT_GIVEN,
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
fallback: NotGivenOr[FallbackType] = NOT_GIVEN,
extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
) -> None:
pass
Expand All @@ -181,6 +268,7 @@ def __init__(
api_key: NotGivenOr[str] = NOT_GIVEN,
api_secret: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
fallback: NotGivenOr[FallbackType] = NOT_GIVEN,
extra_kwargs: NotGivenOr[
dict[str, Any] | CartesiaOptions | ElevenlabsOptions | RimeOptions | InworldOptions
] = NOT_GIVEN,
Expand Down Expand Up @@ -232,6 +320,10 @@ def __init__(
"api_secret is required, either as argument or set LIVEKIT_API_SECRET environmental variable"
)

fallback_model: NotGivenOr[Fallback] = NOT_GIVEN
if is_given(fallback):
fallback_model = _normalize_fallback(fallback) # type: ignore[arg-type]

self._opts = _TTSOptions(
model=model,
voice=voice,
Expand All @@ -242,6 +334,7 @@ def __init__(
api_key=lk_api_key,
api_secret=lk_api_secret,
extra_kwargs=dict(extra_kwargs) if is_given(extra_kwargs) else {},
fallback=fallback_model,
)
self._session = http_session
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
Expand All @@ -262,11 +355,8 @@ def from_model_string(cls, model: str) -> TTS:
Returns:
TTS: TTS instance
"""
voice: NotGivenOr[str] = NOT_GIVEN
if (idx := model.rfind(":")) != -1:
voice = model[idx + 1 :]
model = model[:idx]
return cls(model, voice=voice)
model, voice = parse_model_string(model)
return cls(model=model, voice=voice if voice else NOT_GIVEN)

@property
def model(self) -> str:
Expand Down Expand Up @@ -308,6 +398,8 @@ async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
params["model"] = self._opts.model
if self._opts.language:
params["language"] = self._opts.language
if self._opts.fallback:
params["fallback"] = self._opts.fallback

try:
await ws.send_str(json.dumps(params))
Expand Down
Loading
Loading