From bf68bd492fea1a1b11dee92efbfca09db6bbf3c2 Mon Sep 17 00:00:00 2001 From: Adam Newgas Date: Tue, 30 Sep 2025 11:03:16 +0000 Subject: [PATCH] OpenRouter chat model should support models that start with "openrouter/" --- .env.example | 2 ++ .gitignore | 1 + safetytooling/apis/inference/openrouter.py | 16 +++++++++++----- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/.env.example b/.env.example index e2711f8..3538656 100644 --- a/.env.example +++ b/.env.example @@ -40,6 +40,8 @@ CACHE_DIR= REDIS_CACHE= # password for redis REDIS_PASSWORD= +# Set to true to disable caching +NO_CACHE= # For logging finetune metrics to wandb. WANDB_API_KEY= \ No newline at end of file diff --git a/.gitignore b/.gitignore index bf0f947..67c0c00 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ __pycache__ *.ipyc .conda .venv +build/ # Mac .DS_Store diff --git a/safetytooling/apis/inference/openrouter.py b/safetytooling/apis/inference/openrouter.py index 6841649..14822f5 100644 --- a/safetytooling/apis/inference/openrouter.py +++ b/safetytooling/apis/inference/openrouter.py @@ -117,7 +117,7 @@ def _extract_text_completion(self, generated_content: list[ChatMessage]) -> str: text_parts = [part for part in text_parts if part.strip() != ""] return "\n\n".join(text_parts) if text_parts else "" - async def _execute_tool_loop(self, messages, model_id, openai_tools, tools, **kwargs): + async def _execute_tool_loop(self, messages, model_id, real_model_id, openai_tools, tools, **kwargs): """Handle OpenAI-style tool execution loop and return all generated content.""" current_messages = messages.copy() all_generated_content = [] @@ -125,7 +125,7 @@ async def _execute_tool_loop(self, messages, model_id, openai_tools, tools, **kw while True: response_data = await self.aclient.chat.completions.create( messages=current_messages, - model=model_id, + model=real_model_id, tools=openai_tools, tool_choice="auto", **kwargs, @@ -256,6 +256,12 @@ async def __call__( if tools: openai_tools = 
convert_tools_to_openai(tools) + # Remove openrouter/ prefix if it exists + if model_id.startswith("openrouter/"): + real_model_id = model_id.split("/", 1)[1] + else: + real_model_id = model_id + start = time.time() prompt_file = self.create_prompt_history_file(prompt, model_id, self.prompt_history_dir) @@ -279,12 +285,12 @@ async def __call__( # Handle tool execution if tools are provided if openai_tools: response_data, generated_content = await self._execute_tool_loop( - prompt.openai_format(), model_id, openai_tools, tools, **kwargs + prompt.openai_format(), model_id, real_model_id, openai_tools, tools, **kwargs ) else: response_data = await self.aclient.chat.completions.create( messages=prompt.openai_format(), - model=model_id, + model=real_model_id, **kwargs, ) # Convert single response to ChatMessage @@ -308,7 +314,7 @@ async def __call__( ) ): # sometimes gemini will never return a response - if model_id == "google/gemini-2.0-flash-001": + if real_model_id == "google/gemini-2.0-flash-001": LOGGER.warn(f"Empty response from {model_id} (returning empty response)") return [ LLMResponse(