Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions ai_lab_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class LaboratoryWorkflow:
def __init__(self, research_topic, openai_api_key, max_steps=100, num_papers_lit_review=5, agent_model_backbone=f"{DEFAULT_LLM_BACKBONE}", notes=list(), human_in_loop_flag=None, compile_pdf=True, mlesolver_max_steps=3, papersolver_max_steps=5):
def __init__(self, research_topic, openai_api_key, max_steps=100, num_papers_lit_review=5, agent_model_backbone=f"{DEFAULT_LLM_BACKBONE}", notes=list(), human_in_loop_flag=None, compile_pdf=True, mlesolver_max_steps=3, papersolver_max_steps=5, openai_base_url="https://api.openai.com/v1"):
"""
Initialize laboratory workflow
@param research_topic: (str) description of research idea to explore
Expand All @@ -28,6 +28,7 @@ def __init__(self, research_topic, openai_api_key, max_steps=100, num_papers_lit
self.research_topic = research_topic
self.model_backbone = agent_model_backbone
self.num_papers_lit_review = num_papers_lit_review
self.openai_base_url = openai_base_url

self.print_cost = True
self.review_override = True # should review be overridden?
Expand Down Expand Up @@ -569,6 +570,13 @@ def parse_arguments():
help='Provide the OpenAI API key.'
)

parser.add_argument(
'--base-url',
type=str,
default="https://api.openai.com/v1",
help='Base URL for OpenAI API.'
)

parser.add_argument(
'--compile-latex',
type=str,
Expand Down Expand Up @@ -646,6 +654,11 @@ def parse_arguments():
if not api_key and not deepseek_api_key:
raise ValueError("API key must be provided via --api-key / -deepseek-api-key or the OPENAI_API_KEY / DEEPSEEK_API_KEY environment variable.")

base_url = os.getenv('OPENAI_BASE_URL') or args.base_url
if args.base_url is not None and os.getenv('OPENAI_BASE_URL') is None:
os.environ["OPENAI_BASE_URL"] = args.base_url


##########################################################
# Research question that the agents are going to explore #
##########################################################
Expand All @@ -662,7 +675,7 @@ def parse_arguments():
"note": "Please use gpt-4o-mini for your experiments."},

{"phases": ["running experiments"],
"note": f'Use the following code to inference gpt-4o-mini: \nfrom openai import OpenAI\nos.environ["OPENAI_API_KEY"] = "{api_key}"\nclient = OpenAI()\ncompletion = client.chat.completions.create(\nmodel="gpt-4o-mini-2024-07-18", messages=messages)\nanswer = completion.choices[0].message.content\n'},
"note": f'Use the following code to inference gpt-4o-mini: \nfrom openai import OpenAI\nos.environ["OPENAI_API_KEY"] = "{api_key}"\nopenai.base_url = {base_url}\nclient = OpenAI()\ncompletion = client.chat.completions.create(\nmodel="gpt-4o-mini-2024-07-18", messages=messages)\nanswer = completion.choices[0].message.content\n'},

{"phases": ["running experiments"],
"note": f"You have access to only gpt-4o-mini using the OpenAI API, please use the following key {api_key} but do not use too many inferences. Do not use openai.ChatCompletion.create or any openai==0.28 commands. Instead use the provided inference code."},
Expand Down Expand Up @@ -725,6 +738,7 @@ def parse_arguments():
num_papers_lit_review=num_papers_lit_review,
papersolver_max_steps=papersolver_max_steps,
mlesolver_max_steps=mlesolver_max_steps,
openai_base_url=base_url
)

lab.perform_research()
Expand Down
2 changes: 2 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def query_model(model_str, prompt, system_prompt, openai_api_key=None, anthropic
os.environ["OPENAI_API_KEY"] = openai_api_key
if anthropic_api_key is not None:
os.environ["ANTHROPIC_API_KEY"] = anthropic_api_key
base_url = os.getenv('OPENAI_BASE_URL', 'https://api.openai.com/v1')
openai.base_url = base_url
for _ in range(tries):
try:
if model_str == "gpt-4o-mini" or model_str == "gpt4omini" or model_str == "gpt-4omini" or model_str == "gpt4o-mini":
Expand Down