33
44import dspy
55from langfuse .decorators import observe # type: ignore
6- from litellm .exceptions import ServiceUnavailableError
6+ from litellm .exceptions import RateLimitError , ServiceUnavailableError
77from loguru import logger as log
88from tenacity import (
99 retry ,
@@ -23,20 +23,24 @@ def __init__(
2323 tools : list [Callable [..., Any ]] | None = None ,
2424 observe : bool = True ,
2525 model_name : str = global_config .default_llm .default_model ,
26+ fallback_model_name : str | None = global_config .default_llm .fallback_model ,
2627 temperature : float = global_config .default_llm .default_temperature ,
2728 max_tokens : int = global_config .default_llm .default_max_tokens ,
2829 max_iters : int = 5 ,
2930 ) -> None :
3031 if tools is None :
3132 tools = []
3233
33- api_key = global_config .llm_api_key (model_name )
34- self .lm = dspy .LM (
35- model = model_name ,
36- api_key = api_key ,
37- cache = global_config .llm_config .cache_enabled ,
38- temperature = temperature ,
39- max_tokens = max_tokens ,
34+ self .lm = self ._build_lm (model_name , temperature , max_tokens )
35+ self .fallback_model_name = (
36+ fallback_model_name
37+ if fallback_model_name and fallback_model_name != model_name
38+ else None
39+ )
40+ self .fallback_lm = (
41+ self ._build_lm (self .fallback_model_name , temperature , max_tokens )
42+ if self .fallback_model_name is not None
43+ else None
4044 )
4145 if observe :
4246 # Initialize a LangFuseDSPYCallback and configure the LM instance for generation tracing
@@ -58,25 +62,61 @@ def __init__(
5862 self .inference_module
5963 )
6064
61- @observe ()
6265 @retry (
63- retry = retry_if_exception_type (ServiceUnavailableError ),
66+ retry = retry_if_exception_type (( RateLimitError , ServiceUnavailableError ) ),
6467 stop = stop_after_attempt (global_config .llm_config .retry .max_attempts ),
6568 wait = wait_exponential (
6669 multiplier = global_config .llm_config .retry .min_wait_seconds ,
6770 max = global_config .llm_config .retry .max_wait_seconds ,
6871 ),
6972 before_sleep = lambda retry_state : log .warning (
70- f"Retrying due to ServiceUnavailableError. Attempt { retry_state .attempt_number } "
73+ "Retrying due to LLM error "
74+ f"{ retry_state .outcome .exception ().__class__ .__name__ } . "
75+ f"Attempt { retry_state .attempt_number } "
7176 ),
7277 )
78+ async def _run_with_retry (
79+ self ,
80+ lm : dspy .LM ,
81+ ** kwargs : Any ,
82+ ) -> Any :
83+ return await self .inference_module_async (** kwargs , lm = lm )
84+
85+ def _build_lm (
86+ self ,
87+ model_name : str ,
88+ temperature : float ,
89+ max_tokens : int ,
90+ ) -> dspy .LM :
91+ api_key = global_config .llm_api_key (model_name )
92+ return dspy .LM (
93+ model = model_name ,
94+ api_key = api_key ,
95+ cache = global_config .llm_config .cache_enabled ,
96+ temperature = temperature ,
97+ max_tokens = max_tokens ,
98+ )
99+
100+ @observe ()
73101 async def run (
74102 self ,
75103 ** kwargs : Any ,
76104 ) -> Any :
77105 try :
78106 # user_id is passed if the pred_signature requires it.
79- result = await self .inference_module_async (** kwargs , lm = self .lm )
107+ result = await self ._run_with_retry (self .lm , ** kwargs )
108+ except (RateLimitError , ServiceUnavailableError ) as e :
109+ if not self .fallback_lm :
110+ log .error (f"{ e .__class__ .__name__ } without fallback: { str (e )} " )
111+ raise
112+ log .warning (
113+ f"Primary model unavailable; falling back to { self .fallback_model_name } "
114+ )
115+ try :
116+ result = await self ._run_with_retry (self .fallback_lm , ** kwargs )
117+ except (RateLimitError , ServiceUnavailableError ) as fallback_error :
118+ log .error (f"Fallback model failed: { fallback_error .__class__ .__name__ } " )
119+ raise
80120 except (RuntimeError , ValueError , TypeError ) as e :
81121 log .error (f"Error in run: { str (e )} " )
82122 raise
0 commit comments