Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common/config_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class DefaultLlm(BaseModel):
"""Default LLM configuration."""

default_model: str
fallback_model: str | None = None
default_temperature: float
default_max_tokens: int

Expand Down
5 changes: 3 additions & 2 deletions common/global_config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
model_name: gemini/gemini-2.0-flash
model_name: gemini/gemini-3-flash
dot_global_config_health_check: true

example_parent:
Expand All @@ -8,7 +8,8 @@ example_parent:
# LLMs
########################################################
default_llm:
default_model: gemini/gemini-2.0-flash
default_model: gemini/gemini-3-flash
fallback_model: gemini/gemini-2.5-flash
default_temperature: 0.5
default_max_tokens: 100000

Expand Down
64 changes: 52 additions & 12 deletions utils/llm/dspy_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import dspy
from langfuse.decorators import observe # type: ignore
from litellm.exceptions import ServiceUnavailableError
from litellm.exceptions import RateLimitError, ServiceUnavailableError
from loguru import logger as log
from tenacity import (
retry,
Expand All @@ -23,20 +23,24 @@ def __init__(
tools: list[Callable[..., Any]] | None = None,
observe: bool = True,
model_name: str = global_config.default_llm.default_model,
fallback_model_name: str | None = global_config.default_llm.fallback_model,
temperature: float = global_config.default_llm.default_temperature,
max_tokens: int = global_config.default_llm.default_max_tokens,
max_iters: int = 5,
) -> None:
if tools is None:
tools = []

api_key = global_config.llm_api_key(model_name)
self.lm = dspy.LM(
model=model_name,
api_key=api_key,
cache=global_config.llm_config.cache_enabled,
temperature=temperature,
max_tokens=max_tokens,
self.lm = self._build_lm(model_name, temperature, max_tokens)
self.fallback_model_name = (
fallback_model_name
if fallback_model_name and fallback_model_name != model_name
else None
)
self.fallback_lm = (
self._build_lm(self.fallback_model_name, temperature, max_tokens)
if self.fallback_model_name is not None
else None
Comment thread
Miyamura80 marked this conversation as resolved.
)
Comment thread
Miyamura80 marked this conversation as resolved.
Comment thread
Miyamura80 marked this conversation as resolved.
if observe:
# Initialize a LangFuseDSPYCallback and configure the LM instance for generation tracing
Expand All @@ -58,25 +62,61 @@ def __init__(
self.inference_module
)

@observe()
@retry(
retry=retry_if_exception_type(ServiceUnavailableError),
retry=retry_if_exception_type((RateLimitError, ServiceUnavailableError)),
stop=stop_after_attempt(global_config.llm_config.retry.max_attempts),
wait=wait_exponential(
multiplier=global_config.llm_config.retry.min_wait_seconds,
max=global_config.llm_config.retry.max_wait_seconds,
),
before_sleep=lambda retry_state: log.warning(
f"Retrying due to ServiceUnavailableError. Attempt {retry_state.attempt_number}"
"Retrying due to LLM error "
f"{retry_state.outcome.exception().__class__.__name__}. "
f"Attempt {retry_state.attempt_number}"
),
)
async def _run_with_retry(
self,
lm: dspy.LM,
**kwargs: Any,
) -> Any:
return await self.inference_module_async(**kwargs, lm=lm)

def _build_lm(
self,
model_name: str,
temperature: float,
max_tokens: int,
) -> dspy.LM:
api_key = global_config.llm_api_key(model_name)
return dspy.LM(
model=model_name,
api_key=api_key,
cache=global_config.llm_config.cache_enabled,
temperature=temperature,
max_tokens=max_tokens,
)

@observe()
async def run(
self,
**kwargs: Any,
) -> Any:
try:
# user_id is passed if the pred_signature requires it.
result = await self.inference_module_async(**kwargs, lm=self.lm)
result = await self._run_with_retry(self.lm, **kwargs)
except (RateLimitError, ServiceUnavailableError) as e:
if not self.fallback_lm:
log.error(f"{e.__class__.__name__} without fallback: {str(e)}")
raise
log.warning(
f"Primary model unavailable; falling back to {self.fallback_model_name}"
)
try:
result = await self._run_with_retry(self.fallback_lm, **kwargs)
except (RateLimitError, ServiceUnavailableError) as fallback_error:
log.error(f"Fallback model failed: {fallback_error.__class__.__name__}")
raise
except (RuntimeError, ValueError, TypeError) as e:
log.error(f"Error in run: {str(e)}")
raise
Expand Down