7 changes: 3 additions & 4 deletions ai_scientist/generate_ideas.py
@@ -7,7 +7,7 @@
import backoff
import requests

from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS
from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, validate_model_choice

S2_API_KEY = os.getenv("S2_API_KEY")

@@ -507,10 +507,9 @@ def check_idea_novelty(
)
parser.add_argument(
"--model",
type=str,
type=validate_model_choice,
default="gpt-4o-2024-05-13",
choices=AVAILABLE_LLMS,
help="Model to use for AI Scientist.",
help="Model to use (AVAILABLE_LLMS or openrouter/<provider>/<model>).",
)
parser.add_argument(
"--skip-idea-generation",
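Replacing choices=AVAILABLE_LLMS with type=validate_model_choice moves validation into the converter function, so the fixed choice list no longer rejects arbitrary OpenRouter names. A minimal sketch of the resulting CLI behavior (the model strings below are illustrative):

import argparse

from ai_scientist.llm import validate_model_choice

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=validate_model_choice, default="gpt-4o-2024-05-13")

# Accepted: any AVAILABLE_LLMS entry, or any name carrying the openrouter/ prefix.
args = parser.parse_args(["--model", "openrouter/meta-llama/llama-3.1-405b-instruct"])

# Rejected: argparse turns the ArgumentTypeError into a usage error and exits.
# parser.parse_args(["--model", "not-a-model"])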
82 changes: 54 additions & 28 deletions ai_scientist/llm.py
@@ -1,3 +1,4 @@
import argparse
import json
import os
import re
@@ -9,6 +10,28 @@
from google.generativeai.types import GenerationConfig

MAX_NUM_TOKENS = 4096
OPENROUTER_PREFIX = "openrouter/"


def is_openrouter_model(model_name: str) -> bool:
return model_name == "llama3.1-405b" or model_name.startswith(OPENROUTER_PREFIX)


def normalize_openrouter_model(model_name: str) -> str:
if model_name.startswith(OPENROUTER_PREFIX):
return model_name[len(OPENROUTER_PREFIX) :]
if model_name == "llama3.1-405b":
return "meta-llama/llama-3.1-405b-instruct"
return model_name


def validate_model_choice(model: str) -> str:
if model in AVAILABLE_LLMS or model.startswith(OPENROUTER_PREFIX):
return model
raise argparse.ArgumentTypeError(
"Model must be in AVAILABLE_LLMS or start with openrouter/<provider>/<model>."
)
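# Illustrative behavior of the three helpers above (the model names here are
# hypothetical examples, not values taken from this diff):
#   is_openrouter_model("llama3.1-405b")                 -> True  (legacy alias)
#   is_openrouter_model("openrouter/qwen/qwen-2.5-72b")  -> True
#   is_openrouter_model("gpt-4o-2024-05-13")             -> False
#   normalize_openrouter_model("openrouter/qwen/qwen-2.5-72b") -> "qwen/qwen-2.5-72b"
#   normalize_openrouter_model("llama3.1-405b") -> "meta-llama/llama-3.1-405b-instruct"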


AVAILABLE_LLMS = [
# Anthropic models
@@ -33,7 +56,7 @@
"o1-mini-2024-09-12",
"o3-mini",
"o3-mini-2025-01-31",
# OpenRouter models
# OpenRouter models (use openrouter/<provider>/<model> for others)
"llama3.1-405b",
# Anthropic Claude models via Amazon Bedrock
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
@@ -77,10 +100,11 @@ def get_batch_responses_from_llm(
if msg_history is None:
msg_history = []

if 'gpt' in model:
if is_openrouter_model(model):
new_msg_history = msg_history + [{"role": "user", "content": msg}]
api_model = normalize_openrouter_model(model)
response = client.chat.completions.create(
model=model,
model=api_model,
messages=[
{"role": "system", "content": system_message},
*new_msg_history,
@@ -89,16 +113,15 @@
max_tokens=MAX_NUM_TOKENS,
n=n_responses,
stop=None,
seed=0,
)
content = [r.message.content for r in response.choices]
new_msg_history = [
new_msg_history + [{"role": "assistant", "content": c}] for c in content
]
elif model == "llama-3-1-405b-instruct":
elif "gpt" in model:
new_msg_history = msg_history + [{"role": "user", "content": msg}]
response = client.chat.completions.create(
model="meta-llama/llama-3.1-405b-instruct",
model=model,
messages=[
{"role": "system", "content": system_message},
*new_msg_history,
@@ -107,6 +130,7 @@
max_tokens=MAX_NUM_TOKENS,
n=n_responses,
stop=None,
seed=0,
)
content = [r.message.content for r in response.choices]
new_msg_history = [
@@ -152,7 +176,23 @@ def get_response_from_llm(
if msg_history is None:
msg_history = []

if "claude" in model:
if is_openrouter_model(model):
new_msg_history = msg_history + [{"role": "user", "content": msg}]
api_model = normalize_openrouter_model(model)
response = client.chat.completions.create(
model=api_model,
messages=[
{"role": "system", "content": system_message},
*new_msg_history,
],
temperature=temperature,
max_tokens=MAX_NUM_TOKENS,
n=1,
stop=None,
)
content = response.choices[0].message.content
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
elif "claude" in model:
new_msg_history = msg_history + [
{
"role": "user",
@@ -214,21 +254,6 @@ def get_response_from_llm(
)
content = response.choices[0].message.content
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
elif model in ["meta-llama/llama-3.1-405b-instruct", "llama-3-1-405b-instruct"]:
new_msg_history = msg_history + [{"role": "user", "content": msg}]
response = client.chat.completions.create(
model="meta-llama/llama-3.1-405b-instruct",
messages=[
{"role": "system", "content": system_message},
*new_msg_history,
],
temperature=temperature,
max_tokens=MAX_NUM_TOKENS,
n=1,
stop=None,
)
content = response.choices[0].message.content
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
elif model in ["deepseek-chat", "deepseek-coder"]:
new_msg_history = msg_history + [{"role": "user", "content": msg}]
response = client.chat.completions.create(
@@ -326,6 +351,13 @@ def create_client(model):
client_model = model.split("/")[-1]
print(f"Using Vertex AI with model {client_model}.")
return anthropic.AnthropicVertex(), client_model
elif is_openrouter_model(model):
normalized_model = normalize_openrouter_model(model)
print(f"Using OpenRouter API with {normalized_model}.")
return openai.OpenAI(
api_key=os.environ["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1",
), f"{OPENROUTER_PREFIX}{normalized_model}"
elif 'gpt' in model or "o1" in model or "o3" in model:
print(f"Using OpenAI API with model {model}.")
return openai.OpenAI(), model
@@ -335,12 +367,6 @@ def create_client(model):
api_key=os.environ["DEEPSEEK_API_KEY"],
base_url="https://api.deepseek.com"
), model
elif model == "llama3.1-405b":
print(f"Using OpenAI API with {model}.")
return openai.OpenAI(
api_key=os.environ["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1"
), "meta-llama/llama-3.1-405b-instruct"
elif "gemini" in model:
print(f"Using OpenAI API with {model}.")
return openai.OpenAI(
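Note the round-trip in the new create_client branch: the returned client_model keeps the openrouter/ prefix, so the is_openrouter_model dispatch in get_response_from_llm fires again, and normalize_openrouter_model strips the prefix only at the API call. A minimal sketch under these assumptions (OPENROUTER_API_KEY must be set; the model name and prompt are illustrative, and the call assumes this module's usual get_response_from_llm(msg, client, model, system_message, ...) signature):

from ai_scientist.llm import create_client, get_response_from_llm

client, client_model = create_client("openrouter/meta-llama/llama-3.1-405b-instruct")
# client_model == "openrouter/meta-llama/llama-3.1-405b-instruct" (prefix preserved)

text, msg_history = get_response_from_llm(
    "Say hello in one sentence.",
    client=client,
    model=client_model,  # routed through the is_openrouter_model branch
    system_message="You are a helpful assistant.",
)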
12 changes: 6 additions & 6 deletions ai_scientist/perform_writeup.py
@@ -8,7 +8,7 @@
from typing import Optional, Tuple

from ai_scientist.generate_ideas import search_for_papers
from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS
from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, OPENROUTER_PREFIX, is_openrouter_model, normalize_openrouter_model, validate_model_choice


# GENERATE LATEX
@@ -523,10 +523,9 @@ def perform_writeup(
parser.add_argument("--no-writing", action="store_true", help="Only generate")
parser.add_argument(
"--model",
type=str,
type=validate_model_choice,
default="gpt-4o-2024-05-13",
choices=AVAILABLE_LLMS,
help="Model to use for AI Scientist.",
help="Model to use (AVAILABLE_LLMS or openrouter/<provider>/<model>).",
)
parser.add_argument(
"--engine",
@@ -558,8 +557,9 @@ def perform_writeup(
io = InputOutput(yes=True, chat_history_file=f"{folder_name}/{idea_name}_aider.txt")
if args.model == "deepseek-coder-v2-0724":
main_model = Model("deepseek/deepseek-coder")
elif args.model == "llama3.1-405b":
main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
elif is_openrouter_model(args.model):
normalized_model = normalize_openrouter_model(args.model)
main_model = Model(f"{OPENROUTER_PREFIX}{normalized_model}")
else:
main_model = Model(model)
coder = Coder.create(
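The aider model name now mirrors the CLI value: OpenRouter-shaped names are normalized and re-prefixed, which matches the openrouter/<provider>/<model> format aider expects. A condensed sketch of the mapping logic shared by perform_writeup.py and launch_scientist.py (the function name is illustrative):

from aider.models import Model

from ai_scientist.llm import OPENROUTER_PREFIX, is_openrouter_model, normalize_openrouter_model

def resolve_aider_model(model_arg: str) -> Model:
    # Mirrors the branch added in the diffs above.
    if model_arg == "deepseek-coder-v2-0724":
        return Model("deepseek/deepseek-coder")
    if is_openrouter_model(model_arg):
        # Strip the prefix (or map the legacy llama3.1-405b alias), then re-prefix.
        normalized = normalize_openrouter_model(model_arg)
        return Model(f"{OPENROUTER_PREFIX}{normalized}")
    return Model(model_arg)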
29 changes: 19 additions & 10 deletions launch_scientist.py
@@ -14,7 +14,7 @@
from datetime import datetime

from ai_scientist.generate_ideas import generate_ideas, check_idea_novelty
from ai_scientist.llm import create_client, AVAILABLE_LLMS
from ai_scientist.llm import create_client, OPENROUTER_PREFIX, is_openrouter_model, normalize_openrouter_model, validate_model_choice
from ai_scientist.perform_experiments import perform_experiments
from ai_scientist.perform_review import perform_review, load_paper, perform_improvement
from ai_scientist.perform_writeup import perform_writeup, generate_latex
@@ -47,10 +47,9 @@ def parse_arguments():
)
parser.add_argument(
"--model",
type=str,
type=validate_model_choice,
default="claude-3-5-sonnet-20240620",
choices=AVAILABLE_LLMS,
help="Model to use for AI Scientist.",
help="Model to use (AVAILABLE_LLMS or openrouter/<provider>/<model>).",
)
parser.add_argument(
"--writeup",
@@ -129,6 +128,7 @@ def worker(
writeup,
improvement,
gpu_id,
engine,
):
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
print(f"Worker {gpu_id} started.")
@@ -145,7 +145,8 @@
client_model,
writeup,
improvement,
log_file=True,
engine,
True,
)
print(f"Completed idea: {idea['Name']}, Success: {success}")
print(f"Worker {gpu_id} finished.")
@@ -161,6 +162,7 @@ def do_idea(
writeup,
improvement,
engine,
log_file=False,
):
## CREATE PROJECT FOLDER
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -202,8 +204,9 @@
main_model = Model("deepseek/deepseek-coder")
elif model == "deepseek-reasoner":
main_model = Model("deepseek/deepseek-reasoner")
elif model == "llama3.1-405b":
main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
elif is_openrouter_model(model):
normalized_model = normalize_openrouter_model(model)
main_model = Model(f"{OPENROUTER_PREFIX}{normalized_model}")
else:
main_model = Model(model)
coder = Coder.create(
@@ -238,8 +241,9 @@
main_model = Model("deepseek/deepseek-coder")
elif model == "deepseek-reasoner":
main_model = Model("deepseek/deepseek-reasoner")
elif model == "llama3.1-405b":
main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
elif is_openrouter_model(model):
normalized_model = normalize_openrouter_model(model)
main_model = Model(f"{OPENROUTER_PREFIX}{normalized_model}")
else:
main_model = Model(model)
coder = Coder.create(
@@ -251,7 +255,7 @@ def do_idea(
edit_format="diff",
)
try:
perform_writeup(idea, folder_name, coder, client, client_model, engine=args.engine)
perform_writeup(idea, folder_name, coder, client, client_model, engine=engine)
except Exception as e:
print(f"Failed to perform writeup: {e}")
return False
@@ -356,6 +360,9 @@ def do_idea(
model=client_model,
engine=args.engine,
)
else:
for idea in ideas:
idea["novel"] = True

with open(osp.join(base_dir, "ideas.json"), "w") as f:
json.dump(ideas, f, indent=4)
@@ -384,6 +391,7 @@ def do_idea(
args.writeup,
args.improvement,
gpu_id,
args.engine,
),
)
p.start()
Expand Down Expand Up @@ -411,6 +419,7 @@ def do_idea(
client_model,
args.writeup,
args.improvement,
args.engine,
)
print(f"Completed idea: {idea['Name']}, Success: {success}")
except Exception as e:
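One subtlety in the do_idea signature: a positional parameter may not follow one with a default, so engine has to come before log_file=False (as shown above). With that ordering, both call sites line up; a comment-level sketch, with the leading arguments elided as in the diff:

# Parallel path (inside worker): engine is passed positionally and True binds to log_file.
#   success = do_idea(..., writeup, improvement, engine, True)
# Sequential path: log_file keeps its default of False.
#   success = do_idea(..., args.writeup, args.improvement, args.engine)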