6 changes: 5 additions & 1 deletion aide/agent.py
@@ -409,8 +409,12 @@ def parse_exec_result(self, node: Node, exec_result: ExecutionResult):
),
)

print("Response:", response)
# manually set the metric to None if it's not in the response
if "metric" not in response:
response["metric"] = None
# if the metric isn't a float then fill the metric with the worst metric
if not isinstance(response["metric"], float):
elif not isinstance(response["metric"], float):
response["metric"] = None

node.analysis = response["summary"]
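The net effect of this hunk: a response whose "metric" key is missing or non-float is normalized to None instead of raising a KeyError. A minimal standalone sketch of the resulting behavior (only the response dict shape is taken from the diff; the helper name is illustrative):

def normalize_metric(response: dict) -> dict:
    # Mirror of the logic added above: an absent or non-float metric
    # collapses to None before the summary/metric fields are consumed.
    if "metric" not in response:
        response["metric"] = None
    elif not isinstance(response["metric"], float):
        response["metric"] = None
    return response

assert normalize_metric({})["metric"] is None
assert normalize_metric({"metric": "n/a"})["metric"] is None
assert normalize_metric({"metric": 0.87})["metric"] == 0.87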
22 changes: 22 additions & 0 deletions aide/backend/__init__.py
@@ -13,9 +13,28 @@
"gpt-4o-2024-08-06": {"input": 2.5 / 1000000, "output": 10 / 1000000},
"o3-mini-2025-01-31": {"input": 1.1 / 1000000, "output": 4.4 / 1000000},
"o3-2025-04-16": {"input": 10 / 1000000, "output": 40 / 1000000},
"o4-mini-2025-04-16": {"input": 1.1 / 1000000, "output": 4.4 / 1000000},
"gpt-4.1-2025-04-14": {"input": 2 / 1000000, "output": 8 / 1000000},
"gpt-4.1-mini-2025-04-14": {"input": 0.4 / 1000000, "output": 1.6 / 1000000},
"claude-opus-4-20250514": {"input": 15 / 1000000, "output": 75 / 1000000},
"claude-sonnet-4-20250514": {"input": 3 / 1000000, "output": 15 / 1000000},
"claude-3-7-sonnet-20250219": {"input": 3 / 1000000, "output": 15 / 1000000},
"claude-3-7-sonnet-20250219-think": {"input": 3 / 1000000, "output": 15 / 1000000},
"claude-3-5-sonnet-20241022": {"input": 3 / 1000000, "output": 15 / 1000000},
"claude-3-5-sonnet-20241022-think": {"input": 3 / 1000000, "output": 15 / 1000000},
"gemini-2.5-flash-preview-05-20": {
"input": 0.15 / 1000000,
"output": 3.5 / 1000000,
},
"gemini-2.5-pro-preview-06-05": {"input": 1.25 / 1000000, "output": 10 / 1000000},
"deepseek-reasoner": {"input": 0.55 / 1000000, "output": 2.19 / 1000000},
"deepseek-chat": {"input": 0.27 / 1000000, "output": 1.1 / 1000000},
"Llama-4-Maverick-17B-128E-Instruct-FP8": {
"input": 0 / 1000000,
"output": 0 / 1000000,
},
"Llama-3.3-8B-Instruct": {"input": 0 / 1000000, "output": 0 / 1000000},
"Llama-3.3-70B-Instruct": {"input": 0 / 1000000, "output": 0 / 1000000},
}
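Each entry is priced in dollars per token, so a request's cost is a dot product with its token counts. A hypothetical helper for illustration (neither the helper nor the MODEL_COST name below appears in the diff):

# Illustrative cost computation over the per-token price table above.
MODEL_COST = {
    "gpt-4.1-2025-04-14": {"input": 2 / 1000000, "output": 8 / 1000000},
}

def request_cost(model: str, in_tokens: int, out_tokens: int) -> float:
    prices = MODEL_COST[model]
    return in_tokens * prices["input"] + out_tokens * prices["output"]

# 10k input + 2k output tokens on gpt-4.1: 0.02 + 0.016 dollars
print(request_cost("gpt-4.1-2025-04-14", 10_000, 2_000))  # ~0.036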


@@ -25,6 +44,8 @@ def determine_provider(model: str) -> str:
or model.startswith("o1-")
or model.startswith("o3-")
or model.startswith("o4-")
or model.startswith("deepseek-")
or model.startswith("Llama-")
):
return "openai"
elif model.startswith("claude-"):
@@ -41,6 +62,7 @@ def determine_provider(model: str) -> str:
"anthropic": backend_anthropic.query,
"gdm": backend_gdm.query,
"openrouter": backend_openrouter.query,
# "meta": backend_meta.query,
}
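With the prefix checks above, DeepSeek- and Llama-prefixed models are routed to the OpenAI backend, presumably because they are served through OpenAI-compatible endpoints, while a dedicated Meta backend is left commented out. A quick sanity check of the dispatch (import path taken from the file name above):

from aide.backend import determine_provider

assert determine_provider("deepseek-chat") == "openai"
assert determine_provider("Llama-3.3-70B-Instruct") == "openai"
assert determine_provider("claude-sonnet-4-20250514") == "anthropic"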


97 changes: 83 additions & 14 deletions aide/backend/backend_anthropic.py
@@ -1,11 +1,11 @@
"""Backend for Anthropic API."""

import time
import logging
import time

import anthropic
from .utils import FunctionSpec, OutputType, backoff_create, opt_messages_to_list
from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create
from funcy import notnone, once, select_values
import anthropic

logger = logging.getLogger("aide")

@@ -18,6 +18,11 @@
anthropic.InternalServerError,
)

ANTHROPIC_MODEL_ALIASES = {
"claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
"claude-3.7-sonnet": "claude-3-7-sonnet-20250219",
}


@once
def _setup_anthropic_client():
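The alias map lets configs use short model names; query() below rewrites them to the dated API identifiers, and a "-think" suffix (stripped before the request) switches on the thinking modes. A condensed sketch of that resolution (names from the diff):

from aide.backend.backend_anthropic import ANTHROPIC_MODEL_ALIASES

# Alias resolution as performed in query():
name = ANTHROPIC_MODEL_ALIASES.get("claude-3.7-sonnet", "claude-3.7-sonnet")
# -> "claude-3-7-sonnet-20250219"

# The "-think" suffix toggles thinking and is stripped before the API call:
name = "claude-3-7-sonnet-20250219-think"
think = name.endswith("-think")
if think:
    name = name[: -len("-think")]  # -> "claude-3-7-sonnet-20250219"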
@@ -32,45 +37,109 @@ def query(
convert_system_to_user: bool = False,
**model_kwargs,
) -> tuple[OutputType, float, int, int, dict]:
"""
Query Anthropic's API, optionally with tool use (Anthropic's equivalent to function calling).
"""
_setup_anthropic_client()

filtered_kwargs: dict = select_values(notnone, model_kwargs) # type: ignore
if "max_tokens" not in filtered_kwargs:
filtered_kwargs["max_tokens"] = 4096 # default for Claude models
if "claude-3-5-sonnet" in filtered_kwargs["model"]:
filtered_kwargs["max_tokens"] = 8192
else:
filtered_kwargs["max_tokens"] = 16000 # default for Claude models

model_name = filtered_kwargs.get("model", "")
logger.debug(f"Anthropic query called with model='{model_name}'")

if func_spec is not None:
raise NotImplementedError(
"Anthropic does not support function calling for now."
)
if model_name in ANTHROPIC_MODEL_ALIASES:
model_name = ANTHROPIC_MODEL_ALIASES[model_name]
filtered_kwargs["model"] = model_name
logger.debug(f"Using aliased model name: {model_name}")

# Anthropic doesn't allow not having a user messages
if func_spec is not None and func_spec.name == "submit_review":
filtered_kwargs["tools"] = [func_spec.as_anthropic_tool_dict]
# Force tool use
filtered_kwargs["tool_choice"] = func_spec.anthropic_tool_choice_dict

# Anthropic doesn't allow not having user messages
# if we only have system msg -> use it as user msg
if system_message is not None and user_message is None:
system_message, user_message = user_message, system_message

# Anthropic passes the system messages as a separate argument
# Anthropic passes system messages as a separate argument
if system_message is not None:
filtered_kwargs["system"] = system_message

messages = opt_messages_to_list(None, user_message)

think = False
if (
"claude-sonnet-4" in model_name or "claude-opus-4" in model_name
) and func_spec is not None:
if model_name.endswith("-think"):
think = True
print("interleaved thinking enabled...")
filtered_kwargs["extra_headers"] = {
"anthropic-beta": "interleaved-thinking-2025-05-14"
}

if (
"claude-sonnet-4" in model_name
or "claude-opus-4" in model_name
or "claude-3-7-sonnet" in model_name
) and func_spec is None:
if model_name.endswith("-think"):
think = True
print("extended thinking enabled...")
filtered_kwargs["thinking"] = {"type": "enabled", "budget_tokens": 10000}
filtered_kwargs["temperature"] = 1
if model_name.endswith("-think"):
# remove trailing '-think' from model name
filtered_kwargs["model"] = model_name[:-6]
t0 = time.time()
message = backoff_create(
_client.messages.create,
ANTHROPIC_TIMEOUT_EXCEPTIONS,
retry_exceptions=ANTHROPIC_TIMEOUT_EXCEPTIONS,
messages=messages,
**filtered_kwargs,
)
req_time = time.time() - t0

assert len(message.content) == 1 and message.content[0].type == "text"
# Handle tool calls if present
if (
func_spec is not None
and "tools" in filtered_kwargs
and len(message.content) > 0
and message.content[0].type == "tool_use"
):
block = message.content[0] # This is a "ToolUseBlock"
# block has attributes: type, id, name, input
assert (
block.name == func_spec.name
), f"Function name mismatch: expected {func_spec.name}, got {block.name}"
output = block.input # Anthropic calls the parameters "input"

# handle thinking if enabled
elif think:
for block in message.content:
if block.type == "thinking":
continue #! skip thinking blocks for now
elif block.type == "text":
output = block.text
else:
# For non-tool responses, ensure we have text content
assert len(message.content) == 1, "Expected single content item"
assert (
message.content[0].type == "text"
), f"Expected text response, got {message.content[0].type}"
output = message.content[0].text

output: str = message.content[0].text
in_tokens = message.usage.input_tokens
out_tokens = message.usage.output_tokens

info = {
"stop_reason": message.stop_reason,
"model": message.model,
}

return output, req_time, in_tokens, out_tokens, info
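Because tool use is now wired in, output is no longer always a string: on the submit_review path it is the tool's input dict, otherwise plain text. A hypothetical call site (review_spec is an illustrative FunctionSpec named "submit_review"; it does not appear in this diff):

output, req_time, in_tok, out_tok, info = query(
    system_message="You are a reviewer.",
    user_message="Assess the attached solution.",
    func_spec=review_spec,
    model="claude-3-7-sonnet",
)
if isinstance(output, dict):  # tool-use path
    metric = output.get("metric")
else:  # plain-text path
    summary = output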
136 changes: 86 additions & 50 deletions aide/backend/backend_gdm.py
@@ -4,11 +4,19 @@
import logging
import os

import google.api_core.exceptions
import google.generativeai as genai
from google.generativeai.generative_models import generation_types
from google.api_core import exceptions
from google import genai
from google.genai.types import (
HarmCategory,
HarmBlockThreshold,
SafetySetting,
GenerateContentConfig,
Tool,
)

# from google.generativeai.generative_models import generation_types

from funcy import once
from funcy import notnone, once, select_values
from .utils import FunctionSpec, OutputType, backoff_create

logger = logging.getLogger("aide")
@@ -17,39 +25,36 @@
generation_config = None # type: ignore

GDM_TIMEOUT_EXCEPTIONS = (
google.api_core.exceptions.RetryError,
google.api_core.exceptions.TooManyRequests,
google.api_core.exceptions.ResourceExhausted,
google.api_core.exceptions.InternalServerError,
exceptions.RetryError,
exceptions.TooManyRequests,
exceptions.ResourceExhausted,
exceptions.InternalServerError,
)
SAFETY_SETTINGS = [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
},
safety_settings = [
SafetySetting(
category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
),
SafetySetting(
category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
),
]
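Note this is more than a port from dict-style settings to the typed google-genai objects: the threshold also moves from BLOCK_NONE to BLOCK_LOW_AND_ABOVE, i.e. strictly more content filtering than before. One entry side by side for comparison:

# Old (google-generativeai, dict-style, nothing blocked):
#   {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}
# New (google-genai, typed, blocks low-severity content and above):
SafetySetting(
    category=HarmCategory.HARM_CATEGORY_HARASSMENT,
    threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
)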


@once
def _setup_gdm_client(model_name: str, temperature: float):
global gdm_model
def _setup_gdm_client():
global _client
global generation_config

genai.configure(api_key=os.environ["GEMINI_API_KEY"])
gdm_model = genai.GenerativeModel(model_name)
generation_config = genai.GenerationConfig(temperature=temperature)
_client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])


def query(
@@ -59,36 +64,67 @@ def query(
convert_system_to_user: bool = False,
**model_kwargs,
) -> tuple[OutputType, float, int, int, dict]:
model = model_kwargs.pop("model")
temperature = model_kwargs.pop("temperature", None)

_setup_gdm_client(model, temperature)
_setup_gdm_client()
filtered_kwargs: dict = select_values(notnone, model_kwargs) # type: ignore
if "max_tokens" not in filtered_kwargs:
filtered_kwargs["max_tokens"] = 65536 # default for Claude models
model = filtered_kwargs.get("model", "")
temperature = filtered_kwargs.get("temperature", None)
max_output_tokens = filtered_kwargs.get("max_output_tokens", None)

if func_spec is not None:
raise NotImplementedError(
"GDM supports function calling but we won't use it for now."
tools = [Tool(function_declarations=[func_spec.as_gdm_tool_dict])]
generation_config = GenerateContentConfig(
tools=tools,
temperature=temperature,
max_output_tokens=max_output_tokens,
safety_settings=safety_settings,
)
else:
generation_config = GenerateContentConfig(
temperature=temperature,
max_output_tokens=max_output_tokens,
safety_settings=safety_settings,
)

    # GDM's Gemini API doesn't support system messages outside of the beta
messages = [
{"role": "user", "parts": message}
for message in [system_message, user_message]
if message
]
parts = []
if system_message:
parts.append({"text": system_message})
if user_message:
parts.append({"text": user_message})
messages = [{"role": "user", "parts": parts}] if parts else []

t0 = time.time()
response: generation_types.GenerateContentResponse = backoff_create(
gdm_model.generate_content,
response = backoff_create(
_client.models.generate_content,
retry_exceptions=GDM_TIMEOUT_EXCEPTIONS,
model=model,
contents=messages,
generation_config=generation_config,
safety_settings=SAFETY_SETTINGS,
config=generation_config,
)
req_time = time.time() - t0

if response.prompt_feedback.block_reason:
output = str(response.prompt_feedback)
# Check if the model responded with a function call
# iterate over parts of the response
function_call = None
for part in response.candidates[0].content.parts:
if part.function_call:
function_call = part.function_call
break
if func_spec is not None:
assert function_call is not None, "Function call not found in response"
func_name = function_call.name
assert (
func_name == func_spec.name
), f"Function name mismatch: expected {func_spec.name}, got {func_name}"
func_args = {key: value for key, value in function_call.args.items()}
func_args["function_name"] = func_name
output = func_args
else:
# if response.prompt_feedback and response.prompt_feedback.block_reason:
# output = response.prompt_feedback
# print(output)
# else:
# # Otherwise, return the text content
output = response.text
in_tokens = response.usage_metadata.prompt_token_count
out_tokens = response.usage_metadata.candidates_token_count
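As in the Anthropic backend, the return type now depends on func_spec: a dict of tool arguments (with the function name folded in under "function_name") on the tool path, response.text otherwise. A minimal sketch of consuming that result (values are illustrative):

# Illustrative shapes of the two possible outputs of this backend.
output = {"summary": "...", "metric": 0.91, "function_name": "submit_review"}
if isinstance(output, dict):  # func_spec was supplied
    fn_name = output.pop("function_name")  # "submit_review"
    tool_args = output  # the model's structured arguments
else:  # plain generation
    text = output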