Skip to content

Commit 9b7aa60

Browse files
feat(memreader/LLM): add backup config for openai memreader (#1246)
feat: add backend for openai Co-authored-by: harvey_xiang <harvey_xiang@163.com>
1 parent 88efb42 commit 9b7aa60

File tree

4 files changed

+114
-45
lines changed

4 files changed

+114
-45
lines changed

src/memos/api/config.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -321,23 +321,40 @@ def get_activation_config() -> dict[str, Any]:
321321

322322
@staticmethod
323323
def get_memreader_config() -> dict[str, Any]:
324-
"""Get MemReader configuration for chat/doc extraction (fine-tuned 0.6B model)."""
325-
return {
326-
"backend": "openai",
327-
"config": {
328-
"model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"),
329-
"temperature": 0.6,
330-
"max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")),
331-
"top_p": 0.95,
332-
"top_k": 20,
333-
"api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"),
334-
# Default to OpenAI base URL when env var is not provided to satisfy pydantic
335-
# validation requirements during tests/import.
336-
"api_base": os.getenv("MEMRADER_API_BASE", "https://api.openai.com/v1"),
337-
"remove_think_prefix": True,
338-
},
324+
"""Get MemReader configuration for chat/doc extraction (fine-tuned 0.6B model).
325+
326+
When MEMREADER_GENERAL_MODEL is configured (i.e. a separate stable LLM exists),
327+
the backup client is automatically enabled so that primary failures (self-deployed
328+
model) fall back to the general LLM.
329+
"""
330+
config = {
331+
"model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"),
332+
"temperature": 0.6,
333+
"max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")),
334+
"top_p": 0.95,
335+
"top_k": 20,
336+
"api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"),
337+
# Default to OpenAI base URL when env var is not provided to satisfy pydantic
338+
# validation requirements during tests/import.
339+
"api_base": os.getenv("MEMRADER_API_BASE", "https://api.openai.com/v1"),
340+
"remove_think_prefix": True,
339341
}
340342

343+
general_model = os.getenv("MEMREADER_GENERAL_MODEL")
344+
enable_backup = os.getenv("MEMREADER_ENABLE_BACKUP", "false").lower() == "true"
345+
if general_model and enable_backup:
346+
config["backup_client"] = True
347+
config["backup_model_name_or_path"] = general_model
348+
config["backup_api_key"] = os.getenv(
349+
"MEMREADER_GENERAL_API_KEY", os.getenv("OPENAI_API_KEY", "EMPTY")
350+
)
351+
config["backup_api_base"] = os.getenv(
352+
"MEMREADER_GENERAL_API_BASE",
353+
os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
354+
)
355+
356+
return {"backend": "openai", "config": config}
357+
341358
@staticmethod
342359
def get_memreader_general_llm_config() -> dict[str, Any]:
343360
"""Get general LLM configuration for non-chat/doc tasks.

src/memos/configs/llm.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,22 @@ class OpenAILLMConfig(BaseLLMConfig):
2828
default="https://api.openai.com/v1", description="Base URL for OpenAI API"
2929
)
3030
# Tail of OpenAILLMConfig: passthrough request body plus optional
# backup-endpoint settings consumed by the OpenAI LLM client.
extra_body: Any = Field(default=None, description="extra body")
# When True, the OpenAI LLM client constructs a second client from the
# backup_* fields below and retries a failed chat completion against it.
backup_client: bool = Field(
    default=False,
    description="Whether to enable backup client for fallback on primary failure",
)
backup_api_key: str | None = Field(
    default=None, description="API key for backup OpenAI-compatible endpoint"
)
backup_api_base: str | None = Field(
    default=None, description="Base URL for backup OpenAI-compatible endpoint"
)
# If None, the fallback request reuses the primary request's model name.
backup_model_name_or_path: str | None = Field(
    default=None, description="Model name for backup endpoint"
)
backup_headers: dict[str, Any] | None = Field(
    default=None, description="Default headers for backup client requests"
)
3147

3248

3349
class OpenAIResponsesLLMConfig(BaseLLMConfig):
@@ -42,22 +58,18 @@ class OpenAIResponsesLLMConfig(BaseLLMConfig):
4258
)
4359

4460

45-
class QwenLLMConfig(OpenAILLMConfig):
    """Config for Qwen (DashScope) via its OpenAI-compatible endpoint.

    Inherits ``api_key``, ``extra_body`` and the ``backup_*`` fields from
    :class:`OpenAILLMConfig`; only the default ``api_base`` is overridden.
    """

    api_base: str = Field(
        default="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
        description="Base URL for Qwen OpenAI-compatible API",
    )
5266

5367

54-
class DeepSeekLLMConfig(OpenAILLMConfig):
    """Config for DeepSeek via its OpenAI-compatible endpoint.

    Inherits ``api_key``, ``extra_body`` and the ``backup_*`` fields from
    :class:`OpenAILLMConfig`; only the default ``api_base`` is overridden.
    """

    api_base: str = Field(
        default="https://api.deepseek.com",
        description="Base URL for DeepSeek OpenAI-compatible API",
    )
6173

6274

6375
class AzureLLMConfig(BaseLLMConfig):

src/memos/llms/openai.py

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,39 @@ def __init__(self, config: OpenAILLMConfig):
2727
self.client = openai.Client(
2828
api_key=config.api_key, base_url=config.api_base, default_headers=config.default_headers
2929
)
30-
logger.info("OpenAI LLM instance initialized")
30+
self.use_backup_client = config.backup_client
31+
if self.use_backup_client:
32+
self.backup_client = openai.Client(
33+
api_key=config.backup_api_key,
34+
base_url=config.backup_api_base,
35+
default_headers=config.backup_headers,
36+
)
37+
logger.info(
38+
f"OpenAI LLM instance initialized with backup "
39+
f"(model={config.backup_model_name_or_path})"
40+
)
41+
else:
42+
self.backup_client = None
43+
logger.info("OpenAI LLM instance initialized")
44+
45+
def _parse_response(self, response) -> str:
    """Extract text content from a chat completion response.

    Handles three response shapes: no choices (returns ""), tool calls
    (delegated to ``self.tool_call_parser``), and plain text with an
    optional ``reasoning_content`` field.
    """
    if not response.choices:
        logger.warning("OpenAI response has no choices")
        return ""

    # Tool-call responses are serialized by the configured parser instead
    # of being treated as plain text.
    tool_calls = getattr(response.choices[0].message, "tool_calls", None)
    if isinstance(tool_calls, list) and len(tool_calls) > 0:
        return self.tool_call_parser(tool_calls)
    response_content = response.choices[0].message.content
    # A non-empty "reasoning_content" field (exposed by some backends) is
    # wrapped in <think> tags so the logic below can treat it uniformly.
    reasoning_content = getattr(response.choices[0].message, "reasoning_content", None)
    if isinstance(reasoning_content, str) and reasoning_content:
        reasoning_content = f"<think>{reasoning_content}</think>"
    if self.config.remove_think_prefix:
        # NOTE(review): response_content may be None here; assumes
        # remove_thinking_tags tolerates None -- confirm.
        return remove_thinking_tags(response_content)
    if reasoning_content:
        return reasoning_content + (response_content or "")
    return response_content or ""
3163

3264
@timed_with_status(
3365
log_prefix="OpenAI LLM",
@@ -50,29 +82,32 @@ def generate(self, messages: MessageList, **kwargs) -> str:
5082
start_time = time.perf_counter()
5183
logger.info(f"OpenAI LLM Request body: {request_body}")
5284

53-
response = self.client.chat.completions.create(**request_body)
54-
55-
cost_time = time.perf_counter() - start_time
56-
logger.info(
57-
f"Request body: {request_body}, Response from OpenAI: {response.model_dump_json()}, Cost time: {cost_time}"
58-
)
59-
60-
if not response.choices:
61-
logger.warning("OpenAI response has no choices")
62-
return ""
63-
64-
tool_calls = getattr(response.choices[0].message, "tool_calls", None)
65-
if isinstance(tool_calls, list) and len(tool_calls) > 0:
66-
return self.tool_call_parser(tool_calls)
67-
response_content = response.choices[0].message.content
68-
reasoning_content = getattr(response.choices[0].message, "reasoning_content", None)
69-
if isinstance(reasoning_content, str) and reasoning_content:
70-
reasoning_content = f"<think>{reasoning_content}</think>"
71-
if self.config.remove_think_prefix:
72-
return remove_thinking_tags(response_content)
73-
if reasoning_content:
74-
return reasoning_content + (response_content or "")
75-
return response_content or ""
85+
try:
86+
response = self.client.chat.completions.create(**request_body)
87+
cost_time = time.perf_counter() - start_time
88+
logger.info(
89+
f"Request body: {request_body}, Response from OpenAI: "
90+
f"{response.model_dump_json()}, Cost time: {cost_time}"
91+
)
92+
return self._parse_response(response)
93+
except Exception as e:
94+
if not self.use_backup_client:
95+
raise
96+
logger.warning(
97+
f"Primary LLM request failed with {type(e).__name__}: {e}, "
98+
f"falling back to backup client"
99+
)
100+
backup_body = {
101+
**request_body,
102+
"model": self.config.backup_model_name_or_path or request_body["model"],
103+
}
104+
backup_response = self.backup_client.chat.completions.create(**backup_body)
105+
cost_time = time.perf_counter() - start_time
106+
logger.info(
107+
f"Backup LLM request succeeded, Response: "
108+
f"{backup_response.model_dump_json()}, Cost time: {cost_time}"
109+
)
110+
return self._parse_response(backup_response)
76111

77112
@timed_with_status(
78113
log_prefix="OpenAI LLM Stream",

tests/configs/test_llm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ def test_openai_llm_config():
5656
"remove_think_prefix",
5757
"extra_body",
5858
"default_headers",
59+
"backup_client",
60+
"backup_api_key",
61+
"backup_api_base",
62+
"backup_model_name_or_path",
63+
"backup_headers",
5964
],
6065
)
6166

0 commit comments

Comments
 (0)