diff --git a/common/config_models.py b/common/config_models.py index 79b952b..e16db00 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -19,8 +19,10 @@ class DefaultLlm(BaseModel): """Default LLM configuration.""" default_model: str + fallback_model: str | None = None default_temperature: float default_max_tokens: int + default_request_timeout: int = 120 class RetryConfig(BaseModel): diff --git a/common/global_config.yaml b/common/global_config.yaml index e77b235..b51fe57 100644 --- a/common/global_config.yaml +++ b/common/global_config.yaml @@ -9,8 +9,10 @@ example_parent: ######################################################## default_llm: default_model: gemini/gemini-2.0-flash + fallback_model: null default_temperature: 0.5 default_max_tokens: 100000 + default_request_timeout: 300 llm_config: cache_enabled: false diff --git a/tests/healthcheck/test_env_var_loading.py b/tests/healthcheck/test_env_var_loading.py index d558397..12c6061 100644 --- a/tests/healthcheck/test_env_var_loading.py +++ b/tests/healthcheck/test_env_var_loading.py @@ -31,7 +31,7 @@ def test_env_var_loading_precedence(monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "system_openai_key") # 2. Create a temporary .env file - dot_env_content = "DEV_ENV=dotenv\n" "OPENAI_API_KEY=dotenv_openai_key\n" + dot_env_content = "DEV_ENV=dotenv\nOPENAI_API_KEY=dotenv_openai_key\n" with open(dot_env_path, "w") as f: f.write(dot_env_content) @@ -41,12 +41,12 @@ def test_env_var_loading_precedence(monkeypatch): # 4. Assert that the variables are loaded with the correct precedence assert reloaded_config.DEV_ENV == "dotenv", "Should load from .env first" - assert ( - reloaded_config.ANTHROPIC_API_KEY == "system_anthropic_key" - ), "Should fall back to system env" - assert ( - reloaded_config.OPENAI_API_KEY == "dotenv_openai_key" - ), "Should load from .env" + assert reloaded_config.ANTHROPIC_API_KEY == "system_anthropic_key", ( + "Should fall back to system env" + ) + assert reloaded_config.OPENAI_API_KEY == "dotenv_openai_key", ( + "Should load from .env" + ) finally: # Clean up and restore the original .env file if it existed diff --git a/tests/healthcheck/test_pydantic_type_coercion.py b/tests/healthcheck/test_pydantic_type_coercion.py index ab895a8..2ad38a8 100644 --- a/tests/healthcheck/test_pydantic_type_coercion.py +++ b/tests/healthcheck/test_pydantic_type_coercion.py @@ -39,51 +39,51 @@ def test_pydantic_type_coercion(monkeypatch): config = common_module.global_config # Verify integer coercion - assert isinstance( - config.default_llm.default_max_tokens, int - ), "default_max_tokens should be int" - assert ( - config.default_llm.default_max_tokens == 50000 - ), "default_max_tokens should be 50000" - - assert isinstance( - config.llm_config.retry.max_attempts, int - ), "max_attempts should be int" + assert isinstance(config.default_llm.default_max_tokens, int), ( + "default_max_tokens should be int" + ) + assert config.default_llm.default_max_tokens == 50000, ( + "default_max_tokens should be 50000" + ) + + assert isinstance(config.llm_config.retry.max_attempts, int), ( + "max_attempts should be int" + ) assert config.llm_config.retry.max_attempts == 5, "max_attempts should be 5" - assert isinstance( - config.llm_config.retry.min_wait_seconds, int - ), "min_wait_seconds should be int" + assert isinstance(config.llm_config.retry.min_wait_seconds, int), ( + "min_wait_seconds should be int" + ) assert config.llm_config.retry.min_wait_seconds == 2, "min_wait_seconds should be 2" - assert isinstance( - config.llm_config.retry.max_wait_seconds, int - ), "max_wait_seconds should be int" - assert ( - config.llm_config.retry.max_wait_seconds == 10 - ), "max_wait_seconds should be 10" + assert isinstance(config.llm_config.retry.max_wait_seconds, int), ( + "max_wait_seconds should be int" + ) + assert config.llm_config.retry.max_wait_seconds == 10, ( + "max_wait_seconds should be 10" + ) # Verify float coercion - assert isinstance( - config.default_llm.default_temperature, float - ), "default_temperature should be float" - assert ( - config.default_llm.default_temperature == 0.7 - ), "default_temperature should be 0.7" + assert isinstance(config.default_llm.default_temperature, float), ( + "default_temperature should be float" + ) + assert config.default_llm.default_temperature == 0.7, ( + "default_temperature should be 0.7" + ) # Verify boolean coercion - assert isinstance( - config.llm_config.cache_enabled, bool - ), "cache_enabled should be bool" + assert isinstance(config.llm_config.cache_enabled, bool), ( + "cache_enabled should be bool" + ) assert config.llm_config.cache_enabled is True, "cache_enabled should be True" assert isinstance(config.logging.verbose, bool), "verbose should be bool" assert config.logging.verbose is False, "verbose should be False" assert isinstance(config.logging.format.show_time, bool), "show_time should be bool" - assert ( - config.logging.format.show_time is True - ), "show_time should be True (from '1')" + assert config.logging.format.show_time is True, ( + "show_time should be True (from '1')" + ) assert isinstance(config.logging.levels.debug, bool), "debug should be bool" assert config.logging.levels.debug is True, "debug should be True" diff --git a/tests/utils/llm/test_dspy_inference_robustness.py b/tests/utils/llm/test_dspy_inference_robustness.py new file mode 100644 index 0000000..8e3e73f --- /dev/null +++ b/tests/utils/llm/test_dspy_inference_robustness.py @@ -0,0 +1,142 @@ +import asyncio +from unittest.mock import AsyncMock, patch +from tests.test_template import TestTemplate +from utils.llm.dspy_inference import DSPYInference +from litellm.exceptions import RateLimitError, ServiceUnavailableError, Timeout +import dspy + +class MockSignature(dspy.Signature): + """A mock DSPy signature for testing purposes.""" + input = dspy.InputField() + output = dspy.OutputField() + +class TestDSPYInferenceRobustness(TestTemplate): + """Tests for DSPYInference robustness features (retry, fallback, timeout).""" + + def test_retry_on_rate_limit(self): + """ + Tests that the DSPYInference class correctly retries operations when encountering + a RateLimitError from the LLM provider. + """ + async def _test(): + with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"), \ + patch("common.global_config.global_config.default_llm.default_request_timeout", 1): + + inference = DSPYInference(pred_signature=MockSignature, observe=False) + + error = RateLimitError("Rate limit", llm_provider="openai", model="gpt-4") + + mock_method = AsyncMock(side_effect=[ + error, + dspy.Prediction(output="Success") + ]) + inference.inference_module_async = mock_method + + result = await inference.run(input="test") + + assert result.output == "Success" + assert mock_method.call_count == 2 + + asyncio.run(_test()) + + def test_retry_on_timeout(self): + """ + Tests that the DSPYInference class correctly retries operations when encountering + a Timeout error. + """ + async def _test(): + with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"), \ + patch("common.global_config.global_config.default_llm.default_request_timeout", 1): + + inference = DSPYInference(pred_signature=MockSignature, observe=False) + + error = Timeout("Timeout", llm_provider="openai", model="gpt-4") + + mock_method = AsyncMock(side_effect=[ + error, + dspy.Prediction(output="Success") + ]) + inference.inference_module_async = mock_method + + result = await inference.run(input="test") + + assert result.output == "Success" + assert mock_method.call_count == 2 + + asyncio.run(_test()) + + def test_fallback_logic(self): + """ + Tests the model fallback mechanism. + """ + async def _test(): + with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"): + + # Setup with fallback + inference = DSPYInference( + pred_signature=MockSignature, + observe=False, + model_name="primary-model", + fallback_model="fallback-model" + ) + + async def side_effect(*args, **kwargs): + lm = kwargs.get('lm') + if lm.model == "primary-model": + raise ServiceUnavailableError("Down", llm_provider="openai", model="primary-model") + elif lm.model == "fallback-model": + return dspy.Prediction(output="Fallback Success") + else: + raise ValueError(f"Unknown model: {lm.model}") + + inference.inference_module_async = AsyncMock(side_effect=side_effect) + + result = await inference.run(input="test") + + assert result.output == "Fallback Success" + + # Primary model should have been retried max_attempts times (default 3) + # Fallback model called once + # Total 4 + assert inference.inference_module_async.call_count == 4 + + asyncio.run(_test()) + + def test_fallback_failure(self): + """ + Tests the scenario where both the primary and fallback models fail. + """ + async def _test(): + with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"): + + # Setup with fallback where fallback also fails + inference = DSPYInference( + pred_signature=MockSignature, + observe=False, + model_name="primary-model", + fallback_model="fallback-model" + ) + + async def side_effect(*args, **kwargs): + lm = kwargs.get('lm') + if lm.model == "primary-model": + raise ServiceUnavailableError("Down", llm_provider="openai", model="primary-model") + elif lm.model == "fallback-model": + raise ServiceUnavailableError("Also Down", llm_provider="openai", model="fallback-model") + else: + raise ValueError(f"Unknown model: {lm.model}") + + inference.inference_module_async = AsyncMock(side_effect=side_effect) + + # Execute and expect exception + try: + await inference.run(input="test") + assert False, "Should have raised ServiceUnavailableError" + except ServiceUnavailableError: + pass + + # Primary called 3 times, Fallback called 3 times + # Total 6 + assert inference.inference_module_async.call_count == 6 + + asyncio.run(_test()) diff --git a/utils/llm/dspy_inference.py b/utils/llm/dspy_inference.py index bcc50b1..9a7e43d 100644 --- a/utils/llm/dspy_inference.py +++ b/utils/llm/dspy_inference.py @@ -10,7 +10,7 @@ retry_if_exception_type, ) from utils.llm.dspy_langfuse import LangFuseDSPYCallback -from litellm.exceptions import ServiceUnavailableError +from litellm.exceptions import ServiceUnavailableError, RateLimitError, Timeout from langfuse.decorators import observe # type: ignore @@ -21,8 +21,10 @@ def __init__( tools: list[Callable[..., Any]] | None = None, observe: bool = True, model_name: str = global_config.default_llm.default_model, + fallback_model: str | None = global_config.default_llm.fallback_model, temperature: float = global_config.default_llm.default_temperature, max_tokens: int = global_config.default_llm.default_max_tokens, + request_timeout: int = global_config.default_llm.default_request_timeout, max_iters: int = 5, ) -> None: if tools is None: @@ -35,7 +37,26 @@ def __init__( cache=global_config.llm_config.cache_enabled, temperature=temperature, max_tokens=max_tokens, + timeout=request_timeout, ) + + self.fallback_lm = None + if fallback_model: + try: + fallback_api_key = global_config.llm_api_key(fallback_model) + self.fallback_lm = dspy.LM( + model=fallback_model, + api_key=fallback_api_key, + cache=global_config.llm_config.cache_enabled, + temperature=temperature, + max_tokens=max_tokens, + timeout=request_timeout, + ) + except Exception as e: + log.warning( + f"Failed to initialize fallback model {fallback_model}: {e}" + ) + if observe: # Initialize a LangFuseDSPYCallback and configure the LM instance for generation tracing self.callback = LangFuseDSPYCallback(pred_signature) @@ -56,26 +77,40 @@ def __init__( self.inference_module ) - @observe() @retry( - retry=retry_if_exception_type(ServiceUnavailableError), + retry=retry_if_exception_type( + (ServiceUnavailableError, RateLimitError, Timeout) + ), stop=stop_after_attempt(global_config.llm_config.retry.max_attempts), wait=wait_exponential( multiplier=global_config.llm_config.retry.min_wait_seconds, max=global_config.llm_config.retry.max_wait_seconds, ), before_sleep=lambda retry_state: log.warning( - f"Retrying due to ServiceUnavailableError. Attempt {retry_state.attempt_number}" + f"Retrying due to LLM Error ({retry_state.outcome.exception()}). Attempt {retry_state.attempt_number}" ), + reraise=True, ) + async def _run_inference(self, lm, **kwargs) -> Any: + return await self.inference_module_async(**kwargs, lm=lm) + + @observe() async def run( self, **kwargs: Any, ) -> Any: try: # user_id is passed if the pred_signature requires it. - result = await self.inference_module_async(**kwargs, lm=self.lm) + result = await self._run_inference(lm=self.lm, **kwargs) except Exception as e: - log.error(f"Error in run: {str(e)}") - raise + if self.fallback_lm: + log.warning(f"Primary model failed: {e}. Switching to fallback model.") + try: + result = await self._run_inference(lm=self.fallback_lm, **kwargs) + except Exception as fallback_error: + log.error(f"Fallback model also failed: {fallback_error}") + raise fallback_error + else: + log.error(f"Error in run: {str(e)}") + raise return result diff --git a/utils/llm/dspy_langfuse.py b/utils/llm/dspy_langfuse.py index 26132f6..992c1b3 100644 --- a/utils/llm/dspy_langfuse.py +++ b/utils/llm/dspy_langfuse.py @@ -374,26 +374,30 @@ def on_tool_start( # noqa inputs: dict[str, Any], ) -> None: """Called when a tool execution starts.""" - tool_name = getattr(instance, "__name__", None) or getattr( - instance, "name", None - ) or str(type(instance).__name__) - + tool_name = ( + getattr(instance, "__name__", None) + or getattr(instance, "name", None) + or str(type(instance).__name__) + ) + # Skip internal DSPy tools if tool_name in self.INTERNAL_TOOLS: self.current_tool_span.set(None) return - + # Extract tool arguments tool_args = inputs.get("args", {}) if not tool_args: # Try to get kwargs directly - tool_args = {k: v for k, v in inputs.items() if k not in ["call_id", "instance"]} - + tool_args = { + k: v for k, v in inputs.items() if k not in ["call_id", "instance"] + } + log.debug(f"Tool call started: {tool_name} with args: {tool_args}") - + trace_id = langfuse_context.get_current_trace_id() parent_observation_id = langfuse_context.get_current_observation_id() - + if trace_id: # Create a span for the tool call tool_span = self.langfuse.span( @@ -416,12 +420,12 @@ def on_tool_end( # noqa ) -> None: """Called when a tool execution ends.""" tool_span = self.current_tool_span.get(None) - + if tool_span: level: Literal["DEFAULT", "WARNING", "ERROR"] = "DEFAULT" status_message: Optional[str] = None output_value: Any = None - + if exception: level = "ERROR" status_message = str(exception) @@ -438,12 +442,12 @@ def on_tool_end( # noqa output_value = str(outputs) except Exception as e: output_value = {"serialization_error": str(e), "raw": str(outputs)} - + tool_span.end( output=output_value, level=level, status_message=status_message, ) self.current_tool_span.set(None) - + log.debug(f"Tool call ended with output: {str(output_value)[:100]}...")