diff --git a/ai_lab_repo.py b/ai_lab_repo.py index dbe9541..0e26f17 100755 --- a/ai_lab_repo.py +++ b/ai_lab_repo.py @@ -11,7 +11,7 @@ class LaboratoryWorkflow: - def __init__(self, research_topic, openai_api_key, max_steps=100, num_papers_lit_review=5, agent_model_backbone=f"{DEFAULT_LLM_BACKBONE}", notes=list(), human_in_loop_flag=None, compile_pdf=True, mlesolver_max_steps=3, papersolver_max_steps=5): + def __init__(self, research_topic, openai_api_key, max_steps=100, num_papers_lit_review=5, agent_model_backbone=f"{DEFAULT_LLM_BACKBONE}", notes=list(), human_in_loop_flag=None, compile_pdf=True, mlesolver_max_steps=3, papersolver_max_steps=5, openai_base_url="https://api.openai.com/v1"): """ Initialize laboratory workflow @param research_topic: (str) description of research idea to explore @@ -28,6 +28,7 @@ def __init__(self, research_topic, openai_api_key, max_steps=100, num_papers_lit self.research_topic = research_topic self.model_backbone = agent_model_backbone self.num_papers_lit_review = num_papers_lit_review + self.openai_base_url = openai_base_url self.print_cost = True self.review_override = True # should review be overridden? @@ -569,6 +570,13 @@ def parse_arguments(): help='Provide the OpenAI API key.' ) + parser.add_argument( + '--base-url', + type=str, + default="https://api.openai.com/v1", + help='Base URL for OpenAI API.' 
+ ) + parser.add_argument( '--compile-latex', type=str, @@ -646,6 +654,11 @@ def parse_arguments(): if not api_key and not deepseek_api_key: raise ValueError("API key must be provided via --api-key / -deepseek-api-key or the OPENAI_API_KEY / DEEPSEEK_API_KEY environment variable.") + base_url = os.getenv('OPENAI_BASE_URL') or args.base_url + if os.getenv('OPENAI_BASE_URL') is None: + os.environ["OPENAI_BASE_URL"] = args.base_url + + ########################################################## # Research question that the agents are going to explore # ########################################################## @@ -662,7 +675,7 @@ def parse_arguments(): "note": "Please use gpt-4o-mini for your experiments."}, {"phases": ["running experiments"], - "note": f'Use the following code to inference gpt-4o-mini: \nfrom openai import OpenAI\nos.environ["OPENAI_API_KEY"] = "{api_key}"\nclient = OpenAI()\ncompletion = client.chat.completions.create(\nmodel="gpt-4o-mini-2024-07-18", messages=messages)\nanswer = completion.choices[0].message.content\n'}, + "note": f'Use the following code to inference gpt-4o-mini: \nfrom openai import OpenAI\nos.environ["OPENAI_API_KEY"] = "{api_key}"\nclient = OpenAI(base_url="{base_url}")\ncompletion = client.chat.completions.create(\nmodel="gpt-4o-mini-2024-07-18", messages=messages)\nanswer = completion.choices[0].message.content\n'}, {"phases": ["running experiments"], "note": f"You have access to only gpt-4o-mini using the OpenAI API, please use the following key {api_key} but do not use too many inferences. Do not use openai.ChatCompletion.create or any openai==0.28 commands. 
Instead use the provided inference code."}, @@ -725,6 +738,7 @@ def parse_arguments(): num_papers_lit_review=num_papers_lit_review, papersolver_max_steps=papersolver_max_steps, mlesolver_max_steps=mlesolver_max_steps, + openai_base_url=base_url ) lab.perform_research() diff --git a/inference.py b/inference.py index d87ad9e..a36d72a 100755 --- a/inference.py +++ b/inference.py @@ -40,6 +40,8 @@ def query_model(model_str, prompt, system_prompt, openai_api_key=None, anthropic os.environ["OPENAI_API_KEY"] = openai_api_key if anthropic_api_key is not None: os.environ["ANTHROPIC_API_KEY"] = anthropic_api_key + base_url = os.getenv('OPENAI_BASE_URL', 'https://api.openai.com/v1') + openai.base_url = base_url for _ in range(tries): try: if model_str == "gpt-4o-mini" or model_str == "gpt4omini" or model_str == "gpt-4omini" or model_str == "gpt4o-mini":