Skip to content

Commit 9385f0b

Browse files
committed
Merge remote-tracking branch 'origin/main' into chore/ruff-format
2 parents 92e82b0 + e312512 commit 9385f0b

10 files changed

Lines changed: 1400 additions & 1154 deletions

.github/workflows/a_test_target_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
echo "Log Level: ${{ github.event.inputs.log_level }}"
4141
echo "Environment: ${{ github.event.inputs.environment }}"
4242
- name: Set up Python
43-
uses: actions/setup-python@v5
43+
uses: actions/setup-python@v6
4444
with:
4545
python-version: ${{ env.PYTHON_VERSION }}
4646
- name: Install uv

.github/workflows/agents_validate.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
- name: Read Python version
3636
run: echo "PYTHON_VERSION=$(cat .python-version | tr -d '\n')" >> $GITHUB_ENV
3737
- name: Set up Python
38-
uses: actions/setup-python@v5
38+
uses: actions/setup-python@v6
3939
with:
4040
python-version: ${{ env.PYTHON_VERSION }}
4141
- name: Install uv

.github/workflows/docs-lint.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
fetch-depth: 0
2727

2828
- name: Setup Bun
29-
uses: oven-sh/setup-bun@v1
29+
uses: oven-sh/setup-bun@v2
3030
with:
3131
bun-version: latest
3232

.github/workflows/linter_require_ruff.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
- uses: actions/checkout@v4
1818
- name: Read Python version
1919
run: echo "PYTHON_VERSION=$(cat .python-version | tr -d '\n')" >> $GITHUB_ENV
20-
- uses: actions/setup-python@v5
20+
- uses: actions/setup-python@v6
2121
with:
2222
python-version: ${{ env.PYTHON_VERSION }}
2323
- name: Install uv

.github/workflows/nightly_tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
- name: Read Python version
2323
run: echo "PYTHON_VERSION=$(cat .python-version | tr -d '\n')" >> $GITHUB_ENV
2424
- name: Set up Python
25-
uses: actions/setup-python@v5
25+
uses: actions/setup-python@v6
2626
with:
2727
python-version: ${{ env.PYTHON_VERSION }}
2828
- name: Install uv

.github/workflows/vulture_dead_code.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
echo "Log Level: ${{ github.event.inputs.log_level }}"
3131
echo "Environment: ${{ github.event.inputs.environment }}"
3232
- name: Set up Python
33-
uses: actions/setup-python@v5
33+
uses: actions/setup-python@v6
3434
with:
3535
python-version: ${{ env.PYTHON_VERSION }}
3636
- name: Install uv

common/config_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class DefaultLlm(BaseModel):
1919
"""Default LLM configuration."""
2020

2121
default_model: str
22+
fallback_model: str | None = None
2223
default_temperature: float
2324
default_max_tokens: int
2425

common/global_config.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
model_name: gemini/gemini-2.0-flash
1+
model_name: gemini/gemini-3-flash
22
dot_global_config_health_check: true
33

44
example_parent:
@@ -8,7 +8,8 @@ example_parent:
88
# LLMs
99
########################################################
1010
default_llm:
11-
default_model: gemini/gemini-2.0-flash
11+
default_model: gemini/gemini-3-flash
12+
fallback_model: gemini/gemini-2.5-flash
1213
default_temperature: 0.5
1314
default_max_tokens: 100000
1415

utils/llm/dspy_inference.py

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import dspy
55
from langfuse.decorators import observe # type: ignore
6-
from litellm.exceptions import ServiceUnavailableError
6+
from litellm.exceptions import RateLimitError, ServiceUnavailableError
77
from loguru import logger as log
88
from tenacity import (
99
retry,
@@ -23,20 +23,24 @@ def __init__(
2323
tools: list[Callable[..., Any]] | None = None,
2424
observe: bool = True,
2525
model_name: str = global_config.default_llm.default_model,
26+
fallback_model_name: str | None = global_config.default_llm.fallback_model,
2627
temperature: float = global_config.default_llm.default_temperature,
2728
max_tokens: int = global_config.default_llm.default_max_tokens,
2829
max_iters: int = 5,
2930
) -> None:
3031
if tools is None:
3132
tools = []
3233

33-
api_key = global_config.llm_api_key(model_name)
34-
self.lm = dspy.LM(
35-
model=model_name,
36-
api_key=api_key,
37-
cache=global_config.llm_config.cache_enabled,
38-
temperature=temperature,
39-
max_tokens=max_tokens,
34+
self.lm = self._build_lm(model_name, temperature, max_tokens)
35+
self.fallback_model_name = (
36+
fallback_model_name
37+
if fallback_model_name and fallback_model_name != model_name
38+
else None
39+
)
40+
self.fallback_lm = (
41+
self._build_lm(self.fallback_model_name, temperature, max_tokens)
42+
if self.fallback_model_name is not None
43+
else None
4044
)
4145
if observe:
4246
# Initialize a LangFuseDSPYCallback and configure the LM instance for generation tracing
@@ -58,25 +62,61 @@ def __init__(
5862
self.inference_module
5963
)
6064

61-
@observe()
6265
@retry(
63-
retry=retry_if_exception_type(ServiceUnavailableError),
66+
retry=retry_if_exception_type((RateLimitError, ServiceUnavailableError)),
6467
stop=stop_after_attempt(global_config.llm_config.retry.max_attempts),
6568
wait=wait_exponential(
6669
multiplier=global_config.llm_config.retry.min_wait_seconds,
6770
max=global_config.llm_config.retry.max_wait_seconds,
6871
),
6972
before_sleep=lambda retry_state: log.warning(
70-
f"Retrying due to ServiceUnavailableError. Attempt {retry_state.attempt_number}"
73+
"Retrying due to LLM error "
74+
f"{retry_state.outcome.exception().__class__.__name__}. "
75+
f"Attempt {retry_state.attempt_number}"
7176
),
7277
)
78+
async def _run_with_retry(
79+
self,
80+
lm: dspy.LM,
81+
**kwargs: Any,
82+
) -> Any:
83+
return await self.inference_module_async(**kwargs, lm=lm)
84+
85+
def _build_lm(
86+
self,
87+
model_name: str,
88+
temperature: float,
89+
max_tokens: int,
90+
) -> dspy.LM:
91+
api_key = global_config.llm_api_key(model_name)
92+
return dspy.LM(
93+
model=model_name,
94+
api_key=api_key,
95+
cache=global_config.llm_config.cache_enabled,
96+
temperature=temperature,
97+
max_tokens=max_tokens,
98+
)
99+
100+
@observe()
73101
async def run(
74102
self,
75103
**kwargs: Any,
76104
) -> Any:
77105
try:
78106
# user_id is passed if the pred_signature requires it.
79-
result = await self.inference_module_async(**kwargs, lm=self.lm)
107+
result = await self._run_with_retry(self.lm, **kwargs)
108+
except (RateLimitError, ServiceUnavailableError) as e:
109+
if not self.fallback_lm:
110+
log.error(f"{e.__class__.__name__} without fallback: {str(e)}")
111+
raise
112+
log.warning(
113+
f"Primary model unavailable; falling back to {self.fallback_model_name}"
114+
)
115+
try:
116+
result = await self._run_with_retry(self.fallback_lm, **kwargs)
117+
except (RateLimitError, ServiceUnavailableError) as fallback_error:
118+
log.error(f"Fallback model failed: {fallback_error.__class__.__name__}")
119+
raise
80120
except (RuntimeError, ValueError, TypeError) as e:
81121
log.error(f"Error in run: {str(e)}")
82122
raise

0 commit comments

Comments
 (0)