-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Fallback API for Inference #4099
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f7d368b
6332115
e6099e6
983a592
5cd6166
c9dde41
5ca9661
8e29769
9bc29f0
adb0c19
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -5,10 +5,12 @@ | |||
| import json | ||||
| import os | ||||
| import weakref | ||||
| from collections.abc import Mapping, Sequence | ||||
| from dataclasses import dataclass, replace | ||||
| from typing import Any, Literal, TypedDict, Union, overload | ||||
|
|
||||
| import aiohttp | ||||
| from typing_extensions import Required | ||||
|
|
||||
| from .. import tokenize, tts, utils | ||||
| from .._exceptions import APIConnectionError, APIError, APIStatusError, APITimeoutError | ||||
|
|
@@ -42,6 +44,87 @@ | |||
| "inworld/inworld-tts-1", | ||||
| ] | ||||
|
|
||||
| TTSModels = Union[CartesiaModels, ElevenlabsModels, RimeModels, InworldModels] | ||||
|
|
||||
def parse_model_string(model: str) -> tuple[str, str | None]:
    """Split a ``"model:voice"`` string into its model and voice parts.

    The split happens at the last ``":"`` so that a model name which itself
    contains a colon is kept intact.

    Args:
        model (str): Model string to parse, optionally suffixed with ``:voice``.

    Returns:
        tuple[str, str | None]: Model and voice. The voice is None when it is
        not specified, i.e. there is no ``":"`` or nothing follows the last one.
    """
    name, sep, voice = model.rpartition(":")
    if not sep:
        # No separator at all: the whole string is the model name.
        return model, None
    # A trailing ":" yields an empty voice; report it as "not specified" so
    # callers only ever have to check for None.
    return name, voice or None
|
|
||||
|
|
||||
class ConnectionOptions(TypedDict, total=False):
    """Optional connection settings applied to fallback attempts.

    Every key may be omitted (``total=False``).
    """

    timeout: float
    """Connection timeout, in seconds."""

    retries: int
    """How many retries are made for each model."""
|
|
||||
|
|
||||
class FallbackModel(TypedDict, total=False):
    """One fallback model, optionally carrying extra configuration.

    Extra fields are passed through to the provider.

    Example:
        >>> FallbackModel(name="cartesia/sonic", voice="")
    """

    name: Required[str]
    """Model name (e.g. "cartesia/sonic", "elevenlabs/eleven_flash_v2", "rime/arcana")."""

    voice: Required[str | None]
    """Voice the model should use; the key is required but its value may be None."""

    extra_kwargs: dict[str, Any]
    """Additional configuration for the model."""
|
|
||||
|
|
||||
| FallbackModelType = Union[FallbackModel, str] | ||||
|
|
||||
|
|
||||
class Fallback(TypedDict, total=False):
    """Describes which models to fall back to when the primary model fails."""

    models: Required[Sequence[FallbackModelType]]
    """Fallback models, listed in priority order."""

    connection: ConnectionOptions
    """Connection options used for the fallback attempts."""
|
Comment on lines
+95
to
+102
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm wondering whether this class is really useful. Maybe the argument inside the TTS/STT could directly accept a Sequence of FallbackModelType, and the ConnectionOptions could go directly inside the constructor? Probably nitpicking, but maybe we can also re-use […] (the identifier referenced here was lost when the page was extracted).
I'm not 100% sure; what do you think?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Right now, we already have […] (the rest of this reply was lost when the page was extracted) |
||||
|
|
||||
|
|
||||
| FallbackType = Union[Sequence[FallbackModelType], Fallback] | ||||
|
|
||||
|
|
||||
def _normalize_fallback(fallback: FallbackType) -> Fallback:
    """Convert any accepted fallback value into a canonical ``Fallback`` dict.

    Args:
        fallback: A ``Fallback`` mapping (``"models"`` plus an optional
            ``"connection"`` entry), a sequence of fallback models, or a
            single model string.

    Returns:
        Fallback: ``models`` is a list of ``FallbackModel`` dicts, with string
        entries parsed by ``parse_model_string``; ``connection`` holds the
        supplied connection options, or an empty dict when none were given.
    """
    options: ConnectionOptions = {}
    models: Sequence[FallbackModelType]
    if isinstance(fallback, Mapping):
        models = fallback.get("models", ())
        options = fallback.get("connection", options)
    elif isinstance(fallback, str):
        # A bare string is itself a Sequence[str]; without this guard it would
        # be iterated character by character, yielding one model per character.
        models = [fallback]
    else:
        models = fallback

    models_list: list[FallbackModel] = []
    for m in models:
        if isinstance(m, str):
            model, voice = parse_model_string(m)
            models_list.append({"name": model, "voice": voice})
        else:
            # Dict entries are passed through unchanged (same object, not copied).
            models_list.append(m)

    return Fallback(models=models_list, connection=options)
|
|
||||
|
|
||||
| class CartesiaOptions(TypedDict, total=False): | ||||
| duration: float # max duration of audio in seconds | ||||
|
|
@@ -61,8 +144,6 @@ class InworldOptions(TypedDict, total=False): | |||
| pass | ||||
|
|
||||
|
|
||||
| TTSModels = Union[CartesiaModels, ElevenlabsModels, RimeModels, InworldModels] | ||||
|
|
||||
| TTSEncoding = Literal["pcm_s16le"] | ||||
|
|
||||
| DEFAULT_ENCODING: TTSEncoding = "pcm_s16le" | ||||
|
|
@@ -81,6 +162,7 @@ class _TTSOptions: | |||
| api_key: str | ||||
| api_secret: str | ||||
| extra_kwargs: dict[str, Any] | ||||
| fallback: NotGivenOr[Fallback] | ||||
|
|
||||
|
|
||||
| class TTS(tts.TTS): | ||||
|
|
@@ -97,6 +179,7 @@ def __init__( | |||
| api_key: NotGivenOr[str] = NOT_GIVEN, | ||||
| api_secret: NotGivenOr[str] = NOT_GIVEN, | ||||
| http_session: aiohttp.ClientSession | None = None, | ||||
| fallback: NotGivenOr[FallbackType] = NOT_GIVEN, | ||||
| extra_kwargs: NotGivenOr[CartesiaOptions] = NOT_GIVEN, | ||||
| ) -> None: | ||||
| pass | ||||
|
|
@@ -114,6 +197,7 @@ def __init__( | |||
| api_key: NotGivenOr[str] = NOT_GIVEN, | ||||
| api_secret: NotGivenOr[str] = NOT_GIVEN, | ||||
| http_session: aiohttp.ClientSession | None = None, | ||||
| fallback: NotGivenOr[FallbackType] = NOT_GIVEN, | ||||
| extra_kwargs: NotGivenOr[ElevenlabsOptions] = NOT_GIVEN, | ||||
| ) -> None: | ||||
| pass | ||||
|
|
@@ -131,6 +215,7 @@ def __init__( | |||
| api_key: NotGivenOr[str] = NOT_GIVEN, | ||||
| api_secret: NotGivenOr[str] = NOT_GIVEN, | ||||
| http_session: aiohttp.ClientSession | None = None, | ||||
| fallback: NotGivenOr[FallbackType] = NOT_GIVEN, | ||||
| extra_kwargs: NotGivenOr[RimeOptions] = NOT_GIVEN, | ||||
| ) -> None: | ||||
| pass | ||||
|
|
@@ -148,6 +233,7 @@ def __init__( | |||
| api_key: NotGivenOr[str] = NOT_GIVEN, | ||||
| api_secret: NotGivenOr[str] = NOT_GIVEN, | ||||
| http_session: aiohttp.ClientSession | None = None, | ||||
| fallback: NotGivenOr[FallbackType] = NOT_GIVEN, | ||||
| extra_kwargs: NotGivenOr[InworldOptions] = NOT_GIVEN, | ||||
| ) -> None: | ||||
| pass | ||||
|
|
@@ -165,6 +251,7 @@ def __init__( | |||
| api_key: NotGivenOr[str] = NOT_GIVEN, | ||||
| api_secret: NotGivenOr[str] = NOT_GIVEN, | ||||
| http_session: aiohttp.ClientSession | None = None, | ||||
| fallback: NotGivenOr[FallbackType] = NOT_GIVEN, | ||||
| extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN, | ||||
| ) -> None: | ||||
| pass | ||||
|
|
@@ -181,6 +268,7 @@ def __init__( | |||
| api_key: NotGivenOr[str] = NOT_GIVEN, | ||||
| api_secret: NotGivenOr[str] = NOT_GIVEN, | ||||
| http_session: aiohttp.ClientSession | None = None, | ||||
| fallback: NotGivenOr[FallbackType] = NOT_GIVEN, | ||||
| extra_kwargs: NotGivenOr[ | ||||
| dict[str, Any] | CartesiaOptions | ElevenlabsOptions | RimeOptions | InworldOptions | ||||
| ] = NOT_GIVEN, | ||||
|
|
@@ -232,6 +320,10 @@ def __init__( | |||
| "api_secret is required, either as argument or set LIVEKIT_API_SECRET environmental variable" | ||||
| ) | ||||
|
|
||||
| fallback_model: NotGivenOr[Fallback] = NOT_GIVEN | ||||
| if is_given(fallback): | ||||
| fallback_model = _normalize_fallback(fallback) # type: ignore[arg-type] | ||||
|
|
||||
| self._opts = _TTSOptions( | ||||
| model=model, | ||||
| voice=voice, | ||||
|
|
@@ -242,6 +334,7 @@ def __init__( | |||
| api_key=lk_api_key, | ||||
| api_secret=lk_api_secret, | ||||
| extra_kwargs=dict(extra_kwargs) if is_given(extra_kwargs) else {}, | ||||
| fallback=fallback_model, | ||||
| ) | ||||
| self._session = http_session | ||||
| self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse]( | ||||
|
|
@@ -262,11 +355,8 @@ def from_model_string(cls, model: str) -> TTS: | |||
| Returns: | ||||
| TTS: TTS instance | ||||
| """ | ||||
| voice: NotGivenOr[str] = NOT_GIVEN | ||||
| if (idx := model.rfind(":")) != -1: | ||||
| voice = model[idx + 1 :] | ||||
| model = model[:idx] | ||||
| return cls(model, voice=voice) | ||||
| model, voice = parse_model_string(model) | ||||
| return cls(model=model, voice=voice if voice else NOT_GIVEN) | ||||
|
|
||||
| @property | ||||
| def model(self) -> str: | ||||
|
|
@@ -308,6 +398,8 @@ async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse: | |||
| params["model"] = self._opts.model | ||||
| if self._opts.language: | ||||
| params["language"] = self._opts.language | ||||
| if self._opts.fallback: | ||||
| params["fallback"] = self._opts.fallback | ||||
|
|
||||
| try: | ||||
| await ws.send_str(json.dumps(params)) | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can the voice be None? Is this different from omitting it (i.e. removing the Required flag)?