diff --git a/.gitignore b/.gitignore index ea6f4be..e01df78 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,11 @@ research_dir/* state_saves/* __pycache__/* Figure*.png -testrun.py \ No newline at end of file +testrun.py +/data/ +/.vscode +/settings/user_settings.json +/AgentLaboratoryWebUI +settings/task_note_llm_config.json +*.pyc +/.claude diff --git a/README.md b/README.md index 11710d7..42ccb97 100755 --- a/README.md +++ b/README.md @@ -33,95 +33,229 @@ ### πŸ‘Ύ Currently supported models -* **OpenAI**: o1, o1-preview, o1-mini, gpt-4o +* **OpenAI**: + * GPT-5 series: gpt-5.2, gpt-5.2-pro, gpt-5-mini + * GPT-4.1 series: gpt-4.1, gpt-4.1-mini, gpt-4.1-nano + * GPT-4o series: gpt-4o, gpt-4o-mini + * o-series (reasoning): o1, o1-preview, o1-mini, o3-mini, o4-mini * **DeepSeek**: deepseek-chat (deepseek-v3) +* **Anthropic**: + * Claude 4.5: claude-4.5-opus, claude-4.5-sonnet, claude-4.5-haiku + * Claude 4.1: claude-4.1-opus + * Claude 4: claude-4-opus, claude-4-sonnet + * Claude 3.x: claude-3-5-sonnet, claude-3-5-haiku, claude-3-7-sonnet +* **Google**: + * Gemini 3.0: gemini-3.0-pro, gemini-3.0-flash + * Gemini 2.5: gemini-2.5-pro, gemini-2.5-flash, gemini-2.5-flash-lite + * Gemini 2.0: gemini-2.0-flash +* **Ollama**: Any model that you can find in the [Ollama Website](https://ollama.com/search) To select a specific llm set the flag `--llm-backend="llm_model"` for example `--llm-backend="gpt-4o"` or `--llm-backend="deepseek-chat"`. Please feel free to add a PR supporting new models according to your need! ## πŸ–₯️ Installation -### Python venv option - -* We recommend using python 3.12 +> [!IMPORTANT] +> We recommend using python 3.12 1. **Clone the GitHub Repository**: Begin by cloning the repository using the command: -```bash -git clone git@github.com:SamuelSchmidgall/AgentLaboratory.git -``` + ```bash + git clone git@github.com:SamuelSchmidgall/AgentLaboratory.git + ``` 2. 
**Set up and Activate Python Environment** -```bash -python -m venv venv_agent_lab -``` -- Now activate this environment: -```bash -source venv_agent_lab/bin/activate -``` + + Python venv option + ```bash + python -m venv venv_agent_lab + source venv_agent_lab/bin/activate + ``` + + Conda option + ```bash + conda create -n agent_lab python=3.12 + conda activate agent_lab + ``` 3. **Install required libraries** -```bash -pip install -r requirements.txt -``` - -4. **Install pdflatex [OPTIONAL]** -```bash -sudo apt install pdflatex -``` -- This enables latex source to be compiled by the agents. -- **[IMPORTANT]** If this step cannot be run due to not having sudo access, pdf compiling can be turned off via running Agent Laboratory via setting the `--compile-latex` flag to false: `--compile-latex "false"` - - - -5. **Now run Agent Laboratory!** - -`python ai_lab_repo.py --api-key "API_KEY_HERE" --llm-backend "o1-mini" --research-topic "YOUR RESEARCH IDEA"` - -or, if you don't have pdflatex installed - -`python ai_lab_repo.py --api-key "API_KEY_HERE" --llm-backend "o1-mini" --research-topic "YOUR RESEARCH IDEA" --compile-latex "false"` - -### Co-Pilot mode - -To run Agent Laboratory in copilot mode, simply set the copilot-mode flag to `"true"` - -`python ai_lab_repo.py --api-key "API_KEY_HERE" --llm-backend "o1-mini" --research-topic "YOUR RESEARCH IDEA" --copilot-mode "true"` + ```bash + pip install -r requirements.txt + ``` + +4. **Install a Newer Version of Gradio** + ```bash + pip install gradio==4.44.1 + ``` + +> [!NOTE] +> This is only required for the current version of the web interface. +> We will move to `Flask` in the future for better extensibility of the package. + +5. 
**Install pdflatex [OPTIONAL]** + + **Instructions for Ubuntu:** + ```bash + sudo apt install pdflatex + ``` + + If that package is not available, + you can install the TeX Live components directly: + ```bash + sudo apt-get install texlive-latex-base + + sudo apt-get install texlive-fonts-recommended + sudo apt-get install texlive-fonts-extra + + sudo apt-get install texlive-latex-extra + sudo apt-get install texlive-science + ``` + + **Instructions for Windows Users** + + To have a TeX engine on Windows, you need to install the MikTeX software. Follow these steps: + + 1. Go to the [MikTeX download page](https://miktex.org/download). + 2. Download the installer for your version of Windows. + 3. Run the installer and follow the on-screen instructions to complete the installation. + 4. Once installed, you can verify the installation by opening a command prompt and typing `latex --version`. + + - This enables LaTeX source to be compiled by the agents. +> [!IMPORTANT] +> If this step cannot be run because you do not have sudo access, +> PDF compilation can be turned off by setting the `--compile-latex` flag to false: `--compile-latex "false"`. +> Alternatively, uncheck the `Compile LaTeX` option in the web interface. + +## πŸš€ Quick Start + +1. **Set up the configuration file** + + - You can set up the configuration file by editing the `config.py` file. + - See the [configuration file](./config.py) for more details. + +2. **Now run Agent Laboratory!** + + #### Basic Usage of Agent Laboratory in the Web Interface + ```bash + python config_gradio.py + ``` + + #### Basic Usage of Agent Laboratory in the CLI + + ##### 1. A simple command to run Agent Laboratory + ```bash + python ai_lab_repo.py --api-key "API_KEY_HERE" --llm-backend "o1-mini" --research-topic "YOUR RESEARCH IDEA" + ``` + + ##### 2. 
Available Configuration Options + + **API Keys:** + - `--api-key`: OpenAI API key or set to "ollama" for Ollama usage **(required)** + - `--deepseek-api-key`: DeepSeek API key + - `--google-api-key`: Google API key + - `--anthropic-api-key`: Anthropic API key + + **LLM Settings:** + - `--llm-backend`: Backend LLM to use (default: "o1-mini"); please ensure your model string is correct. Here are some common models: + - OpenAI: "o1", "o1-preview", "o1-mini", "gpt-4o" + - DeepSeek: "deepseek-chat" (deepseek-v3) + - Anthropic: "claude-3-5-sonnet-latest", "claude-3-5-haiku-latest" + - Google: "gemini-2.5-pro", "gemini-2.0-flash" + - Ollama: Any model that you can find in the [Ollama Website](https://ollama.com/search) + - `--ollama-max-tokens`: Max tokens for Ollama (default: 2048) + + **Research Parameters:** + - `--research-topic`: Your research topic/idea or an open-ended question to ask; this **must be provided** + - `--language`: Operating language (default: "English"), which will instruct the agents to perform research in your preferred language (not fully supported yet) + - `--num-papers-lit-review`: Number of papers for literature review (default: 5) + - `--mlesolver-max-steps`: Steps for MLE solver (default: 3) + - `--papersolver-max-steps`: Steps for paper solver (default: 5) + + **Operation Modes:** + - `--copilot-mode`: Enable human interaction mode (default: "false"); you will need to check the terminal for input in this mode + - `--compile-latex`: Enable LaTeX PDF compilation (default: "true"), **please ensure you have pdflatex installed** + + **State Management:** + - `--load-existing`: Load from existing state (default: "false") + - `--load-existing-path`: Path to load state from (e.g., "state_saves/results_interpretation.pkl") +
+ πŸ“š Example Usage + + Basic run without PDF compilation: + ```bash + python ai_lab_repo.py --api-key "API_KEY_HERE" --llm-backend "o1-mini" --research-topic "YOUR RESEARCH IDEA" --compile-latex "false" + ``` + + Run in copilot mode: + ```bash + python ai_lab_repo.py --api-key "API_KEY_HERE" --llm-backend "o1-mini" --research-topic "YOUR RESEARCH IDEA" --copilot-mode "true" + ``` + + Run with custom solver steps and language: + ```bash + python ai_lab_repo.py --api-key "API_KEY_HERE" --llm-backend "o1-mini" --research-topic "YOUR RESEARCH IDEA" --mlesolver-max-steps "5" --papersolver-max-steps "7" --language "Spanish" + ``` + + Load from an existing state: + ```bash + python ai_lab_repo.py --api-key "API_KEY_HERE" --load-existing "true" --research-topic "YOUR RESEARCH IDEA" --load-existing-path "state_saves/results_interpretation.pkl" + ```
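+ Run with a local Ollama backend (the model name below is only an example; pass the literal string "ollama" as the API key):
+ ```bash
+ python ai_lab_repo.py --api-key "ollama" --llm-backend "qwen2.5:32b" --research-topic "YOUR RESEARCH IDEA" --ollama-max-tokens "128000"
+ ```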
+ +> [!NOTE] +> You must provide at least one API key. +> Even when you run a local Ollama model, you must provide the string "ollama" as the API key. + +> [!TIP] +> - Set `--ollama-max-tokens` to the model's real context length (e.g. 128000 for `qwen2.5:32b`) for much better performance. +> - Use a model that supports `tools`, since Agent Laboratory instructs the model to output formatted code or actions (this is effectively required in the current version of Agent Laboratory). ----- + ## Tips for better research outcomes #### [Tip #1] πŸ“ Make sure to write extensive notes! πŸ“ -**Writing extensive notes is important** for helping your agent understand what you're looking to accomplish in your project, as well as any style preferences. Notes can include any experiments you want the agents to perform, providing API keys, certain plots or figures you want included, or anything you want the agent to know when performing research. +**Writing extensive notes is important** for helping your agent understand what you're looking to accomplish in your project, +as well as any style preferences. Notes can include any experiments you want the agents to perform, providing API keys, certain plots or figures you want included, or anything you want the agent to know when performing research. -This is also your opportunity to let the agent know **what compute resources it has access to**, e.g. GPUs (how many, what type of GPU, how many GBs), CPUs (how many cores, what type of CPUs), storage limitations, and hardware specs. +This is also your opportunity to let the agent know **what compute resources it has access to**, +e.g. GPUs (how many, what type of GPU, how many GBs), CPUs (how many cores, what type of CPUs), storage limitations, and hardware specs. -In order to add notes, you must modify the task_notes_LLM structure inside of `ai_lab_repo.py`. Provided below is an example set of notes used for some of our experiments. 
+In order to add notes, you must modify the TASK_NOTE_LLM structure inside of `config.py`. +Provided below is an example set of notes used for some of our experiments. ``` -task_notes_LLM = [ +TASK_NOTE_LLM = [ {"phases": ["plan formulation"], "note": f"You should come up with a plan for TWO experiments."}, - {"phases": ["plan formulation", "data preparation", "running experiments"], + {"phases": ["plan formulation", "data preparation", "running experiments"], "note": "Please use gpt-4o-mini for your experiments."}, {"phases": ["running experiments"], - "note": f'Use the following code to inference gpt-4o-mini: \nfrom openai import OpenAI\nos.environ["OPENAI_API_KEY"] = "{api_key}"\nclient = OpenAI()\ncompletion = client.chat.completions.create(\nmodel="gpt-4o-mini-2024-07-18", messages=messages)\nanswer = completion.choices[0].message.content\n'}, + "note": 'Use the following code to inference gpt-4o-mini: \nfrom openai import OpenAI\nos.environ["OPENAI_API_KEY"] = "{{api_key}}"\nclient = OpenAI()\ncompletion = client.chat.completions.create(\nmodel="gpt-4o-mini-2024-07-18", messages=messages)\nanswer = completion.choices[0].message.content\n'}, {"phases": ["running experiments"], - "note": f"You have access to only gpt-4o-mini using the OpenAI API, please use the following key {api_key} but do not use too many inferences. Do not use openai.ChatCompletion.create or any openai==0.28 commands. Instead use the provided inference code."}, + "note": "You have access to only gpt-4o-mini using the OpenAI API, please use the following key {{api_key}} but do not use too many inferences. Do not use openai.ChatCompletion.create or any openai==0.28 commands. Instead use the provided inference code."}, {"phases": ["running experiments"], "note": "I would recommend using a small dataset (approximately only 100 data points) to run experiments in order to save time. 
Do not use much more than this unless you have to or are running the final tests."}, {"phases": ["data preparation", "running experiments"], - "note": "You are running on a MacBook laptop. You can use 'mps' with PyTorch"}, + "note": "You are running on an Ubuntu system. You can use 'cuda' with PyTorch"}, {"phases": ["data preparation", "running experiments"], "note": "Generate figures with very colorful and artistic design."}, - ] + + {"phases": ["literature review", "plan formulation", + "data preparation", "running experiments", + "results interpretation", "report writing", + "report refinement"], + "note": "You should always converse and write the report in the following language: {{language}}"} +] ``` -------- @@ -138,9 +272,22 @@ When resources are limited, **optimize by fine-tuning smaller models** on your s #### [Tip #3] βœ… You can load previous saves from checkpoints βœ… -**If you lose progress, internet connection, or if a subtask fails, you can always load from a previous state.** All of your progress is saved by default in the `state_saves` variable, which stores each individual checkpoint. Just pass the following arguments when running `ai_lab_repo.py` +**If you lose progress, internet connection, or if a subtask fails, you can always load from a previous state.** +All of your progress is saved by default in the `state_saves` directory, which stores each individual checkpoint. + +##### **For Web Interface** + +You can check out the `Resume Previous Research` section to load from a previous state. +By checking the `Load Existing Research State` flag and then selecting the stage you want to load from, you can easily load from a previous state. +If the list of states is out of date, you can click the `Refresh Saved States` button to refresh it. 
+ +##### **For CLI** -`python ai_lab_repo.py --api-key "API_KEY_HERE" --research-topic "YOUR RESEARCH IDEA" --llm-backend "o1-mini" --load-existing True --load-existing-path "state_saves/LOAD_PATH"` +Just pass the following arguments when running `ai_lab_repo.py`: + +```bash +python ai_lab_repo.py --api-key "API_KEY_HERE" --research-topic "YOUR RESEARCH IDEA" --llm-backend "o1-mini" --load-existing True --load-existing-path "state_saves/LOAD_PATH" +``` ----- @@ -150,8 +297,11 @@ When resources are limited, **optimize by fine-tuning smaller models** on your s If you are running Agent Laboratory in a language other than English, no problem, just make sure to provide a language flag to the agents to perform research in your preferred language. Note that we have not extensively studied running Agent Laboratory in other languages, so be sure to report any problems you encounter. -For example, if you are running in Chinese: +##### **For Web Interface** +You can select the language in the dropdown menu. If the language you want is not available, you can edit the `config_gradio.py` file to add it. 
+##### **For CLI** +If you are running in Chinese, you can run the following command: `python ai_lab_repo.py --api-key "API_KEY_HERE" --research-topic "YOUR RESEARCH IDEA (in your language)" --llm-backend "o1-mini" --language "δΈ­ζ–‡"` ---- diff --git a/agents.py b/agents.py index c6fd4cd..53195df 100755 --- a/agents.py +++ b/agents.py @@ -1,3 +1,5 @@ +import json + from utils import * from tools import * from inference import * diff --git a/ai_lab_repo.py b/ai_lab_repo.py index dbe9541..50aaadb 100755 --- a/ai_lab_repo.py +++ b/ai_lab_repo.py @@ -1,6 +1,7 @@ from agents import * from copy import copy from common_imports import * +from config import TASK_NOTE_LLM, CONFIG_HUMAN_IN_THE_LOOP, CONFIG_AGENT_MODELS from mlesolver import MLESolver from torch.backends.mkl import verbose @@ -11,7 +12,7 @@ class LaboratoryWorkflow: - def __init__(self, research_topic, openai_api_key, max_steps=100, num_papers_lit_review=5, agent_model_backbone=f"{DEFAULT_LLM_BACKBONE}", notes=list(), human_in_loop_flag=None, compile_pdf=True, mlesolver_max_steps=3, papersolver_max_steps=5): + def __init__(self, research_topic, openai_api_key, max_steps=100, num_papers_lit_review=5, agent_model_backbone=None, notes=None, human_in_loop_flag=None, compile_pdf=True, mlesolver_max_steps=3, papersolver_max_steps=5): """ Initialize laboratory workflow @param research_topic: (str) description of research idea to explore @@ -21,12 +22,15 @@ def __init__(self, research_topic, openai_api_key, max_steps=100, num_papers_lit @param notes: (list) notes for agent to follow during tasks """ + if notes is None: + notes = [] + self.notes = notes self.max_steps = max_steps self.compile_pdf = compile_pdf self.openai_api_key = openai_api_key self.research_topic = research_topic - self.model_backbone = agent_model_backbone + self.model_backbone = os.getenv('LLM_BACKEND') if os.getenv('LLM_BACKEND') is not None else DEFAULT_LLM_BACKBONE self.num_papers_lit_review = num_papers_lit_review self.print_cost = True @@ 
-35,6 +39,18 @@ def __init__(self, research_topic, openai_api_key, max_steps=100, num_papers_lit self.arxiv_paper_exp_time = 3 self.reference_papers = list() + + if agent_model_backbone is None: + agent_model_backbone = { + "literature review": self.model_backbone, + "plan formulation": self.model_backbone, + "data preparation": self.model_backbone, + "running experiments": self.model_backbone, + "results interpretation": self.model_backbone, + "report writing": self.model_backbone, + "report refinement": self.model_backbone, + } + ########################################## ####### COMPUTE BUDGET PARAMETERS ######## ########################################## @@ -61,9 +77,13 @@ def __init__(self, research_topic, openai_api_key, max_steps=100, num_papers_lit for subtask in subtasks: self.phase_models[subtask] = agent_model_backbone elif type(agent_model_backbone) == dict: - # todo: check if valid - self.phase_models = agent_model_backbone - + # Load models for each phase if key exists otherwise use the default model + for phase, subtasks in self.phases: + for subtask in subtasks: + if subtask in agent_model_backbone: + self.phase_models[subtask] = agent_model_backbone[subtask] + else: + self.phase_models[subtask] = self.model_backbone self.human_in_loop_flag = human_in_loop_flag @@ -147,7 +167,11 @@ def perform_research(self): if type(self.phase_models) == dict: if subtask in self.phase_models: self.set_model(self.phase_models[subtask]) - else: self.set_model(f"{DEFAULT_LLM_BACKBONE}") + elif os.getenv('LLM_BACKEND') is not None: + self.set_model(os.getenv('LLM_BACKEND')) + else: + print(f"Warning: Model for subtask {subtask} not found in phase_models dictionary or passed with argument. 
Using default model.") + self.set_model(f"{DEFAULT_LLM_BACKBONE}") if (subtask not in self.phase_status or not self.phase_status[subtask]) and subtask == "literature review": repeat = True while repeat: repeat = self.literature_review() @@ -225,7 +249,9 @@ def report_refinement(self): raise Exception("Model did not respond") response = response.lower().strip()[0] if response == "n": - if verbose: print("*"*40, "\n", "REVIEW COMPLETE", "\n", "*"*40) + if verbose: + print("*"*40, "\n", "REVIEW COMPLETE", "\n", "*"*40) + return False elif response == "y": self.set_agent_attr("reviewer_response", f"Provided are reviews from a set of three reviewers: {reviews}.") @@ -243,7 +269,7 @@ def report_writing(self): # instantiate mle-solver from papersolver import PaperSolver self.reference_papers = [] - solver = PaperSolver(notes=report_notes, max_steps=self.papersolver_max_steps, plan=lab.phd.plan, exp_code=lab.phd.results_code, exp_results=lab.phd.exp_results, insights=lab.phd.interpretation, lit_review=lab.phd.lit_review, ref_papers=self.reference_papers, topic=research_topic, openai_api_key=self.openai_api_key, llm_str=self.model_backbone["report writing"], compile_pdf=compile_pdf) + solver = PaperSolver(notes=report_notes, max_steps=self.papersolver_max_steps, plan=lab.phd.plan, exp_code=lab.phd.results_code, exp_results=lab.phd.exp_results, insights=lab.phd.interpretation, lit_review=lab.phd.lit_review, ref_papers=self.reference_papers, topic=research_topic, openai_api_key=self.openai_api_key, llm_str=self.phase_models["report writing"], compile_pdf=compile_pdf) # run initialization for solver solver.initial_solve() # run solver for N mle optimization steps @@ -307,7 +333,7 @@ def running_experiments(self): experiment_notes = [_note["note"] for _note in self.ml_engineer.notes if "running experiments" in _note["phases"]] experiment_notes = f"Notes for the task objective: {experiment_notes}\n" if len(experiment_notes) > 0 else "" # instantiate mle-solver - solver = 
MLESolver(dataset_code=self.ml_engineer.dataset_code, notes=experiment_notes, insights=self.ml_engineer.lit_review_sum, max_steps=self.mlesolver_max_steps, plan=self.ml_engineer.plan, openai_api_key=self.openai_api_key, llm_str=self.model_backbone["running experiments"]) + solver = MLESolver(dataset_code=self.ml_engineer.dataset_code, notes=experiment_notes, insights=self.ml_engineer.lit_review_sum, max_steps=self.mlesolver_max_steps, plan=self.ml_engineer.plan, openai_api_key=self.openai_api_key, llm_str=self.phase_models["running experiments"]) # run initialization for solver solver.initial_solve() # run solver for N mle optimization steps @@ -544,6 +570,18 @@ def parse_arguments(): help='Provide the DeepSeek API key.' ) + parser.add_argument( + '--google-api-key', + type=str, + help='Provide the Google API key.' + ) + + parser.add_argument( + '--anthropic-api-key', + type=str, + help='Provide the Anthropic API key.' + ) + parser.add_argument( '--load-existing', type=str, @@ -611,126 +649,168 @@ def parse_arguments(): help='Total number of paper-solver steps' ) + parser.add_argument( + '--ollama-max-tokens', + type=str, + default="2048", + help='Total number of tokens to use for OLLAMA' + ) + + parser.add_argument( + '--task-note-llm-config-file', + type=str, + help='Provide path to the task note LLM config file.' 
+ ) return parser.parse_args() if __name__ == "__main__": - args = parse_arguments() - - llm_backend = args.llm_backend - human_mode = args.copilot_mode.lower() == "true" - compile_pdf = args.compile_latex.lower() == "true" - load_existing = args.load_existing.lower() == "true" - try: - num_papers_lit_review = int(args.num_papers_lit_review.lower()) - except Exception: - raise Exception("args.num_papers_lit_review must be a valid integer!") - try: - papersolver_max_steps = int(args.papersolver_max_steps.lower()) - except Exception: - raise Exception("args.papersolver_max_steps must be a valid integer!") try: - mlesolver_max_steps = int(args.mlesolver_max_steps.lower()) - except Exception: - raise Exception("args.papersolver_max_steps must be a valid integer!") - - - api_key = os.getenv('OPENAI_API_KEY') or args.api_key - deepseek_api_key = os.getenv('DEEPSEEK_API_KEY') or args.deepseek_api_key - if args.api_key is not None and os.getenv('OPENAI_API_KEY') is None: - os.environ["OPENAI_API_KEY"] = args.api_key - if args.deepseek_api_key is not None and os.getenv('DEEPSEEK_API_KEY') is None: - os.environ["DEEPSEEK_API_KEY"] = args.deepseek_api_key - - if not api_key and not deepseek_api_key: - raise ValueError("API key must be provided via --api-key / -deepseek-api-key or the OPENAI_API_KEY / DEEPSEEK_API_KEY environment variable.") - - ########################################################## - # Research question that the agents are going to explore # - ########################################################## - if human_mode or args.research_topic is None: - research_topic = input("Please name an experiment idea for AgentLaboratory to perform: ") - else: - research_topic = args.research_topic - - task_notes_LLM = [ - {"phases": ["plan formulation"], - "note": f"You should come up with a plan for TWO experiments."}, - - {"phases": ["plan formulation", "data preparation", "running experiments"], - "note": "Please use gpt-4o-mini for your experiments."}, - - 
{"phases": ["running experiments"], - "note": f'Use the following code to inference gpt-4o-mini: \nfrom openai import OpenAI\nos.environ["OPENAI_API_KEY"] = "{api_key}"\nclient = OpenAI()\ncompletion = client.chat.completions.create(\nmodel="gpt-4o-mini-2024-07-18", messages=messages)\nanswer = completion.choices[0].message.content\n'}, - - {"phases": ["running experiments"], - "note": f"You have access to only gpt-4o-mini using the OpenAI API, please use the following key {api_key} but do not use too many inferences. Do not use openai.ChatCompletion.create or any openai==0.28 commands. Instead use the provided inference code."}, - - {"phases": ["running experiments"], - "note": "I would recommend using a small dataset (approximately only 100 data points) to run experiments in order to save time. Do not use much more than this unless you have to or are running the final tests."}, - - {"phases": ["data preparation", "running experiments"], - "note": "You are running on a MacBook laptop. You can use 'mps' with PyTorch"}, - - {"phases": ["data preparation", "running experiments"], - "note": "Generate figures with very colorful and artistic design."}, - ] - - task_notes_LLM.append( - {"phases": ["literature review", "plan formulation", "data preparation", "running experiments", "results interpretation", "report writing", "report refinement"], - "note": f"You should always write in the following language to converse and to write the report {args.language}"}, - ) + args = parse_arguments() + + llm_backend = args.llm_backend + human_mode = args.copilot_mode.lower() == "true" + compile_pdf = args.compile_latex.lower() == "true" + load_existing = args.load_existing.lower() == "true" + try: + num_papers_lit_review = int(args.num_papers_lit_review.lower()) + except Exception: + raise Exception("args.num_papers_lit_review must be a valid integer!") + try: + papersolver_max_steps = int(args.papersolver_max_steps.lower()) + except Exception: + raise 
Exception("args.papersolver_max_steps must be a valid integer!") + try: + mlesolver_max_steps = int(args.mlesolver_max_steps.lower()) + except Exception: + raise Exception("args.mlesolver_max_steps must be a valid integer!") + + # If using ollama, set the max tokens + if args.api_key is not None: + if args.api_key == "ollama": + try: + ollama_max_tokens = int(args.ollama_max_tokens.lower()) + os.environ["OLLAMA_MAX_TOKENS"] = str(ollama_max_tokens) + except Exception: + raise Exception("args.ollama_max_tokens must be a valid integer!") + + api_key = os.getenv('OPENAI_API_KEY') or args.api_key + deepseek_api_key = os.getenv('DEEPSEEK_API_KEY') or args.deepseek_api_key + google_api_key = os.getenv('GOOGLE_API_KEY') or args.google_api_key + anthropic_api_key = os.getenv('ANTHROPIC_API_KEY') or args.anthropic_api_key + if args.api_key is not None and os.getenv('OPENAI_API_KEY') is None: + os.environ["OPENAI_API_KEY"] = args.api_key + if args.deepseek_api_key is not None and os.getenv('DEEPSEEK_API_KEY') is None: + os.environ["DEEPSEEK_API_KEY"] = args.deepseek_api_key + if args.google_api_key is not None and os.getenv('GOOGLE_API_KEY') is None: + os.environ["GOOGLE_API_KEY"] = args.google_api_key + if args.anthropic_api_key is not None and os.getenv('ANTHROPIC_API_KEY') is None: + os.environ["ANTHROPIC_API_KEY"] = args.anthropic_api_key + + if not api_key and not deepseek_api_key and not google_api_key and not anthropic_api_key: + raise ValueError( + "API key must be provided via --api-key / --deepseek-api-key / --google-api-key / --anthropic-api-key argument " + "or the OPENAI_API_KEY / DEEPSEEK_API_KEY / GOOGLE_API_KEY / ANTHROPIC_API_KEY environment variable." 
+ ) + + # Store the backend LLM to use for the agents + if not llm_backend: + raise ValueError("Please provide a valid LLM backend to use for the agents.") + + os.environ["LLM_BACKEND"] = llm_backend + print(f"Using {llm_backend} as the backend LLM for the agents.") + + ########################################################## + # Research question that the agents are going to explore # + ########################################################## + if human_mode or args.research_topic is None: + research_topic = input("Please name an experiment idea for AgentLaboratory to perform: ") + else: + research_topic = args.research_topic + + if args.task_note_llm_config_file is not None: + try: + with open(args.task_note_llm_config_file, "r") as f: + task_note_json = json.load(f) + + # Verify that the JSON file is in the correct format + if not validate_task_note_config(task_note_json): + raise ValueError("The task note LLM config file is not in the correct format.") + except Exception as e: + print(f"[Warning] Error loading the task note LLM config file: {e}") + # Use the default task note JSON + task_note_json = TASK_NOTE_LLM + else: + task_note_json = TASK_NOTE_LLM - #################################################### - ### Stages where human input will be requested ### - #################################################### - human_in_loop = { - "literature review": human_mode, - "plan formulation": human_mode, - "data preparation": human_mode, - "running experiments": human_mode, - "results interpretation": human_mode, - "report writing": human_mode, - "report refinement": human_mode, - } - - ################################################### - ### LLM Backend used for the different phases ### - ################################################### - agent_models = { - "literature review": llm_backend, - "plan formulation": llm_backend, - "data preparation": llm_backend, - "running experiments": llm_backend, - "report writing": llm_backend, - "results interpretation": 
llm_backend, - "paper refinement": llm_backend, - } - - if load_existing: - load_path = args.load_existing_path - if load_path is None: - raise ValueError("Please provide path to load existing state.") - with open(load_path, "rb") as f: - lab = pickle.load(f) - else: - lab = LaboratoryWorkflow( + task_notes_LLM = build_task_note( + task_note_json, research_topic=research_topic, - notes=task_notes_LLM, - agent_model_backbone=agent_models, - human_in_loop_flag=human_in_loop, - openai_api_key=api_key, - compile_pdf=compile_pdf, - num_papers_lit_review=num_papers_lit_review, - papersolver_max_steps=papersolver_max_steps, - mlesolver_max_steps=mlesolver_max_steps, + api_key=api_key, + deepseek_api_key=deepseek_api_key, + google_api_key=google_api_key, + anthropic_api_key=anthropic_api_key, + language=args.language, + llm_backend=llm_backend ) - lab.perform_research() - - - - - - + #################################################### + ### Stages where human input will be requested ### + #################################################### + human_in_loop = { + "literature review": human_mode, + "plan formulation": human_mode, + "data preparation": human_mode, + "running experiments": human_mode, + "results interpretation": human_mode, + "report writing": human_mode, + "report refinement": human_mode, + } + for phase, mode in human_in_loop.items(): + if phase not in CONFIG_HUMAN_IN_THE_LOOP: + continue + if type(CONFIG_HUMAN_IN_THE_LOOP[phase]) == bool: + human_in_loop[phase] = CONFIG_HUMAN_IN_THE_LOOP[phase] + + ################################################### + ### LLM Backend used for the different phases ### + ################################################### + agent_models = { + "literature review": llm_backend, + "plan formulation": llm_backend, + "data preparation": llm_backend, + "running experiments": llm_backend, + "report writing": llm_backend, + "results interpretation": llm_backend, + "report refinement": llm_backend, + } + for phase, model in 
agent_models.items(): + if CONFIG_AGENT_MODELS.get(phase) is None: + continue + if type(CONFIG_AGENT_MODELS[phase]) == str: + agent_models[phase] = CONFIG_AGENT_MODELS.get(phase) + + if load_existing: + load_path = args.load_existing_path + if load_path is None: + raise ValueError("Please provide path to load existing state.") + with open(load_path, "rb") as f: + lab = pickle.load(f) + else: + lab = LaboratoryWorkflow( + research_topic=research_topic, + notes=task_notes_LLM, + agent_model_backbone=agent_models, + human_in_loop_flag=human_in_loop, + openai_api_key=api_key, + compile_pdf=compile_pdf, + num_papers_lit_review=num_papers_lit_review, + papersolver_max_steps=papersolver_max_steps, + mlesolver_max_steps=mlesolver_max_steps, + ) + lab.perform_research() + except Exception as e: + input(f"An error occurred: {e}\nPress enter to exit.") + + input("The research project has been completed. Press enter to exit.") diff --git a/app.py b/app.py new file mode 100644 index 0000000..08cf72f --- /dev/null +++ b/app.py @@ -0,0 +1,340 @@ +import json +import os +import subprocess +import sys +from flask import Flask, request, jsonify, render_template +from flask_cors import CORS + +from settings_manager import SettingsManager +from config import TASK_NOTE_LLM +from utils import validate_task_note_config +from model_registry import ModelRegistry + + +# Check if the WebUI repository is cloned +def check_webui_cloned(): + if not os.path.exists("AgentLaboratoryWebUI"): + # Ask the user if they want to clone the repository + print("The WebUI repository is not cloned.") + answer = input("Would you like to clone it now from https://github.com/whats2000/AgentLaboratoryWebUI.git? (y/n) ") + + if answer.lower() != "y": + print("Error: The WebUI repository is not cloned. 
Please clone it manually.") + sys.exit(1) + + print("Cloning the WebUI repository...") + subprocess.run(["git", "clone", "https://github.com/whats2000/AgentLaboratoryWebUI.git"], check=True) + +# Check if Node.js is installed +def check_node_installed(): + try: + subprocess.run(["node", "--version"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + except (FileNotFoundError, subprocess.CalledProcessError): + print("Error: Node.js is not installed. Please install it from https://nodejs.org/") + sys.exit(1) + +# Check if Yarn is installed +def check_yarn_installed(): + try: + subprocess.run(["yarn", "--version"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + except (FileNotFoundError, subprocess.CalledProcessError): + # Ask the user if they want to install Yarn + print("Yarn is not installed.") + answer = input("Would you like to install it now? (y/n) ") + + if answer.lower() != "y": + print("Error: Yarn is not installed. Please install it manually.") + sys.exit(1) + + print("Installing Yarn...") + subprocess.run(["npm", "install", "-g", "yarn"], check=True) + +# Build the WebUI +def build_webui(): + webui_path = os.path.join(os.getcwd(), "AgentLaboratoryWebUI") + if not os.path.exists(os.path.join(webui_path, "dist")): + print("Building the WebUI...") + subprocess.run(["yarn", "install"], check=True, cwd=webui_path) + subprocess.run(["yarn", "build"], check=True, cwd=webui_path) + +# Run the checks and build the WebUI +check_webui_cloned() +check_node_installed() +check_yarn_installed() +build_webui() + +# Initialize the Flask app +app = Flask( + __name__, + static_url_path='', + static_folder='AgentLaboratoryWebUI/dist', + template_folder='AgentLaboratoryWebUI/dist' +) +CORS(app) + +# Define default values +DEFAULT_SETTINGS = { + "research_topic": "", + "api_key": "", + "deepseek_api_key": "", + "google_api_key": "", + "anthropic_api_key": "", + "llm_backend": "gpt-4o", + "custom_llm_backend": "", + "ollama_max_tokens": 2048, + 
"language": "English", + "copilot_mode": False, + "compile_latex": True, + "num_papers_lit_review": 5, + "mlesolver_max_steps": 3, + "papersolver_max_steps": 5, +} + +settings_manager = SettingsManager() + +def save_user_settings_from_dict(settings: dict): + """Save settings using the SettingsManager.""" + settings_manager.save_settings(settings) + +def load_user_settings(): + """Load settings using the SettingsManager.""" + return settings_manager.load_settings() + +def get_existing_saves(): + """ + Retrieve list of existing save files from the 'state_saves' directory. + """ + saves_dir = 'state_saves' + try: + os.makedirs(saves_dir, exist_ok=True) + saves = [f for f in os.listdir(saves_dir) if f.endswith('.pkl')] + return saves if saves else ["No saved states found"] + except Exception as e: + print(f"Error retrieving saves: {e}") + return ["No saved states found"] + +def run_research_process(data: dict) -> str: + """ + Execute the research process based on the provided settings. + This is adapted from your original function. + """ + # Unpack parameters from the incoming JSON payload. 
+ research_topic = data.get('research_topic', '') + api_key = data.get('api_key', '') + llm_backend = data.get('llm_backend', 'o1-mini') + custom_llm_backend = data.get('custom_llm_backend', '') + ollama_max_tokens = data.get('ollama_max_tokens', 2048) + language = data.get('language', 'English') + copilot_mode = data.get('copilot_mode', False) + compile_latex = data.get('compile_latex', True) + num_papers_lit_review = data.get('num_papers_lit_review', 5) + mlesolver_max_steps = data.get('mlesolver_max_steps', 3) + papersolver_max_steps = data.get('papersolver_max_steps', 5) + deepseek_api_key = data.get('deepseek_api_key', '') + google_api_key = data.get('google_api_key', '') + anthropic_api_key = data.get('anthropic_api_key', '') + load_existing = data.get('load_existing', False) + load_existing_path = data.get('load_existing_path', '') + + # Choose backend based on the API key value. + if api_key.strip().lower() == "ollama": + chosen_backend = custom_llm_backend.strip() if custom_llm_backend.strip() else llm_backend + else: + chosen_backend = llm_backend + + # Prepare the command arguments. + cmd = [ + sys.executable, 'ai_lab_repo.py', + '--research-topic', research_topic, + '--llm-backend', chosen_backend, + '--language', language, + '--copilot-mode', str(copilot_mode).lower(), + '--compile-latex', str(compile_latex).lower(), + '--num-papers-lit-review', str(num_papers_lit_review), + '--mlesolver-max-steps', str(mlesolver_max_steps), + '--papersolver-max-steps', str(papersolver_max_steps) + ] + + # Append optional API keys if provided. + if api_key: + cmd.extend(['--api-key', api_key]) + if deepseek_api_key: + cmd.extend(['--deepseek-api-key', deepseek_api_key]) + if google_api_key: + cmd.extend(['--google-api-key', google_api_key]) + if anthropic_api_key: + cmd.extend(['--anthropic-api-key', anthropic_api_key]) + + # Require at least one valid API key. 
+ if not (api_key or deepseek_api_key or google_api_key or anthropic_api_key): + return "**Error starting research process:** No valid API key provided. At least one API key is required." + + # Handle Ollama-specific requirements. + if api_key.strip().lower() == "ollama": + if not custom_llm_backend.strip(): + return "**Error starting research process:** Custom LLM Backend is required for Ollama. Enter a custom model string or select a standard model." + if not ollama_max_tokens: + return "**Error starting research process:** Custom Max Tokens for Ollama is required. Enter a valid integer value." + cmd.extend(['--ollama-max-tokens', str(int(ollama_max_tokens))]) + + # If loading an existing research state, add the flags. + if load_existing and load_existing_path and load_existing_path != "No saved states found": + cmd.extend([ + '--load-existing', 'True', + '--load-existing-path', os.path.join('state_saves', load_existing_path) + ]) + + # Append task note config if the config file exists. + if os.path.exists('settings/task_note_llm_config.json'): + cmd.extend(['--task-note-llm-config-file', 'settings/task_note_llm_config.json']) + + # Create a displayable command string. + command_str = ' '.join( + [arg if (arg == sys.executable or arg == "ai_lab_repo.py" or arg.startswith("--")) + else f'"{arg}"' + for arg in cmd] + ) + markdown_status = f"**Command created:**\n```\n{command_str}\n```\n" + + # Attempt to open a new terminal window and run the command. 
+ try: + if sys.platform == 'win32': + subprocess.Popen(['start', 'cmd', '/k'] + cmd, shell=True) + elif sys.platform == 'darwin': + subprocess.Popen(['open', '-a', 'Terminal'] + cmd) + else: + subprocess.Popen(['x-terminal-emulator', '-e'] + cmd) + markdown_status += "\n**Research process started in a new terminal window.**" + except Exception as e: + markdown_status += f"\n**Error starting research process:** {e}" + + return markdown_status + +# Update the WebUI repository +def update_webui(): + """Pull the latest changes from the WebUI repository and rebuild if needed.""" + webui_path = os.path.join(os.getcwd(), "AgentLaboratoryWebUI") + + try: + # Check if we can pull updates + print("Checking for WebUI updates...") + result = subprocess.run( + ["git", "pull"], + check=True, + cwd=webui_path, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + + # If the output contains "Already up to date", no rebuild is needed + if "Already up to date" in result.stdout: + return {"status": "WebUI is already up to date", "updated": False} + + # If we got here, changes were pulled, rebuild the UI + print("Rebuilding WebUI after update...") + subprocess.run(["yarn", "install"], check=True, cwd=webui_path) + subprocess.run(["yarn", "build"], check=True, cwd=webui_path) + return {"status": "WebUI has been updated and rebuilt successfully", "updated": True} + + except subprocess.CalledProcessError as e: + error_message = f"Error updating WebUI: {e.stderr}" + print(error_message) + return {"status": error_message, "updated": False, "error": True} + except Exception as e: + error_message = f"Unexpected error updating WebUI: {str(e)}" + print(error_message) + return {"status": error_message, "updated": False, "error": True} + +@app.route("/") +@app.route("/config") +@app.route("/monitor") +def hello(): + return render_template("index.html") + +# Endpoint to start the research process. 
+@app.route('/api/research', methods=['POST']) +def api_research(): + data = request.get_json() + result = run_research_process(data) + return jsonify({"status": result}) + +# Endpoint to load or update settings. +@app.route('/api/settings', methods=['GET', 'POST']) +def api_settings(): + if request.method == 'GET': + settings = load_user_settings() + # Merge with defaults to ensure all keys are present. + merged_settings = DEFAULT_SETTINGS.copy() + merged_settings.update(settings or {}) + return jsonify(merged_settings) + elif request.method == 'POST': + settings = request.get_json() + save_user_settings_from_dict(settings) + return jsonify({"status": "Settings saved"}) + +# Endpoint to retrieve saved research states. +@app.route('/api/saves', methods=['GET']) +def api_saves(): + saves = get_existing_saves() + return jsonify({"saves": saves}) + +# Endpoint to retrieve available models from the registry. +@app.route('/api/models', methods=['GET']) +def api_models(): + reg = ModelRegistry(auto_refresh=False) + return jsonify({"models": reg.list_models()}) + +# Endpoint to update the WebUI +@app.route('/api/updateWebUI', methods=['POST']) +def api_update_webui(): + result = update_webui() + return jsonify(result) + +# Endpoint to manage the task note LLM configuration +@app.route('/api/task_note_config', methods=['GET', 'POST']) +def api_task_note_config(): + if request.method == 'GET': + config_file = os.path.join('settings', 'task_note_llm_config.json') + if os.path.exists(config_file): + try: + with open(config_file, 'r') as f: + config = json.load(f) + return jsonify(config) + except Exception as e: + return jsonify({ + "status": "error", + "message": f"Error loading task note config: {str(e)}" + }), 500 + else: + # Return default task note config from config.py + return jsonify(TASK_NOTE_LLM) + elif request.method == 'POST': + config = request.get_json() + + # Validate configuration before saving + if not validate_task_note_config(config): + return jsonify({ + 
"status": "error", + "message": "Invalid task note LLM configuration format" + }), 400 + + try: + os.makedirs('settings', exist_ok=True) + config_file = os.path.join('settings', 'task_note_llm_config.json') + # Save as task note JSON format + with open(config_file, 'w') as f: + json.dump(config, f, indent=2) + + return jsonify({ + "status": "success", + "message": "Task note LLM config saved successfully and set as current config" + }) + except Exception as e: + return jsonify({ + "status": "error", + "message": f"Error saving task note config: {str(e)}" + }), 500 + +if __name__ == '__main__': + app.run(debug=True, host='0.0.0.0', port=5000) diff --git a/common_imports.py b/common_imports.py index 7d968f3..8210852 100755 --- a/common_imports.py +++ b/common_imports.py @@ -40,6 +40,8 @@ import torch.optim as optim import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset, random_split +import torchvision +import torchaudio import tensorflow as tf #import keras diff --git a/config.py b/config.py new file mode 100644 index 0000000..48adbb1 --- /dev/null +++ b/config.py @@ -0,0 +1,103 @@ +""" +Configuration Guide for API Base URLs + +Note: You can set `OLLAMA_API_BASE_URL` to your Ollama API URL if you are using it. + You still need to set the `api_key` to the `ollama` when using it. (The `OpenAI API Key` field in the UI) + Because we use it to identify if we are using ollama providers + Then you can set any model string in the `args.llm_backend` flag or the `Custom LLM Backend (For Ollama)` field in the UI. 
+
+Read more about Ollama: https://ollama.com/blog/openai-compatibility
+"""
+GOOGLE_GENERATIVE_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/"
+DEEPSEEK_API_BASE_URL = "https://api.deepseek.com/v1"
+OLLAMA_API_BASE_URL = "http://localhost:11434/v1/"
+
+"""
+TASK_NOTE_LLM Configuration Guide
+
+# Phase Configuration
+- phases must be one of the following:
+- ["literature review", "plan formulation",
+   "data preparation", "running experiments",
+   "results interpretation", "report writing",
+   "report refinement"]
+
+---
+# Note Configuration
+You can use variables in a note by wrapping them in double curly braces.
+Example: "You should write the report in {{language}}"
+
+Here are the available variables for common use:
+- research_topic: The research topic of the task
+- api_key: OpenAI API Key
+- deepseek_api_key: DeepSeek API Key
+- google_api_key: Google API Key
+- anthropic_api_key: Anthropic API Key
+- language: The language to use for the report
+- llm_backend: The backend to use for the LLM
+"""
+TASK_NOTE_LLM = [
+    {"phases": ["plan formulation"],
+     "note": "You should come up with a plan for TWO experiments."},
+
+    {"phases": ["plan formulation", "data preparation", "running experiments"],
+     "note": "Please use gpt-4o-mini for your experiments."},
+
+    {"phases": ["running experiments"],
+     "note": 'Use the following code to run inference with gpt-4o-mini: \nfrom openai import OpenAI\nos.environ["OPENAI_API_KEY"] = "{{api_key}}"\nclient = OpenAI()\ncompletion = client.chat.completions.create(\nmodel="gpt-4o-mini-2024-07-18", messages=messages)\nanswer = completion.choices[0].message.content\n'},
+
+    {"phases": ["running experiments"],
+     "note": "You have access to only gpt-4o-mini using the OpenAI API; please use the following key {{api_key}}, but do not use too many inferences. Do not use openai.ChatCompletion.create or any openai==0.28 commands.
Instead use the provided inference code."},
+
+    {"phases": ["running experiments"],
+     "note": "I would recommend using a small dataset (only about 100 data points) to run experiments in order to save time. Do not use much more than this unless you have to or are running the final tests."},
+
+    {"phases": ["data preparation", "running experiments"],
+     "note": "You are running on an Ubuntu system. You can use 'cuda' with PyTorch."},
+
+    {"phases": ["data preparation", "running experiments"],
+     "note": "Generate figures with a very colorful and artistic design."},
+
+    {"phases": ["literature review", "plan formulation",
+                "data preparation", "running experiments",
+                "results interpretation", "report writing",
+                "report refinement"],
+     "note": "You should always converse and write the report in the following language: {{language}}"}
+]
+
+"""
+Human-in-the-Loop Configuration Guide
+
+You can configure the stages where human input will be requested.
+- If the value is `True`, the stage will be in human mode.
+- If the value is `False`, the stage will be in AI mode.
+- If the value is `None`, the stage will take the configuration from the `args.copilot_mode` flag.
+  (The `Enable Human-in-Loop Mode` checkbox in the UI)
+"""
+CONFIG_HUMAN_IN_THE_LOOP = {
+    "literature review": None,
+    "plan formulation": None,
+    "data preparation": None,
+    "running experiments": None,
+    "results interpretation": None,
+    "report writing": None,
+    "report refinement": None,
+}
+
+"""
+Agent Models Configuration Guide
+
+You can configure the LLM Backend used for the different phases.
+- If the value is a string, the stage will use the specified backend.
+- If the value is `None`, the stage will take the configuration from the `args.llm_backend` flag.
+ (Or whatever model you select or set in the UI) +""" +CONFIG_AGENT_MODELS = { + "literature review": None, + "plan formulation": None, + "data preparation": None, + "running experiments": None, + "results interpretation": None, + "report writing": None, + "report refinement": None, +} diff --git a/config_gradio.py b/config_gradio.py new file mode 100644 index 0000000..d68a04e --- /dev/null +++ b/config_gradio.py @@ -0,0 +1,373 @@ +import os +import subprocess +import sys +from typing import Any +import gradio as gr + +from settings_manager import SettingsManager + +# Define default values +DEFAULT_SETTINGS = { + "research_topic": "", + "api_key": "", + "deepseek_api_key": "", + "google_api_key": "", + "anthropic_api_key": "", + "llm_backend": "o1-mini", + "custom_llm_backend": "", + "ollama_max_tokens": 2048, + "language": "English", + "copilot_mode": False, + "compile_latex": True, + "num_papers_lit_review": 5, + "mlesolver_max_steps": 3, + "papersolver_max_steps": 5, +} + +settings_manager = SettingsManager() + +def save_user_settings( + research_topic, api_key, llm_backend, custom_llm_backend, + ollama_max_tokens, language, copilot_mode, compile_latex, + num_papers_lit_review, mlesolver_max_steps, papersolver_max_steps, + deepseek_api_key, google_api_key, anthropic_api_key +): + """Save current UI settings""" + settings = { + "research_topic": research_topic, + "api_key": api_key, + "deepseek_api_key": deepseek_api_key, + "google_api_key": google_api_key, + "anthropic_api_key": anthropic_api_key, + "llm_backend": llm_backend, + "custom_llm_backend": custom_llm_backend, + "ollama_max_tokens": ollama_max_tokens, + "language": language, + "copilot_mode": copilot_mode, + "compile_latex": compile_latex, + "num_papers_lit_review": num_papers_lit_review, + "mlesolver_max_steps": mlesolver_max_steps, + "papersolver_max_steps": papersolver_max_steps, + } + settings_manager.save_settings(settings) + +def load_user_settings(): + """Load saved UI settings""" + settings = 
settings_manager.load_settings() + return settings + +def get_existing_saves() -> list: + """Retrieve list of existing save files from state_saves directory.""" + saves_dir = 'state_saves' + try: + os.makedirs(saves_dir, exist_ok=True) + # List all .pkl files in the directory + saves = [f for f in os.listdir(saves_dir) if f.endswith('.pkl')] + return saves if saves else ["No saved states found"] + except Exception as e: + print(f"Error retrieving saves: {e}") + return ["No saved states found"] + + +def refresh_saves_dropdown(): + """ + IMPORTANT PART: + Return a *new* gr.Dropdown component populated with fresh choices. + This replaces the existing dropdown instead of attempting to update it. + """ + new_saves = get_existing_saves() + return gr.Dropdown( + choices=new_saves, + label="Select Saved Research State", + interactive=True + ) + + +def run_research_process( + research_topic: str, + api_key: str, + llm_backend: str, + custom_llm_backend: str, + ollama_max_tokens: Any, + language: str, + copilot_mode: bool, + compile_latex: bool, + num_papers_lit_review: Any, # Gradio numbers may come as float + mlesolver_max_steps: Any, + papersolver_max_steps: Any, + deepseek_api_key: str = "", + google_api_key: str = "", + anthropic_api_key: str = "", + load_existing: bool = False, + load_existing_path: str = "" +) -> str: + # Determine which LLM backend to use: + if api_key.strip().lower() == "ollama": + chosen_backend = custom_llm_backend.strip() if custom_llm_backend.strip() else llm_backend + else: + chosen_backend = llm_backend + + # Prepare the command arguments + cmd = [ + sys.executable, 'ai_lab_repo.py', + '--research-topic', research_topic, + '--llm-backend', chosen_backend, + '--language', language, + '--copilot-mode', str(copilot_mode).lower(), + '--compile-latex', str(compile_latex).lower(), + '--num-papers-lit-review', str(num_papers_lit_review), + '--mlesolver-max-steps', str(mlesolver_max_steps), + '--papersolver-max-steps', str(papersolver_max_steps) + ] + 
+ # Add optional API keys if provided + if api_key: + cmd.extend(['--api-key', api_key]) + if deepseek_api_key: + cmd.extend(['--deepseek-api-key', deepseek_api_key]) + if google_api_key: + cmd.extend(['--google-api-key', google_api_key]) + if anthropic_api_key: + cmd.extend(['--anthropic-api-key', anthropic_api_key]) + + # Valid API keys are required for the research process to start + if not api_key and not deepseek_api_key and not google_api_key and not anthropic_api_key: + return "**Error starting research process:** No valid API key provided. At least one API key is required." + + # For Ollama, require a custom LLM backend and add the custom max tokens + if api_key.strip().lower() == "ollama": + if not custom_llm_backend.strip(): + return "**Error starting research process:** Custom LLM Backend is required for Ollama. Enter a custom model string or select a standard model." + if not ollama_max_tokens: + return "**Error starting research process:** Custom Max Tokens for Ollama is required. Enter a valid integer value." + # Append custom max tokens for Ollama + cmd.extend(['--ollama-max-tokens', str(int(ollama_max_tokens))]) + + # Add load existing flags if selected + if load_existing and load_existing_path and load_existing_path != "No saved states found": + cmd.extend([ + '--load-existing', 'True', + '--load-existing-path', os.path.join('state_saves', load_existing_path) + ]) + + # Create a string version of the command for display purposes. + command_str = ' '.join([ + arg if (arg == sys.executable or arg == "ai_lab_repo.py" or arg.startswith("--")) + else f'"{arg}"' + for arg in cmd + ]) + + # Build the Markdown status message with the created command + markdown_status = f"""**Command created:** +
+<details>
+<summary>Click to view the command created</summary>
+
+{command_str}
+
+</details>
+""" + + # Now attempt to open a new terminal window with the research process + try: + if sys.platform == 'win32': + subprocess.Popen(['start', 'cmd', '/k'] + cmd, shell=True) + elif sys.platform == 'darwin': + subprocess.Popen(['open', '-a', 'Terminal'] + cmd) + else: + subprocess.Popen(['x-terminal-emulator', '-e'] + cmd) + markdown_status += "\n**Research process started in a new terminal window.**" + except Exception as e: + markdown_status += f"\n**Error starting research process:** {e}" + + return markdown_status + + +def create_gradio_config() -> gr.Blocks: + # Populate backend options from model registry + from model_registry import ModelRegistry + _registry = ModelRegistry(auto_refresh=False) + llm_backend_options = _registry.list_models() + languages = [ + "English", "Chinese-Simplified", "Chinese-Traditional", + "Japanese", "Korean", "Filipino", "French", + "Slovak", "Portuguese", "Spanish", "Turkish", "Hindi", "Bengali", + "Vietnamese", "Russian", "Arabic", "Farsi", "Italian" + ] + + with gr.Blocks() as demo: + gr.Markdown("# Agent Laboratory Configuration") + + with gr.Row(): + with gr.Column(): + gr.Markdown("## Basic Configuration") + research_topic = gr.Textbox( + label="Research Topic", + placeholder="Enter your research idea...", + lines=3 + ) + api_key = gr.Textbox( + label="OpenAI API Key", + type="password", + placeholder="Enter your OpenAI API key (for Ollama, set API key to 'ollama', must be set)" + ) + deepseek_api_key = gr.Textbox( + label="DeepSeek API Key (Optional)", + type="password", + placeholder="Enter your DeepSeek API key if using DeepSeek model" + ) + google_api_key = gr.Textbox( + label="Google API Key (Optional)", + type="password", + placeholder="Enter your Google API key if using Google models" + ) + anthropic_api_key = gr.Textbox( + label="Anthropic API Key (Optional)", + type="password", + placeholder="Enter your Anthropic API key if using Anthropic models" + ) + + with gr.Row(): + # Dropdown for standard LLM backend 
options + llm_backend = gr.Dropdown( + choices=llm_backend_options, + label="LLM Backend", + value="o1-mini" + ) + language = gr.Dropdown( + choices=languages, + label="Language", + value="English" + ) + + # Custom LLM Backend textbox for Ollama. + # This is optional and only used when API key is set to "ollama". + custom_llm_backend = gr.Textbox( + label="Custom LLM Backend (For Ollama)", + placeholder="Enter your custom model string (optional)", + value="" + ) + # Custom max tokens for Ollama + ollama_max_tokens = gr.Number( + label="Custom Max Tokens for Ollama", + value=2048, + precision=0, + info="Set the maximum tokens for the Ollama model (only used if API key is 'ollama')" + ) + + with gr.Column(): + gr.Markdown("## Advanced Configuration") + with gr.Accordion(label="Instructions for Use:", open=True): + gr.Markdown( + """ + - Fill in the research configuration. + - Optionally load a previous research state. + - **For standard models:** Select the desired backend from the dropdown. + - **For Ollama:** Set the API key to `ollama` and, if needed, enter your custom model string and max tokens in the **Custom LLM Backend** and **Custom Max Tokens for Ollama** fields. + - If the custom field is left empty when using Ollama, the dropdown value will be used. + - Click **Start Research in Terminal**. + - A new terminal window will open with the research process. 
+ """ + ) + + # Configuration for the research process + with gr.Row(): + copilot_mode = gr.Checkbox(label="Enable Human-in-Loop Mode") + compile_latex = gr.Checkbox(label="Compile LaTeX", value=True) + + with gr.Row(): + num_papers_lit_review = gr.Number( + label="Papers in Literature Review", + value=5, precision=0, minimum=1, maximum=20 + ) + mlesolver_max_steps = gr.Number( + label="MLE Solver Max Steps", + value=3, precision=0, minimum=1, maximum=10 + ) + papersolver_max_steps = gr.Number( + label="Paper Solver Max Steps", + value=5, precision=0, minimum=1, maximum=10 + ) + + # Saved States Section + with gr.Accordion("Resume Previous Research", open=False): + load_existing = gr.Checkbox(label="Load Existing Research State") + existing_saves = gr.Dropdown( + choices=get_existing_saves(), + label="Select Saved Research State", + interactive=True + ) + refresh_saves_btn = gr.Button("Refresh Saved States") + + submit_btn = gr.Button("Start Research in Terminal", variant="primary") + + with gr.Accordion(label="Status", open=True): + # Connect submit button to the research process. + # Output is gr.Markdown so that the returned Markdown is rendered. 
+ submit_btn.click( + fn=run_research_process, + inputs=[ + research_topic, api_key, llm_backend, custom_llm_backend, ollama_max_tokens, language, + copilot_mode, compile_latex, num_papers_lit_review, + mlesolver_max_steps, papersolver_max_steps, + deepseek_api_key, google_api_key, anthropic_api_key, + load_existing, existing_saves, + ], + outputs=gr.Markdown() + ) + + # Instead of returning just a list, return a new Dropdown from refresh_saves_dropdown() + refresh_saves_btn.click( + fn=refresh_saves_dropdown, + inputs=None, + outputs=existing_saves + ) + + # Load saved settings when initializing components + saved_settings = load_user_settings() + + # Define default values + default_settings = DEFAULT_SETTINGS.copy() + # Update default settings with saved settings + default_settings.update(saved_settings) + + # Update component default values with settings + research_topic.value = default_settings["research_topic"] + api_key.value = default_settings["api_key"] + deepseek_api_key.value = default_settings["deepseek_api_key"] + google_api_key.value = default_settings["google_api_key"] + anthropic_api_key.value = default_settings["anthropic_api_key"] + llm_backend.value = default_settings["llm_backend"] + custom_llm_backend.value = default_settings["custom_llm_backend"] + ollama_max_tokens.value = default_settings["ollama_max_tokens"] + language.value = default_settings["language"] + copilot_mode.value = default_settings["copilot_mode"] + compile_latex.value = default_settings["compile_latex"] + num_papers_lit_review.value = default_settings["num_papers_lit_review"] + mlesolver_max_steps.value = default_settings["mlesolver_max_steps"] + papersolver_max_steps.value = default_settings["papersolver_max_steps"] + + # Add change handlers to save settings when values change + for component in [ + research_topic, api_key, llm_backend, custom_llm_backend, + ollama_max_tokens, language, copilot_mode, compile_latex, + num_papers_lit_review, mlesolver_max_steps, 
papersolver_max_steps, + deepseek_api_key, google_api_key, anthropic_api_key + ]: + component.change( + fn=save_user_settings, + inputs=[ + research_topic, api_key, llm_backend, custom_llm_backend, + ollama_max_tokens, language, copilot_mode, compile_latex, + num_papers_lit_review, mlesolver_max_steps, papersolver_max_steps, + deepseek_api_key, google_api_key, anthropic_api_key + ] + ) + + return demo + + +def main(): + demo = create_gradio_config() + demo.launch() + + +if __name__ == "__main__": + main() diff --git a/configs/models_pricing.json b/configs/models_pricing.json new file mode 100644 index 0000000..78b525c --- /dev/null +++ b/configs/models_pricing.json @@ -0,0 +1,337 @@ +{ + "last_updated": "2026-03-28T14:49:47.830867+00:00", + "models": { + "gpt-4o": { + "provider": "openai", + "api_model_name": "gpt-4o-2024-08-06", + "aliases": [ + "gpt4o" + ], + "cost_per_million_input": 2.5, + "cost_per_million_output": 10.0 + }, + "gpt-4o-mini": { + "provider": "openai", + "api_model_name": "gpt-4o-mini-2024-07-18", + "aliases": [ + "gpt4omini", + "gpt-4omini", + "gpt4o-mini" + ], + "cost_per_million_input": 0.15, + "cost_per_million_output": 0.6 + }, + "gpt-4.1": { + "provider": "openai", + "api_model_name": "gpt-4.1", + "aliases": [ + "gpt-4-1" + ], + "cost_per_million_input": 2.0, + "cost_per_million_output": 8.0 + }, + "gpt-4.1-mini": { + "provider": "openai", + "api_model_name": "gpt-4.1-mini", + "aliases": [ + "gpt-4-1-mini" + ], + "cost_per_million_input": 0.4, + "cost_per_million_output": 1.6 + }, + "gpt-4.1-nano": { + "provider": "openai", + "api_model_name": "gpt-4.1-nano", + "aliases": [ + "gpt-4-1-nano" + ], + "cost_per_million_input": 0.1, + "cost_per_million_output": 0.4 + }, + "gpt-5.2": { + "provider": "openai", + "api_model_name": "gpt-5.2", + "aliases": [ + "gpt5.2", + "gpt-5-2" + ], + "cost_per_million_input": 0.875, + "cost_per_million_output": 7.0 + }, + "gpt-5.2-pro": { + "provider": "openai", + "api_model_name": "gpt-5.2-pro", + "aliases": [ 
+ "gpt5.2-pro", + "gpt-5-2-pro" + ], + "cost_per_million_input": 10.5, + "cost_per_million_output": 84.0 + }, + "gpt-5-mini": { + "provider": "openai", + "api_model_name": "gpt-5-mini", + "aliases": [ + "gpt5-mini", + "gpt5mini" + ], + "cost_per_million_input": 0.125, + "cost_per_million_output": 1.0 + }, + "o1": { + "provider": "openai", + "api_model_name": "o1-2024-12-17", + "aliases": [], + "cost_per_million_input": 15.0, + "cost_per_million_output": 60.0 + }, + "o1-preview": { + "provider": "openai", + "api_model_name": "o1-preview-2024-12-17", + "aliases": [], + "cost_per_million_input": 15.0, + "cost_per_million_output": 60.0 + }, + "o1-mini": { + "provider": "openai", + "api_model_name": "o1-mini-2024-09-12", + "aliases": [], + "cost_per_million_input": 0.55, + "cost_per_million_output": 2.2 + }, + "o3-mini": { + "provider": "openai", + "api_model_name": "o3-mini-2025-01-31", + "aliases": [], + "cost_per_million_input": 1.1, + "cost_per_million_output": 4.4 + }, + "o4-mini": { + "provider": "openai", + "api_model_name": "o4-mini", + "aliases": [], + "cost_per_million_input": 1.1, + "cost_per_million_output": 4.4 + }, + "claude-4.5-opus": { + "provider": "anthropic", + "api_model_name": "claude-4.5-opus", + "aliases": [], + "cost_per_million_input": 5.0, + "cost_per_million_output": 25.0 + }, + "claude-4.5-sonnet": { + "provider": "anthropic", + "api_model_name": "claude-4.5-sonnet", + "aliases": [], + "cost_per_million_input": 3.0, + "cost_per_million_output": 15.0 + }, + "claude-4.5-haiku": { + "provider": "anthropic", + "api_model_name": "claude-4.5-haiku", + "aliases": [], + "cost_per_million_input": 1.0, + "cost_per_million_output": 5.0 + }, + "claude-4.1-opus": { + "provider": "anthropic", + "api_model_name": "claude-4.1-opus", + "aliases": [], + "cost_per_million_input": 15.0, + "cost_per_million_output": 75.0 + }, + "claude-4-opus": { + "provider": "anthropic", + "api_model_name": "claude-4-opus", + "aliases": [], + "cost_per_million_input": 15.0, + 
"cost_per_million_output": 75.0 + }, + "claude-4-sonnet": { + "provider": "anthropic", + "api_model_name": "claude-4-sonnet", + "aliases": [], + "cost_per_million_input": 3.0, + "cost_per_million_output": 15.0 + }, + "claude-3-7-sonnet": { + "provider": "anthropic", + "api_model_name": "claude-3-7-sonnet", + "aliases": [], + "cost_per_million_input": 3.0, + "cost_per_million_output": 15.0 + }, + "claude-3-5-sonnet": { + "provider": "anthropic", + "api_model_name": "claude-3-5-sonnet", + "aliases": [], + "cost_per_million_input": 3.0, + "cost_per_million_output": 15.0 + }, + "claude-3-5-haiku": { + "provider": "anthropic", + "api_model_name": "claude-3-5-haiku", + "aliases": [], + "cost_per_million_input": 0.8, + "cost_per_million_output": 4.0 + }, + "deepseek-chat": { + "provider": "deepseek", + "api_model_name": "deepseek-chat", + "aliases": [], + "cost_per_million_input": 0.014, + "cost_per_million_output": 0.028 + }, + "gemini-3.0-pro": { + "provider": "google", + "api_model_name": "gemini-3-pro-preview", + "aliases": [ + "gemini-3-pro", + "gemini-3.0-pro-preview" + ], + "cost_per_million_input": 2.0, + "cost_per_million_output": 12.0 + }, + "gemini-3.0-flash": { + "provider": "google", + "api_model_name": "gemini-3-flash-preview", + "aliases": [ + "gemini-3-flash", + "gemini-3.0-flash-preview" + ], + "cost_per_million_input": 0.5, + "cost_per_million_output": 3.0 + }, + "gemini-2.5-pro": { + "provider": "google", + "api_model_name": "gemini-2.5-pro", + "aliases": [], + "cost_per_million_input": 1.0, + "cost_per_million_output": 10.0 + }, + "gemini-2.5-flash": { + "provider": "google", + "api_model_name": "gemini-2.5-flash", + "aliases": [], + "cost_per_million_input": 0.3, + "cost_per_million_output": 2.5 + }, + "gemini-2.5-flash-lite": { + "provider": "google", + "api_model_name": "gemini-2.5-flash-lite", + "aliases": [], + "cost_per_million_input": 0.1, + "cost_per_million_output": 0.4 + }, + "gemini-2.0-flash": { + "provider": "google", + "api_model_name": 
"gemini-2.0-flash", + "aliases": [], + "cost_per_million_input": 0.1, + "cost_per_million_output": 0.4 + }, + "gpt-5.4": { + "provider": "openai", + "api_model_name": "gpt-5.4", + "aliases": [ + "gpt5.4", + "gpt-5-4" + ], + "cost_per_million_input": 2.5, + "cost_per_million_output": 15.0 + }, + "gpt-5.4-mini": { + "provider": "openai", + "api_model_name": "gpt-5.4-mini", + "aliases": [ + "gpt5.4-mini", + "gpt-5-4-mini" + ], + "cost_per_million_input": 0.75, + "cost_per_million_output": 4.5 + }, + "gpt-5.4-nano": { + "provider": "openai", + "api_model_name": "gpt-5.4-nano", + "aliases": [ + "gpt5.4-nano", + "gpt-5-4-nano" + ], + "cost_per_million_input": 0.2, + "cost_per_million_output": 1.25 + }, + "gpt-5.4-pro": { + "provider": "openai", + "api_model_name": "gpt-5.4-pro", + "aliases": [ + "gpt5.4-pro", + "gpt-5-4-pro" + ], + "cost_per_million_input": 30.0, + "cost_per_million_output": 180.0 + }, + "o3": { + "provider": "openai", + "api_model_name": "o3", + "aliases": [], + "cost_per_million_input": 2.0, + "cost_per_million_output": 8.0 + }, + "o3-pro": { + "provider": "openai", + "api_model_name": "o3-pro", + "aliases": [], + "cost_per_million_input": 20.0, + "cost_per_million_output": 80.0 + }, + "claude-4.6-opus": { + "provider": "anthropic", + "api_model_name": "claude-opus-4-6", + "aliases": [ + "claude-opus-4.6", + "claude-opus-4-6" + ], + "cost_per_million_input": 5.0, + "cost_per_million_output": 25.0 + }, + "claude-4.6-sonnet": { + "provider": "anthropic", + "api_model_name": "claude-sonnet-4-6", + "aliases": [ + "claude-sonnet-4.6", + "claude-sonnet-4-6" + ], + "cost_per_million_input": 3.0, + "cost_per_million_output": 15.0 + }, + "gemini-3.1-pro": { + "provider": "google", + "api_model_name": "gemini-3.1-pro-preview", + "aliases": [ + "gemini-3-1-pro", + "gemini-3.1-pro-preview" + ], + "cost_per_million_input": 2.0, + "cost_per_million_output": 12.0 + }, + "gemini-3.1-flash-lite": { + "provider": "google", + "api_model_name": 
"gemini-3.1-flash-lite-preview", + "aliases": [ + "gemini-3-1-flash-lite" + ], + "cost_per_million_input": 0.25, + "cost_per_million_output": 1.5 + }, + "deepseek-r1": { + "provider": "deepseek", + "api_model_name": "deepseek-reasoner", + "aliases": [ + "deepseek-reasoner" + ], + "cost_per_million_input": 0.55, + "cost_per_million_output": 2.0 + } + } +} \ No newline at end of file diff --git a/docs/superpowers/plans/2026-03-28-model-pricing-registry.md b/docs/superpowers/plans/2026-03-28-model-pricing-registry.md new file mode 100644 index 0000000..602ccf6 --- /dev/null +++ b/docs/superpowers/plans/2026-03-28-model-pricing-registry.md @@ -0,0 +1,1530 @@ +# Model Pricing Registry Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Replace all hardcoded model pricing, aliases, and routing logic with a data-driven registry backed by `configs/models_pricing.json` and a web-based fetch utility. + +**Architecture:** A `ModelRegistry` class loads model definitions from a JSON config file and provides lookup/routing/cost methods. A `ModelFetcher` utility discovers models via provider APIs and scrapes pricing from docs pages. On startup, stale data (>7 days) triggers an auto-refresh attempt with graceful fallback. `inference.py`'s ~250-line if/elif chain is replaced by ~15 lines of registry-driven code. 
+ +**Tech Stack:** Python, requests, beautifulsoup4, json, tiktoken (existing) + +--- + +## File Structure + +| File | Action | Responsibility | +|------|--------|---------------| +| `configs/models_pricing.json` | Create | Cached model/pricing data, single source of truth | +| `model_registry.py` | Create | Load JSON, resolve aliases, route to provider, track costs | +| `model_fetcher.py` | Create | API model discovery + pricing page scraping per provider | +| `update_models.py` | Create | CLI entry point for manual refresh | +| `inference.py` | Modify | Replace if/elif chain + costmaps with registry lookups | +| `provider.py` | Modify | Add provider dispatcher function | +| `AgentLaboratoryWebUI/config_gradio.py` | Modify | Populate dropdown from registry | +| `AgentLaboratoryWebUI/app.py` | Modify | Use registry for model list | +| `requirements.txt` | Modify | Add requests, beautifulsoup4 | +| `tests/test_model_registry.py` | Create | Unit tests for registry | +| `tests/test_model_fetcher.py` | Create | Unit tests for fetcher | + +--- + +### Task 1: Create `configs/models_pricing.json` with all current models + +**Files:** +- Create: `configs/models_pricing.json` + +- [ ] **Step 1: Create the configs directory** + +```bash +mkdir -p configs +``` + +- [ ] **Step 2: Create `configs/models_pricing.json` with all current model data** + +This JSON contains every model currently hardcoded in `inference.py` lines 16-75 and 104-335. All pricing comes from the existing `costmap_in`/`costmap_out` dicts. The `api_model_name` comes from the model name strings passed to provider calls. The `aliases` come from the if/elif conditions. 
+ +```json +{ + "last_updated": "2026-03-28T00:00:00Z", + "models": { + "gpt-4o": { + "provider": "openai", + "api_model_name": "gpt-4o-2024-08-06", + "aliases": ["gpt4o", "gpt-4o"], + "cost_per_million_input": 2.50, + "cost_per_million_output": 10.00 + }, + "gpt-4o-mini": { + "provider": "openai", + "api_model_name": "gpt-4o-mini-2024-07-18", + "aliases": ["gpt4omini", "gpt-4omini", "gpt4o-mini"], + "cost_per_million_input": 0.150, + "cost_per_million_output": 0.60 + }, + "gpt-4.1": { + "provider": "openai", + "api_model_name": "gpt-4.1", + "aliases": ["gpt-4-1"], + "cost_per_million_input": 3.00, + "cost_per_million_output": 12.00 + }, + "gpt-4.1-mini": { + "provider": "openai", + "api_model_name": "gpt-4.1-mini", + "aliases": ["gpt-4-1-mini"], + "cost_per_million_input": 0.80, + "cost_per_million_output": 3.20 + }, + "gpt-4.1-nano": { + "provider": "openai", + "api_model_name": "gpt-4.1-nano", + "aliases": ["gpt-4-1-nano"], + "cost_per_million_input": 0.20, + "cost_per_million_output": 0.80 + }, + "gpt-5.2": { + "provider": "openai", + "api_model_name": "gpt-5.2", + "aliases": ["gpt5.2", "gpt-5-2"], + "cost_per_million_input": 1.75, + "cost_per_million_output": 14.00 + }, + "gpt-5.2-pro": { + "provider": "openai", + "api_model_name": "gpt-5.2-pro", + "aliases": ["gpt5.2-pro", "gpt-5-2-pro"], + "cost_per_million_input": 21.00, + "cost_per_million_output": 168.00 + }, + "gpt-5-mini": { + "provider": "openai", + "api_model_name": "gpt-5-mini", + "aliases": ["gpt5-mini", "gpt5mini"], + "cost_per_million_input": 0.25, + "cost_per_million_output": 2.00 + }, + "o1": { + "provider": "openai", + "api_model_name": "o1-2024-12-17", + "aliases": [], + "cost_per_million_input": 15.00, + "cost_per_million_output": 60.00 + }, + "o1-preview": { + "provider": "openai", + "api_model_name": "o1-preview-2024-12-17", + "aliases": [], + "cost_per_million_input": 15.00, + "cost_per_million_output": 60.00 + }, + "o1-mini": { + "provider": "openai", + "api_model_name": 
"o1-mini-2024-09-12", + "aliases": [], + "cost_per_million_input": 1.10, + "cost_per_million_output": 4.40 + }, + "o3-mini": { + "provider": "openai", + "api_model_name": "o3-mini-2025-01-31", + "aliases": [], + "cost_per_million_input": 1.10, + "cost_per_million_output": 4.40 + }, + "o4-mini": { + "provider": "openai", + "api_model_name": "o4-mini", + "aliases": [], + "cost_per_million_input": 4.00, + "cost_per_million_output": 16.00 + }, + "claude-4.5-opus": { + "provider": "anthropic", + "api_model_name": "claude-4.5-opus", + "aliases": [], + "cost_per_million_input": 5.00, + "cost_per_million_output": 25.00 + }, + "claude-4.5-sonnet": { + "provider": "anthropic", + "api_model_name": "claude-4.5-sonnet", + "aliases": [], + "cost_per_million_input": 3.00, + "cost_per_million_output": 15.00 + }, + "claude-4.5-haiku": { + "provider": "anthropic", + "api_model_name": "claude-4.5-haiku", + "aliases": [], + "cost_per_million_input": 1.00, + "cost_per_million_output": 5.00 + }, + "claude-4.1-opus": { + "provider": "anthropic", + "api_model_name": "claude-4.1-opus", + "aliases": [], + "cost_per_million_input": 15.00, + "cost_per_million_output": 75.00 + }, + "claude-4-opus": { + "provider": "anthropic", + "api_model_name": "claude-4-opus", + "aliases": [], + "cost_per_million_input": 15.00, + "cost_per_million_output": 75.00 + }, + "claude-4-sonnet": { + "provider": "anthropic", + "api_model_name": "claude-4-sonnet", + "aliases": [], + "cost_per_million_input": 3.00, + "cost_per_million_output": 15.00 + }, + "claude-3-7-sonnet": { + "provider": "anthropic", + "api_model_name": "claude-3-7-sonnet", + "aliases": [], + "cost_per_million_input": 3.00, + "cost_per_million_output": 15.00 + }, + "claude-3-5-sonnet": { + "provider": "anthropic", + "api_model_name": "claude-3-5-sonnet", + "aliases": [], + "cost_per_million_input": 3.00, + "cost_per_million_output": 15.00 + }, + "claude-3-5-haiku": { + "provider": "anthropic", + "api_model_name": "claude-3-5-haiku", + "aliases": 
[], + "cost_per_million_input": 0.80, + "cost_per_million_output": 4.00 + }, + "deepseek-chat": { + "provider": "deepseek", + "api_model_name": "deepseek-chat", + "aliases": [], + "cost_per_million_input": 0.27, + "cost_per_million_output": 1.10 + }, + "gemini-3.0-pro": { + "provider": "google", + "api_model_name": "gemini-3-pro-preview", + "aliases": ["gemini-3-pro", "gemini-3.0-pro-preview"], + "cost_per_million_input": 2.00, + "cost_per_million_output": 12.00 + }, + "gemini-3.0-flash": { + "provider": "google", + "api_model_name": "gemini-3-flash-preview", + "aliases": ["gemini-3-flash", "gemini-3.0-flash-preview"], + "cost_per_million_input": 0.50, + "cost_per_million_output": 3.00 + }, + "gemini-2.5-pro": { + "provider": "google", + "api_model_name": "gemini-2.5-pro", + "aliases": [], + "cost_per_million_input": 1.25, + "cost_per_million_output": 10.00 + }, + "gemini-2.5-flash": { + "provider": "google", + "api_model_name": "gemini-2.5-flash", + "aliases": [], + "cost_per_million_input": 0.30, + "cost_per_million_output": 2.50 + }, + "gemini-2.5-flash-lite": { + "provider": "google", + "api_model_name": "gemini-2.5-flash-lite", + "aliases": [], + "cost_per_million_input": 0.10, + "cost_per_million_output": 0.40 + }, + "gemini-2.0-flash": { + "provider": "google", + "api_model_name": "gemini-2.0-flash", + "aliases": [], + "cost_per_million_input": null, + "cost_per_million_output": null + } + } +} +``` + +- [ ] **Step 3: Commit** + +```bash +git add configs/models_pricing.json +git commit -m "feat: add configs/models_pricing.json with all current model data" +``` + +--- + +### Task 2: Create `model_registry.py` + +**Files:** +- Create: `model_registry.py` +- Create: `tests/test_model_registry.py` + +- [ ] **Step 1: Write tests for ModelRegistry** + +Create `tests/test_model_registry.py`: + +```python +import json +import os +import pytest +from unittest.mock import patch +from datetime import datetime, timezone + +# We'll import after creating the module +# 
from model_registry import ModelRegistry, ModelNotFoundError
+
+
+SAMPLE_CONFIG = {
+    "last_updated": "2026-03-28T00:00:00Z",
+    "models": {
+        "gpt-4o": {
+            "provider": "openai",
+            "api_model_name": "gpt-4o-2024-08-06",
+            "aliases": ["gpt4o"],
+            "cost_per_million_input": 2.50,
+            "cost_per_million_output": 10.00
+        },
+        "claude-3-5-sonnet": {
+            "provider": "anthropic",
+            "api_model_name": "claude-3-5-sonnet-20241022",
+            "aliases": ["claude-3-5-sonnet-latest"],
+            "cost_per_million_input": 3.00,
+            "cost_per_million_output": 15.00
+        },
+        "deepseek-chat": {
+            "provider": "deepseek",
+            "api_model_name": "deepseek-chat",
+            "aliases": [],
+            "cost_per_million_input": 0.27,
+            "cost_per_million_output": 1.10
+        },
+        "gemini-2.5-flash": {
+            "provider": "google",
+            "api_model_name": "gemini-2.5-flash",
+            "aliases": [],
+            "cost_per_million_input": 0.30,
+            "cost_per_million_output": 2.50
+        },
+        "null-cost-model": {
+            "provider": "openai",
+            "api_model_name": "null-cost-model",
+            "aliases": [],
+            "cost_per_million_input": None,
+            "cost_per_million_output": None
+        }
+    }
+}
+
+
+@pytest.fixture
+def sample_config_path(tmp_path):
+    config_path = tmp_path / "configs" / "models_pricing.json"
+    config_path.parent.mkdir(parents=True)
+    config_path.write_text(json.dumps(SAMPLE_CONFIG))
+    return str(config_path)
+
+
+@pytest.fixture
+def registry(sample_config_path):
+    from model_registry import ModelRegistry
+    return ModelRegistry(config_path=sample_config_path, auto_refresh=False)
+
+
+class TestModelResolution:
+    def test_resolve_canonical_name(self, registry):
+        model = registry.get_model("gpt-4o")
+        assert model["provider"] == "openai"
+        assert model["api_model_name"] == "gpt-4o-2024-08-06"
+
+    def test_resolve_alias(self, registry):
+        model = registry.get_model("gpt4o")
+        assert model["api_model_name"] == "gpt-4o-2024-08-06"
+
+    def test_unknown_model_raises(self, registry):
+        from model_registry import ModelNotFoundError
+        with pytest.raises(ModelNotFoundError):
+            
registry.get_model("nonexistent-model") + + def test_resolve_alias_method(self, registry): + assert registry.resolve_alias("gpt4o") == "gpt-4o" + + def test_resolve_canonical_returns_same(self, registry): + assert registry.resolve_alias("gpt-4o") == "gpt-4o" + + def test_anthropic_startswith_matching(self, registry): + """Anthropic models with suffixes like -latest or -20241022 should match.""" + model = registry.get_model("claude-3-5-sonnet-20241022") + assert model["provider"] == "anthropic" + + +class TestProviderRouting: + def test_get_provider(self, registry): + assert registry.get_provider("gpt-4o") == "openai" + assert registry.get_provider("claude-3-5-sonnet") == "anthropic" + assert registry.get_provider("deepseek-chat") == "deepseek" + assert registry.get_provider("gemini-2.5-flash") == "google" + + def test_get_api_model_name(self, registry): + assert registry.get_api_model_name("gpt-4o") == "gpt-4o-2024-08-06" + + def test_get_base_url(self, registry): + assert registry.get_base_url("gpt-4o") is None + assert "deepseek" in registry.get_base_url("deepseek-chat") + assert "generativelanguage" in registry.get_base_url("gemini-2.5-flash") + + +class TestCostEstimation: + def test_get_cost_input(self, registry): + assert registry.get_cost_input("gpt-4o") == 2.50 / 1_000_000 + + def test_get_cost_output(self, registry): + assert registry.get_cost_output("gpt-4o") == 10.00 / 1_000_000 + + def test_null_cost_returns_none(self, registry): + assert registry.get_cost_input("null-cost-model") is None + assert registry.get_cost_output("null-cost-model") is None + + def test_curr_cost_est_empty(self, registry): + assert registry.curr_cost_est() == 0.0 + + def test_curr_cost_est_with_tokens(self, registry): + registry.tokens_in["gpt-4o"] = 1000 + registry.tokens_out["gpt-4o"] = 500 + cost = registry.curr_cost_est() + expected = 1000 * (2.50 / 1_000_000) + 500 * (10.00 / 1_000_000) + assert abs(cost - expected) < 1e-10 + + +class TestListModels: + def 
test_list_all(self, registry): + models = registry.list_models() + assert "gpt-4o" in models + assert "claude-3-5-sonnet" in models + + def test_list_by_provider(self, registry): + models = registry.list_models(provider="openai") + assert "gpt-4o" in models + assert "claude-3-5-sonnet" not in models + + +class TestFallback: + def test_missing_file_uses_defaults(self, tmp_path): + from model_registry import ModelRegistry + bad_path = str(tmp_path / "nonexistent" / "models_pricing.json") + reg = ModelRegistry(config_path=bad_path, auto_refresh=False) + # Should have loaded DEFAULT_MODELS + assert len(reg.list_models()) > 0 + + def test_corrupt_file_uses_defaults(self, tmp_path): + from model_registry import ModelRegistry + bad_file = tmp_path / "bad.json" + bad_file.write_text("not json{{{") + reg = ModelRegistry(config_path=str(bad_file), auto_refresh=False) + assert len(reg.list_models()) > 0 + + +class TestStaleness: + def test_is_stale_when_old(self, sample_config_path): + from model_registry import ModelRegistry + # Set last_updated to 30 days ago + with open(sample_config_path) as f: + data = json.load(f) + data["last_updated"] = "2026-02-01T00:00:00Z" + with open(sample_config_path, "w") as f: + json.dump(data, f) + reg = ModelRegistry(config_path=sample_config_path, auto_refresh=False) + assert reg.is_stale() is True + + def test_is_not_stale_when_fresh(self, sample_config_path): + from model_registry import ModelRegistry + with open(sample_config_path) as f: + data = json.load(f) + data["last_updated"] = datetime.now(timezone.utc).isoformat() + with open(sample_config_path, "w") as f: + json.dump(data, f) + reg = ModelRegistry(config_path=sample_config_path, auto_refresh=False) + assert reg.is_stale() is False +``` + +- [ ] **Step 2: Run tests to verify they fail** + +```bash +cd d:/GitHub/AgentLaboratory && uv run pytest tests/test_model_registry.py -v +``` + +Expected: FAIL β€” `ModuleNotFoundError: No module named 'model_registry'` + +- [ ] **Step 3: 
Implement `model_registry.py`** + +Create `model_registry.py` in the project root: + +```python +import json +import os +from datetime import datetime, timezone, timedelta + +from config import GOOGLE_GENERATIVE_API_BASE_URL, DEEPSEEK_API_BASE_URL, OLLAMA_API_BASE_URL + +STALENESS_DAYS = 7 +DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs", "models_pricing.json") + +# Base URLs per provider (non-OpenAI providers that need a custom base_url) +PROVIDER_BASE_URLS = { + "deepseek": DEEPSEEK_API_BASE_URL, + "google": GOOGLE_GENERATIVE_API_BASE_URL, +} + + +class ModelNotFoundError(Exception): + def __init__(self, model_name, available_models): + self.model_name = model_name + self.available_models = available_models + super().__init__( + f"Model '{model_name}' not found. Available models: {', '.join(sorted(available_models))}" + ) + + +class ModelRegistry: + def __init__(self, config_path=None, auto_refresh=True): + self.config_path = config_path or DEFAULT_CONFIG_PATH + self.models = {} + self.last_updated = None + self.tokens_in = {} + self.tokens_out = {} + self._alias_map = {} + self._load() + if auto_refresh and self.is_stale(): + self._try_refresh() + + def _load(self): + try: + with open(self.config_path, "r") as f: + data = json.load(f) + self.models = data.get("models", {}) + self.last_updated = data.get("last_updated") + except (FileNotFoundError, json.JSONDecodeError, OSError): + print(f"Warning: Could not load {self.config_path}, using default models.") + self.models = DEFAULT_MODELS + self.last_updated = None + self._save() + self._build_alias_map() + + def _build_alias_map(self): + self._alias_map = {} + for name, info in self.models.items(): + self._alias_map[name] = name + for alias in info.get("aliases", []): + self._alias_map[alias] = name + + def _save(self): + data = { + "last_updated": self.last_updated or datetime.now(timezone.utc).isoformat(), + "models": self.models, + } + os.makedirs(os.path.dirname(self.config_path), 
exist_ok=True) + with open(self.config_path, "w") as f: + json.dump(data, f, indent=2) + + def _try_refresh(self): + try: + from model_fetcher import update_models_pricing + updated = update_models_pricing(self.config_path) + if updated: + self._load() + except Exception as e: + print(f"Warning: Auto-refresh failed ({e}), using cached data.") + + def is_stale(self): + if self.last_updated is None: + return True + try: + last = datetime.fromisoformat(self.last_updated) + if last.tzinfo is None: + last = last.replace(tzinfo=timezone.utc) + return datetime.now(timezone.utc) - last > timedelta(days=STALENESS_DAYS) + except (ValueError, TypeError): + return True + + def resolve_alias(self, name): + if name in self._alias_map: + return self._alias_map[name] + # Anthropic models: match by startswith (e.g. claude-3-5-sonnet-20241022) + for canonical, info in self.models.items(): + if info.get("provider") == "anthropic" and name.startswith(canonical): + return canonical + raise ModelNotFoundError(name, list(self.models.keys())) + + def get_model(self, name_or_alias): + canonical = self.resolve_alias(name_or_alias) + return self.models[canonical] + + def get_provider(self, name_or_alias): + return self.get_model(name_or_alias)["provider"] + + def get_api_model_name(self, name_or_alias): + model = self.get_model(name_or_alias) + # For anthropic, if user passed a full versioned name, use it directly + if model["provider"] == "anthropic" and name_or_alias != self.resolve_alias(name_or_alias): + if name_or_alias.startswith(self.resolve_alias(name_or_alias)): + return name_or_alias + return model["api_model_name"] + + def get_base_url(self, name_or_alias): + provider = self.get_provider(name_or_alias) + return PROVIDER_BASE_URLS.get(provider) + + def get_cost_input(self, name_or_alias): + cost = self.get_model(name_or_alias).get("cost_per_million_input") + if cost is None: + return None + return cost / 1_000_000 + + def get_cost_output(self, name_or_alias): + cost = 
self.get_model(name_or_alias).get("cost_per_million_output")
+        if cost is None:
+            return None
+        return cost / 1_000_000
+
+    def list_models(self, provider=None):
+        if provider:
+            return [name for name, info in self.models.items() if info["provider"] == provider]
+        return list(self.models.keys())
+
+    def get_canonical_for_cost(self, name_or_alias):
+        """Return canonical name used as key in tokens_in/tokens_out dicts."""
+        return self.resolve_alias(name_or_alias)
+
+    def curr_cost_est(self):
+        total = 0.0
+        for model_name, count in self.tokens_in.items():
+            cost = self.get_cost_input(model_name)
+            if cost is not None:
+                total += cost * count
+        for model_name, count in self.tokens_out.items():
+            cost = self.get_cost_output(model_name)
+            if cost is not None:
+                total += cost * count
+        return total
+
+
+# Hardcoded fallback covering a minimal subset of the models in
+# configs/models_pricing.json. Used only when the JSON file is missing or corrupt.
+DEFAULT_MODELS = {
+    "gpt-4o": {
+        "provider": "openai",
+        "api_model_name": "gpt-4o-2024-08-06",
+        "aliases": ["gpt4o"],
+        "cost_per_million_input": 2.50,
+        "cost_per_million_output": 10.00,
+    },
+    "gpt-4o-mini": {
+        "provider": "openai",
+        "api_model_name": "gpt-4o-mini-2024-07-18",
+        "aliases": ["gpt4omini", "gpt-4omini", "gpt4o-mini"],
+        "cost_per_million_input": 0.150,
+        "cost_per_million_output": 0.60,
+    },
+    "o1-mini": {
+        "provider": "openai",
+        "api_model_name": "o1-mini-2024-09-12",
+        "aliases": [],
+        "cost_per_million_input": 1.10,
+        "cost_per_million_output": 4.40,
+    },
+    "o3-mini": {
+        "provider": "openai",
+        "api_model_name": "o3-mini-2025-01-31",
+        "aliases": [],
+        "cost_per_million_input": 1.10,
+        "cost_per_million_output": 4.40,
+    },
+    "claude-3-5-sonnet": {
+        "provider": "anthropic",
+        "api_model_name": "claude-3-5-sonnet",
+        "aliases": [],
+        "cost_per_million_input": 3.00,
+        "cost_per_million_output": 15.00,
+    },
+    "gemini-2.5-flash": {
+        "provider": "google",
+        "api_model_name": "gemini-2.5-flash",
+        "aliases": [],
+        
"cost_per_million_input": 0.30, + "cost_per_million_output": 2.50, + }, + "deepseek-chat": { + "provider": "deepseek", + "api_model_name": "deepseek-chat", + "aliases": [], + "cost_per_million_input": 0.27, + "cost_per_million_output": 1.10, + }, +} +``` + +- [ ] **Step 4: Run tests to verify they pass** + +```bash +cd d:/GitHub/AgentLaboratory && uv run pytest tests/test_model_registry.py -v +``` + +Expected: All tests PASS. + +- [ ] **Step 5: Commit** + +```bash +git add model_registry.py tests/test_model_registry.py +git commit -m "feat: add ModelRegistry with alias resolution, provider routing, and cost tracking" +``` + +--- + +### Task 3: Add provider dispatcher to `provider.py` + +**Files:** +- Modify: `provider.py:92-127` (add after `AnthropicProvider`) + +- [ ] **Step 1: Add `get_provider_response` dispatcher function** + +Append to `provider.py` after the `AnthropicProvider` class: + +```python +def get_provider_response(provider, api_key, model_name, user_prompt, system_prompt, temperature=None, base_url=None): + """Dispatch to the correct provider based on provider string.""" + if provider == "anthropic": + return AnthropicProvider.get_response( + api_key=api_key, + model_name=model_name, + user_prompt=user_prompt, + system_prompt=system_prompt, + temperature=temperature, + ) + else: + # openai, google, deepseek, ollama all use OpenAI-compatible API + return OpenaiProvider.get_response( + api_key=api_key, + model_name=model_name, + user_prompt=user_prompt, + system_prompt=system_prompt, + temperature=temperature, + base_url=base_url, + ) +``` + +- [ ] **Step 2: Commit** + +```bash +git add provider.py +git commit -m "feat: add get_provider_response dispatcher to provider.py" +``` + +--- + +### Task 4: Rewrite `inference.py` to use `ModelRegistry` + +**Files:** +- Modify: `inference.py` (full rewrite of lines 1-378) + +- [ ] **Step 1: Rewrite `inference.py`** + +Replace the entire file with: + +```python +import os +import tiktoken +import time + +from 
config import OLLAMA_API_BASE_URL +from model_registry import ModelRegistry +from provider import get_provider_response +from utils import remove_thinking_process + +# Global registry instance +registry = ModelRegistry() + +encoding = tiktoken.get_encoding("cl100k_base") + + +def curr_cost_est(): + return registry.curr_cost_est() + + +def query_model(model_str, prompt, system_prompt, + openai_api_key=None, anthropic_api_key=None, + tries=5, timeout=5.0, + temp=None, print_cost=True, version="1.5"): + # Override the API keys if provided in the function call + if openai_api_key is not None: + os.environ["OPENAI_API_KEY"] = openai_api_key + if anthropic_api_key is not None: + os.environ["ANTHROPIC_API_KEY"] = anthropic_api_key + + preloaded_openai_api = os.getenv('OPENAI_API_KEY') + preload_anthropic_api = os.getenv('ANTHROPIC_API_KEY') + preload_google_api = os.getenv('GOOGLE_API_KEY') + preload_deepseek_api = os.getenv('DEEPSEEK_API_KEY') + + if (preloaded_openai_api is None and + preload_anthropic_api is None and + preload_google_api is None and + preload_deepseek_api is None): + raise Exception("No API key provided in query_model function") + + # Handle Ollama passthrough + if preloaded_openai_api == "ollama": + return _query_ollama(model_str, prompt, system_prompt, tries, timeout, temp) + + for _ in range(tries): + try: + # Resolve model via registry + canonical = registry.get_canonical_for_cost(model_str) + provider = registry.get_provider(model_str) + api_model_name = registry.get_api_model_name(model_str) + base_url = registry.get_base_url(model_str) + + # Determine API key based on provider + api_key_map = { + "openai": os.getenv('OPENAI_API_KEY'), + "anthropic": os.getenv('ANTHROPIC_API_KEY'), + "google": os.getenv('GOOGLE_API_KEY'), + "deepseek": os.getenv('DEEPSEEK_API_KEY'), + } + api_key = api_key_map.get(provider) + if api_key is None: + raise Exception(f"No API key set for provider '{provider}'") + + answer = get_provider_response( + provider=provider, 
+ api_key=api_key, + model_name=api_model_name, + user_prompt=prompt, + system_prompt=system_prompt, + temperature=temp, + base_url=base_url, + ) + + answer = remove_thinking_process(answer) + + # Cost estimation + try: + try: + model_encoding = tiktoken.encoding_for_model(canonical) + except KeyError: + model_encoding = tiktoken.encoding_for_model("gpt-4o") + if canonical not in registry.tokens_in: + registry.tokens_in[canonical] = 0 + registry.tokens_out[canonical] = 0 + registry.tokens_in[canonical] += len(model_encoding.encode(system_prompt + prompt)) + registry.tokens_out[canonical] += len(model_encoding.encode(answer)) + if print_cost: + print(f"Current experiment cost = ${curr_cost_est()}, ** Approximate values, may not reflect true cost") + except Exception as e: + if print_cost: + print(f"Cost approximation has an error? {e}") + + return answer + except Exception as e: + print("Inference Exception:", e) + time.sleep(timeout) + continue + raise Exception("Max retries: timeout") + + +def _query_ollama(model_str, prompt, system_prompt, tries, timeout, temp): + """Handle Ollama models β€” bypass registry, pass model string directly.""" + from provider import OpenaiProvider + for _ in range(tries): + try: + answer = OpenaiProvider.get_response( + api_key="ollama", + model_name=model_str, + user_prompt=prompt, + system_prompt=system_prompt, + temperature=temp, + base_url=OLLAMA_API_BASE_URL, + ) + return remove_thinking_process(answer) + except Exception as e: + print("Inference Exception:", e) + time.sleep(timeout) + continue + raise Exception("Max retries: timeout") +``` + +- [ ] **Step 2: Verify the project still works** + +```bash +cd d:/GitHub/AgentLaboratory && uv run python -c "from inference import query_model, curr_cost_est; print('Import OK')" +``` + +Expected: `Import OK` + +- [ ] **Step 3: Commit** + +```bash +git add inference.py +git commit -m "refactor: rewrite inference.py to use ModelRegistry instead of hardcoded if/elif chain" +``` + +--- + 
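The cost bookkeeping that the rewritten `query_model` delegates to the registry reduces to simple per-token arithmetic over the `tokens_in`/`tokens_out` tallies. A standalone sketch of that calculation, where the two-entry pricing table is illustrative only (real prices live in `configs/models_pricing.json`):

```python
# Illustrative pricing table; $ per million tokens. Real data comes from
# configs/models_pricing.json via ModelRegistry.
PRICING = {
    "gpt-4o": {"in": 2.50, "out": 10.00},
    "deepseek-chat": {"in": 0.27, "out": 1.10},
}


def estimate_cost(tokens_in, tokens_out):
    """tokens_in/tokens_out map canonical model name -> token count."""
    total = 0.0
    for model, count in tokens_in.items():
        total += count * PRICING[model]["in"] / 1_000_000
    for model, count in tokens_out.items():
        total += count * PRICING[model]["out"] / 1_000_000
    return total
```

For example, 1,000,000 input tokens plus 100,000 output tokens on `gpt-4o` would cost $2.50 + $1.00 = $3.50 under this table.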
+### Task 5: Create `model_fetcher.py` + +**Files:** +- Create: `model_fetcher.py` +- Create: `tests/test_model_fetcher.py` + +- [ ] **Step 1: Write tests for model_fetcher** + +Create `tests/test_model_fetcher.py`: + +```python +import json +import pytest +from unittest.mock import patch, MagicMock +from datetime import datetime, timezone + + +SAMPLE_OPENAI_MODELS_RESPONSE = { + "data": [ + {"id": "gpt-4o", "object": "model"}, + {"id": "gpt-4o-mini", "object": "model"}, + {"id": "text-embedding-ada-002", "object": "model"}, + {"id": "dall-e-3", "object": "model"}, + ] +} + + +SAMPLE_ANTHROPIC_MODELS_RESPONSE = { + "data": [ + {"id": "claude-3-5-sonnet-20241022", "type": "model"}, + {"id": "claude-3-5-haiku-20241022", "type": "model"}, + ] +} + + +SAMPLE_GOOGLE_MODELS_RESPONSE = { + "models": [ + {"name": "models/gemini-2.5-flash", "displayName": "Gemini 2.5 Flash"}, + {"name": "models/embedding-001", "displayName": "Embedding 001"}, + ] +} + + +class TestFetchModelsFromAPI: + @patch("model_fetcher.requests.get") + def test_fetch_openai_models(self, mock_get): + from model_fetcher import fetch_models_from_api + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = SAMPLE_OPENAI_MODELS_RESPONSE + mock_get.return_value = mock_resp + + models = fetch_models_from_api("openai", api_key="test-key") + # Should filter out embedding and dall-e models + assert "gpt-4o" in models + assert "gpt-4o-mini" in models + assert "text-embedding-ada-002" not in models + assert "dall-e-3" not in models + + @patch("model_fetcher.requests.get") + def test_fetch_returns_empty_on_failure(self, mock_get): + from model_fetcher import fetch_models_from_api + mock_get.side_effect = Exception("Network error") + models = fetch_models_from_api("openai", api_key="test-key") + assert models == [] + + @patch("model_fetcher.requests.get") + def test_fetch_google_models(self, mock_get): + from model_fetcher import fetch_models_from_api + mock_resp = MagicMock() + 
mock_resp.status_code = 200 + mock_resp.json.return_value = SAMPLE_GOOGLE_MODELS_RESPONSE + mock_get.return_value = mock_resp + + models = fetch_models_from_api("google", api_key="test-key") + assert "gemini-2.5-flash" in models + assert "embedding-001" not in models + + +class TestMergeModels: + def test_merge_adds_new_model(self): + from model_fetcher import merge_discovered_models + existing = { + "last_updated": "2026-03-01T00:00:00Z", + "models": {} + } + discovered = {"openai": ["gpt-4o-new"]} + result = merge_discovered_models(existing, discovered) + assert "gpt-4o-new" in result["models"] + assert result["models"]["gpt-4o-new"]["provider"] == "openai" + assert result["models"]["gpt-4o-new"]["cost_per_million_input"] is None + + def test_merge_keeps_existing(self): + from model_fetcher import merge_discovered_models + existing = { + "last_updated": "2026-03-01T00:00:00Z", + "models": { + "gpt-4o": { + "provider": "openai", + "api_model_name": "gpt-4o", + "aliases": [], + "cost_per_million_input": 2.50, + "cost_per_million_output": 10.00 + } + } + } + discovered = {"openai": ["gpt-4o"]} + result = merge_discovered_models(existing, discovered) + # Existing model should keep its pricing + assert result["models"]["gpt-4o"]["cost_per_million_input"] == 2.50 + + def test_merge_does_not_remove_existing(self): + from model_fetcher import merge_discovered_models + existing = { + "last_updated": "2026-03-01T00:00:00Z", + "models": { + "gpt-4o": { + "provider": "openai", + "api_model_name": "gpt-4o", + "aliases": [], + "cost_per_million_input": 2.50, + "cost_per_million_output": 10.00 + } + } + } + discovered = {"openai": ["gpt-4o-new"]} + result = merge_discovered_models(existing, discovered) + assert "gpt-4o" in result["models"] + + +class TestUpdateModelsPricing: + @patch("model_fetcher.fetch_pricing") + @patch("model_fetcher.fetch_models_from_api") + def test_update_writes_file(self, mock_fetch_models, mock_fetch_pricing, tmp_path): + from model_fetcher import 
update_models_pricing + config_path = str(tmp_path / "models_pricing.json") + existing = { + "last_updated": "2026-01-01T00:00:00Z", + "models": {} + } + with open(config_path, "w") as f: + json.dump(existing, f) + + mock_fetch_models.return_value = ["gpt-4o"] + mock_fetch_pricing.return_value = {} + + result = update_models_pricing(config_path, force=True) + assert result is True + + with open(config_path) as f: + data = json.load(f) + assert data["last_updated"] != "2026-01-01T00:00:00Z" +``` + +- [ ] **Step 2: Run tests to verify they fail** + +```bash +cd d:/GitHub/AgentLaboratory && uv run pytest tests/test_model_fetcher.py -v +``` + +Expected: FAIL β€” `ModuleNotFoundError: No module named 'model_fetcher'` + +- [ ] **Step 3: Implement `model_fetcher.py`** + +Create `model_fetcher.py`: + +```python +import json +import os +import re +from datetime import datetime, timezone, timedelta + +import requests +from bs4 import BeautifulSoup + +from config import GOOGLE_GENERATIVE_API_BASE_URL, DEEPSEEK_API_BASE_URL + +FETCH_TIMEOUT = 10 + +# Model ID prefixes to exclude from API discovery (embeddings, image gen, etc.) +EXCLUDED_PREFIXES = ( + "text-embedding", "embedding", "dall-e", "tts-", "whisper", + "davinci", "babbage", "curie", "ada", +) + +# Provider API endpoints for model discovery +PROVIDER_API_ENDPOINTS = { + "openai": "https://api.openai.com/v1/models", + "anthropic": "https://api.anthropic.com/v1/models", + "google": f"{GOOGLE_GENERATIVE_API_BASE_URL}models", + "deepseek": f"{DEEPSEEK_API_BASE_URL}/models", +} + +# Provider pricing page URLs +PROVIDER_PRICING_URLS = { + "openai": "https://platform.openai.com/docs/pricing", + "anthropic": "https://docs.anthropic.com/en/docs/about-claude/models", + "google": "https://ai.google.dev/gemini-api/docs/pricing", + "deepseek": "https://api-docs.deepseek.com/quick_start/pricing", +} + + +def fetch_models_from_api(provider, api_key): + """Fetch available model IDs from a provider's API. 
Returns list of model ID strings.""" + endpoint = PROVIDER_API_ENDPOINTS.get(provider) + if not endpoint: + return [] + + try: + headers = {} + params = {} + if provider == "openai" or provider == "deepseek": + headers["Authorization"] = f"Bearer {api_key}" + elif provider == "anthropic": + headers["x-api-key"] = api_key + headers["anthropic-version"] = "2023-06-01" + elif provider == "google": + params["key"] = api_key + + resp = requests.get(endpoint, headers=headers, params=params, timeout=FETCH_TIMEOUT) + if resp.status_code != 200: + print(f"Warning: {provider} API returned status {resp.status_code}") + return [] + + data = resp.json() + + if provider == "google": + # Google returns {"models": [{"name": "models/gemini-2.5-flash", ...}]} + raw_ids = [m["name"].replace("models/", "") for m in data.get("models", [])] + else: + # OpenAI/Anthropic/DeepSeek return {"data": [{"id": "model-name", ...}]} + raw_ids = [m["id"] for m in data.get("data", [])] + + # Filter out non-chat models + return [mid for mid in raw_ids if not mid.startswith(EXCLUDED_PREFIXES)] + + except Exception as e: + print(f"Warning: Failed to fetch models from {provider}: {e}") + return [] + + +def fetch_pricing(provider): + """Scrape pricing page for a provider. Returns dict of {model_name: {input: float, output: float}}. 
+ Returns empty dict on failure.""" + url = PROVIDER_PRICING_URLS.get(provider) + if not url: + return {} + + try: + resp = requests.get(url, timeout=FETCH_TIMEOUT, headers={"User-Agent": "AgentLaboratory/1.0"}) + if resp.status_code != 200: + print(f"Warning: {provider} pricing page returned status {resp.status_code}") + return {} + + soup = BeautifulSoup(resp.text, "html.parser") + + if provider == "openai": + return _parse_openai_pricing(soup) + elif provider == "anthropic": + return _parse_anthropic_pricing(soup) + elif provider == "google": + return _parse_google_pricing(soup) + elif provider == "deepseek": + return _parse_deepseek_pricing(soup) + + except Exception as e: + print(f"Warning: Failed to fetch pricing for {provider}: {e}") + return {} + + +def _parse_price_str(text): + """Extract a numeric price from a string like '$2.50' or '$0.30 / 1M tokens'.""" + match = re.search(r'\$?([\d.]+)', text.strip()) + if match: + return float(match.group(1)) + return None + + +def _parse_openai_pricing(soup): + """Parse OpenAI pricing page. 
Returns {model: {input: float, output: float}}.""" + pricing = {} + # Look for tables with pricing data + for table in soup.find_all("table"): + rows = table.find_all("tr") + for row in rows: + cells = row.find_all(["td", "th"]) + if len(cells) >= 3: + model_name = cells[0].get_text(strip=True).lower() + input_price = _parse_price_str(cells[1].get_text(strip=True)) + output_price = _parse_price_str(cells[2].get_text(strip=True)) + if input_price is not None and output_price is not None: + pricing[model_name] = {"input": input_price, "output": output_price} + return pricing + + +def _parse_anthropic_pricing(soup): + """Parse Anthropic pricing page.""" + pricing = {} + for table in soup.find_all("table"): + rows = table.find_all("tr") + for row in rows: + cells = row.find_all(["td", "th"]) + if len(cells) >= 3: + model_name = cells[0].get_text(strip=True).lower() + input_price = _parse_price_str(cells[1].get_text(strip=True)) + output_price = _parse_price_str(cells[2].get_text(strip=True)) + if input_price is not None and output_price is not None: + pricing[model_name] = {"input": input_price, "output": output_price} + return pricing + + +def _parse_google_pricing(soup): + """Parse Google pricing page.""" + pricing = {} + for table in soup.find_all("table"): + rows = table.find_all("tr") + for row in rows: + cells = row.find_all(["td", "th"]) + if len(cells) >= 3: + model_name = cells[0].get_text(strip=True).lower() + input_price = _parse_price_str(cells[1].get_text(strip=True)) + output_price = _parse_price_str(cells[2].get_text(strip=True)) + if input_price is not None and output_price is not None: + pricing[model_name] = {"input": input_price, "output": output_price} + return pricing + + +def _parse_deepseek_pricing(soup): + """Parse DeepSeek pricing page.""" + pricing = {} + for table in soup.find_all("table"): + rows = table.find_all("tr") + for row in rows: + cells = row.find_all(["td", "th"]) + if len(cells) >= 3: + model_name = 
cells[0].get_text(strip=True).lower() + input_price = _parse_price_str(cells[1].get_text(strip=True)) + output_price = _parse_price_str(cells[2].get_text(strip=True)) + if input_price is not None and output_price is not None: + pricing[model_name] = {"input": input_price, "output": output_price} + return pricing + + +def merge_discovered_models(existing_data, discovered_by_provider, pricing_by_provider=None): + """Merge newly discovered models into existing config data. + + Args: + existing_data: Current JSON data dict + discovered_by_provider: dict of {provider: [model_id, ...]} + pricing_by_provider: dict of {provider: {model_name: {input: float, output: float}}} + + Returns: + Updated data dict + """ + if pricing_by_provider is None: + pricing_by_provider = {} + + models = existing_data.get("models", {}) + + for provider, model_ids in discovered_by_provider.items(): + provider_pricing = pricing_by_provider.get(provider, {}) + for model_id in model_ids: + if model_id not in models: + # New model β€” add with pricing if available + price_info = provider_pricing.get(model_id, {}) + models[model_id] = { + "provider": provider, + "api_model_name": model_id, + "aliases": [], + "cost_per_million_input": price_info.get("input"), + "cost_per_million_output": price_info.get("output"), + } + if price_info: + print(f" Added new model: {model_id} (with pricing)") + else: + print(f" Added new model: {model_id} (no pricing available)") + else: + # Existing model β€” update pricing if we scraped new data + price_info = provider_pricing.get(model_id, {}) + if price_info: + models[model_id]["cost_per_million_input"] = price_info.get("input", models[model_id].get("cost_per_million_input")) + models[model_id]["cost_per_million_output"] = price_info.get("output", models[model_id].get("cost_per_million_output")) + + existing_data["models"] = models + existing_data["last_updated"] = datetime.now(timezone.utc).isoformat() + return existing_data + + +def 
update_models_pricing(config_path, force=False):
+    """Main entry point: refresh models_pricing.json.
+
+    Args:
+        config_path: Path to models_pricing.json
+        force: If True, skip staleness check
+
+    Returns:
+        True if update was performed, False if skipped
+    """
+    # Load existing data
+    try:
+        with open(config_path) as f:
+            existing_data = json.load(f)
+    except (FileNotFoundError, json.JSONDecodeError):
+        existing_data = {"last_updated": None, "models": {}}
+
+    # Check staleness
+    if not force:
+        last = existing_data.get("last_updated")
+        if last:
+            try:
+                last_dt = datetime.fromisoformat(last)
+                if last_dt.tzinfo is None:
+                    last_dt = last_dt.replace(tzinfo=timezone.utc)
+                if datetime.now(timezone.utc) - last_dt < timedelta(days=7):
+                    print("Model pricing data is fresh, skipping update.")
+                    return False
+            except (ValueError, TypeError):
+                pass
+
+    print("Updating model pricing data...")
+
+    # Gather API keys
+    api_keys = {
+        "openai": os.getenv("OPENAI_API_KEY"),
+        "anthropic": os.getenv("ANTHROPIC_API_KEY"),
+        "google": os.getenv("GOOGLE_API_KEY"),
+        "deepseek": os.getenv("DEEPSEEK_API_KEY"),
+    }
+
+    # Discover models via APIs
+    discovered = {}
+    for provider, key in api_keys.items():
+        if key and key != "ollama":
+            print(f"  Fetching models from {provider}...")
+            models = fetch_models_from_api(provider, key)
+            if models:
+                discovered[provider] = models
+                print(f"    Found {len(models)} models from {provider}")
+
+    # Scrape pricing
+    pricing = {}
+    for provider in ["openai", "anthropic", "google", "deepseek"]:
+        print(f"  Fetching pricing from {provider}...")
+        provider_pricing = fetch_pricing(provider)
+        if provider_pricing:
+            pricing[provider] = provider_pricing
+            print(f"    Got pricing for {len(provider_pricing)} models from {provider}")
+
+    # Merge
+    updated_data = merge_discovered_models(existing_data, discovered, pricing)
+
+    # Write (guard against a config_path with no directory component,
+    # where os.makedirs("") would raise)
+    config_dir = os.path.dirname(config_path)
+    if config_dir:
+        os.makedirs(config_dir, exist_ok=True)
+    with open(config_path, "w") as f:
+        json.dump(updated_data, f, indent=2)
+
+ print(f"Model pricing data updated at {config_path}") + return True +``` + +- [ ] **Step 4: Run tests** + +```bash +cd d:/GitHub/AgentLaboratory && uv run pytest tests/test_model_fetcher.py -v +``` + +Expected: All tests PASS. + +- [ ] **Step 5: Commit** + +```bash +git add model_fetcher.py tests/test_model_fetcher.py +git commit -m "feat: add model_fetcher with API discovery and pricing scraping" +``` + +--- + +### Task 6: Create `update_models.py` CLI entry point + +**Files:** +- Create: `update_models.py` + +- [ ] **Step 1: Create `update_models.py`** + +```python +"""CLI tool to refresh model pricing data. + +Usage: + python update_models.py # Update if stale (>7 days) + python update_models.py --force # Force update regardless of freshness +""" +import argparse +import os +import sys + +from model_fetcher import update_models_pricing +from model_registry import DEFAULT_CONFIG_PATH + + +def main(): + parser = argparse.ArgumentParser(description="Update model pricing data") + parser.add_argument("--force", action="store_true", help="Force update regardless of freshness") + parser.add_argument("--config", default=DEFAULT_CONFIG_PATH, help="Path to models_pricing.json") + args = parser.parse_args() + + try: + updated = update_models_pricing(args.config, force=args.force) + if updated: + print("Done.") + else: + print("No update needed. Use --force to override.") + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() +``` + +- [ ] **Step 2: Verify it runs** + +```bash +cd d:/GitHub/AgentLaboratory && uv run python update_models.py --help +``` + +Expected: Shows help text with `--force` and `--config` options. 
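
The design spec calls for a non-blocking refresh at app startup; one way to reuse this same entry point without stalling the UI is a daemon thread. This is only a sketch — `update_models_pricing` is the Task 5 function, and any failure (missing module, network error) is reduced to a warning:

```python
import threading


def refresh_models_in_background(config_path="configs/models_pricing.json"):
    """Kick off a best-effort pricing refresh that never blocks startup."""
    def _worker():
        try:
            from model_fetcher import update_models_pricing
            update_models_pricing(config_path)
        except Exception as e:
            print(f"Warning: background model refresh failed: {e}")

    thread = threading.Thread(target=_worker, daemon=True)
    thread.start()
    return thread
```

Because the thread is a daemon, a refresh still in flight at shutdown is simply abandoned rather than delaying exit.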
+ +- [ ] **Step 3: Commit** + +```bash +git add update_models.py +git commit -m "feat: add update_models.py CLI for manual pricing refresh" +``` + +--- + +### Task 7: Update `requirements.txt` + +**Files:** +- Modify: `requirements.txt` + +- [ ] **Step 1: Add dependencies** + +Add these lines to `requirements.txt`: + +``` +requests +beautifulsoup4 +``` + +- [ ] **Step 2: Commit** + +```bash +git add requirements.txt +git commit -m "chore: add requests and beautifulsoup4 to requirements.txt" +``` + +--- + +### Task 8: Update UI files to use registry + +**Files:** +- Modify: `AgentLaboratoryWebUI/config_gradio.py:185-198` +- Modify: `AgentLaboratoryWebUI/app.py:75-90` + +- [ ] **Step 1: Update `config_gradio.py`** + +Replace the hardcoded `llm_backend_options` list (lines 185-198) with: + +```python +from model_registry import ModelRegistry + +_registry = ModelRegistry(auto_refresh=False) +llm_backend_options = _registry.list_models() +``` + +This replaces the static list: +```python +# OLD (remove): +# llm_backend_options = [ +# "o1", "o1-preview", "o1-mini", "o3-mini", +# "gpt-4o", "gpt-4o-mini", +# "deepseek-chat", +# ... +# ] +``` + +- [ ] **Step 2: Update `app.py`** + +In the Flask app, add an endpoint or use registry for model listing. The key change is replacing any hardcoded model list references with `ModelRegistry.list_models()`. + +Find any place in `app.py` where model names are listed or validated and replace with registry lookups. + +- [ ] **Step 3: Verify UI imports work** + +```bash +cd d:/GitHub/AgentLaboratory && uv run python -c "from AgentLaboratoryWebUI.config_gradio import llm_backend_options; print(llm_backend_options[:5])" +``` + +Expected: Prints first 5 model names from the registry. 
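
When wiring the registry into the UI, it is worth guarding against a missing or unreadable `configs/models_pricing.json` so the dropdown never comes up empty. A defensive sketch — `FALLBACK_MODELS` here is a hypothetical minimal list, not something defined by the tasks above:

```python
# Hypothetical last-resort list used only if the registry cannot be loaded.
FALLBACK_MODELS = ["gpt-4o", "gpt-4o-mini", "deepseek-chat"]


def load_backend_options():
    """Return model names for the UI dropdown, falling back to a static list."""
    try:
        from model_registry import ModelRegistry
        models = ModelRegistry(auto_refresh=False).list_models()
        return models if models else FALLBACK_MODELS
    except Exception:
        # Corrupt JSON, import error, etc. — degrade gracefully.
        return FALLBACK_MODELS
```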
+ +- [ ] **Step 4: Commit** + +```bash +git add AgentLaboratoryWebUI/config_gradio.py AgentLaboratoryWebUI/app.py +git commit -m "refactor: use ModelRegistry for UI model dropdowns instead of hardcoded lists" +``` + +--- + +### Task 9: Integration test β€” end-to-end validation + +**Files:** +- No new files β€” validation only + +- [ ] **Step 1: Verify imports chain works** + +```bash +cd d:/GitHub/AgentLaboratory && uv run python -c " +from inference import query_model, curr_cost_est +from model_registry import ModelRegistry +reg = ModelRegistry(auto_refresh=False) +print('Models:', len(reg.list_models())) +print('OpenAI models:', reg.list_models(provider='openai')) +print('Resolve gpt4o:', reg.resolve_alias('gpt4o')) +print('Provider for claude-3-5-sonnet:', reg.get_provider('claude-3-5-sonnet')) +print('API name for gpt-4o:', reg.get_api_model_name('gpt-4o')) +print('Cost est:', curr_cost_est()) +print('All OK') +" +``` + +Expected: Prints model info and `All OK`. + +- [ ] **Step 2: Verify update_models.py runs (dry run)** + +```bash +cd d:/GitHub/AgentLaboratory && uv run python update_models.py +``` + +Expected: Either "No update needed" (if JSON is fresh) or attempts to fetch and prints progress. + +- [ ] **Step 3: Final commit if any fixups needed** + +```bash +git add -A && git commit -m "fix: integration fixups for model registry" +``` + +Only run this if fixups were needed. diff --git a/docs/superpowers/specs/2026-03-28-model-pricing-registry-design.md b/docs/superpowers/specs/2026-03-28-model-pricing-registry-design.md new file mode 100644 index 0000000..519e07c --- /dev/null +++ b/docs/superpowers/specs/2026-03-28-model-pricing-registry-design.md @@ -0,0 +1,160 @@ +# Model Pricing Registry β€” Design Spec + +## Problem + +Model pricing, aliases, and routing logic are hardcoded across multiple files (`inference.py`, `config_gradio.py`, `app.py`). Adding or updating a model requires editing 3+ files and touching a ~250-line if/elif chain. 
Pricing drifts out of date silently. + +## Solution + +Replace all hardcoded model/pricing data with a single JSON file (`configs/models_pricing.json`) backed by a registry class and a web-based fetch utility that auto-refreshes pricing and discovers new models. + +## Data Model β€” `configs/models_pricing.json` + +```json +{ + "last_updated": "2026-03-28T12:00:00Z", + "models": { + "gpt-4o": { + "provider": "openai", + "api_model_name": "gpt-4o-2024-08-06", + "aliases": ["gpt4o"], + "cost_per_million_input": 2.50, + "cost_per_million_output": 10.00 + } + } +} +``` + +Fields: +- **`provider`**: One of `openai`, `anthropic`, `google`, `deepseek`, `ollama`. Determines which provider class handles the request. +- **`api_model_name`**: Actual model name sent to the API (handles version suffix mapping). +- **`aliases`**: Alternative user-facing names that resolve to this model. +- **`cost_per_million_input/output`**: USD pricing. `null` if unknown. + +## Architecture + +### New files + +| File | Purpose | +|------|---------| +| `model_registry.py` | `ModelRegistry` class β€” loads JSON, resolves aliases, routes to provider, calculates costs | +| `model_fetcher.py` | Fetch utility β€” API model discovery + pricing page scraping per provider | +| `update_models.py` | CLI entry point: `python update_models.py [--force]` | +| `configs/models_pricing.json` | Cached model/pricing data (checked into repo with defaults) | + +### Modified files + +| File | Change | +|------|--------| +| `inference.py` | Replace if/elif chain + costmaps with `ModelRegistry` lookups (~250 lines β†’ ~15) | +| `provider.py` | Add `get_response_by_provider(provider_name, ...)` dispatcher | +| `config_gradio.py` | Populate UI dropdown from `ModelRegistry.list_models()` | +| `app.py` | Same β€” use registry for defaults and model lists | + +### Flow + +``` +App startup + -> ModelRegistry loads configs/models_pricing.json + -> Checks last_updated; if stale (>7 days): + -> model_fetcher attempts refresh 
+            -> Success: updates JSON + reloads
+            -> Failure: warns, continues with cached data
+    -> inference.py uses registry for all lookups
+
+Manual: python update_models.py [--force]
+    -> model_fetcher fetches all providers
+    -> Writes updated configs/models_pricing.json
+```
+
+## ModelRegistry API
+
+```python
+class ModelRegistry:
+    def __init__(self, config_path="configs/models_pricing.json", auto_refresh=True):
+        """Load JSON, check freshness; if auto_refresh, attempt refresh when stale."""
+
+    def get_model(self, name_or_alias: str) -> dict:
+        """Resolve alias, return model entry. Raises ModelNotFoundError if unknown."""
+
+    def get_cost_input(self, model_name: str) -> float | None:
+        """Return per-token input cost, or None if unknown."""
+
+    def get_cost_output(self, model_name: str) -> float | None:
+        """Return per-token output cost, or None if unknown."""
+
+    def list_models(self, provider: str = None) -> list[str]:
+        """List available model names, optionally filtered by provider."""
+
+    def resolve_alias(self, name: str) -> str:
+        """Return canonical model name from alias."""
+
+    def get_provider(self, model_name: str) -> str:
+        """Return provider string for routing."""
+
+    def get_api_model_name(self, model_name: str) -> str:
+        """Return the actual API model name to send."""
+
+    def curr_cost_est(self) -> float:
+        """Calculate cumulative cost from tracked tokens."""
+```
+
+## Fetch Strategy
+
+### Model Discovery (via API)
+
+| Provider | Endpoint | Auth |
+|----------|----------|------|
+| OpenAI | `GET /v1/models` | API key |
+| Anthropic | `GET /v1/models` | API key |
+| Google | `GET /v1beta/models` | API key |
+| DeepSeek | `GET /v1/models` | API key |
+
+Returns available model IDs. Filter to relevant ones (skip embeddings, fine-tunes). Skip provider if no API key is set. 
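
The per-provider auth in the table above differs only in where the key goes. A small helper makes that explicit (a sketch mirroring the request construction planned for `fetch_models_from_api` in Task 5; header names follow each provider's public API documentation):

```python
def build_discovery_auth(provider, api_key):
    """Return (headers, params) for a provider's model-list endpoint."""
    if provider in ("openai", "deepseek"):
        # OpenAI-compatible: bearer token in the Authorization header
        return {"Authorization": f"Bearer {api_key}"}, {}
    if provider == "anthropic":
        # Anthropic: custom header plus a pinned API version
        return {"x-api-key": api_key, "anthropic-version": "2023-06-01"}, {}
    if provider == "google":
        # Google: key passed as a query parameter
        return {}, {"key": api_key}
    raise ValueError(f"Unknown provider: {provider}")
```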
+ +### Pricing Scraping + +| Provider | Source URL | Strategy | +|----------|-----------|----------| +| OpenAI | `https://platform.openai.com/docs/pricing` | Parse pricing table HTML | +| Anthropic | `https://docs.anthropic.com/en/docs/about-claude/models` | Parse model comparison table | +| Google | `https://ai.google.dev/gemini-api/docs/pricing` | Parse pricing table | +| DeepSeek | `https://api-docs.deepseek.com/quick_start/pricing` | Parse pricing table | + +Each provider gets its own parser function in `model_fetcher.py`. + +### Merge Logic + +- **New model via API, no pricing scraped**: Add with `cost: null`, log notice. +- **Existing model, updated pricing**: Update values. +- **Existing model not in API response**: Keep (may be deprecated but usable). +- **Update `last_updated`**: On any successful fetch (even partial). + +## Error Handling + +| Scenario | Behavior | +|----------|----------| +| JSON file missing or corrupt | Fall back to `DEFAULT_MODELS` hardcoded in `model_registry.py` (contains all currently supported models with current pricing), write fresh JSON | +| Fetch timeout (>10s) | Abort fetch, warn, continue with cached data | +| Scraping fails (page format changed) | Log warning per provider, keep existing JSON values for that provider | +| Unknown model requested | `ModelNotFoundError` with list of available models | +| Model has `null` pricing | Skip in cost calculation, print notice | +| Ollama (`OPENAI_API_KEY == "ollama"`) | Bypass registry, pass model string directly | +| No API keys set for any provider | Skip all fetching, use cached JSON silently | + +## Refresh Policy + +- **Auto**: On startup, if `last_updated` is >7 days old, attempt fetch. Non-blocking with 10s timeout. +- **Manual**: `python update_models.py --force` bypasses staleness check. 
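
The 7-day rule can be isolated into a pure predicate, which keeps it testable without touching the network (a sketch consistent with the ISO-8601 `last_updated` field in the data model; the `Z`-suffix replacement covers Python versions before 3.11, where `fromisoformat` rejects a trailing `Z`):

```python
from datetime import datetime, timezone, timedelta


def is_stale(last_updated, max_age=timedelta(days=7)):
    """Return True when cached pricing data should be refreshed."""
    if not last_updated:
        return True
    try:
        parsed = datetime.fromisoformat(last_updated.replace("Z", "+00:00"))
    except ValueError:
        # Unparseable timestamp: treat as stale rather than trusting it
        return True
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return datetime.now(timezone.utc) - parsed >= max_age
```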
+ +## Dependencies + +- `requests` β€” HTTP client for API calls and page fetches +- `beautifulsoup4` β€” HTML parsing for pricing pages +- Both added to `requirements.txt` + +## Testing + +- JSON ships pre-populated with current pricing as defaults +- Fetcher functions independently testable per provider +- Registry testable with mock JSON (no network) diff --git a/inference.py b/inference.py index 74b6d0f..8f2bce1 100755 --- a/inference.py +++ b/inference.py @@ -1,182 +1,95 @@ -import time, tiktoken -from openai import OpenAI -import openai -import os, anthropic, json +import os +import tiktoken +import time -TOKENS_IN = dict() -TOKENS_OUT = dict() +from config import OLLAMA_API_BASE_URL +from model_registry import ModelRegistry +from provider import get_provider_response +from utils import remove_thinking_process + +# Global registry instance +registry = ModelRegistry() encoding = tiktoken.get_encoding("cl100k_base") + def curr_cost_est(): - costmap_in = { - "gpt-4o": 2.50 / 1000000, - "gpt-4o-mini": 0.150 / 1000000, - "o1-preview": 15.00 / 1000000, - "o1-mini": 3.00 / 1000000, - "claude-3-5-sonnet": 3.00 / 1000000, - "deepseek-chat": 1.00 / 1000000, - "o1": 15.00 / 1000000, - } - costmap_out = { - "gpt-4o": 10.00/ 1000000, - "gpt-4o-mini": 0.6 / 1000000, - "o1-preview": 60.00 / 1000000, - "o1-mini": 12.00 / 1000000, - "claude-3-5-sonnet": 12.00 / 1000000, - "deepseek-chat": 5.00 / 1000000, - "o1": 60.00 / 1000000, - } - return sum([costmap_in[_]*TOKENS_IN[_] for _ in TOKENS_IN]) + sum([costmap_out[_]*TOKENS_OUT[_] for _ in TOKENS_OUT]) - -def query_model(model_str, prompt, system_prompt, openai_api_key=None, anthropic_api_key=None, tries=5, timeout=5.0, temp=None, print_cost=True, version="1.5"): - preloaded_api = os.getenv('OPENAI_API_KEY') - if openai_api_key is None and preloaded_api is not None: - openai_api_key = preloaded_api - if openai_api_key is None and anthropic_api_key is None: - raise Exception("No API key provided in query_model function") + 
return registry.curr_cost_est() + + +def query_model(model_str, prompt, system_prompt, + openai_api_key=None, anthropic_api_key=None, + tries=5, timeout=5.0, + temp=None, print_cost=True, version="1.5"): + # Override the API keys if provided in the function call if openai_api_key is not None: - openai.api_key = openai_api_key os.environ["OPENAI_API_KEY"] = openai_api_key if anthropic_api_key is not None: os.environ["ANTHROPIC_API_KEY"] = anthropic_api_key + + preloaded_openai_api = os.getenv('OPENAI_API_KEY') + preload_anthropic_api = os.getenv('ANTHROPIC_API_KEY') + preload_google_api = os.getenv('GOOGLE_API_KEY') + preload_deepseek_api = os.getenv('DEEPSEEK_API_KEY') + + if (preloaded_openai_api is None and + preload_anthropic_api is None and + preload_google_api is None and + preload_deepseek_api is None): + raise Exception("No API key provided in query_model function") + + # Handle Ollama passthrough + if preloaded_openai_api == "ollama": + return _query_ollama(model_str, prompt, system_prompt, tries, timeout, temp) + for _ in range(tries): try: - if model_str == "gpt-4o-mini" or model_str == "gpt4omini" or model_str == "gpt-4omini" or model_str == "gpt4o-mini": - model_str = "gpt-4o-mini" - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt}] - if version == "0.28": - if temp is None: - completion = openai.ChatCompletion.create( - model=f"{model_str}", # engine = "deployment_name". - messages=messages - ) - else: - completion = openai.ChatCompletion.create( - model=f"{model_str}", # engine = "deployment_name". 
- messages=messages, temperature=temp - ) - else: - client = OpenAI() - if temp is None: - completion = client.chat.completions.create( - model="gpt-4o-mini-2024-07-18", messages=messages, ) - else: - completion = client.chat.completions.create( - model="gpt-4o-mini-2024-07-18", messages=messages, temperature=temp) - answer = completion.choices[0].message.content - elif model_str == "claude-3.5-sonnet": - client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"]) - message = client.messages.create( - model="claude-3-5-sonnet-latest", - system=system_prompt, - messages=[{"role": "user", "content": prompt}]) - answer = json.loads(message.to_json())["content"][0]["text"] - elif model_str == "gpt4o" or model_str == "gpt-4o": - model_str = "gpt-4o" - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt}] - if version == "0.28": - if temp is None: - completion = openai.ChatCompletion.create( - model=f"{model_str}", # engine = "deployment_name". - messages=messages - ) - else: - completion = openai.ChatCompletion.create( - model=f"{model_str}", # engine = "deployment_name". 
- messages=messages, temperature=temp) - else: - client = OpenAI() - if temp is None: - completion = client.chat.completions.create( - model="gpt-4o-2024-08-06", messages=messages, ) - else: - completion = client.chat.completions.create( - model="gpt-4o-2024-08-06", messages=messages, temperature=temp) - answer = completion.choices[0].message.content - elif model_str == "deepseek-chat": - model_str = "deepseek-chat" - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt}] - if version == "0.28": - raise Exception("Please upgrade your OpenAI version to use DeepSeek client") - else: - deepseek_client = OpenAI( - api_key=os.getenv('DEEPSEEK_API_KEY'), - base_url="https://api.deepseek.com/v1" - ) - if temp is None: - completion = deepseek_client.chat.completions.create( - model="deepseek-chat", - messages=messages) - else: - completion = deepseek_client.chat.completions.create( - model="deepseek-chat", - messages=messages, - temperature=temp) - answer = completion.choices[0].message.content - elif model_str == "o1-mini": - model_str = "o1-mini" - messages = [ - {"role": "user", "content": system_prompt + prompt}] - if version == "0.28": - completion = openai.ChatCompletion.create( - model=f"{model_str}", # engine = "deployment_name". - messages=messages) - else: - client = OpenAI() - completion = client.chat.completions.create( - model="o1-mini-2024-09-12", messages=messages) - answer = completion.choices[0].message.content - elif model_str == "o1": - model_str = "o1" - messages = [ - {"role": "user", "content": system_prompt + prompt}] - if version == "0.28": - completion = openai.ChatCompletion.create( - model="o1-2024-12-17", # engine = "deployment_name". 
- messages=messages) - else: - client = OpenAI() - completion = client.chat.completions.create( - model="o1-2024-12-17", messages=messages) - answer = completion.choices[0].message.content - elif model_str == "o1-preview": - model_str = "o1-preview" - messages = [ - {"role": "user", "content": system_prompt + prompt}] - if version == "0.28": - completion = openai.ChatCompletion.create( - model=f"{model_str}", # engine = "deployment_name". - messages=messages) - else: - client = OpenAI() - completion = client.chat.completions.create( - model="o1-preview", messages=messages) - answer = completion.choices[0].message.content + # Resolve model via registry + canonical = registry.get_canonical_for_cost(model_str) + provider = registry.get_provider(model_str) + api_model_name = registry.get_api_model_name(model_str) + base_url = registry.get_base_url(model_str) + + # Determine API key based on provider + api_key_map = { + "openai": os.getenv('OPENAI_API_KEY'), + "anthropic": os.getenv('ANTHROPIC_API_KEY'), + "google": os.getenv('GOOGLE_API_KEY'), + "deepseek": os.getenv('DEEPSEEK_API_KEY'), + } + api_key = api_key_map.get(provider) + if api_key is None: + raise Exception(f"No API key set for provider '{provider}'") + answer = get_provider_response( + provider=provider, + api_key=api_key, + model_name=api_model_name, + user_prompt=prompt, + system_prompt=system_prompt, + temperature=temp, + base_url=base_url, + ) + + answer = remove_thinking_process(answer) + + # Cost estimation try: - if model_str in ["o1-preview", "o1-mini", "claude-3.5-sonnet", "o1"]: - encoding = tiktoken.encoding_for_model("gpt-4o") - elif model_str in ["deepseek-chat"]: - encoding = tiktoken.encoding_for_model("cl100k_base") - else: - encoding = tiktoken.encoding_for_model(model_str) - if model_str not in TOKENS_IN: - TOKENS_IN[model_str] = 0 - TOKENS_OUT[model_str] = 0 - TOKENS_IN[model_str] += len(encoding.encode(system_prompt + prompt)) - TOKENS_OUT[model_str] += len(encoding.encode(answer)) + 
try: + model_encoding = tiktoken.encoding_for_model(canonical) + except KeyError: + model_encoding = tiktoken.encoding_for_model("gpt-4o") + if canonical not in registry.tokens_in: + registry.tokens_in[canonical] = 0 + registry.tokens_out[canonical] = 0 + registry.tokens_in[canonical] += len(model_encoding.encode(system_prompt + prompt)) + registry.tokens_out[canonical] += len(model_encoding.encode(answer)) if print_cost: print(f"Current experiment cost = ${curr_cost_est()}, ** Approximate values, may not reflect true cost") except Exception as e: if print_cost: print(f"Cost approximation has an error? {e}") + return answer except Exception as e: print("Inference Exception:", e) @@ -185,4 +98,22 @@ def query_model(model_str, prompt, system_prompt, openai_api_key=None, anthropic raise Exception("Max retries: timeout") -#print(query_model(model_str="o1-mini", prompt="hi", system_prompt="hey")) \ No newline at end of file +def _query_ollama(model_str, prompt, system_prompt, tries, timeout, temp): + """Handle Ollama models β€” bypass registry, pass model string directly.""" + from provider import OpenaiProvider + for _ in range(tries): + try: + answer = OpenaiProvider.get_response( + api_key="ollama", + model_name=model_str, + user_prompt=prompt, + system_prompt=system_prompt, + temperature=temp, + base_url=OLLAMA_API_BASE_URL, + ) + return remove_thinking_process(answer) + except Exception as e: + print("Inference Exception:", e) + time.sleep(timeout) + continue + raise Exception("Max retries: timeout") diff --git a/mlesolver.py b/mlesolver.py index cfc4896..0eeb52d 100755 --- a/mlesolver.py +++ b/mlesolver.py @@ -103,7 +103,7 @@ def docstring(self) -> str: "============= CODE EDITING TOOL =============\n" "You also have access to a code editing tool. \n" "This tool allows you to replace lines indexed n through m (n:m) of the current code with as many lines of new code as you want to add. 
This removal is inclusive meaning that line n and m and everything between n and m is removed. This will be the primary way that you interact with code. \n" - "You can edit code using the following command: ```EDIT N M\n\n``` EDIT is the word EDIT, N is the first line index you want to replace and M the the last line index you want to replace (everything inbetween will also be removed), and will be the new code that is replacing the old code. Before changing the existing code to be your new code, your new code will be tested and if it returns an error it will not replace the existing code. Your changes should significantly change the functionality of the code." + "You can edit code using the following command: ```EDIT N M\n\n``` EDIT is the word EDIT, N is the first line index you want to replace and M the last line index you want to replace (everything inbetween will also be removed), and will be the new code that is replacing the old code. Before changing the existing code to be your new code, your new code will be tested and if it returns an error it will not replace the existing code. Your changes should significantly change the functionality of the code." 
) def execute_command(self, *args) -> str: @@ -154,7 +154,7 @@ def get_score(outlined_plan, code, code_return, REWARD_MODEL_LLM, attempts=3, op try: # todo: have a reward function here sys = ( - f"You are a professor agent who is serving as an expert reward model that can read a research plan, research code, and code output and are able to determine how well a model followed the plan, built the code, and got the proper output scored from 0 to 1 as a float.\n\n" + f"You are a professor agent who is serving as an expert reward model that can read a research plan, research code, and code output and can determine how well a model followed the plan, built the code, and got the proper output scored from 0 to 1 as a float.\n\n" f"You must structure your score exactly in the following way: ```SCORE\n\n``` where SCORE is just the word score, is a floating point number between 0 and 1 representing how well the model followed the plan, built the code, and got the proper output." ) scoring = query_model( @@ -197,7 +197,7 @@ def code_repair(code, error, ctype, REPAIR_LLM, openai_api_key=None): "============= CODE EDITING TOOL =============\n" "You have access to a code editing tool. \n" "This tool allows you to replace lines indexed n through m (n:m) of the current code with as many lines of new code as you want to add. This removal is inclusive meaning that line n and m and everything between n and m is removed. This will be the primary way that you interact with code. \n" - "You can edit code using the following command: ```EDIT N M\n\n``` EDIT is the word EDIT, N is the first line index you want to replace and M the the last line index you want to replace (everything inbetween will also be removed), and will be the new code that is replacing the old code. 
Before changing the existing code to be your new code, your new code will be tested and if it returns an error it will not replace the existing code.\n" + "You can edit code using the following command: ```EDIT N M\n\n``` EDIT is the word EDIT, N is the first line index you want to replace and M the last line index you want to replace (everything inbetween will also be removed), and will be the new code that is replacing the old code. Before changing the existing code to be your new code, your new code will be tested and if it returns an error it will not replace the existing code.\n" "Please use the code editing tool to fix this code." "Do not forget the opening ```EDIT N M and the closing ```." "Your output should look like the following\n\n```EDIT N M\n\n```" @@ -332,7 +332,7 @@ def reflect_code(self): code_strs = ("$"*40 + "\n\n").join([self.generate_code_lines(_code[0]) + f"\nCode Return {_code[1]}" for _code in self.best_codes]) code_strs = f"Please reflect on the following sets of code: {code_strs} and come up with generalizable insights that will help you improve your performance on this benchmark." syst = self.system_prompt(commands=False) + code_strs - return query_model(prompt="Please reflect on ideas for how to improve your current code. Examine the provided code and think very specifically (with precise ideas) on how to improve performance, which methods to use, how to improve generalization on the test set with line-by-line examples below:\n", system_prompt=syst, model_str=f"{self.llm_str}", openai_api_key=self.openai_api_key) + return query_model(prompt="Please reflect on ideas for how to improve your current code. 
Examine the provided code and think very specifically (with precise ideas) on how to improve performance, which methods to use, and how to improve generalization on the test set with line-by-line examples below:\n", system_prompt=syst, model_str=f"{self.llm_str}", openai_api_key=self.openai_api_key) def process_command(self, model_resp): """ @@ -486,7 +486,7 @@ def feedback(self, code_return): grade_return = get_score(self.plan, "\n".join(self.prev_working_code), code_return, REWARD_MODEL_LLM=self.llm_str, openai_api_key=self.openai_api_key)[0] print(f"@@@@ SUBMISSION: model score {grade_return}") f"Your code was properly submitted and you have just received a grade for your model.\nYour score was {grade_return}.\n\n" - reflect_prompt = f"This is your code: {code_str}\n\nYour code successfully returned a submission csv. Consider further improving your technique through advanced learning techniques, data augmentation, or hyperparamter tuning to increase the score. Please provide a detailed reflection on how to improve your performance, which lines in the code could be improved upon, and exactly (line by line) how you hope to improve this in the next update. This step is mostly meant to reflect in order to help your future self." + reflect_prompt = f"This is your code: {code_str}\n\nYour code successfully returned a submission csv. Consider further improving your technique through advanced learning techniques, data augmentation, or hyperparameter tuning to increase the score. Please provide a detailed reflection on how to improve your performance, which lines in the code could be improved upon, and exactly (line by line) how you hope to improve this in the next update. This step is mostly meant to reflect in order to help your future self."
for file in os.listdir("."): if file.endswith(".csv"): @@ -526,7 +526,7 @@ def phase_prompt(self,): phase_str = ( "You are an ML engineer and you will be writing the code for a research project.\n" "Your goal is to produce code that obtains final results for a set of research experiments. You should aim for simple code to collect all results, not complex code. You should integrate the provided literature review and the plan to make sure you are implementing everything outlined in the plan. The dataset code will be added to the beginning of your code always, so this does not need to be rewritten. Make sure you do not write functions, only loose code.\n" - "I would recommend writing smaller code so you do not run out of time but make sure to work on all points in the plan in the same code. You code should run every experiment outlined in the plan for a single code.\n", + "I would recommend writing smaller code so you do not run out of time but make sure to work on all points in the plan in the same code. Your code should run every experiment outlined in the plan in a single script.\n" "You cannot pip install new libraries, but many machine learning libraries already work. If you wish to use a language model in your code, please use the following:\nAnything you decide to print inside your code will be provided to you as input, and you will be able to see that part of the code. Using print statements is useful for figuring out what is wrong and understanding your code better."
) return phase_str diff --git a/model_fetcher.py b/model_fetcher.py new file mode 100644 index 0000000..621e47e --- /dev/null +++ b/model_fetcher.py @@ -0,0 +1,247 @@ +import json +import os +import re +from datetime import datetime, timezone, timedelta + +import requests + +from config import GOOGLE_GENERATIVE_API_BASE_URL, DEEPSEEK_API_BASE_URL + +FETCH_TIMEOUT = 15 + +# Pricing aggregator API - single source for all providers +PRICING_API_URL = "https://pricepertoken.com/api/pricing" + +# Map pricepertoken.com provider_name to our internal provider names +PROVIDER_NAME_MAP = { + "OpenAI": "openai", + "Anthropic": "anthropic", + "Google": "google", + "Deepseek": "deepseek", +} + +# Model ID prefixes to exclude (embeddings, image gen, etc.) +EXCLUDED_PREFIXES = ( + "text-embedding", "embedding", "dall-e", "tts-", "whisper", + "davinci", "babbage", "curie", "ada", "text-ada", "text-davinci", +) + +# Provider API endpoints for model discovery +PROVIDER_API_ENDPOINTS = { + "openai": "https://api.openai.com/v1/models", + "anthropic": "https://api.anthropic.com/v1/models", + "google": f"{GOOGLE_GENERATIVE_API_BASE_URL}models", + "deepseek": f"{DEEPSEEK_API_BASE_URL}/models", +} + + +def fetch_models_from_api(provider, api_key): + """Fetch available model IDs from a provider's API.
Returns list of model ID strings.""" + endpoint = PROVIDER_API_ENDPOINTS.get(provider) + if not endpoint: + return [] + + try: + headers = {} + params = {} + if provider in ("openai", "deepseek"): + headers["Authorization"] = f"Bearer {api_key}" + elif provider == "anthropic": + headers["x-api-key"] = api_key + headers["anthropic-version"] = "2023-06-01" + elif provider == "google": + params["key"] = api_key + + resp = requests.get(endpoint, headers=headers, params=params, timeout=FETCH_TIMEOUT) + if resp.status_code != 200: + print(f"Warning: {provider} API returned status {resp.status_code}") + return [] + + data = resp.json() + + if provider == "google": + raw_ids = [m["name"].replace("models/", "") for m in data.get("models", [])] + else: + raw_ids = [m["id"] for m in data.get("data", [])] + + # Filter out non-chat models + return [mid for mid in raw_ids if not mid.startswith(EXCLUDED_PREFIXES)] + + except Exception as e: + print(f"Warning: Failed to fetch models from {provider}: {e}") + return [] + + +def fetch_pricing_from_aggregator(): + """Fetch pricing for all providers from pricepertoken.com. + + Returns: + dict of {provider: {model_id: {input: float, output: float}}} + where input/output are per-million-token prices in USD. 
+ """ + try: + resp = requests.get(PRICING_API_URL, timeout=FETCH_TIMEOUT, headers={ + "User-Agent": "Mozilla/5.0 (compatible; AgentLaboratory/1.0)", + }) + if resp.status_code != 200: + print(f"Warning: Pricing API returned status {resp.status_code}") + return {} + + data = resp.json() + results = data.get("results", []) + print(f" Pricing API returned {len(results)} models") + + pricing = {} + for entry in results: + provider_name = entry.get("provider_name", "") + provider = PROVIDER_NAME_MAP.get(provider_name) + if provider is None: + continue + + model_id = entry.get("model", "") + input_price = entry.get("input_price_per_1m_tokens") + output_price = entry.get("output_price_per_1m_tokens") + + if not model_id or input_price is None or output_price is None: + continue + + if provider not in pricing: + pricing[provider] = {} + + pricing[provider][model_id] = { + "input": float(input_price), + "output": float(output_price), + } + + for p, models in pricing.items(): + print(f" {p}: {len(models)} models with pricing") + + return pricing + + except Exception as e: + print(f"Warning: Failed to fetch pricing from aggregator: {e}") + return {} + + +def merge_discovered_models(existing_data, discovered_by_provider, pricing_by_provider=None): + """Merge newly discovered models into existing config data. 
+ + Args: + existing_data: Current JSON data dict + discovered_by_provider: dict of {provider: [model_id, ...]} + pricing_by_provider: dict of {provider: {model_name: {input: float, output: float}}} + + Returns: + Updated data dict + """ + if pricing_by_provider is None: + pricing_by_provider = {} + + models = existing_data.get("models", {}) + + for provider, model_ids in discovered_by_provider.items(): + provider_pricing = pricing_by_provider.get(provider, {}) + for model_id in model_ids: + if model_id not in models: + price_info = provider_pricing.get(model_id, {}) + models[model_id] = { + "provider": provider, + "api_model_name": model_id, + "aliases": [], + "cost_per_million_input": price_info.get("input"), + "cost_per_million_output": price_info.get("output"), + } + if price_info: + print(f" Added new model: {model_id} (with pricing)") + else: + print(f" Added new model: {model_id} (no pricing available)") + else: + price_info = provider_pricing.get(model_id, {}) + if price_info: + models[model_id]["cost_per_million_input"] = price_info.get("input", models[model_id].get("cost_per_million_input")) + models[model_id]["cost_per_million_output"] = price_info.get("output", models[model_id].get("cost_per_million_output")) + + # Also update pricing for existing models that weren't in discovered_by_provider + # (e.g. 
models already in our JSON that have updated prices) + for provider, provider_pricing in pricing_by_provider.items(): + for model_id, price_info in provider_pricing.items(): + if model_id in models and price_info: + models[model_id]["cost_per_million_input"] = price_info.get("input", models[model_id].get("cost_per_million_input")) + models[model_id]["cost_per_million_output"] = price_info.get("output", models[model_id].get("cost_per_million_output")) + + existing_data["models"] = models + existing_data["last_updated"] = datetime.now(timezone.utc).isoformat() + return existing_data + + +def update_models_pricing(config_path, force=False): + """Main entry point: refresh models_pricing.json. + + Args: + config_path: Path to models_pricing.json + force: If True, skip staleness check + + Returns: + True if update was performed, False if skipped + """ + # Load existing data + try: + with open(config_path) as f: + existing_data = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + existing_data = {"last_updated": None, "models": {}} + + # Check staleness + if not force: + last = existing_data.get("last_updated") + if last: + try: + last_dt = datetime.fromisoformat(last) + if last_dt.tzinfo is None: + last_dt = last_dt.replace(tzinfo=timezone.utc) + if datetime.now(timezone.utc) - last_dt < timedelta(days=7): + print("Model pricing data is fresh, skipping update.") + return False + except (ValueError, TypeError): + pass + + print("Updating model pricing data...") + + # Step 1: Fetch pricing from aggregator (no API keys needed!) 
+ print(" Fetching pricing from pricepertoken.com...") + pricing = fetch_pricing_from_aggregator() + + # Step 2: Discover models via provider APIs (requires API keys) + api_keys = { + "openai": os.getenv("OPENAI_API_KEY"), + "anthropic": os.getenv("ANTHROPIC_API_KEY"), + "google": os.getenv("GOOGLE_API_KEY"), + "deepseek": os.getenv("DEEPSEEK_API_KEY"), + } + + discovered = {} + for provider, key in api_keys.items(): + if key and key != "ollama": + print(f" Fetching models from {provider} API...") + models = fetch_models_from_api(provider, key) + if models: + discovered[provider] = models + print(f" Found {len(models)} models from {provider}") + else: + print(f" Skipping {provider} API (no API key set)") + + if not discovered and not pricing: + print("Warning: No data fetched. Check your network connection.") + # Still update timestamp to avoid repeated failures + existing_data["last_updated"] = datetime.now(timezone.utc).isoformat() + + # Step 3: Merge + updated_data = merge_discovered_models(existing_data, discovered, pricing) + + # Write + os.makedirs(os.path.dirname(config_path), exist_ok=True) + with open(config_path, "w") as f: + json.dump(updated_data, f, indent=2) + + print(f"Model pricing data updated at {config_path}") + return True diff --git a/model_registry.py b/model_registry.py new file mode 100644 index 0000000..b8d0fd6 --- /dev/null +++ b/model_registry.py @@ -0,0 +1,200 @@ +import json +import os +from datetime import datetime, timezone, timedelta + +from config import GOOGLE_GENERATIVE_API_BASE_URL, DEEPSEEK_API_BASE_URL, OLLAMA_API_BASE_URL + +STALENESS_DAYS = 7 +DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs", "models_pricing.json") + +# Base URLs per provider (non-OpenAI providers that need a custom base_url) +PROVIDER_BASE_URLS = { + "deepseek": DEEPSEEK_API_BASE_URL, + "google": GOOGLE_GENERATIVE_API_BASE_URL, +} + + +class ModelNotFoundError(Exception): + def __init__(self, model_name, available_models): + 
self.model_name = model_name + self.available_models = available_models + super().__init__( + f"Model '{model_name}' not found. Available models: {', '.join(sorted(available_models))}" + ) + + +class ModelRegistry: + def __init__(self, config_path=None, auto_refresh=True): + self.config_path = config_path or DEFAULT_CONFIG_PATH + self.models = {} + self.last_updated = None + self.tokens_in = {} + self.tokens_out = {} + self._alias_map = {} + self._load() + if auto_refresh and self.is_stale(): + self._try_refresh() + + def _load(self): + try: + with open(self.config_path, "r") as f: + data = json.load(f) + self.models = data.get("models", {}) + self.last_updated = data.get("last_updated") + except (FileNotFoundError, json.JSONDecodeError, OSError): + print(f"Warning: Could not load {self.config_path}, using default models.") + self.models = DEFAULT_MODELS + self.last_updated = None + self._save() + self._build_alias_map() + + def _build_alias_map(self): + self._alias_map = {} + for name, info in self.models.items(): + self._alias_map[name] = name + for alias in info.get("aliases", []): + self._alias_map[alias] = name + + def _save(self): + data = { + "last_updated": self.last_updated or datetime.now(timezone.utc).isoformat(), + "models": self.models, + } + os.makedirs(os.path.dirname(self.config_path), exist_ok=True) + with open(self.config_path, "w") as f: + json.dump(data, f, indent=2) + + def _try_refresh(self): + try: + from model_fetcher import update_models_pricing + updated = update_models_pricing(self.config_path) + if updated: + self._load() + except Exception as e: + print(f"Warning: Auto-refresh failed ({e}), using cached data.") + + def is_stale(self): + if self.last_updated is None: + return True + try: + last = datetime.fromisoformat(self.last_updated) + if last.tzinfo is None: + last = last.replace(tzinfo=timezone.utc) + return datetime.now(timezone.utc) - last > timedelta(days=STALENESS_DAYS) + except (ValueError, TypeError): + return True + + def 
resolve_alias(self, name): + if name in self._alias_map: + return self._alias_map[name] + # Anthropic models: match by startswith (e.g. claude-3-5-sonnet-20241022) + for canonical, info in self.models.items(): + if info.get("provider") == "anthropic" and name.startswith(canonical): + return canonical + raise ModelNotFoundError(name, list(self.models.keys())) + + def get_model(self, name_or_alias): + canonical = self.resolve_alias(name_or_alias) + return self.models[canonical] + + def get_provider(self, name_or_alias): + return self.get_model(name_or_alias)["provider"] + + def get_api_model_name(self, name_or_alias): + model = self.get_model(name_or_alias) + # For anthropic, if user passed a full versioned name, use it directly + if model["provider"] == "anthropic" and name_or_alias != self.resolve_alias(name_or_alias): + if name_or_alias.startswith(self.resolve_alias(name_or_alias)): + return name_or_alias + return model["api_model_name"] + + def get_base_url(self, name_or_alias): + provider = self.get_provider(name_or_alias) + return PROVIDER_BASE_URLS.get(provider) + + def get_cost_input(self, name_or_alias): + cost = self.get_model(name_or_alias).get("cost_per_million_input") + if cost is None: + return None + return cost / 1_000_000 + + def get_cost_output(self, name_or_alias): + cost = self.get_model(name_or_alias).get("cost_per_million_output") + if cost is None: + return None + return cost / 1_000_000 + + def list_models(self, provider=None): + if provider: + return [name for name, info in self.models.items() if info["provider"] == provider] + return list(self.models.keys()) + + def get_canonical_for_cost(self, name_or_alias): + """Return canonical name used as key in tokens_in/tokens_out dicts.""" + return self.resolve_alias(name_or_alias) + + def curr_cost_est(self): + total = 0.0 + for model_name, count in self.tokens_in.items(): + cost = self.get_cost_input(model_name) + if cost is not None: + total += cost * count + for model_name, count in 
self.tokens_out.items(): + cost = self.get_cost_output(model_name) + if cost is not None: + total += cost * count + return total + + +# Hardcoded fallback - used only when the JSON file is missing or corrupt +DEFAULT_MODELS = { + "gpt-4o": { + "provider": "openai", + "api_model_name": "gpt-4o-2024-08-06", + "aliases": ["gpt4o"], + "cost_per_million_input": 2.50, + "cost_per_million_output": 10.00, + }, + "gpt-4o-mini": { + "provider": "openai", + "api_model_name": "gpt-4o-mini-2024-07-18", + "aliases": ["gpt4omini", "gpt-4omini", "gpt4o-mini"], + "cost_per_million_input": 0.150, + "cost_per_million_output": 0.60, + }, + "o1-mini": { + "provider": "openai", + "api_model_name": "o1-mini-2024-09-12", + "aliases": [], + "cost_per_million_input": 1.10, + "cost_per_million_output": 4.40, + }, + "o3-mini": { + "provider": "openai", + "api_model_name": "o3-mini-2025-01-31", + "aliases": [], + "cost_per_million_input": 1.10, + "cost_per_million_output": 4.40, + }, + "claude-3-5-sonnet": { + "provider": "anthropic", + "api_model_name": "claude-3-5-sonnet", + "aliases": [], + "cost_per_million_input": 3.00, + "cost_per_million_output": 15.00, + }, + "gemini-2.5-flash": { + "provider": "google", + "api_model_name": "gemini-2.5-flash", + "aliases": [], + "cost_per_million_input": 0.30, + "cost_per_million_output": 2.50, + }, + "deepseek-chat": { + "provider": "deepseek", + "api_model_name": "deepseek-chat", + "aliases": [], + "cost_per_million_input": 0.27, + "cost_per_million_output": 1.10, + }, +} diff --git a/papersolver.py b/papersolver.py index 39222f0..885b032 100755 --- a/papersolver.py +++ b/papersolver.py @@ -335,17 +335,17 @@ def clean_text(text): return text def gen_initial_report(self): - num_attempts = 0 arx = ArxivSearch() section_scaffold = str() # 1. Abstract 2. Introduction, 3. Background, 4. Methods, 5. Experimental Setup 6. Results, and 7.
Discussion for _section in ["scaffold", "abstract", "introduction", "related work", "background", "methods", "experimental setup", "results", "discussion"]: section_complete = False + num_attempts = 0 # reset per-section so previous section's context never bleeds in if _section in ["introduction", "related work", "background", "methods", "discussion"]: attempts = 0 papers = str() first_attempt = True - while len(papers) == 0: + while not papers: att_str = str() if attempts > 5: break @@ -356,7 +356,7 @@ def gen_initial_report(self): papers = arx.find_papers_by_str(query=search_query, N=10) first_attempt = False attempts += 1 - if len(papers) != 0: + if papers: self.section_related_work[_section] = papers while not section_complete: section_scaffold_temp = copy(section_scaffold) @@ -378,19 +378,26 @@ def gen_initial_report(self): model_resp = self.clean_text(model_resp) if _section == "scaffold": # minimal scaffold (some other sections can be combined) + scaffold_valid = True for _sect in ["[ABSTRACT HERE]", "[INTRODUCTION HERE]", "[METHODS HERE]", "[RESULTS HERE]", "[DISCUSSION HERE]"]: if _sect not in model_resp: - cmd_str = "Error: scaffold section placeholders were not present (e.g. [ABSTRACT HERE])." + cmd_str = f"Error: scaffold missing placeholder {_sect}." print("@@@ INIT ATTEMPT:", cmd_str) - continue + scaffold_valid = False + break + if not scaffold_valid: + num_attempts += 1 + continue elif _section != "scaffold": new_text = extract_prompt(model_resp, "REPLACE") - section_scaffold_temp = section_scaffold_temp.replace(f"[{_section.upper()} HERE]", new_text) - model_resp = '```REPLACE\n' + copy(section_scaffold_temp) + '\n```' if "documentclass{article}" in new_text or "usepackage{" in new_text: - cmd_str = "Error: You must not include packages or documentclass in the text! Your latex must only include the section text, equations, and tables." 
- print("@@@ INIT ATTEMPT:", cmd_str) - continue + # Model returned a full LaTeX document (correct for REPLACE command). + # Use it directly instead of trying to inject only section text. + model_resp = '```REPLACE\n' + new_text + '\n```' + else: + # Model returned only section text β€” inject into scaffold at placeholder. + section_scaffold_temp = section_scaffold_temp.replace(f"[{_section.upper()} HERE]", new_text) + model_resp = '```REPLACE\n' + copy(section_scaffold_temp) + '\n```' cmd_str, latex_lines, prev_latex_ret, score = self.process_command(model_resp, scoring=False) print(f"@@@ INIT ATTEMPT: Command Exec // Attempt {num_attempts}: ", str(cmd_str).replace("\n", " | ")) #print(f"$$$ Score: {score}") @@ -582,6 +589,3 @@ def phase_prompt(self,): "You are a PhD student who has submitted a paper to an ML conference called ICLR. Your goal was to write a research paper and get high scores from the reviewers so that it get accepted to the conference.\n" ) return phase_str - - - diff --git a/provider.py b/provider.py new file mode 100644 index 0000000..7f8896c --- /dev/null +++ b/provider.py @@ -0,0 +1,149 @@ +import os + +import anthropic +import openai +from openai import OpenAI + + +class OpenaiProvider: + @staticmethod + def get_response( + api_key: str, + model_name: str, + user_prompt: str, + system_prompt: str, + temperature: float = None, + base_url: str | None = None, + ) -> str: + openai.api_key = api_key + client_config = { + "api_key": api_key, + } + + if base_url: + openai.base_url = base_url + client_config["base_url"] = base_url + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ] + + version = openai.__version__ + + if api_key == "ollama": + ollama_max_tokens = int(os.getenv("OLLAMA_MAX_TOKENS", 2048)) + if version == "0.28": + if temperature is None: + completion = openai.ChatCompletion.create( + model=model_name, + messages=messages, + max_tokens=ollama_max_tokens, + ) + else: + 
completion = openai.ChatCompletion.create( + model=model_name, + messages=messages, + temperature=temperature, + max_tokens=ollama_max_tokens, + ) + else: + client = OpenAI(**client_config) + if temperature is None: + completion = client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=ollama_max_tokens, + ) + else: + completion = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=temperature, + max_tokens=ollama_max_tokens, + ) + else: + if version == "0.28": + if temperature is None: + completion = openai.ChatCompletion.create( + model=model_name, + messages=messages, + ) + else: + completion = openai.ChatCompletion.create( + model=model_name, + messages=messages, + temperature=temperature, + ) + else: + client = OpenAI(**client_config) + if temperature is None: + completion = client.chat.completions.create( + model=model_name, + messages=messages, + ) + else: + completion = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=temperature, + ) + + return completion.choices[0].message.content + +class AnthropicProvider: + @staticmethod + def get_response( + api_key: str, + model_name: str, + user_prompt: str, + system_prompt: str, + temperature: float = None, + ) -> str: + client = anthropic.Anthropic(api_key=api_key) + # Set max_tokens based on model type + if 'opus' in model_name: + max_tokens = 16384 + elif 'sonnet' in model_name: + max_tokens = 8192 + else: # haiku or other + max_tokens = 4096 + + if temperature is None: + message = client.messages.create( + model=model_name, + system=system_prompt, + messages=[{"role": "user", "content": user_prompt}], + max_tokens=max_tokens, + ) + else: + message = client.messages.create( + model=model_name, + system=system_prompt, + messages=[{"role": "user", "content": user_prompt}], + max_tokens=max_tokens, + temperature=temperature, + ) + return message.content[0].text + + +def get_provider_response(provider, api_key, 
model_name, user_prompt, system_prompt, temperature=None, base_url=None): + """Dispatch to the correct provider based on provider string.""" + if provider == "anthropic": + return AnthropicProvider.get_response( + api_key=api_key, + model_name=model_name, + user_prompt=user_prompt, + system_prompt=system_prompt, + temperature=temperature, + ) + else: + # openai, google, deepseek, ollama all use OpenAI-compatible API + return OpenaiProvider.get_response( + api_key=api_key, + model_name=model_name, + user_prompt=user_prompt, + system_prompt=system_prompt, + temperature=temperature, + base_url=base_url, + ) diff --git a/requirements.txt b/requirements.txt index e08992b..6d84e37 100755 --- a/requirements.txt +++ b/requirements.txt @@ -28,11 +28,14 @@ exceptiongroup==1.2.2 feedparser==6.0.11 filelock==3.16.1 flatbuffers==24.3.25 +flask==3.1.0 +flask-cors==5.0.0 fonttools==4.55.0 frozenlist==1.5.0 fsspec==2024.9.0 gast==0.6.0 google-pasta==0.2.0 +gradio grpcio==1.68.0 h11==0.14.0 h5py==3.12.1 @@ -68,7 +71,7 @@ nest-asyncio==1.6.0 networkx==3.2.1 nltk==3.9.1 numpy==2.0.2 -openai==1.55.1 +openai==1.55.3 opt_einsum==3.4.0 optree==0.13.1 packaging==24.2 @@ -122,6 +125,8 @@ tifffile==2024.8.30 tiktoken==0.8.0 tokenizers==0.20.4 torch==2.5.1 +torchaudio==2.5.1 +torchvision==0.20.1 tqdm==4.67.1 transformers==4.46.3 typer==0.13.1 diff --git a/settings_manager.py b/settings_manager.py new file mode 100644 index 0000000..7bddc64 --- /dev/null +++ b/settings_manager.py @@ -0,0 +1,40 @@ +import json +from pathlib import Path +import logging + +class SettingsManager: + def __init__(self): + self.settings_dir = Path("settings") + self.settings_file = self.settings_dir / "user_settings.json" + self._ensure_settings_dir() + + def _ensure_settings_dir(self): + """Ensure the settings directory exists""" + try: + self.settings_dir.mkdir(exist_ok=True) + except Exception as e: + logging.error(f"Failed to create settings directory: {e}") + + def save_settings(self, settings: dict): + """Save 
settings to JSON file""" + try: + # Filter out empty API keys before saving + filtered_settings = { + k: v for k, v in settings.items() + if not (k.endswith('_api_key') and not v) + } + + with open(self.settings_file, 'w', encoding='utf-8') as f: + json.dump(filtered_settings, f, indent=2) + except Exception as e: + logging.error(f"Failed to save settings: {e}") + + def load_settings(self) -> dict: + """Load settings from JSON file""" + try: + if self.settings_file.exists(): + with open(self.settings_file, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + logging.error(f"Failed to load settings: {e}") + return {} diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_model_fetcher.py b/tests/test_model_fetcher.py new file mode 100644 index 0000000..83cb0bd --- /dev/null +++ b/tests/test_model_fetcher.py @@ -0,0 +1,219 @@ +import json +import pytest +from unittest.mock import patch, MagicMock +from datetime import datetime, timezone + + +SAMPLE_OPENAI_MODELS_RESPONSE = { + "data": [ + {"id": "gpt-4o", "object": "model"}, + {"id": "gpt-4o-mini", "object": "model"}, + {"id": "text-embedding-ada-002", "object": "model"}, + {"id": "dall-e-3", "object": "model"}, + ] +} + + +SAMPLE_GOOGLE_MODELS_RESPONSE = { + "models": [ + {"name": "models/gemini-2.5-flash", "displayName": "Gemini 2.5 Flash"}, + {"name": "models/embedding-001", "displayName": "Embedding 001"}, + ] +} + + +SAMPLE_AGGREGATOR_RESPONSE = { + "results": [ + { + "model": "gpt-4o", + "provider_name": "OpenAI", + "input_price_per_1m_tokens": 2.50, + "output_price_per_1m_tokens": 10.00, + }, + { + "model": "claude-3-5-sonnet", + "provider_name": "Anthropic", + "input_price_per_1m_tokens": 3.00, + "output_price_per_1m_tokens": 15.00, + }, + { + "model": "gemini-2.5-flash", + "provider_name": "Google", + "input_price_per_1m_tokens": 0.30, + "output_price_per_1m_tokens": 2.50, + }, + { + "model": 
"some-unknown-provider-model", + "provider_name": "SomeOtherProvider", + "input_price_per_1m_tokens": 1.00, + "output_price_per_1m_tokens": 2.00, + }, + ] +} + + +class TestFetchModelsFromAPI: + @patch("model_fetcher.requests.get") + def test_fetch_openai_models(self, mock_get): + from model_fetcher import fetch_models_from_api + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = SAMPLE_OPENAI_MODELS_RESPONSE + mock_get.return_value = mock_resp + + models = fetch_models_from_api("openai", api_key="test-key") + assert "gpt-4o" in models + assert "gpt-4o-mini" in models + assert "text-embedding-ada-002" not in models + assert "dall-e-3" not in models + + @patch("model_fetcher.requests.get") + def test_fetch_returns_empty_on_failure(self, mock_get): + from model_fetcher import fetch_models_from_api + mock_get.side_effect = Exception("Network error") + models = fetch_models_from_api("openai", api_key="test-key") + assert models == [] + + @patch("model_fetcher.requests.get") + def test_fetch_google_models(self, mock_get): + from model_fetcher import fetch_models_from_api + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = SAMPLE_GOOGLE_MODELS_RESPONSE + mock_get.return_value = mock_resp + + models = fetch_models_from_api("google", api_key="test-key") + assert "gemini-2.5-flash" in models + assert "embedding-001" not in models + + +class TestFetchPricingFromAggregator: + @patch("model_fetcher.requests.get") + def test_fetch_pricing(self, mock_get): + from model_fetcher import fetch_pricing_from_aggregator + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = SAMPLE_AGGREGATOR_RESPONSE + mock_get.return_value = mock_resp + + pricing = fetch_pricing_from_aggregator() + assert "openai" in pricing + assert "anthropic" in pricing + assert "google" in pricing + # Unknown providers should be filtered out + assert "someotherprovider" not in pricing + + assert 
pricing["openai"]["gpt-4o"]["input"] == 2.50 + assert pricing["openai"]["gpt-4o"]["output"] == 10.00 + assert pricing["anthropic"]["claude-3-5-sonnet"]["input"] == 3.00 + + @patch("model_fetcher.requests.get") + def test_fetch_pricing_returns_empty_on_failure(self, mock_get): + from model_fetcher import fetch_pricing_from_aggregator + mock_get.side_effect = Exception("Network error") + pricing = fetch_pricing_from_aggregator() + assert pricing == {} + + @patch("model_fetcher.requests.get") + def test_fetch_pricing_handles_bad_status(self, mock_get): + from model_fetcher import fetch_pricing_from_aggregator + mock_resp = MagicMock() + mock_resp.status_code = 500 + mock_get.return_value = mock_resp + pricing = fetch_pricing_from_aggregator() + assert pricing == {} + + +class TestMergeModels: + def test_merge_adds_new_model(self): + from model_fetcher import merge_discovered_models + existing = { + "last_updated": "2026-03-01T00:00:00Z", + "models": {} + } + discovered = {"openai": ["gpt-4o-new"]} + result = merge_discovered_models(existing, discovered) + assert "gpt-4o-new" in result["models"] + assert result["models"]["gpt-4o-new"]["provider"] == "openai" + assert result["models"]["gpt-4o-new"]["cost_per_million_input"] is None + + def test_merge_keeps_existing(self): + from model_fetcher import merge_discovered_models + existing = { + "last_updated": "2026-03-01T00:00:00Z", + "models": { + "gpt-4o": { + "provider": "openai", + "api_model_name": "gpt-4o", + "aliases": [], + "cost_per_million_input": 2.50, + "cost_per_million_output": 10.00 + } + } + } + discovered = {"openai": ["gpt-4o"]} + result = merge_discovered_models(existing, discovered) + assert result["models"]["gpt-4o"]["cost_per_million_input"] == 2.50 + + def test_merge_does_not_remove_existing(self): + from model_fetcher import merge_discovered_models + existing = { + "last_updated": "2026-03-01T00:00:00Z", + "models": { + "gpt-4o": { + "provider": "openai", + "api_model_name": "gpt-4o", + "aliases": 
[], + "cost_per_million_input": 2.50, + "cost_per_million_output": 10.00 + } + } + } + discovered = {"openai": ["gpt-4o-new"]} + result = merge_discovered_models(existing, discovered) + assert "gpt-4o" in result["models"] + + def test_merge_updates_pricing_for_existing_models(self): + from model_fetcher import merge_discovered_models + existing = { + "last_updated": "2026-03-01T00:00:00Z", + "models": { + "gpt-4o": { + "provider": "openai", + "api_model_name": "gpt-4o", + "aliases": [], + "cost_per_million_input": 2.50, + "cost_per_million_output": 10.00 + } + } + } + pricing = {"openai": {"gpt-4o": {"input": 3.00, "output": 12.00}}} + result = merge_discovered_models(existing, {}, pricing) + assert result["models"]["gpt-4o"]["cost_per_million_input"] == 3.00 + assert result["models"]["gpt-4o"]["cost_per_million_output"] == 12.00 + + +class TestUpdateModelsPricing: + @patch("model_fetcher.fetch_pricing_from_aggregator") + @patch("model_fetcher.fetch_models_from_api") + def test_update_writes_file(self, mock_fetch_models, mock_fetch_pricing, tmp_path): + from model_fetcher import update_models_pricing + config_path = str(tmp_path / "models_pricing.json") + existing = { + "last_updated": "2026-01-01T00:00:00Z", + "models": {} + } + with open(config_path, "w") as f: + json.dump(existing, f) + + mock_fetch_models.return_value = ["gpt-4o"] + mock_fetch_pricing.return_value = { + "openai": {"gpt-4o": {"input": 2.50, "output": 10.00}} + } + + result = update_models_pricing(config_path, force=True) + assert result is True + + with open(config_path) as f: + data = json.load(f) + assert data["last_updated"] != "2026-01-01T00:00:00Z" diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py new file mode 100644 index 0000000..ca7fdb6 --- /dev/null +++ b/tests/test_model_registry.py @@ -0,0 +1,174 @@ +import json +import os +import pytest +from unittest.mock import patch +from datetime import datetime, timezone + + +SAMPLE_CONFIG = { + "last_updated": 
"2026-03-28T00:00:00Z", + "models": { + "gpt-4o": { + "provider": "openai", + "api_model_name": "gpt-4o-2024-08-06", + "aliases": ["gpt4o"], + "cost_per_million_input": 2.50, + "cost_per_million_output": 10.00 + }, + "claude-3-5-sonnet": { + "provider": "anthropic", + "api_model_name": "claude-3-5-sonnet-20241022", + "aliases": ["claude-3-5-sonnet-latest"], + "cost_per_million_input": 3.00, + "cost_per_million_output": 15.00 + }, + "deepseek-chat": { + "provider": "deepseek", + "api_model_name": "deepseek-chat", + "aliases": [], + "cost_per_million_input": 0.27, + "cost_per_million_output": 1.10 + }, + "gemini-2.5-flash": { + "provider": "google", + "api_model_name": "gemini-2.5-flash", + "aliases": [], + "cost_per_million_input": 0.30, + "cost_per_million_output": 2.50 + }, + "null-cost-model": { + "provider": "openai", + "api_model_name": "null-cost-model", + "aliases": [], + "cost_per_million_input": None, + "cost_per_million_output": None + } + } +} + + +@pytest.fixture +def sample_config_path(tmp_path): + config_path = tmp_path / "configs" / "models_pricing.json" + config_path.parent.mkdir(parents=True) + config_path.write_text(json.dumps(SAMPLE_CONFIG)) + return str(config_path) + + +@pytest.fixture +def registry(sample_config_path): + from model_registry import ModelRegistry + return ModelRegistry(config_path=sample_config_path, auto_refresh=False) + + +class TestModelResolution: + def test_resolve_canonical_name(self, registry): + model = registry.get_model("gpt-4o") + assert model["provider"] == "openai" + assert model["api_model_name"] == "gpt-4o-2024-08-06" + + def test_resolve_alias(self, registry): + model = registry.get_model("gpt4o") + assert model["api_model_name"] == "gpt-4o-2024-08-06" + + def test_unknown_model_raises(self, registry): + from model_registry import ModelNotFoundError + with pytest.raises(ModelNotFoundError): + registry.get_model("nonexistent-model") + + def test_resolve_alias_method(self, registry): + assert 
registry.resolve_alias("gpt4o") == "gpt-4o" + + def test_resolve_canonical_returns_same(self, registry): + assert registry.resolve_alias("gpt-4o") == "gpt-4o" + + def test_anthropic_startswith_matching(self, registry): + model = registry.get_model("claude-3-5-sonnet-20241022") + assert model["provider"] == "anthropic" + + +class TestProviderRouting: + def test_get_provider(self, registry): + assert registry.get_provider("gpt-4o") == "openai" + assert registry.get_provider("claude-3-5-sonnet") == "anthropic" + assert registry.get_provider("deepseek-chat") == "deepseek" + assert registry.get_provider("gemini-2.5-flash") == "google" + + def test_get_api_model_name(self, registry): + assert registry.get_api_model_name("gpt-4o") == "gpt-4o-2024-08-06" + + def test_get_base_url(self, registry): + assert registry.get_base_url("gpt-4o") is None + assert "deepseek" in registry.get_base_url("deepseek-chat") + assert "generativelanguage" in registry.get_base_url("gemini-2.5-flash") + + +class TestCostEstimation: + def test_get_cost_input(self, registry): + assert registry.get_cost_input("gpt-4o") == 2.50 / 1_000_000 + + def test_get_cost_output(self, registry): + assert registry.get_cost_output("gpt-4o") == 10.00 / 1_000_000 + + def test_null_cost_returns_none(self, registry): + assert registry.get_cost_input("null-cost-model") is None + assert registry.get_cost_output("null-cost-model") is None + + def test_curr_cost_est_empty(self, registry): + assert registry.curr_cost_est() == 0.0 + + def test_curr_cost_est_with_tokens(self, registry): + registry.tokens_in["gpt-4o"] = 1000 + registry.tokens_out["gpt-4o"] = 500 + cost = registry.curr_cost_est() + expected = 1000 * (2.50 / 1_000_000) + 500 * (10.00 / 1_000_000) + assert abs(cost - expected) < 1e-10 + + +class TestListModels: + def test_list_all(self, registry): + models = registry.list_models() + assert "gpt-4o" in models + assert "claude-3-5-sonnet" in models + + def test_list_by_provider(self, registry): + models = 
registry.list_models(provider="openai") + assert "gpt-4o" in models + assert "claude-3-5-sonnet" not in models + + +class TestFallback: + def test_missing_file_uses_defaults(self, tmp_path): + from model_registry import ModelRegistry + bad_path = str(tmp_path / "nonexistent" / "models_pricing.json") + reg = ModelRegistry(config_path=bad_path, auto_refresh=False) + assert len(reg.list_models()) > 0 + + def test_corrupt_file_uses_defaults(self, tmp_path): + from model_registry import ModelRegistry + bad_file = tmp_path / "bad.json" + bad_file.write_text("not json{{{") + reg = ModelRegistry(config_path=str(bad_file), auto_refresh=False) + assert len(reg.list_models()) > 0 + + +class TestStaleness: + def test_is_stale_when_old(self, sample_config_path): + from model_registry import ModelRegistry + with open(sample_config_path) as f: + data = json.load(f) + data["last_updated"] = "2026-02-01T00:00:00Z" + with open(sample_config_path, "w") as f: + json.dump(data, f) + reg = ModelRegistry(config_path=sample_config_path, auto_refresh=False) + assert reg.is_stale() is True + + def test_is_not_stale_when_fresh(self, sample_config_path): + from model_registry import ModelRegistry + with open(sample_config_path) as f: + data = json.load(f) + data["last_updated"] = datetime.now(timezone.utc).isoformat() + with open(sample_config_path, "w") as f: + json.dump(data, f) + reg = ModelRegistry(config_path=sample_config_path, auto_refresh=False) + assert reg.is_stale() is False diff --git a/tools.py b/tools.py index 5d0d4a9..e189171 100755 --- a/tools.py +++ b/tools.py @@ -263,27 +263,58 @@ def find_papers_by_str(self, query, N=20): def retrieve_full_paper_text(self, query): pdf_text = str() - paper = next(arxiv.Client().results(arxiv.Search(id_list=[query]))) - # Download the PDF to the PWD with a custom filename. 
- paper.download_pdf(filename="downloaded-paper.pdf") - # creating a pdf reader object - reader = PdfReader('downloaded-paper.pdf') - # Iterate over all the pages - for page_number, page in enumerate(reader.pages, start=1): - # Extract text from the page + pdf_filename = "downloaded-paper.pdf" # Temporary PDF filename + + try: + # Fetch the paper try: - text = page.extract_text() + paper = next(arxiv.Client().results(arxiv.Search(id_list=[query]))) + except StopIteration: + print(f"No results found for query: {query}") + return "NO RESULTS FOUND" + except arxiv.HTTPError as e: + print(f"Failed to fetch paper {query}: {e}") + return f"HTTP ERROR: {e}" except Exception as e: - os.remove("downloaded-paper.pdf") - time.sleep(2.0) - return "EXTRACTION FAILED" - - # Do something with the text (e.g., print it) - pdf_text += f"--- Page {page_number} ---" - pdf_text += text - pdf_text += "\n" - os.remove("downloaded-paper.pdf") - time.sleep(2.0) + print(f"An unexpected error occurred while fetching paper: {e}") + return f"UNEXPECTED ERROR: {e}" + + # Download the PDF file + try: + paper.download_pdf(filename=pdf_filename) + except Exception as e: + print(f"Failed to download PDF for paper {query}: {e}") + return f"DOWNLOAD ERROR: {e}" + + # Read the PDF file and extract text + try: + reader = PdfReader(pdf_filename) + for page_number, page in enumerate(reader.pages, start=1): + try: + text = page.extract_text() + pdf_text += f"--- Page {page_number} ---\n" + pdf_text += text + pdf_text += "\n" + except Exception as e: + print(f"Failed to extract text from page {page_number}: {e}") + pdf_text += f"--- Page {page_number} ---\n" + pdf_text += "EXTRACTION FAILED\n" + except Exception as e: + print(f"Failed to read PDF file: {e}") + return f"PDF READ ERROR: {e}" + + except Exception as e: + print(f"An unexpected error occurred: {e}") + return f"UNEXPECTED ERROR: {e}" + finally: + # Clean up temporary files + if os.path.exists(pdf_filename): + try: + os.remove(pdf_filename) + 
except Exception as e: + print(f"Failed to delete temporary file {pdf_filename}: {e}") + time.sleep(2.0) # Avoid frequent requests + return pdf_text """ diff --git a/update_models.py b/update_models.py new file mode 100644 index 0000000..3688d74 --- /dev/null +++ b/update_models.py @@ -0,0 +1,32 @@ +"""CLI tool to refresh model pricing data. + +Usage: + python update_models.py # Update if stale (>7 days) + python update_models.py --force # Force update regardless of freshness +""" +import argparse +import sys + +from model_fetcher import update_models_pricing +from model_registry import DEFAULT_CONFIG_PATH + + +def main(): + parser = argparse.ArgumentParser(description="Update model pricing data") + parser.add_argument("--force", action="store_true", help="Force update regardless of freshness") + parser.add_argument("--config", default=DEFAULT_CONFIG_PATH, help="Path to models_pricing.json") + args = parser.parse_args() + + try: + updated = update_models_pricing(args.config, force=args.force) + if updated: + print("Done.") + else: + print("No update needed. Use --force to override.") + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/utils.py b/utils.py index a163273..8e103d8 100755 --- a/utils.py +++ b/utils.py @@ -118,4 +118,77 @@ def extract_prompt(text, word): extracted_code = "\n".join(code_blocks).strip() return extracted_code - +def build_task_note(task_note, **kwargs): + # Replace the `{{variable}}` placeholders in the task note with the provided values + for note in task_note: + for key, value in kwargs.items(): + placeholder = f"{{{{{key}}}}}" + note["note"] = note["note"].replace(placeholder, str(value)) + return task_note + +def remove_thinking_process(text): + """ + Remove the first occurrence of a substring enclosed in ... or ..., + even if it spans multiple lines. + """ + pattern = r'<(?:thinking|think)>.*?' + # Using re.DOTALL allows '.' to match newline characters. 
+ return re.sub(pattern, '', text, count=1, flags=re.DOTALL) + +# Define allowed phases and variables according to your guide +ALLOWED_PHASES = [ + "literature review", "plan formulation", + "data preparation", "running experiments", + "results interpretation", "report writing", + "report refinement" +] + +ALLOWED_VARIABLES = { + "research_topic", "api_key", "deepseek_api_key", + "google_api_key", "anthropic_api_key", "language", + "llm_backend" +} + +def validate_task_note_config(task_note_config): + """ + Validate the task note configuration based on the allowed phases and variables. + """ + # Ensure the configuration is a list + if not isinstance(task_note_config, list): + raise ValueError("Configuration must be a list.") + + for idx, note in enumerate(task_note_config): + # Each note should be a dictionary + if not isinstance(note, dict): + raise ValueError(f"Entry {idx} must be a dictionary.") + + # Must contain both 'phases' and 'note' + if "phases" not in note or "note" not in note: + raise ValueError(f"Entry {idx} must have both 'phases' and 'note' keys.") + + # Validate phases: it must be a list and contain only allowed values + phases = note["phases"] + if not isinstance(phases, list): + raise ValueError(f"'phases' in entry {idx} must be a list.") + for phase in phases: + if phase not in ALLOWED_PHASES: + raise ValueError( + f"Invalid phase '{phase}' in entry {idx}. " + f"Allowed phases are: {ALLOWED_PHASES}" + ) + + # Validate note: it must be a string + text = note["note"] + if not isinstance(text, str): + raise ValueError(f"'note' in entry {idx} must be a string.") + + # Validate the variables inside the note using a regex that matches double curly braces + variables_found = re.findall(r"{{\s*(\w+)\s*}}", text) + for var in variables_found: + if var not in ALLOWED_VARIABLES: + raise ValueError( + f"Invalid variable '{var}' in note in entry {idx}. " + f"Allowed variables are: {ALLOWED_VARIABLES}" + ) + + return True