diff --git a/ai_scientist/llm.py b/ai_scientist/llm.py
index b79c6c67..e2594c19 100644
--- a/ai_scientist/llm.py
+++ b/ai_scientist/llm.py
@@ -47,6 +47,8 @@
     "vertex_ai/claude-3-5-sonnet@20241022",
     "vertex_ai/claude-3-sonnet@20240229",
     "vertex_ai/claude-3-haiku@20240307",
+    # Anthropic Claude models via Azure AI Foundry
+    "azure_foundry/claude-opus-4-5",
     # Google Gemini models
     "gemini-2.0-flash",
     "gemini-2.5-flash-preview-04-17",
@@ -489,6 +491,21 @@ def create_client(model) -> tuple[Any, str]:
         client_model = model.split("/")[-1]
         print(f"Using Vertex AI with model {client_model}.")
         return anthropic.AnthropicVertex(), client_model
+    elif model.startswith("azure_foundry/") and "claude" in model:
+        client_model = model.split("/", 1)[-1]
+        azure_endpoint = os.environ.get("AZURE_FOUNDRY_ENDPOINT")
+        azure_api_key = os.environ.get("AZURE_FOUNDRY_API_KEY")
+
+        if not azure_endpoint:
+            raise ValueError("AZURE_FOUNDRY_ENDPOINT environment variable not set")
+        if not azure_api_key:
+            raise ValueError("AZURE_FOUNDRY_API_KEY environment variable not set")
+
+        print(f"Using Azure AI Foundry with model {client_model}.")
+        return anthropic.Anthropic(
+            base_url=azure_endpoint,
+            api_key=azure_api_key,
+        ), client_model
     elif model.startswith("ollama/"):
         print(f"Using Ollama with model {model}.")
         return openai.OpenAI(
diff --git a/ai_scientist/treesearch/backend/backend_anthropic.py b/ai_scientist/treesearch/backend/backend_anthropic.py
index effa6573..d074ebe3 100644
--- a/ai_scientist/treesearch/backend/backend_anthropic.py
+++ b/ai_scientist/treesearch/backend/backend_anthropic.py
@@ -14,9 +14,25 @@
     anthropic.APIStatusError,
 )
 
-def get_ai_client(model : str, max_retries=2) -> anthropic.AnthropicBedrock:
-    client = anthropic.AnthropicBedrock(max_retries=max_retries)
-    return client
+def get_ai_client(model: str, max_retries=2) -> anthropic.Anthropic:
+    """Get appropriate Anthropic client based on model prefix."""
+    if model.startswith("azure_foundry/"):
+        azure_endpoint = os.environ.get("AZURE_FOUNDRY_ENDPOINT")
+        azure_api_key = os.environ.get("AZURE_FOUNDRY_API_KEY")
+
+        if not azure_endpoint:
+            raise ValueError("AZURE_FOUNDRY_ENDPOINT environment variable not set")
+        if not azure_api_key:
+            raise ValueError("AZURE_FOUNDRY_API_KEY environment variable not set")
+
+        return anthropic.Anthropic(
+            base_url=azure_endpoint,
+            api_key=azure_api_key,
+            max_retries=max_retries,
+        )
+    else:
+        # Default to Bedrock for backward compatibility
+        return anthropic.AnthropicBedrock(max_retries=max_retries)
 
 def query(
     system_message: str | None,
@@ -24,7 +40,12 @@ def query(
     func_spec: FunctionSpec | None = None,
     **model_kwargs,
 ) -> tuple[OutputType, float, int, int, dict]:
-    client = get_ai_client(model_kwargs.get("model"), max_retries=0)
+    model = model_kwargs.get("model")
+    client = get_ai_client(model, max_retries=0)
+
+    # Strip provider prefix from model name for API call
+    if model and model.startswith("azure_foundry/"):
+        model_kwargs["model"] = model.split("/", 1)[-1]
     filtered_kwargs: dict = select_values(notnone, model_kwargs)  # type: ignore
     if "max_tokens" not in filtered_kwargs:
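
For reference, a minimal usage sketch of the new azure_foundry/ provider prefix, not part of the patch above. The endpoint URL and API key are placeholders, and it is assumed here that the Azure AI Foundry endpoint accepts the standard Anthropic Messages API that the returned anthropic.Anthropic client speaks:

import os

from ai_scientist.llm import create_client

# Placeholders -- substitute a real Azure AI Foundry endpoint and key.
os.environ["AZURE_FOUNDRY_ENDPOINT"] = "https://<your-resource>.example-endpoint"
os.environ["AZURE_FOUNDRY_API_KEY"] = "<your-api-key>"

# create_client strips the provider prefix, so client_model is "claude-opus-4-5"
# and client is an anthropic.Anthropic pointed at the Azure endpoint.
client, client_model = create_client("azure_foundry/claude-opus-4-5")

response = client.messages.create(
    model=client_model,
    max_tokens=256,
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response.content[0].text)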