diff --git a/common/config_models.py b/common/config_models.py index 79b952b..0be73d9 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -19,6 +19,7 @@ class DefaultLlm(BaseModel): """Default LLM configuration.""" default_model: str + fallback_model: str | None = None default_temperature: float default_max_tokens: int diff --git a/common/global_config.yaml b/common/global_config.yaml index e77b235..091fcab 100644 --- a/common/global_config.yaml +++ b/common/global_config.yaml @@ -1,4 +1,4 @@ -model_name: gemini/gemini-2.0-flash +model_name: gemini/gemini-3-flash dot_global_config_health_check: true example_parent: @@ -8,7 +8,8 @@ example_parent: # LLMs ######################################################## default_llm: - default_model: gemini/gemini-2.0-flash + default_model: gemini/gemini-3-flash + fallback_model: gemini/gemini-2.5-flash default_temperature: 0.5 default_max_tokens: 100000 diff --git a/utils/llm/dspy_inference.py b/utils/llm/dspy_inference.py index 4dd8cf0..695d152 100644 --- a/utils/llm/dspy_inference.py +++ b/utils/llm/dspy_inference.py @@ -3,7 +3,7 @@ import dspy from langfuse.decorators import observe # type: ignore -from litellm.exceptions import ServiceUnavailableError +from litellm.exceptions import RateLimitError, ServiceUnavailableError from loguru import logger as log from tenacity import ( retry, @@ -23,6 +23,7 @@ def __init__( tools: list[Callable[..., Any]] | None = None, observe: bool = True, model_name: str = global_config.default_llm.default_model, + fallback_model_name: str | None = global_config.default_llm.fallback_model, temperature: float = global_config.default_llm.default_temperature, max_tokens: int = global_config.default_llm.default_max_tokens, max_iters: int = 5, @@ -30,13 +31,16 @@ def __init__( if tools is None: tools = [] - api_key = global_config.llm_api_key(model_name) - self.lm = dspy.LM( - model=model_name, - api_key=api_key, - cache=global_config.llm_config.cache_enabled, - temperature=temperature, - max_tokens=max_tokens, + self.lm = self._build_lm(model_name, temperature, max_tokens) + self.fallback_model_name = ( + fallback_model_name + if fallback_model_name and fallback_model_name != model_name + else None + ) + self.fallback_lm = ( + self._build_lm(self.fallback_model_name, temperature, max_tokens) + if self.fallback_model_name is not None + else None ) if observe: # Initialize a LangFuseDSPYCallback and configure the LM instance for generation tracing @@ -58,25 +62,61 @@ def __init__( self.inference_module ) - @observe() @retry( - retry=retry_if_exception_type(ServiceUnavailableError), + retry=retry_if_exception_type((RateLimitError, ServiceUnavailableError)), stop=stop_after_attempt(global_config.llm_config.retry.max_attempts), wait=wait_exponential( multiplier=global_config.llm_config.retry.min_wait_seconds, max=global_config.llm_config.retry.max_wait_seconds, ), before_sleep=lambda retry_state: log.warning( - f"Retrying due to ServiceUnavailableError. Attempt {retry_state.attempt_number}" + "Retrying due to LLM error " + f"{retry_state.outcome.exception().__class__.__name__}. " + f"Attempt {retry_state.attempt_number}" ), ) + async def _run_with_retry( + self, + lm: dspy.LM, + **kwargs: Any, + ) -> Any: + return await self.inference_module_async(**kwargs, lm=lm) + + def _build_lm( + self, + model_name: str, + temperature: float, + max_tokens: int, + ) -> dspy.LM: + api_key = global_config.llm_api_key(model_name) + return dspy.LM( + model=model_name, + api_key=api_key, + cache=global_config.llm_config.cache_enabled, + temperature=temperature, + max_tokens=max_tokens, + ) + + @observe() async def run( self, **kwargs: Any, ) -> Any: try: # user_id is passed if the pred_signature requires it. - result = await self.inference_module_async(**kwargs, lm=self.lm) + result = await self._run_with_retry(self.lm, **kwargs) + except (RateLimitError, ServiceUnavailableError) as e: + if not self.fallback_lm: + log.error(f"{e.__class__.__name__} without fallback: {str(e)}") + raise + log.warning( + f"Primary model unavailable; falling back to {self.fallback_model_name}" + ) + try: + result = await self._run_with_retry(self.fallback_lm, **kwargs) + except (RateLimitError, ServiceUnavailableError) as fallback_error: + log.error(f"Fallback model failed: {fallback_error.__class__.__name__}") + raise except (RuntimeError, ValueError, TypeError) as e: log.error(f"Error in run: {str(e)}") raise