From db98bb23748fe72073a65d4ffe6bb9ac50672667 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 17 Jan 2026 17:35:07 +0000 Subject: [PATCH 1/3] feat: enhance dspy inference robustness with retry, fallback, and timeout - Update `LlmConfig` and `DefaultLlm` in `common/config_models.py` to support `fallback_model` and `default_request_timeout`. - Update `common/global_config.yaml` with default values. - Refactor `utils/llm/dspy_inference.py` to: - Implement retry logic for `ServiceUnavailableError`, `RateLimitError`, and `Timeout` using `tenacity` with exponential backoff. - Implement fallback to a secondary model if the primary fails. - Pass `timeout` to the LLM provider. - Add `tests/utils/llm/test_dspy_inference_robustness.py` to verify retry and fallback mechanisms. --- common/config_models.py | 2 + common/global_config.yaml | 2 + .../llm/test_dspy_inference_robustness.py | 127 ++++++++++++++++++ utils/llm/dspy_inference.py | 45 ++++++- 4 files changed, 169 insertions(+), 7 deletions(-) create mode 100644 tests/utils/llm/test_dspy_inference_robustness.py diff --git a/common/config_models.py b/common/config_models.py index 79b952b..e16db00 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -19,8 +19,10 @@ class DefaultLlm(BaseModel): """Default LLM configuration.""" default_model: str + fallback_model: str | None = None default_temperature: float default_max_tokens: int + default_request_timeout: int = 120 class RetryConfig(BaseModel): diff --git a/common/global_config.yaml b/common/global_config.yaml index e77b235..b51fe57 100644 --- a/common/global_config.yaml +++ b/common/global_config.yaml @@ -9,8 +9,10 @@ example_parent: ######################################################## default_llm: default_model: gemini/gemini-2.0-flash + fallback_model: null default_temperature: 0.5 default_max_tokens: 100000 + default_request_timeout: 300 llm_config: cache_enabled: false diff --git a/tests/utils/llm/test_dspy_inference_robustness.py b/tests/utils/llm/test_dspy_inference_robustness.py new file mode 100644 index 0000000..f60c6b7 --- /dev/null +++ b/tests/utils/llm/test_dspy_inference_robustness.py @@ -0,0 +1,127 @@ +import asyncio +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +from utils.llm.dspy_inference import DSPYInference +from litellm.exceptions import RateLimitError, ServiceUnavailableError, Timeout +import dspy + +class MockSignature(dspy.Signature): + input = dspy.InputField() + output = dspy.OutputField() + +def run_async(coro): + return asyncio.run(coro) + +def test_retry_on_rate_limit(): + async def _test(): + with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"), \ + patch("common.global_config.global_config.default_llm.default_request_timeout", 1): + + inference = DSPYInference(pred_signature=MockSignature, observe=False) + + error = RateLimitError("Rate limit", llm_provider="openai", model="gpt-4") + + mock_method = AsyncMock(side_effect=[ + error, + dspy.Prediction(output="Success") + ]) + inference.inference_module_async = mock_method + + result = await inference.run(input="test") + + assert result.output == "Success" + assert mock_method.call_count == 2 + + run_async(_test()) + +def test_retry_on_timeout(): + async def _test(): + with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"), \ + patch("common.global_config.global_config.default_llm.default_request_timeout", 1): + + inference = DSPYInference(pred_signature=MockSignature, observe=False) + + error = Timeout("Timeout", llm_provider="openai", model="gpt-4") + + mock_method = AsyncMock(side_effect=[ + error, + dspy.Prediction(output="Success") + ]) + inference.inference_module_async = mock_method + + result = await inference.run(input="test") + + assert result.output == "Success" + assert mock_method.call_count == 2 + + run_async(_test()) + +def test_fallback_logic(): + async def _test(): + with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"): + + # Setup with fallback + inference = DSPYInference( + pred_signature=MockSignature, + observe=False, + model_name="primary-model", + fallback_model="fallback-model" + ) + + async def side_effect(*args, **kwargs): + lm = kwargs.get('lm') + if lm.model == "primary-model": + raise ServiceUnavailableError("Down", llm_provider="openai", model="primary-model") + elif lm.model == "fallback-model": + return dspy.Prediction(output="Fallback Success") + else: + raise ValueError(f"Unknown model: {lm.model}") + + inference.inference_module_async = AsyncMock(side_effect=side_effect) + + result = await inference.run(input="test") + + assert result.output == "Fallback Success" + + # Primary model should have been retried max_attempts times (default 3) + # Fallback model called once + # Total 4 + assert inference.inference_module_async.call_count == 4 + + run_async(_test()) + +def test_fallback_failure(): + async def _test(): + with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"): + + # Setup with fallback where fallback also fails + inference = DSPYInference( + pred_signature=MockSignature, + observe=False, + model_name="primary-model", + fallback_model="fallback-model" + ) + + async def side_effect(*args, **kwargs): + lm = kwargs.get('lm') + if lm.model == "primary-model": + raise ServiceUnavailableError("Down", llm_provider="openai", model="primary-model") + elif lm.model == "fallback-model": + raise ServiceUnavailableError("Also Down", llm_provider="openai", model="fallback-model") + else: + raise ValueError(f"Unknown model: {lm.model}") + + inference.inference_module_async = AsyncMock(side_effect=side_effect) + + # Execute and expect exception + try: + await inference.run(input="test") + assert False, "Should have raised ServiceUnavailableError" + except ServiceUnavailableError: + pass + + # Primary called 3 times, Fallback called 3 times + # Total 6 + assert inference.inference_module_async.call_count == 6 + + run_async(_test()) diff --git a/utils/llm/dspy_inference.py b/utils/llm/dspy_inference.py index bcc50b1..2c655b4 100644 --- a/utils/llm/dspy_inference.py +++ b/utils/llm/dspy_inference.py @@ -10,7 +10,7 @@ retry_if_exception_type, ) from utils.llm.dspy_langfuse import LangFuseDSPYCallback -from litellm.exceptions import ServiceUnavailableError +from litellm.exceptions import ServiceUnavailableError, RateLimitError, Timeout from langfuse.decorators import observe # type: ignore @@ -21,8 +21,10 @@ def __init__( tools: list[Callable[..., Any]] | None = None, observe: bool = True, model_name: str = global_config.default_llm.default_model, + fallback_model: str | None = global_config.default_llm.fallback_model, temperature: float = global_config.default_llm.default_temperature, max_tokens: int = global_config.default_llm.default_max_tokens, + request_timeout: int = global_config.default_llm.default_request_timeout, max_iters: int = 5, ) -> None: if tools is None: @@ -35,7 +37,24 @@ def __init__( cache=global_config.llm_config.cache_enabled, temperature=temperature, max_tokens=max_tokens, + timeout=request_timeout, ) + + self.fallback_lm = None + if fallback_model: + try: + fallback_api_key = global_config.llm_api_key(fallback_model) + self.fallback_lm = dspy.LM( + model=fallback_model, + api_key=fallback_api_key, + cache=global_config.llm_config.cache_enabled, + temperature=temperature, + max_tokens=max_tokens, + timeout=request_timeout, + ) + except Exception as e: + log.warning(f"Failed to initialize fallback model {fallback_model}: {e}") + if observe: # Initialize a LangFuseDSPYCallback and configure the LM instance for generation tracing self.callback = LangFuseDSPYCallback(pred_signature) @@ -56,26 +75,38 @@ def __init__( self.inference_module ) - @observe() @retry( - retry=retry_if_exception_type(ServiceUnavailableError), + retry=retry_if_exception_type((ServiceUnavailableError, RateLimitError, Timeout)), stop=stop_after_attempt(global_config.llm_config.retry.max_attempts), wait=wait_exponential( multiplier=global_config.llm_config.retry.min_wait_seconds, max=global_config.llm_config.retry.max_wait_seconds, ), before_sleep=lambda retry_state: log.warning( - f"Retrying due to ServiceUnavailableError. Attempt {retry_state.attempt_number}" + f"Retrying due to LLM Error ({retry_state.outcome.exception()}). Attempt {retry_state.attempt_number}" ), + reraise=True, ) + async def _run_inference(self, lm, **kwargs) -> Any: + return await self.inference_module_async(**kwargs, lm=lm) + + @observe() async def run( self, **kwargs: Any, ) -> Any: try: # user_id is passed if the pred_signature requires it. - result = await self.inference_module_async(**kwargs, lm=self.lm) + result = await self._run_inference(lm=self.lm, **kwargs) except Exception as e: - log.error(f"Error in run: {str(e)}") - raise + if self.fallback_lm: + log.warning(f"Primary model failed: {e}. Switching to fallback model.") + try: + result = await self._run_inference(lm=self.fallback_lm, **kwargs) + except Exception as fallback_error: + log.error(f"Fallback model also failed: {fallback_error}") + raise fallback_error + else: + log.error(f"Error in run: {str(e)}") + raise return result From 832afaf127fd2aa1c1464c8e6063347fd22e287f Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 17 Jan 2026 18:24:35 +0000 Subject: [PATCH 2/3] feat: enhance dspy inference robustness with retry, fallback, and timeout - Update `LlmConfig` and `DefaultLlm` in `common/config_models.py` to support `fallback_model` and `default_request_timeout`. - Update `common/global_config.yaml` with default values. - Refactor `utils/llm/dspy_inference.py` to: - Implement retry logic for `ServiceUnavailableError`, `RateLimitError`, and `Timeout` using `tenacity` with exponential backoff. - Implement fallback to a secondary model if the primary fails. - Pass `timeout` to the LLM provider. - Add `tests/utils/llm/test_dspy_inference_robustness.py` to verify retry and fallback mechanisms. --- tests/healthcheck/test_env_var_loading.py | 14 ++--- .../test_pydantic_type_coercion.py | 62 +++++++++---------- .../llm/test_dspy_inference_robustness.py | 43 ++++++++++++- utils/llm/dspy_inference.py | 8 ++- utils/llm/dspy_langfuse.py | 30 +++++---- 5 files changed, 102 insertions(+), 55 deletions(-) diff --git a/tests/healthcheck/test_env_var_loading.py b/tests/healthcheck/test_env_var_loading.py index d558397..12c6061 100644 --- a/tests/healthcheck/test_env_var_loading.py +++ b/tests/healthcheck/test_env_var_loading.py @@ -31,7 +31,7 @@ def test_env_var_loading_precedence(monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "system_openai_key") # 2. Create a temporary .env file - dot_env_content = "DEV_ENV=dotenv\n" "OPENAI_API_KEY=dotenv_openai_key\n" + dot_env_content = "DEV_ENV=dotenv\nOPENAI_API_KEY=dotenv_openai_key\n" with open(dot_env_path, "w") as f: f.write(dot_env_content) @@ -41,12 +41,12 @@ def test_env_var_loading_precedence(monkeypatch): # 4. Assert that the variables are loaded with the correct precedence assert reloaded_config.DEV_ENV == "dotenv", "Should load from .env first" - assert ( - reloaded_config.ANTHROPIC_API_KEY == "system_anthropic_key" - ), "Should fall back to system env" - assert ( - reloaded_config.OPENAI_API_KEY == "dotenv_openai_key" - ), "Should load from .env" + assert reloaded_config.ANTHROPIC_API_KEY == "system_anthropic_key", ( + "Should fall back to system env" + ) + assert reloaded_config.OPENAI_API_KEY == "dotenv_openai_key", ( + "Should load from .env" + ) finally: # Clean up and restore the original .env file if it existed diff --git a/tests/healthcheck/test_pydantic_type_coercion.py b/tests/healthcheck/test_pydantic_type_coercion.py index ab895a8..2ad38a8 100644 --- a/tests/healthcheck/test_pydantic_type_coercion.py +++ b/tests/healthcheck/test_pydantic_type_coercion.py @@ -39,51 +39,51 @@ def test_pydantic_type_coercion(monkeypatch): config = common_module.global_config # Verify integer coercion - assert isinstance( - config.default_llm.default_max_tokens, int - ), "default_max_tokens should be int" - assert ( - config.default_llm.default_max_tokens == 50000 - ), "default_max_tokens should be 50000" - - assert isinstance( - config.llm_config.retry.max_attempts, int - ), "max_attempts should be int" + assert isinstance(config.default_llm.default_max_tokens, int), ( + "default_max_tokens should be int" + ) + assert config.default_llm.default_max_tokens == 50000, ( + "default_max_tokens should be 50000" + ) + + assert isinstance(config.llm_config.retry.max_attempts, int), ( + "max_attempts should be int" + ) assert config.llm_config.retry.max_attempts == 5, "max_attempts should be 5" - assert isinstance( - config.llm_config.retry.min_wait_seconds, int - ), "min_wait_seconds should be int" + assert isinstance(config.llm_config.retry.min_wait_seconds, int), ( + "min_wait_seconds should be int" + ) assert config.llm_config.retry.min_wait_seconds == 2, "min_wait_seconds should be 2" - assert isinstance( - config.llm_config.retry.max_wait_seconds, int - ), "max_wait_seconds should be int" - assert ( - config.llm_config.retry.max_wait_seconds == 10 - ), "max_wait_seconds should be 10" + assert isinstance(config.llm_config.retry.max_wait_seconds, int), ( + "max_wait_seconds should be int" + ) + assert config.llm_config.retry.max_wait_seconds == 10, ( + "max_wait_seconds should be 10" + ) # Verify float coercion - assert isinstance( - config.default_llm.default_temperature, float - ), "default_temperature should be float" - assert ( - config.default_llm.default_temperature == 0.7 - ), "default_temperature should be 0.7" + assert isinstance(config.default_llm.default_temperature, float), ( + "default_temperature should be float" + ) + assert config.default_llm.default_temperature == 0.7, ( + "default_temperature should be 0.7" + ) # Verify boolean coercion - assert isinstance( - config.llm_config.cache_enabled, bool - ), "cache_enabled should be bool" + assert isinstance(config.llm_config.cache_enabled, bool), ( + "cache_enabled should be bool" + ) assert config.llm_config.cache_enabled is True, "cache_enabled should be True" assert isinstance(config.logging.verbose, bool), "verbose should be bool" assert config.logging.verbose is False, "verbose should be False" assert isinstance(config.logging.format.show_time, bool), "show_time should be bool" - assert ( - config.logging.format.show_time is True - ), "show_time should be True (from '1')" + assert config.logging.format.show_time is True, ( + "show_time should be True (from '1')" + ) assert isinstance(config.logging.levels.debug, bool), "debug should be bool" assert config.logging.levels.debug is True, "debug should be True" diff --git a/tests/utils/llm/test_dspy_inference_robustness.py b/tests/utils/llm/test_dspy_inference_robustness.py index f60c6b7..7e5e58f 100644 --- a/tests/utils/llm/test_dspy_inference_robustness.py +++ b/tests/utils/llm/test_dspy_inference_robustness.py @@ -1,18 +1,28 @@ import asyncio -import pytest -from unittest.mock import MagicMock, AsyncMock, patch +from unittest.mock import AsyncMock, patch from utils.llm.dspy_inference import DSPYInference from litellm.exceptions import RateLimitError, ServiceUnavailableError, Timeout import dspy class MockSignature(dspy.Signature): + """A mock DSPy signature for testing purposes.""" input = dspy.InputField() output = dspy.OutputField() def run_async(coro): + """Helper to run async test functions.""" return asyncio.run(coro) def test_retry_on_rate_limit(): + """ + Tests that the DSPYInference class correctly retries operations when encountering + a RateLimitError from the LLM provider. + + It mocks the underlying async inference call to raise a RateLimitError on the first + attempt and succeed on the second. It verifies that: + 1. The operation eventually succeeds. + 2. The inference method was called twice (initial attempt + 1 retry). + """ async def _test(): with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"), \ patch("common.global_config.global_config.default_llm.default_request_timeout", 1): @@ -35,6 +45,15 @@ async def _test(): run_async(_test()) def test_retry_on_timeout(): + """ + Tests that the DSPYInference class correctly retries operations when encountering + a Timeout error. + + It mocks the underlying async inference call to raise a Timeout on the first + attempt and succeed on the second. It verifies that: + 1. The operation eventually succeeds. + 2. The inference method was called twice. + """ async def _test(): with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"), \ patch("common.global_config.global_config.default_llm.default_request_timeout", 1): @@ -57,6 +76,17 @@ async def _test(): run_async(_test()) def test_fallback_logic(): + """ + Tests the model fallback mechanism. + + It configures a primary model and a fallback model. It mocks the inference call + to simulate the primary model failing with a ServiceUnavailableError (which triggers retries). + After the primary model's retries are exhausted, the system should switch to the fallback model. + + It verifies that: + 1. The operation succeeds using the fallback model. + 2. The total call count reflects the primary model's retries + the successful fallback attempt. + """ async def _test(): with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"): @@ -91,6 +121,15 @@ async def side_effect(*args, **kwargs): run_async(_test()) def test_fallback_failure(): + """ + Tests the scenario where both the primary and fallback models fail. + + It mocks both models to raise ServiceUnavailableError. + + It verifies that: + 1. The ServiceUnavailableError is ultimately raised to the caller. + 2. Both primary and fallback models were attempted (with their respective retries). + """ async def _test(): with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"): diff --git a/utils/llm/dspy_inference.py b/utils/llm/dspy_inference.py index 2c655b4..9a7e43d 100644 --- a/utils/llm/dspy_inference.py +++ b/utils/llm/dspy_inference.py @@ -53,7 +53,9 @@ def __init__( timeout=request_timeout, ) except Exception as e: - log.warning(f"Failed to initialize fallback model {fallback_model}: {e}") + log.warning( + f"Failed to initialize fallback model {fallback_model}: {e}" + ) if observe: # Initialize a LangFuseDSPYCallback and configure the LM instance for generation tracing @@ -76,7 +78,9 @@ def __init__( ) @retry( - retry=retry_if_exception_type((ServiceUnavailableError, RateLimitError, Timeout)), + retry=retry_if_exception_type( + (ServiceUnavailableError, RateLimitError, Timeout) + ), stop=stop_after_attempt(global_config.llm_config.retry.max_attempts), wait=wait_exponential( multiplier=global_config.llm_config.retry.min_wait_seconds, diff --git a/utils/llm/dspy_langfuse.py b/utils/llm/dspy_langfuse.py index 26132f6..992c1b3 100644 --- a/utils/llm/dspy_langfuse.py +++ b/utils/llm/dspy_langfuse.py @@ -374,26 +374,30 @@ def on_tool_start( # noqa inputs: dict[str, Any], ) -> None: """Called when a tool execution starts.""" - tool_name = getattr(instance, "__name__", None) or getattr( - instance, "name", None - ) or str(type(instance).__name__) - + tool_name = ( + getattr(instance, "__name__", None) + or getattr(instance, "name", None) + or str(type(instance).__name__) + ) + # Skip internal DSPy tools if tool_name in self.INTERNAL_TOOLS: self.current_tool_span.set(None) return - + # Extract tool arguments tool_args = inputs.get("args", {}) if not tool_args: # Try to get kwargs directly - tool_args = {k: v for k, v in inputs.items() if k not in ["call_id", "instance"]} - + tool_args = { + k: v for k, v in inputs.items() if k not in ["call_id", "instance"] + } + log.debug(f"Tool call started: {tool_name} with args: {tool_args}") - + trace_id = langfuse_context.get_current_trace_id() parent_observation_id = langfuse_context.get_current_observation_id() - + if trace_id: # Create a span for the tool call tool_span = self.langfuse.span( @@ -416,12 +420,12 @@ def on_tool_end( # noqa ) -> None: """Called when a tool execution ends.""" tool_span = self.current_tool_span.get(None) - + if tool_span: level: Literal["DEFAULT", "WARNING", "ERROR"] = "DEFAULT" status_message: Optional[str] = None output_value: Any = None - + if exception: level = "ERROR" status_message = str(exception) @@ -438,12 +442,12 @@ def on_tool_end( # noqa output_value = str(outputs) except Exception as e: output_value = {"serialization_error": str(e), "raw": str(outputs)} - + tool_span.end( output=output_value, level=level, status_message=status_message, ) self.current_tool_span.set(None) - + log.debug(f"Tool call ended with output: {str(output_value)[:100]}...") From fb470cd8f6cc4bcdf69c11dc8c45b6580f5f8369 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 17 Jan 2026 20:00:12 +0000 Subject: [PATCH 3/3] feat: enhance dspy inference robustness with retry, fallback, and timeout - Update `LlmConfig` and `DefaultLlm` in `common/config_models.py` to support `fallback_model` and `default_request_timeout`. - Update `common/global_config.yaml` with default values. - Refactor `utils/llm/dspy_inference.py` to: - Implement retry logic for `ServiceUnavailableError`, `RateLimitError`, and `Timeout` using `tenacity` with exponential backoff. - Implement fallback to a secondary model if the primary fails. - Pass `timeout` to the LLM provider. - Add `tests/utils/llm/test_dspy_inference_robustness.py` to verify retry and fallback mechanisms. --- .../llm/test_dspy_inference_robustness.py | 286 ++++++++---------- 1 file changed, 131 insertions(+), 155 deletions(-) diff --git a/tests/utils/llm/test_dspy_inference_robustness.py b/tests/utils/llm/test_dspy_inference_robustness.py index 7e5e58f..8e3e73f 100644 --- a/tests/utils/llm/test_dspy_inference_robustness.py +++ b/tests/utils/llm/test_dspy_inference_robustness.py @@ -1,5 +1,6 @@ import asyncio from unittest.mock import AsyncMock, patch +from tests.test_template import TestTemplate from utils.llm.dspy_inference import DSPYInference from litellm.exceptions import RateLimitError, ServiceUnavailableError, Timeout import dspy @@ -9,158 +10,133 @@ class MockSignature(dspy.Signature): input = dspy.InputField() output = dspy.OutputField() -def run_async(coro): - """Helper to run async test functions.""" - return asyncio.run(coro) - -def test_retry_on_rate_limit(): - """ - Tests that the DSPYInference class correctly retries operations when encountering - a RateLimitError from the LLM provider. - - It mocks the underlying async inference call to raise a RateLimitError on the first - attempt and succeed on the second. It verifies that: - 1. The operation eventually succeeds. - 2. The inference method was called twice (initial attempt + 1 retry). - """ - async def _test(): - with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"), \ - patch("common.global_config.global_config.default_llm.default_request_timeout", 1): - - inference = DSPYInference(pred_signature=MockSignature, observe=False) - - error = RateLimitError("Rate limit", llm_provider="openai", model="gpt-4") - - mock_method = AsyncMock(side_effect=[ - error, - dspy.Prediction(output="Success") - ]) - inference.inference_module_async = mock_method - - result = await inference.run(input="test") - - assert result.output == "Success" - assert mock_method.call_count == 2 - - run_async(_test()) - -def test_retry_on_timeout(): - """ - Tests that the DSPYInference class correctly retries operations when encountering - a Timeout error. - - It mocks the underlying async inference call to raise a Timeout on the first - attempt and succeed on the second. It verifies that: - 1. The operation eventually succeeds. - 2. The inference method was called twice. - """ - async def _test(): - with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"), \ - patch("common.global_config.global_config.default_llm.default_request_timeout", 1): - - inference = DSPYInference(pred_signature=MockSignature, observe=False) - - error = Timeout("Timeout", llm_provider="openai", model="gpt-4") - - mock_method = AsyncMock(side_effect=[ - error, - dspy.Prediction(output="Success") - ]) - inference.inference_module_async = mock_method - - result = await inference.run(input="test") - - assert result.output == "Success" - assert mock_method.call_count == 2 - - run_async(_test()) - -def test_fallback_logic(): - """ - Tests the model fallback mechanism. - - It configures a primary model and a fallback model. It mocks the inference call - to simulate the primary model failing with a ServiceUnavailableError (which triggers retries). - After the primary model's retries are exhausted, the system should switch to the fallback model. - - It verifies that: - 1. The operation succeeds using the fallback model. - 2. The total call count reflects the primary model's retries + the successful fallback attempt. - """ - async def _test(): - with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"): - - # Setup with fallback - inference = DSPYInference( - pred_signature=MockSignature, - observe=False, - model_name="primary-model", - fallback_model="fallback-model" - ) - - async def side_effect(*args, **kwargs): - lm = kwargs.get('lm') - if lm.model == "primary-model": - raise ServiceUnavailableError("Down", llm_provider="openai", model="primary-model") - elif lm.model == "fallback-model": - return dspy.Prediction(output="Fallback Success") - else: - raise ValueError(f"Unknown model: {lm.model}") - - inference.inference_module_async = AsyncMock(side_effect=side_effect) - - result = await inference.run(input="test") - - assert result.output == "Fallback Success" - - # Primary model should have been retried max_attempts times (default 3) - # Fallback model called once - # Total 4 - assert inference.inference_module_async.call_count == 4 - - run_async(_test()) - -def test_fallback_failure(): - """ - Tests the scenario where both the primary and fallback models fail. - - It mocks both models to raise ServiceUnavailableError. - - It verifies that: - 1. The ServiceUnavailableError is ultimately raised to the caller. - 2. Both primary and fallback models were attempted (with their respective retries). - """ - async def _test(): - with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"): - - # Setup with fallback where fallback also fails - inference = DSPYInference( - pred_signature=MockSignature, - observe=False, - model_name="primary-model", - fallback_model="fallback-model" - ) - - async def side_effect(*args, **kwargs): - lm = kwargs.get('lm') - if lm.model == "primary-model": - raise ServiceUnavailableError("Down", llm_provider="openai", model="primary-model") - elif lm.model == "fallback-model": - raise ServiceUnavailableError("Also Down", llm_provider="openai", model="fallback-model") - else: - raise ValueError(f"Unknown model: {lm.model}") - - inference.inference_module_async = AsyncMock(side_effect=side_effect) - - # Execute and expect exception - try: - await inference.run(input="test") - assert False, "Should have raised ServiceUnavailableError" - except ServiceUnavailableError: - pass - - # Primary called 3 times, Fallback called 3 times - # Total 6 - assert inference.inference_module_async.call_count == 6 - - run_async(_test()) +class TestDSPYInferenceRobustness(TestTemplate): + """Tests for DSPYInference robustness features (retry, fallback, timeout).""" + + def test_retry_on_rate_limit(self): + """ + Tests that the DSPYInference class correctly retries operations when encountering + a RateLimitError from the LLM provider. + """ + async def _test(): + with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"), \ + patch("common.global_config.global_config.default_llm.default_request_timeout", 1): + + inference = DSPYInference(pred_signature=MockSignature, observe=False) + + error = RateLimitError("Rate limit", llm_provider="openai", model="gpt-4") + + mock_method = AsyncMock(side_effect=[ + error, + dspy.Prediction(output="Success") + ]) + inference.inference_module_async = mock_method + + result = await inference.run(input="test") + + assert result.output == "Success" + assert mock_method.call_count == 2 + + asyncio.run(_test()) + + def test_retry_on_timeout(self): + """ + Tests that the DSPYInference class correctly retries operations when encountering + a Timeout error. + """ + async def _test(): + with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"), \ + patch("common.global_config.global_config.default_llm.default_request_timeout", 1): + + inference = DSPYInference(pred_signature=MockSignature, observe=False) + + error = Timeout("Timeout", llm_provider="openai", model="gpt-4") + + mock_method = AsyncMock(side_effect=[ + error, + dspy.Prediction(output="Success") + ]) + inference.inference_module_async = mock_method + + result = await inference.run(input="test") + + assert result.output == "Success" + assert mock_method.call_count == 2 + + asyncio.run(_test()) + + def test_fallback_logic(self): + """ + Tests the model fallback mechanism. + """ + async def _test(): + with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"): + + # Setup with fallback + inference = DSPYInference( + pred_signature=MockSignature, + observe=False, + model_name="primary-model", + fallback_model="fallback-model" + ) + + async def side_effect(*args, **kwargs): + lm = kwargs.get('lm') + if lm.model == "primary-model": + raise ServiceUnavailableError("Down", llm_provider="openai", model="primary-model") + elif lm.model == "fallback-model": + return dspy.Prediction(output="Fallback Success") + else: + raise ValueError(f"Unknown model: {lm.model}") + + inference.inference_module_async = AsyncMock(side_effect=side_effect) + + result = await inference.run(input="test") + + assert result.output == "Fallback Success" + + # Primary model should have been retried max_attempts times (default 3) + # Fallback model called once + # Total 4 + assert inference.inference_module_async.call_count == 4 + + asyncio.run(_test()) + + def test_fallback_failure(self): + """ + Tests the scenario where both the primary and fallback models fail. + """ + async def _test(): + with patch("common.global_config.global_config.llm_api_key", return_value="fake-key"): + + # Setup with fallback where fallback also fails + inference = DSPYInference( + pred_signature=MockSignature, + observe=False, + model_name="primary-model", + fallback_model="fallback-model" + ) + + async def side_effect(*args, **kwargs): + lm = kwargs.get('lm') + if lm.model == "primary-model": + raise ServiceUnavailableError("Down", llm_provider="openai", model="primary-model") + elif lm.model == "fallback-model": + raise ServiceUnavailableError("Also Down", llm_provider="openai", model="fallback-model") + else: + raise ValueError(f"Unknown model: {lm.model}") + + inference.inference_module_async = AsyncMock(side_effect=side_effect) + + # Execute and expect exception + try: + await inference.run(input="test") + assert False, "Should have raised ServiceUnavailableError" + except ServiceUnavailableError: + pass + + # Primary called 3 times, Fallback called 3 times + # Total 6 + assert inference.inference_module_async.call_count == 6 + + asyncio.run(_test())