2 changes: 2 additions & 0 deletions .env.example
@@ -40,6 +40,8 @@ CACHE_DIR=
 REDIS_CACHE=
 # password for redis
 REDIS_PASSWORD=
+# Set to true to disable caching
+NO_CACHE=
 
 # For logging finetune metrics to wandb.
 WANDB_API_KEY=
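How safetytooling parses the new NO_CACHE variable is not shown in this diff; the sketch below is a minimal, assumed pattern for reading a boolean environment flag (the helper name cache_disabled is hypothetical):

```python
import os


def cache_disabled() -> bool:
    # Hypothetical helper: the diff only adds NO_CACHE to .env.example, so the
    # exact parsing is an assumption. This is the common truthy-string pattern.
    return os.environ.get("NO_CACHE", "").strip().lower() in {"1", "true", "yes"}
```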
1 change: 1 addition & 0 deletions .gitignore
@@ -7,6 +7,7 @@ __pycache__
 *.ipyc
 .conda
 .venv
+build/
 
 # Mac
 .DS_Store
16 changes: 11 additions & 5 deletions safetytooling/apis/inference/openrouter.py
@@ -117,15 +117,15 @@ def _extract_text_completion(self, generated_content: list[ChatMessage]) -> str:
         text_parts = [part for part in text_parts if part.strip() != ""]
         return "\n\n".join(text_parts) if text_parts else ""
 
-    async def _execute_tool_loop(self, messages, model_id, openai_tools, tools, **kwargs):
+    async def _execute_tool_loop(self, messages, model_id, real_model_id, openai_tools, tools, **kwargs):
         """Handle OpenAI-style tool execution loop and return all generated content."""
         current_messages = messages.copy()
         all_generated_content = []
 
         while True:
             response_data = await self.aclient.chat.completions.create(
                 messages=current_messages,
-                model=model_id,
+                model=real_model_id,
                 tools=openai_tools,
                 tool_choice="auto",
                 **kwargs,
@@ -256,6 +256,12 @@ async def __call__(
         if tools:
             openai_tools = convert_tools_to_openai(tools)
 
+        # Remove openrouter/ prefix if it exists
+        if model_id.startswith("openrouter/"):
+            real_model_id = model_id.split("/", 1)[1]
+        else:
+            real_model_id = model_id
+
         start = time.time()
         prompt_file = self.create_prompt_history_file(prompt, model_id, self.prompt_history_dir)
 
@@ -279,12 +285,12 @@
         # Handle tool execution if tools are provided
         if openai_tools:
             response_data, generated_content = await self._execute_tool_loop(
-                prompt.openai_format(), model_id, openai_tools, tools, **kwargs
+                prompt.openai_format(), model_id, real_model_id, openai_tools, tools, **kwargs
             )
         else:
             response_data = await self.aclient.chat.completions.create(
                 messages=prompt.openai_format(),
-                model=model_id,
+                model=real_model_id,
                 **kwargs,
             )
             # Convert single response to ChatMessage
@@ -308,7 +314,7 @@
                 )
             ):
                 # sometimes gemini will never return a response
-                if model_id == "google/gemini-2.0-flash-001":
+                if real_model_id == "google/gemini-2.0-flash-001":
                     LOGGER.warn(f"Empty response from {model_id} (returning empty response)")
                     return [
                         LLMResponse(
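The central change in this file is the split between model_id, the name the caller passed (possibly prefixed with openrouter/, and still used for the prompt-history file and log messages), and real_model_id, the bare name actually sent to the OpenRouter API. A minimal standalone sketch of that normalization, pulled out into a helper for illustration (strip_openrouter_prefix is a hypothetical name; the PR inlines the logic):

```python
def strip_openrouter_prefix(model_id: str) -> str:
    # Mirrors the PR's inline logic: drop a leading "openrouter/" so the API
    # receives the bare provider/model name.
    if model_id.startswith("openrouter/"):
        return model_id.split("/", 1)[1]
    return model_id


assert strip_openrouter_prefix("openrouter/google/gemini-2.0-flash-001") == "google/gemini-2.0-flash-001"
assert strip_openrouter_prefix("google/gemini-2.0-flash-001") == "google/gemini-2.0-flash-001"
```

Note that _execute_tool_loop now receives both identifiers: real_model_id goes into each chat.completions.create call, while model_id is kept for logging and history. The Gemini empty-response workaround likewise compares against real_model_id, so it still fires when the caller used the openrouter/ prefix.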