diff --git a/livekit-agents/livekit/agents/inference/stt.py b/livekit-agents/livekit/agents/inference/stt.py index beed9dc177..5c39466cf0 100644 --- a/livekit-agents/livekit/agents/inference/stt.py +++ b/livekit-agents/livekit/agents/inference/stt.py @@ -9,6 +9,7 @@ from typing import Any, Literal, TypedDict, Union, overload import aiohttp +from typing_extensions import Required from livekit import rtc @@ -69,6 +70,48 @@ class AssemblyaiOptions(TypedDict, total=False): STTLanguages = Literal["multi", "en", "de", "es", "fr", "ja", "pt", "zh", "hi"] +class FallbackModel(TypedDict, total=False): + """A fallback model with optional extra configuration. + + Extra fields are passed through to the provider. + + Example: + >>> FallbackModel(name="deepgram/nova-3", extra_kwargs={"keywords": ["livekit"]}) + """ + + name: Required[str] + """Model name (e.g. "deepgram/nova-3", "assemblyai/universal-streaming", "cartesia/ink-whisper").""" + + extra_kwargs: dict[str, Any] + """Extra configuration for the model.""" + + +FallbackModelType = Union[FallbackModel, str] + + +def _parse_model_string(model: str) -> tuple[str, NotGivenOr[str]]: + language: NotGivenOr[str] = NOT_GIVEN + if (idx := model.rfind(":")) != -1: + language = model[idx + 1 :] + model = model[:idx] + return model, language + + +def _normalize_fallback( + fallback: list[FallbackModelType] | FallbackModelType, +) -> list[FallbackModel]: + def _make_fallback(model: FallbackModelType) -> FallbackModel: + if isinstance(model, str): + name, _ = _parse_model_string(model) + return FallbackModel(name=name) + return model + + if isinstance(fallback, list): + return [_make_fallback(m) for m in fallback] + + return [_make_fallback(fallback)] + + STTModels = Union[ DeepgramModels, CartesiaModels, @@ -77,6 +120,7 @@ class AssemblyaiOptions(TypedDict, total=False): ] STTEncoding = Literal["pcm_s16le"] + DEFAULT_ENCODING: STTEncoding = "pcm_s16le" DEFAULT_SAMPLE_RATE: int = 16000 DEFAULT_BASE_URL = "https://agent-gateway.livekit.cloud/v1" @@ -92,6 +136,8 @@ class STTOptions: api_key: str api_secret: str extra_kwargs: dict[str, Any] + fallback: NotGivenOr[list[FallbackModel]] + conn_options: NotGivenOr[APIConnectOptions] class STT(stt.STT): @@ -108,6 +154,8 @@ def __init__( api_secret: NotGivenOr[str] = NOT_GIVEN, http_session: aiohttp.ClientSession | None = None, extra_kwargs: NotGivenOr[CartesiaOptions] = NOT_GIVEN, + fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN, + conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN, ) -> None: ... @overload @@ -123,6 +171,8 @@ def __init__( api_secret: NotGivenOr[str] = NOT_GIVEN, http_session: aiohttp.ClientSession | None = None, extra_kwargs: NotGivenOr[DeepgramOptions] = NOT_GIVEN, + fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN, + conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN, ) -> None: ... @overload @@ -138,6 +188,8 @@ def __init__( api_secret: NotGivenOr[str] = NOT_GIVEN, http_session: aiohttp.ClientSession | None = None, extra_kwargs: NotGivenOr[AssemblyaiOptions] = NOT_GIVEN, + fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN, + conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN, ) -> None: ... 
@overload @@ -153,6 +205,8 @@ def __init__( api_secret: NotGivenOr[str] = NOT_GIVEN, http_session: aiohttp.ClientSession | None = None, extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN, + fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN, + conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN, ) -> None: ... def __init__( @@ -169,6 +223,8 @@ def __init__( extra_kwargs: NotGivenOr[ dict[str, Any] | CartesiaOptions | DeepgramOptions | AssemblyaiOptions ] = NOT_GIVEN, + fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN, + conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN, ) -> None: """Livekit Cloud Inference STT @@ -182,6 +238,9 @@ def __init__( api_secret (str, optional): LIVEKIT_API_SECRET, if not provided, read from environment variable. http_session (aiohttp.ClientSession, optional): HTTP session to use. extra_kwargs (dict, optional): Extra kwargs to pass to the STT model. + fallback (FallbackModelType, optional): Fallback models - either a list of model names or + a list of FallbackModel instances; a single model name or FallbackModel is also accepted. + conn_options (APIConnectOptions, optional): Connection options for request attempts. """ super().__init__( capabilities=stt.STTCapabilities(streaming=True, interim_results=True), @@ -212,6 +271,9 @@ def __init__( raise ValueError( "api_secret is required, either as argument or set LIVEKIT_API_SECRET environmental variable" ) + fallback_models: NotGivenOr[list[FallbackModel]] = NOT_GIVEN + if is_given(fallback): + fallback_models = _normalize_fallback(fallback) # type: ignore[arg-type] self._opts = STTOptions( model=model, @@ -222,6 +284,8 @@ def __init__( api_key=lk_api_key, api_secret=lk_api_secret, extra_kwargs=dict(extra_kwargs) if is_given(extra_kwargs) else {}, + fallback=fallback_models, + conn_options=conn_options if is_given(conn_options) else DEFAULT_API_CONNECT_OPTIONS, ) self._session = http_session @@ -237,12 +301,8 @@ def from_model_string(cls, model: str) -> STT: Returns: STT: STT instance """ - - language: NotGivenOr[str] = NOT_GIVEN - if (idx := model.rfind(":")) != -1: - language = model[idx + 1 :] - model = model[:idx] - return cls(model, language=language) + model_name, language = _parse_model_string(model) + return cls(model=model_name, language=language) @property def model(self) -> str: @@ -459,6 +519,18 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: if self._opts.language: params["settings"]["language"] = self._opts.language + if self._opts.fallback: + models = [ + {"name": m.get("name"), "extra": m.get("extra_kwargs")} for m in self._opts.fallback + ] + params["fallback"] = {"models": models} + + if self._opts.conn_options: + params["connection"] = { + "timeout": self._opts.conn_options.timeout, + "retries": self._opts.conn_options.max_retry, + } + base_url = self._opts.base_url if base_url.startswith(("http://", "https://")): base_url = base_url.replace("http", "ws", 1) diff --git a/livekit-agents/livekit/agents/inference/tts.py b/livekit-agents/livekit/agents/inference/tts.py index c37de348ce..dbc64b5fe1 100644 --- a/livekit-agents/livekit/agents/inference/tts.py +++ b/livekit-agents/livekit/agents/inference/tts.py @@ -9,6 +9,7 @@ from typing import Any, Literal, TypedDict, Union, overload import aiohttp +from typing_extensions import NotRequired from ..
import tokenize, tts, utils from .._exceptions import APIConnectionError, APIError, APIStatusError, APITimeoutError @@ -42,6 +43,59 @@ "inworld/inworld-tts-1", ] +TTSModels = Union[CartesiaModels, ElevenlabsModels, RimeModels, InworldModels] + + +def _parse_model_string(model: str) -> tuple[str, str | None]: + """Parse a model string into a model and voice + Args: + model (str): Model string to parse + Returns: + tuple[str, str | None]: Model and voice (voice is None if not specified) + """ + voice: str | None = None + if (idx := model.rfind(":")) != -1: + voice = model[idx + 1 :] + model = model[:idx] + return model, voice + + +class FallbackModel(TypedDict): + """A fallback model with optional extra configuration. + + Extra fields are passed through to the provider. + + Example: + >>> FallbackModel(name="cartesia/sonic", voice="") + """ + + name: str + """Model name (e.g. "cartesia/sonic", "elevenlabs/eleven_flash_v2", "rime/arcana").""" + + voice: str + """Voice to use for the model.""" + + extra_kwargs: NotRequired[dict[str, Any]] + """Extra configuration for the model.""" + + +FallbackModelType = Union[FallbackModel, str] + + +def _normalize_fallback( + fallback: list[FallbackModelType] | FallbackModelType, +) -> list[FallbackModel]: + def _make_fallback(model: FallbackModelType) -> FallbackModel: + if isinstance(model, str): + name, voice = _parse_model_string(model) + return FallbackModel(name=name, voice=voice if voice else "") + return model + + if isinstance(fallback, list): + return [_make_fallback(m) for m in fallback] + + return [_make_fallback(fallback)] + class CartesiaOptions(TypedDict, total=False): duration: float # max duration of audio in seconds @@ -61,8 +115,6 @@ class InworldOptions(TypedDict, total=False): pass -TTSModels = Union[CartesiaModels, ElevenlabsModels, RimeModels, InworldModels] - TTSEncoding = Literal["pcm_s16le"] DEFAULT_ENCODING: TTSEncoding = "pcm_s16le" @@ -81,6 +133,8 @@ class _TTSOptions: api_key: str api_secret: str extra_kwargs: dict[str, Any] + fallback: NotGivenOr[list[FallbackModel]] + conn_options: NotGivenOr[APIConnectOptions] class TTS(tts.TTS): @@ -98,6 +152,8 @@ def __init__( api_secret: NotGivenOr[str] = NOT_GIVEN, http_session: aiohttp.ClientSession | None = None, extra_kwargs: NotGivenOr[CartesiaOptions] = NOT_GIVEN, + fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN, + conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN, ) -> None: pass @@ -115,6 +171,8 @@ def __init__( api_secret: NotGivenOr[str] = NOT_GIVEN, http_session: aiohttp.ClientSession | None = None, extra_kwargs: NotGivenOr[ElevenlabsOptions] = NOT_GIVEN, + fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN, + conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN, ) -> None: pass @@ -132,6 +190,8 @@ def __init__( api_secret: NotGivenOr[str] = NOT_GIVEN, http_session: aiohttp.ClientSession | None = None, extra_kwargs: NotGivenOr[RimeOptions] = NOT_GIVEN, + fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN, + conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN, ) -> None: pass @@ -149,6 +209,8 @@ def __init__( api_secret: NotGivenOr[str] = NOT_GIVEN, http_session: aiohttp.ClientSession | None = None, extra_kwargs: NotGivenOr[InworldOptions] = NOT_GIVEN, + fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN, + conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN, ) -> None: pass @@ -166,6 +228,8 @@ def __init__( api_secret: NotGivenOr[str] = NOT_GIVEN, 
http_session: aiohttp.ClientSession | None = None, extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN, + fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN, + conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN, ) -> None: pass @@ -184,6 +248,8 @@ def __init__( extra_kwargs: NotGivenOr[ dict[str, Any] | CartesiaOptions | ElevenlabsOptions | RimeOptions | InworldOptions ] = NOT_GIVEN, + fallback: NotGivenOr[list[FallbackModelType] | FallbackModelType] = NOT_GIVEN, + conn_options: NotGivenOr[APIConnectOptions] = NOT_GIVEN, ) -> None: """Livekit Cloud Inference TTS @@ -198,6 +264,9 @@ def __init__( api_secret (str, optional): LIVEKIT_API_SECRET, if not provided, read from environment variable. http_session (aiohttp.ClientSession, optional): HTTP session to use. extra_kwargs (dict, optional): Extra kwargs to pass to the TTS model. + fallback (FallbackModelType, optional): Fallback models - either a list of model names or + a list of FallbackModel instances; a single model name or FallbackModel is also accepted. + conn_options (APIConnectOptions, optional): Connection options for request attempts. """ sample_rate = sample_rate if is_given(sample_rate) else DEFAULT_SAMPLE_RATE super().__init__( @@ -232,6 +301,10 @@ def __init__( "api_secret is required, either as argument or set LIVEKIT_API_SECRET environmental variable" ) + fallback_models: NotGivenOr[list[FallbackModel]] = NOT_GIVEN + if is_given(fallback): + fallback_models = _normalize_fallback(fallback) # type: ignore[arg-type] + self._opts = _TTSOptions( model=model, voice=voice, @@ -242,6 +315,8 @@ def __init__( api_key=lk_api_key, api_secret=lk_api_secret, extra_kwargs=dict(extra_kwargs) if is_given(extra_kwargs) else {}, + fallback=fallback_models, + conn_options=conn_options if is_given(conn_options) else DEFAULT_API_CONNECT_OPTIONS, ) self._session = http_session self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse]( @@ -262,11 +337,8 @@ def from_model_string(cls, model: str) -> TTS: Returns: TTS: TTS instance """ - voice: NotGivenOr[str] = NOT_GIVEN - if (idx := model.rfind(":")) != -1: - voice = model[idx + 1 :] - model = model[:idx] - return cls(model, voice=voice) + model, voice = _parse_model_string(model) + return cls(model=model, voice=voice if voice else NOT_GIVEN) @property def model(self) -> str: @@ -295,7 +367,7 @@ async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse: raise APIStatusError("LiveKit TTS quota exceeded", status_code=e.status) from e raise APIConnectionError("failed to connect to LiveKit TTS") from e - params = { + params: dict[str, Any] = { "type": "session.create", "sample_rate": str(self._opts.sample_rate), "encoding": self._opts.encoding, @@ -308,6 +380,18 @@ async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse: params["model"] = self._opts.model if self._opts.language: params["language"] = self._opts.language + if self._opts.fallback: + models = [ + {"name": m.get("name"), "voice": m.get("voice"), "extra": m.get("extra_kwargs", {})} + for m in self._opts.fallback + ] + params["fallback"] = {"models": models} + + if self._opts.conn_options: + params["connection"] = { + "timeout": self._opts.conn_options.timeout, + "retries": self._opts.conn_options.max_retry, + } try: await ws.send_str(json.dumps(params)) diff --git a/tests/test_inference_stt_fallback.py b/tests/test_inference_stt_fallback.py new file mode 100644 index 0000000000..65e29f6ef5 --- /dev/null +++ b/tests/test_inference_stt_fallback.py @@ -0,0 +1,259 @@ +import pytest + +from livekit.agents.inference.stt
import ( + STT, + FallbackModel, + _normalize_fallback, + _parse_model_string, +) +from livekit.agents.types import ( + DEFAULT_API_CONNECT_OPTIONS, + NOT_GIVEN, + APIConnectOptions, +) + + +def _make_stt(**kwargs): + """Helper to create STT with required credentials.""" + defaults = { + "model": "deepgram", + "api_key": "test-key", + "api_secret": "test-secret", + "base_url": "https://example.livekit.cloud", + } + defaults.update(kwargs) + return STT(**defaults) + + +class TestParseModelString: + def test_simple_model_without_language(self): + """Model string without language suffix returns NOT_GIVEN for language.""" + model, language = _parse_model_string("deepgram") + assert model == "deepgram" + assert language is NOT_GIVEN + + def test_model_with_language_suffix(self): + """Model string with :language suffix extracts the language.""" + model, language = _parse_model_string("deepgram:en") + assert model == "deepgram" + assert language == "en" + + def test_provider_model_format_without_language(self): + """Provider/model format without language suffix.""" + model, language = _parse_model_string("deepgram/nova-3") + assert model == "deepgram/nova-3" + assert language is NOT_GIVEN + + def test_provider_model_format_with_language(self): + """Provider/model format with language suffix.""" + model, language = _parse_model_string("deepgram/nova-3:en") + assert model == "deepgram/nova-3" + assert language == "en" + + @pytest.mark.parametrize( + "model_str,expected_model,expected_lang", + [ + ("cartesia/ink-whisper:de", "cartesia/ink-whisper", "de"), + ("assemblyai:es", "assemblyai", "es"), + ("deepgram/nova-2-medical:ja", "deepgram/nova-2-medical", "ja"), + ("deepgram/nova-3:multi", "deepgram/nova-3", "multi"), + ("cartesia:zh", "cartesia", "zh"), + ], + ) + def test_various_providers_and_languages(self, model_str, expected_model, expected_lang): + """Test various provider/model combinations with different languages.""" + model, language = _parse_model_string(model_str) + assert model == expected_model + assert language == expected_lang + + def test_auto_model(self): + """Auto model without language.""" + model, language = _parse_model_string("auto") + assert model == "auto" + assert language is NOT_GIVEN + + def test_auto_model_with_language(self): + """Auto model with language suffix.""" + model, language = _parse_model_string("auto:pt") + assert model == "auto" + assert language == "pt" + + +class TestNormalizeFallback: + def test_single_string_model(self): + """Single string model becomes a list with one FallbackModel.""" + result = _normalize_fallback("deepgram/nova-3") + assert result == [{"name": "deepgram/nova-3"}] + + def test_single_fallback_model_dict(self): + """Single FallbackModel dict becomes a list with that dict.""" + fallback = FallbackModel(name="deepgram/nova-3") + result = _normalize_fallback(fallback) + assert result == [{"name": "deepgram/nova-3"}] + + def test_list_of_string_models(self): + """List of string models becomes list of FallbackModels.""" + result = _normalize_fallback(["deepgram/nova-3", "cartesia/ink-whisper"]) + assert result == [ + {"name": "deepgram/nova-3"}, + {"name": "cartesia/ink-whisper"}, + ] + + def test_list_of_fallback_model_dicts(self): + """List of FallbackModel dicts is preserved.""" + fallbacks = [ + FallbackModel(name="deepgram/nova-3"), + FallbackModel(name="assemblyai"), + ] + result = _normalize_fallback(fallbacks) + assert result == [ + {"name": "deepgram/nova-3"}, + {"name": "assemblyai"}, + ] + + def 
test_mixed_list_strings_and_dicts(self): + """Mixed list of strings and FallbackModel dicts.""" + fallbacks = [ + "deepgram/nova-3", + FallbackModel(name="cartesia/ink-whisper"), + "assemblyai", + ] + result = _normalize_fallback(fallbacks) + assert result == [ + {"name": "deepgram/nova-3"}, + {"name": "cartesia/ink-whisper"}, + {"name": "assemblyai"}, + ] + + def test_string_with_language_suffix_discards_language(self): + """Language suffix in string model is discarded.""" + result = _normalize_fallback("deepgram/nova-3:en") + assert result == [{"name": "deepgram/nova-3"}] + + def test_fallback_model_with_extra_kwargs(self): + """FallbackModel with extra_kwargs is preserved.""" + fallback = FallbackModel( + name="deepgram/nova-3", + extra_kwargs={"keywords": [("livekit", 1.5)], "punctuate": True}, + ) + result = _normalize_fallback(fallback) + assert result == [ + { + "name": "deepgram/nova-3", + "extra_kwargs": {"keywords": [("livekit", 1.5)], "punctuate": True}, + } + ] + + def test_list_with_extra_kwargs_preserved(self): + """List with FallbackModels containing extra_kwargs.""" + fallbacks = [ + FallbackModel(name="deepgram/nova-3", extra_kwargs={"punctuate": True}), + "cartesia/ink-whisper", + FallbackModel(name="assemblyai", extra_kwargs={"format_turns": True}), + ] + result = _normalize_fallback(fallbacks) + assert result == [ + {"name": "deepgram/nova-3", "extra_kwargs": {"punctuate": True}}, + {"name": "cartesia/ink-whisper"}, + {"name": "assemblyai", "extra_kwargs": {"format_turns": True}}, + ] + + def test_empty_list(self): + """Empty list returns empty list.""" + result = _normalize_fallback([]) + assert result == [] + + def test_multiple_colons_in_model_string(self): + """Multiple colons in model string - splits on last, discards language.""" + result = _normalize_fallback("some:model:part:fr") + assert result == [{"name": "some:model:part"}] + + +class TestSTTConstructorFallbackAndConnectOptions: + """Tests for STT constructor focusing on fallback and connect_options args.""" + + def test_fallback_not_given(self): + """When fallback is not provided, _opts.fallback is NOT_GIVEN.""" + stt = _make_stt() + assert stt._opts.fallback is NOT_GIVEN + + def test_fallback_single_string(self): + """Single string fallback is normalized to list of FallbackModel.""" + stt = _make_stt(fallback="cartesia/ink-whisper") + assert stt._opts.fallback == [{"name": "cartesia/ink-whisper"}] + + def test_fallback_list_of_strings(self): + """List of string fallbacks is normalized.""" + stt = _make_stt(fallback=["deepgram/nova-3", "assemblyai"]) + assert stt._opts.fallback == [ + {"name": "deepgram/nova-3"}, + {"name": "assemblyai"}, + ] + + def test_fallback_single_fallback_model(self): + """Single FallbackModel is normalized to list.""" + stt = _make_stt(fallback=FallbackModel(name="deepgram/nova-3")) + assert stt._opts.fallback == [{"name": "deepgram/nova-3"}] + + def test_fallback_with_extra_kwargs(self): + """FallbackModel with extra_kwargs is preserved in _opts.""" + stt = _make_stt( + fallback=FallbackModel( + name="deepgram/nova-3", + extra_kwargs={"punctuate": True, "keywords": [("livekit", 1.5)]}, + ) + ) + assert stt._opts.fallback == [ + { + "name": "deepgram/nova-3", + "extra_kwargs": {"punctuate": True, "keywords": [("livekit", 1.5)]}, + } + ] + + def test_fallback_mixed_list(self): + """Mixed list of strings and FallbackModels is normalized.""" + stt = _make_stt( + fallback=[ + "deepgram/nova-3", + FallbackModel(name="cartesia", extra_kwargs={"min_volume": 0.5}), + "assemblyai", + ] + ) 
+ assert stt._opts.fallback == [ + {"name": "deepgram/nova-3"}, + {"name": "cartesia", "extra_kwargs": {"min_volume": 0.5}}, + {"name": "assemblyai"}, + ] + + def test_fallback_string_with_language_discarded(self): + """Language suffix in fallback string is discarded.""" + stt = _make_stt(fallback="deepgram/nova-3:en") + assert stt._opts.fallback == [{"name": "deepgram/nova-3"}] + + def test_connect_options_not_given_uses_default(self): + """When connect_options is not provided, uses DEFAULT_API_CONNECT_OPTIONS.""" + stt = _make_stt() + assert stt._opts.conn_options == DEFAULT_API_CONNECT_OPTIONS + + def test_connect_options_custom_timeout(self): + """Custom connect_options with timeout is stored.""" + custom_opts = APIConnectOptions(timeout=30.0) + stt = _make_stt(conn_options=custom_opts) + assert stt._opts.conn_options == custom_opts + assert stt._opts.conn_options.timeout == 30.0 + + def test_connect_options_custom_max_retry(self): + """Custom conn_options with max_retry is stored.""" + custom_opts = APIConnectOptions(max_retry=5) + stt = _make_stt(conn_options=custom_opts) + assert stt._opts.conn_options == custom_opts + assert stt._opts.conn_options.max_retry == 5 + + def test_connect_options_full_custom(self): + """Fully custom connect_options is stored correctly.""" + custom_opts = APIConnectOptions(timeout=60.0, max_retry=10, retry_interval=2.0) + stt = _make_stt(conn_options=custom_opts) + assert stt._opts.conn_options == custom_opts + assert stt._opts.conn_options.timeout == 60.0 + assert stt._opts.conn_options.max_retry == 10 + assert stt._opts.conn_options.retry_interval == 2.0 diff --git a/tests/test_inference_tts_fallback.py b/tests/test_inference_tts_fallback.py new file mode 100644 index 0000000000..8623d7b9b3 --- /dev/null +++ b/tests/test_inference_tts_fallback.py @@ -0,0 +1,169 @@ +import pytest + +from livekit.agents.inference.tts import ( + TTS, + FallbackModel, + _normalize_fallback, + _parse_model_string, +) + + +def _make_tts(**kwargs): + """Helper to create TTS with required credentials.""" + defaults = { + "model": "cartesia/sonic", + "api_key": "test-key", + "api_secret": "test-secret", + "base_url": "https://example.livekit.cloud", + } + defaults.update(kwargs) + return TTS(**defaults) + + +class TestParseModelString: + def test_simple_model_without_voice(self): + """Model string without voice suffix returns None for voice.""" + model, voice = _parse_model_string("cartesia") + assert model == "cartesia" + assert voice is None + + def test_model_with_voice_suffix(self): + """Model string with :voice suffix extracts the voice.""" + model, voice = _parse_model_string("cartesia:my-voice-id") + assert model == "cartesia" + assert voice == "my-voice-id" + + def test_provider_model_format_without_voice(self): + """Provider/model format without voice suffix.""" + model, voice = _parse_model_string("cartesia/sonic") + assert model == "cartesia/sonic" + assert voice is None + + def test_provider_model_format_with_voice(self): + """Provider/model format with voice suffix.""" + model, voice = _parse_model_string("cartesia/sonic:my-voice-id") + assert model == "cartesia/sonic" + assert voice == "my-voice-id" + + @pytest.mark.parametrize( + "model_str,expected_model,expected_voice", + [ + ("elevenlabs/eleven_flash_v2:voice123", "elevenlabs/eleven_flash_v2", "voice123"), + ("rime:speaker-a", "rime", "speaker-a"), + ("rime/mist:narrator", "rime/mist", "narrator"), + ("inworld/inworld-tts-1:character", "inworld/inworld-tts-1", "character"), + ("cartesia/sonic-turbo:deep-voice", 
"cartesia/sonic-turbo", "deep-voice"), + ], + ) + def test_various_providers_and_voices(self, model_str, expected_model, expected_voice): + """Test various provider/model combinations with different voices.""" + model, voice = _parse_model_string(model_str) + assert model == expected_model + assert voice == expected_voice + + def test_empty_voice_after_colon(self): + """Empty string after colon still counts as voice.""" + model, voice = _parse_model_string("cartesia/sonic:") + assert model == "cartesia/sonic" + assert voice == "" + + +class TestNormalizeFallback: + def test_single_string_model(self): + """Single string model becomes a list with one FallbackModel.""" + result = _normalize_fallback("cartesia/sonic") + assert result == [{"name": "cartesia/sonic", "voice": ""}] + + def test_single_string_model_with_voice(self): + """Single string model with voice suffix extracts voice.""" + result = _normalize_fallback("cartesia/sonic:my-voice") + assert result == [{"name": "cartesia/sonic", "voice": "my-voice"}] + + def test_single_fallback_model_dict(self): + """Single FallbackModel dict becomes a list with that dict.""" + fallback = FallbackModel(name="cartesia/sonic", voice="narrator") + result = _normalize_fallback(fallback) + assert result == [{"name": "cartesia/sonic", "voice": "narrator"}] + + def test_list_of_string_models(self): + """List of string models becomes list of FallbackModels.""" + result = _normalize_fallback(["cartesia/sonic", "elevenlabs/eleven_flash_v2"]) + assert result == [ + {"name": "cartesia/sonic", "voice": ""}, + {"name": "elevenlabs/eleven_flash_v2", "voice": ""}, + ] + + def test_list_of_string_models_with_voices(self): + """List of string models with voice suffixes.""" + result = _normalize_fallback(["cartesia/sonic:voice1", "elevenlabs:voice2"]) + assert result == [ + {"name": "cartesia/sonic", "voice": "voice1"}, + {"name": "elevenlabs", "voice": "voice2"}, + ] + + def test_list_of_fallback_model_dicts(self): + """List of FallbackModel dicts is preserved.""" + fallbacks = [ + FallbackModel(name="cartesia/sonic", voice="narrator"), + FallbackModel(name="elevenlabs", voice=""), + ] + result = _normalize_fallback(fallbacks) + assert result == [ + {"name": "cartesia/sonic", "voice": "narrator"}, + {"name": "elevenlabs", "voice": ""}, + ] + + def test_mixed_list_strings_and_dicts(self): + """Mixed list of strings and FallbackModel dicts.""" + fallbacks = [ + "cartesia/sonic:voice1", + FallbackModel(name="elevenlabs/eleven_flash_v2", voice="custom"), + "rime/mist", + ] + result = _normalize_fallback(fallbacks) + assert result == [ + {"name": "cartesia/sonic", "voice": "voice1"}, + {"name": "elevenlabs/eleven_flash_v2", "voice": "custom"}, + {"name": "rime/mist", "voice": ""}, + ] + + def test_fallback_model_with_extra_kwargs(self): + """FallbackModel with extra_kwargs is preserved.""" + fallback = FallbackModel( + name="cartesia/sonic", + voice="narrator", + extra_kwargs={"duration": 30.0, "speed": "fast"}, + ) + result = _normalize_fallback(fallback) + assert result == [ + { + "name": "cartesia/sonic", + "voice": "narrator", + "extra_kwargs": {"duration": 30.0, "speed": "fast"}, + } + ] + + def test_list_with_extra_kwargs_preserved(self): + """List with FallbackModels containing extra_kwargs.""" + fallbacks = [ + FallbackModel(name="cartesia/sonic", voice="v1", extra_kwargs={"speed": "slow"}), + "elevenlabs:voice2", + FallbackModel(name="rime/mist", voice="", extra_kwargs={"custom": True}), + ] + result = _normalize_fallback(fallbacks) + assert result == [ + 
{"name": "cartesia/sonic", "voice": "v1", "extra_kwargs": {"speed": "slow"}}, + {"name": "elevenlabs", "voice": "voice2"}, + {"name": "rime/mist", "voice": "", "extra_kwargs": {"custom": True}}, + ] + + def test_empty_list(self): + """Empty list returns empty list.""" + result = _normalize_fallback([]) + assert result == [] + + def test_fallback_model_with_none_voice(self): + """FallbackModel with explicit None voice.""" + fallback = FallbackModel(name="cartesia/sonic", voice="") + result = _normalize_fallback(fallback) + assert result == [{"name": "cartesia/sonic", "voice": ""}]