diff --git a/README.md b/README.md index ea8df1eac..bc79c3346 100644 --- a/README.md +++ b/README.md @@ -195,6 +195,27 @@ api_key = "env:MY_PROVIDER_KEY" channel = "my-provider/my-model" ``` +**Azure OpenAI Service** — configure Azure OpenAI deployments: + +```toml +[llm.provider.azure] +api_type = "azure" +base_url = "https://{resource-name}.openai.azure.com" +api_key = "env:AZURE_API_KEY" +api_version = "2024-06-01" # required +deployment = "gpt-4o" # required — your deployment name + +[defaults.routing] +channel = "azure/gpt-4o" +worker = "azure/gpt-4o-mini" +``` + +Important notes: +- `base_url` must end with `.openai.azure.com` +- `api_version` and `deployment` are required fields +- API key authentication is handled automatically via the `api-key` header +- For Azure AI Foundry (accessing Anthropic, Llama, or other models through Azure's model catalog), use `api_type = "openai_chat_completions"` instead and configure the deployment endpoint accordingly + Additional built-in providers include **Kilo Gateway**, **OpenCode Go**, **NVIDIA**, **MiniMax**, **Moonshot AI (Kimi)**, and **Z.AI Coding Plan** — configure with `kilo_key`, `opencode_go_key`, `nvidia_key`, `minimax_key`, `moonshot_key`, or `zai_coding_plan_key` in `[llm]`. 
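The three required Azure fields above combine into a deployment-scoped request URL. A minimal TypeScript sketch, assuming Azure OpenAI's documented endpoint shape — `buildAzureChatUrl` is an illustrative helper, not part of this codebase, and the exact path construction inside the provider may differ:

```typescript
// Sketch: how base_url, deployment, and api_version combine into a request URL.
// Follows Azure OpenAI's deployment-scoped REST convention; the helper name is
// an assumption for illustration, not this project's API.
function buildAzureChatUrl(baseUrl: string, deployment: string, apiVersion: string): string {
  const root = baseUrl.trim().replace(/\/+$/, ""); // drop trailing slashes
  return `${root}/openai/deployments/${deployment}/chat/completions?api-version=${apiVersion}`;
}

// Authentication uses the `api-key` header (not `Authorization: Bearer`), e.g.:
// fetch(buildAzureChatUrl(base, dep, ver), { headers: { "api-key": key }, ... })
```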
### Skills diff --git a/docs/content/docs/(configuration)/config.mdx b/docs/content/docs/(configuration)/config.mdx index e85ab5a7d..f0115567d 100644 --- a/docs/content/docs/(configuration)/config.mdx +++ b/docs/content/docs/(configuration)/config.mdx @@ -376,6 +376,11 @@ If you define a custom provider with the same ID as a legacy key, your custom co | `mistral_key` | string | None | Mistral API key (`secret:NAME`, `env:VAR_NAME`, or literal) | | `opencode_zen_key` | string | None | OpenCode Zen API key (`secret:NAME`, `env:VAR_NAME`, or literal) | | `opencode_go_key` | string | None | OpenCode Go API key (`secret:NAME`, `env:VAR_NAME`, or literal) | +| `gemini_key` | string | None | Gemini API key (`secret:NAME`, `env:VAR_NAME`, or literal) | +| `nvidia_key` | string | None | NVIDIA API key (`secret:NAME`, `env:VAR_NAME`, or literal) | +| `minimax_key` | string | None | MiniMax API key (`secret:NAME`, `env:VAR_NAME`, or literal) | +| `moonshot_key` | string | None | Moonshot API key (`secret:NAME`, `env:VAR_NAME`, or literal) | +| `github_copilot_key` | string | None | GitHub Copilot PAT (`secret:NAME`, `env:VAR_NAME`, or literal) | #### Custom Providers @@ -391,10 +396,12 @@ name = "My Provider" # Optional - friendly name for display | Field | Type | Required | Description | |-------|------|----------|-------------| -| `api_type` | string | Yes | API protocol type. One of: `anthropic`, `openai_completions`, `openai_chat_completions`, `openai_responses`, `gemini`, or `kilo_gateway` | -| `base_url` | string | Yes | Base URL of the API endpoint. Must be a valid URL (including protocol) | +| `api_type` | string | Yes | API protocol type. One of: `anthropic`, `openai_completions`, `openai_chat_completions`, `openai_responses`, `gemini`, `kilo_gateway`, or `azure` | +| `base_url` | string | Yes | Base URL of the API endpoint. Must be a valid URL (including protocol). 
For Azure, must end with `.openai.azure.com` | | `api_key` | string | Yes | API key for authentication. Supports `secret:NAME` and `env:VAR_NAME` syntax | | `name` | string | No | Optional friendly name for the provider (displayed in logs and UI) | +| `api_version` | string | Azure only | Azure API version (format: `YYYY-MM-DD` or `YYYY-MM-DD-preview`) | +| `deployment` | string | Azure only | Azure deployment name (alphanumeric, hyphens, and dots allowed) | > Note: > - For `openai_completions`, `openai_chat_completions`, and `openai_responses`, configure `base_url` as the provider root URL (usually without a trailing `/v1`). @@ -421,15 +428,23 @@ api_key = "env:CUSTOM_ANTHROPIC_KEY" name = "Anthropic EU" ``` -**OpenAI Chat Completions provider:** +**Azure OpenAI provider:** ```toml -[llm.provider.azure_openai] -api_type = "openai_responses" -base_url = "https://my-azure-openai.openai.azure.com" -api_key = "env:AZURE_OPENAI_KEY" -name = "Azure OpenAI GPT-4" +[llm.provider.azure] +api_type = "azure" +base_url = "https://my-resource.openai.azure.com" +api_key = "env:AZURE_API_KEY" +api_version = "2024-02-15" # Required for Azure +deployment = "gpt-4o" # Required for Azure (deployment name) +name = "Azure OpenAI" ``` +> **Azure Requirements:** +> - `base_url` must end with `.openai.azure.com` +> - `api_version` must match format: `YYYY-MM-DD` or `YYYY-MM-DD-preview` +> - `deployment` can contain alphanumeric characters, hyphens, and dots (e.g., `gpt-4o`, `gpt-5.2`) +> - Model names in routing should use the format: `azure/<deployment-name>` + **OpenAI Completions provider:** ```toml [llm.provider.local_llm] diff --git a/docs/content/docs/(core)/routing.mdx b/docs/content/docs/(core)/routing.mdx index 28bc5b195..bcced75f7 100644 --- a/docs/content/docs/(core)/routing.mdx +++ b/docs/content/docs/(core)/routing.mdx @@ -26,6 +26,10 @@ branch = "anthropic/claude-sonnet-4-20250514" worker = "anthropic/claude-haiku-4.5-20250514" compactor = "anthropic/claude-haiku-4.5-20250514" cortex =
"anthropic/claude-haiku-4.5-20250514" + +# Azure example: +# channel = "azure/gpt-4o" +# worker = "azure/gpt-4o-mini" ``` | Process | Why this model tier | @@ -36,6 +40,11 @@ cortex = "anthropic/claude-haiku-4.5-20250514" | Compactor | Summarization and memory extraction. Fast and cheap. No personality needed. | | Cortex | System-level observation. Small context, simple signal processing. Cheapest tier. | +**Model Format:** +- Standard providers: `<provider>/<model>` (e.g., `anthropic/claude-sonnet-4-20250514`, `openai/gpt-4o`) +- Azure: `azure/<deployment>` (e.g., `azure/gpt-4o`, `azure/gpt-5.2`) +- Custom providers: `<provider-id>/<model>` (e.g., `my_anthropic/claude-3.5-sonnet`) + ### Level 2: Task-Type Overrides Workers and branches are generic. Different tasks benefit from different models. The channel or branch specifies a task type when spawning, and the routing config maps task types to models. diff --git a/interface/src/api/client.ts b/interface/src/api/client.ts index 9493f76db..2bae0744f 100644 --- a/interface/src/api/client.ts +++ b/interface/src/api/client.ts @@ -1593,28 +1593,44 @@ export const api = { // Provider management providers: () => fetchJson("/providers"), - updateProvider: async (provider: string, apiKey: string, model: string) => { + updateProvider: async (provider: string, apiKey: string, model: string, baseUrl?: string, apiVersion?: string, deployment?: string) => { const response = await fetch(`${getApiBase()}/providers`, { - method: "PUT", + method: "POST", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ provider, api_key: apiKey, model }), + body: JSON.stringify({ provider, api_key: apiKey, model, base_url: baseUrl, api_version: apiVersion, deployment }), }); if (!response.ok) { throw new Error(`API error: ${response.status}`); } return response.json() as Promise; }, - testProviderModel: async (provider: string, apiKey: string, model: string) => { - const response = await fetch(`${getApiBase()}/providers/test`, { + testProviderModel: async (provider:
string, apiKey: string, model: string, baseUrl?: string, apiVersion?: string, deployment?: string) => { + const response = await fetch(`${getApiBase()}/providers/test-model`, { method: "POST", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ provider, api_key: apiKey, model }), + body: JSON.stringify({ provider, api_key: apiKey, model, base_url: baseUrl, api_version: apiVersion, deployment }), }); if (!response.ok) { throw new Error(`API error: ${response.status}`); } return response.json() as Promise; }, + getProviderConfig: async (provider: string, options?: { signal?: AbortSignal }) => { + const response = await fetch(`${getApiBase()}/providers/${provider}/config`, { + method: "GET", + signal: options?.signal, + }); + if (!response.ok) { + throw new Error(`API error: ${response.status}`); + } + return response.json() as Promise<{ + success: boolean; + message: string; + base_url?: string | null; + api_version?: string | null; + deployment?: string | null; + }>; + }, startOpenAiOAuthBrowser: async (params: {model: string}) => { const response = await fetch(`${getApiBase()}/providers/openai/oauth/browser/start`, { method: "POST", diff --git a/interface/src/lib/providerIcons.tsx b/interface/src/lib/providerIcons.tsx index f0207ede2..07b2d99e2 100644 --- a/interface/src/lib/providerIcons.tsx +++ b/interface/src/lib/providerIcons.tsx @@ -140,6 +140,7 @@ export function ProviderIcon({ provider, className = "text-ink-faint", size = 24 "minimax-cn": Minimax, moonshot: Kimi, // Kimi is Moonshot AI's product brand "github-copilot": GithubCopilot, + azure: OpenAI, }; const IconComponent = iconMap[provider.toLowerCase()]; diff --git a/interface/src/routes/Settings.tsx b/interface/src/routes/Settings.tsx index b7e834190..895cc279c 100644 --- a/interface/src/routes/Settings.tsx +++ b/interface/src/routes/Settings.tsx @@ -251,6 +251,14 @@ const PROVIDERS = [ envVar: "GITHUB_COPILOT_API_KEY", defaultModel: "github-copilot/claude-sonnet-4", }, + { + id: 
"azure", + name: "Azure OpenAI", + description: "Azure OpenAI Service with custom deployments", + placeholder: "Azure API key (alphanumeric string)", + envVar: "AZURE_API_KEY", + defaultModel: "azure/gpt-4o", + }, { id: "ollama", name: "Ollama", @@ -305,6 +313,11 @@ export function Settings() { type: "success" | "error"; } | null>(null); + const [azureBaseUrl, setAzureBaseUrl] = useState(""); + const [azureApiVersion, setAzureApiVersion] = useState(""); + const [azureDeployment, setAzureDeployment] = useState(""); + const fetchAbortControllerRef = useRef(null); + // Fetch providers data (only when on providers tab) const { data, isLoading } = useQuery({ queryKey: ["providers"], @@ -322,8 +335,8 @@ export function Settings() { }); const updateMutation = useMutation({ - mutationFn: ({ provider, apiKey, model }: { provider: string; apiKey: string; model: string }) => - api.updateProvider(provider, apiKey, model), + mutationFn: ({ provider, apiKey, model, baseUrl, apiVersion, deployment }: { provider: string; apiKey: string; model: string; baseUrl?: string; apiVersion?: string; deployment?: string }) => + api.updateProvider(provider, apiKey, model, baseUrl, apiVersion, deployment), onSuccess: (result) => { if (result.success) { setEditingProvider(null); @@ -348,8 +361,14 @@ export function Settings() { }); const testModelMutation = useMutation({ - mutationFn: ({ provider, apiKey, model }: { provider: string; apiKey: string; model: string }) => - api.testProviderModel(provider, apiKey, model), + mutationFn: ({ provider, apiKey, model, baseUrl, apiVersion, deployment }: { + provider: string; + apiKey: string; + model: string; + baseUrl?: string; + apiVersion?: string; + deployment?: string; + }) => api.testProviderModel(provider, apiKey, model, baseUrl, apiVersion, deployment), }); const startOpenAiBrowserOAuthMutation = useMutation({ mutationFn: (params: { model: string }) => api.startOpenAiOAuthBrowser(params), @@ -372,20 +391,49 @@ export function Settings() { const 
editingProviderData = PROVIDERS.find((p) => p.id === editingProvider); - const currentSignature = `${editingProvider ?? ""}|${keyInput.trim()}|${modelInput.trim()}`; + const currentSignature = `${editingProvider ?? ""}|${keyInput.trim()}|${editingProvider === "azure" ? azureDeployment.trim() : modelInput.trim()}`; const oauthAutoStartRef = useRef(false); const oauthAbortRef = useRef(null); const handleTestModel = async (): Promise => { - if (!editingProvider || !keyInput.trim() || !modelInput.trim()) return false; + if (!editingProvider || !modelInput.trim()) return false; + + if (editingProvider === "azure") { + if (!keyInput.trim()) { + setTestResult({ success: false, message: "API key is required for Azure OpenAI" }); + return false; + } + if (!azureBaseUrl.trim()) { + setTestResult({ success: false, message: "Base URL is required for Azure OpenAI" }); + return false; + } + if (!azureApiVersion.trim()) { + setTestResult({ success: false, message: "API Version is required for Azure OpenAI" }); + return false; + } + if (!azureDeployment.trim()) { + setTestResult({ success: false, message: "Deployment Name is required for Azure OpenAI" }); + return false; + } + const normalizedBaseUrl = azureBaseUrl.trim().replace(/\/+$/, ''); + if (!normalizedBaseUrl.endsWith(".openai.azure.com")) { + setTestResult({ success: false, message: "Base URL must end with '.openai.azure.com' (e.g., https://{resource-name}.openai.azure.com)" }); + return false; + } + } + setMessage(null); setTestResult(null); try { + const azureModel = editingProvider === "azure" ? `azure/${azureDeployment.trim()}` : modelInput.trim(); const result = await testModelMutation.mutateAsync({ provider: editingProvider, apiKey: keyInput.trim(), - model: modelInput.trim(), + model: azureModel, + baseUrl: editingProvider === "azure" ? azureBaseUrl.trim().replace(/\/+$/, '') : undefined, + apiVersion: editingProvider === "azure" ? azureApiVersion.trim() : undefined, + deployment: editingProvider === "azure" ? 
azureDeployment.trim() : undefined, }); setTestResult({ success: result.success, message: result.message, sample: result.sample }); if (result.success) { @@ -403,18 +451,54 @@ export function Settings() { }; const handleSave = async () => { - if (!keyInput.trim() || !editingProvider || !modelInput.trim()) return; + if (!editingProvider || !modelInput.trim()) return; + + if (editingProvider === "azure") { + if (!keyInput.trim()) { + setMessage({ text: "API key is required for Azure OpenAI", type: "error" }); + return; + } + if (!azureBaseUrl.trim()) { + setMessage({ text: "Base URL is required for Azure OpenAI", type: "error" }); + return; + } + if (!azureApiVersion.trim()) { + setMessage({ text: "API Version is required for Azure OpenAI", type: "error" }); + return; + } + if (!azureDeployment.trim()) { + setMessage({ text: "Deployment Name is required for Azure OpenAI", type: "error" }); + return; + } + const normalizedBaseUrl = azureBaseUrl.trim().replace(/\/+$/, ''); + if (!normalizedBaseUrl.endsWith(".openai.azure.com")) { + setMessage({ text: "Base URL must end with '.openai.azure.com'", type: "error" }); + return; + } + } if (testedSignature !== currentSignature) { const testPassed = await handleTestModel(); if (!testPassed) return; } - updateMutation.mutate({ - provider: editingProvider, - apiKey: keyInput.trim(), - model: modelInput.trim(), - }); + if (editingProvider === "azure") { + const azureModel = `azure/${azureDeployment.trim()}`; + updateMutation.mutate({ + provider: editingProvider, + apiKey: keyInput.trim(), + model: azureModel, + baseUrl: azureBaseUrl.trim().replace(/\/+$/, ''), + apiVersion: azureApiVersion.trim(), + deployment: azureDeployment.trim(), + }); + } else { + updateMutation.mutate({ + provider: editingProvider, + apiKey: keyInput.trim(), + model: modelInput.trim(), + }); + } }; const monitorOpenAiBrowserOAuth = async (stateToken: string, signal: AbortSignal) => { @@ -560,11 +644,15 @@ export function Settings() { }; const handleClose 
= () => { + fetchAbortControllerRef.current?.abort(); setEditingProvider(null); setKeyInput(""); setModelInput(""); setTestedSignature(null); setTestResult(null); + setAzureBaseUrl(""); + setAzureApiVersion(""); + setAzureDeployment(""); }; const isConfigured = (providerId: string): boolean => { @@ -653,6 +741,38 @@ export function Settings() { setTestedSignature(null); setTestResult(null); setMessage(null); + if (provider.id === "azure") { + // Reset Azure fields before hydrating + setAzureBaseUrl(""); + setAzureApiVersion(""); + setAzureDeployment(""); + + // Cancel previous request + fetchAbortControllerRef.current?.abort(); + + // Create new abort controller + const abortController = new AbortController(); + fetchAbortControllerRef.current = abortController; + + api.getProviderConfig("azure", { signal: abortController.signal }) + .then((result) => { + // Check if aborted + if (abortController.signal.aborted) return; + if (!result.success) return; + + setAzureBaseUrl(result.base_url ?? ""); + setAzureApiVersion(result.api_version ?? ""); + const deployment = result.deployment ?? ""; + setAzureDeployment(deployment); + if (deployment) { + setModelInput(`azure/${deployment}`); + } + }) + .catch((error) => { + if (error.name === 'AbortError') return; + console.error("Failed to fetch Azure config:", error); + }); + } }} onRemove={() => removeMutation.mutate(provider.id)} removing={removeMutation.isPending} @@ -733,43 +853,131 @@ export function Settings() { {isConfigured(editingProvider ?? "") ? "Update" : "Add"}{" "} - {editingProvider === "ollama" ? "Endpoint" : "API Key"} + {editingProvider === "ollama" || editingProvider === "azure" ? "Endpoint" : "API Key"} {editingProvider === "ollama" ? `Enter your ${editingProviderData?.name} base URL. It will be saved to your instance config.` + : editingProvider === "azure" + ? "Enter your Azure OpenAI configuration. API key, base URL, API version, and deployment name are required." : editingProvider === "openai" ? 
"Enter an OpenAI API key. The model below will be applied to routing." : `Enter your ${editingProviderData?.name} API key. It will be saved to your instance config.`} - { - setKeyInput(e.target.value); - setTestedSignature(null); - }} - placeholder={editingProviderData?.placeholder} - autoFocus - onKeyDown={(e) => { - if (e.key === "Enter") handleSave(); - }} - /> - { - setModelInput(value); - setTestedSignature(null); - }} - provider={editingProvider ?? undefined} - /> -
+ {editingProvider === "azure" ? ( + <> + { + setKeyInput(e.target.value); + setTestedSignature(null); + }} + placeholder="Azure API key" + autoFocus + onKeyDown={(e) => { + if (e.key === "Enter") handleSave(); + }} + /> +
+ + { + setAzureBaseUrl(e.target.value); + setTestedSignature(null); + }} + placeholder="https://{resource-name}.openai.azure.com" + onKeyDown={(e) => { + if (e.key === "Enter") handleSave(); + }} + /> +

+ Must end with '.openai.azure.com' +

+
+
+ + { + setAzureApiVersion(e.target.value); + setTestedSignature(null); + }} + placeholder="2024-06-01" + onKeyDown={(e) => { + if (e.key === "Enter") handleSave(); + }} + /> +

+ For example: 2024-06-01, 2024-10-01-preview +

+
+
+ + { + setAzureDeployment(e.target.value); + setTestedSignature(null); + }} + placeholder="gpt-4o" + onKeyDown={(e) => { + if (e.key === "Enter") handleSave(); + }} + /> +

+ Your Azure OpenAI deployment name +

+
+
+ + +

+ Model is auto-generated from the deployment name

+
+ + ) : ( + <> + { + setKeyInput(e.target.value); + setTestedSignature(null); + }} + placeholder={editingProviderData?.placeholder} + autoFocus + onKeyDown={(e) => { + if (e.key === "Enter") handleSave(); + }} + /> + { + setModelInput(value); + setTestedSignature(null); + }} + provider={editingProvider ?? undefined} + /> + + )} +
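The Azure field checks that `handleTestModel` and `handleSave` perform above (and that the server repeats in `src/api/providers.rs` with the same regexes) can be condensed into one standalone helper. A sketch — the function and interface names are illustrative, not part of this codebase:

```typescript
interface AzureFields {
  baseUrl: string;
  apiVersion: string;
  deployment: string;
}

// Returns an error message, or null when all three fields are valid.
function validateAzureFields({ baseUrl, apiVersion, deployment }: AzureFields): string | null {
  // Trailing slashes are stripped before the suffix check, matching the UI's normalization.
  const root = baseUrl.trim().replace(/\/+$/, "");
  if (!root.endsWith(".openai.azure.com")) {
    return "Base URL must end with '.openai.azure.com'";
  }
  // Same pattern the server compiles: YYYY-MM-DD with an optional -preview suffix.
  if (!/^\d{4}-\d{2}-\d{2}(-preview)?$/.test(apiVersion.trim())) {
    return "API version must match format: YYYY-MM-DD or YYYY-MM-DD-preview";
  }
  // Deployment names: alphanumerics, hyphens, and dots only.
  if (!/^[a-zA-Z0-9.-]+$/.test(deployment.trim())) {
    return "Deployment name must contain only alphanumeric characters, hyphens, and dots";
  }
  return null;
}
```

Keeping one helper for both the test and save paths avoids the duplicated inline checks shown in the component.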
- + {editingProvider === "azure" ? ( + <> + {isConfigured(editingProvider ?? "") && ( + + )} + + + + ) : ( + <> + + + + )} diff --git a/src/api/providers.rs b/src/api/providers.rs index 3b923e9b5..894f8c2e9 100644 --- a/src/api/providers.rs +++ b/src/api/providers.rs @@ -61,6 +61,7 @@ pub(super) struct ProviderStatus { moonshot: bool, zai_coding_plan: bool, github_copilot: bool, + azure: bool, } #[derive(Serialize, utoipa::ToSchema)] @@ -74,6 +75,13 @@ pub(super) struct ProviderUpdateRequest { provider: String, api_key: String, model: String, + // Azure-specific fields (optional, required for Azure) + #[serde(default)] + base_url: Option, + #[serde(default)] + api_version: Option, + #[serde(default)] + deployment: Option, } #[derive(Serialize, utoipa::ToSchema)] @@ -87,6 +95,13 @@ pub(super) struct ProviderModelTestRequest { provider: String, api_key: String, model: String, + // Azure-specific fields (optional, required for Azure) + #[serde(default)] + base_url: Option, + #[serde(default)] + api_version: Option, + #[serde(default)] + deployment: Option, } #[derive(Serialize, utoipa::ToSchema)] @@ -388,6 +403,7 @@ pub(super) async fn get_providers( moonshot, zai_coding_plan, github_copilot, + azure, ) = if config_path.exists() { let content = tokio::fs::read_to_string(&config_path) .await @@ -455,6 +471,12 @@ pub(super) async fn get_providers( has_value("moonshot_key", "MOONSHOT_API_KEY"), has_value("zai_coding_plan_key", "ZAI_CODING_PLAN_API_KEY"), has_value("github_copilot_key", "GITHUB_COPILOT_API_KEY"), + doc.get("llm") + .and_then(|llm| llm.get("provider")) + .and_then(|provider| provider.get("azure")) + .and_then(|azure| azure.get("base_url")) + .and_then(|base_url| base_url.as_str()) + .is_some_and(|url| !url.trim().is_empty()), ) } else { ( @@ -480,6 +502,7 @@ pub(super) async fn get_providers( env_set("MOONSHOT_API_KEY"), env_set("ZAI_CODING_PLAN_API_KEY"), env_set("GITHUB_COPILOT_API_KEY"), + false, ) }; @@ -506,6 +529,7 @@ pub(super) async fn 
get_providers( moonshot, zai_coding_plan, github_copilot, + azure, }; let has_any = providers.anthropic || providers.openai @@ -528,7 +552,8 @@ pub(super) async fn get_providers( || providers.minimax_cn || providers.moonshot || providers.zai_coding_plan - || providers.github_copilot; + || providers.github_copilot + || providers.azure; Ok(Json(ProvidersResponse { providers, has_any })) } @@ -808,7 +833,20 @@ pub(super) async fn update_provider( Json(request): Json, ) -> Result, StatusCode> { let normalized_provider = request.provider.trim().to_lowercase(); - let normalized_model = request.model.trim(); + let normalized_model = request.model.trim().to_string(); + + if normalized_provider == "azure" { + let azure_request = ProviderUpdateRequest { + provider: request.provider, + api_key: request.api_key, + model: request.model, + base_url: request.base_url, + api_version: request.api_version, + deployment: request.deployment, + }; + return update_azure_provider(state, azure_request, &normalized_model).await; + } + let Some(key_name) = provider_toml_key(&normalized_provider) else { return Ok(Json(ProviderUpdateResponse { success: false, @@ -830,7 +868,7 @@ pub(super) async fn update_provider( })); } - if !model_matches_provider(&normalized_provider, normalized_model) { + if !model_matches_provider(&normalized_provider, &normalized_model) { return Ok(Json(ProviderUpdateResponse { success: false, message: format!( @@ -859,7 +897,7 @@ pub(super) async fn update_provider( } doc["llm"][key_name] = toml_edit::value(request.api_key); - apply_model_routing(&mut doc, normalized_model); + apply_model_routing(&mut doc, &normalized_model); tokio::fs::write(&config_path, doc.to_string()) .await @@ -882,6 +920,284 @@ pub(super) async fn update_provider( })) } +async fn update_azure_provider( + state: Arc, + request: ProviderUpdateRequest, + normalized_model: &str, +) -> Result, StatusCode> { + let base_url = request.base_url.as_ref().ok_or(StatusCode::BAD_REQUEST)?; + let 
normalized_base_url = base_url.trim().trim_end_matches('/'); + + if !normalized_base_url.ends_with(".openai.azure.com") { + return Ok(Json(ProviderUpdateResponse { + success: false, + message: "Base URL must end with .openai.azure.com".to_string(), + })); + } + + let api_version = request + .api_version + .as_ref() + .ok_or(StatusCode::BAD_REQUEST)?; + let api_version_regex = regex::Regex::new(r"^\d{4}-\d{2}-\d{2}(-preview)?$").map_err(|e| { + tracing::error!(error = %e, "failed to compile api_version regex"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + if !api_version_regex.is_match(api_version.trim()) { + return Ok(Json(ProviderUpdateResponse { + success: false, + message: "API version must match format: YYYY-MM-DD or YYYY-MM-DD-preview".to_string(), + })); + } + + let deployment = request.deployment.as_ref().ok_or(StatusCode::BAD_REQUEST)?; + let deployment_regex = regex::Regex::new(r"^[a-zA-Z0-9.-]+$").map_err(|e| { + tracing::error!(error = %e, "failed to compile deployment regex"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + if !deployment_regex.is_match(deployment.trim()) { + return Ok(Json(ProviderUpdateResponse { + success: false, + message: "Deployment name must contain only alphanumeric characters, hyphens, and dots" + .to_string(), + })); + } + + if normalized_model.is_empty() { + return Ok(Json(ProviderUpdateResponse { + success: false, + message: "Model cannot be empty".into(), + })); + } + + let normalized_deployment = request.deployment.as_ref().map(|s| s.trim()).unwrap_or(""); + let azure_model = format!("azure/{}", normalized_deployment); + if !model_matches_provider("azure", &azure_model) { + return Ok(Json(ProviderUpdateResponse { + success: false, + message: format!( + "Deployment '{}' does not match provider 'azure'.", + normalized_deployment + ), + })); + } + + let config_path = state.config_path.read().await.clone(); + let content = if config_path.exists() { + tokio::fs::read_to_string(&config_path) + .await + .map_err(|_| 
StatusCode::INTERNAL_SERVER_ERROR)? + } else { + String::new() + }; + + let mut doc: toml_edit::DocumentMut = content + .parse() + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // Determine the API key: use incoming if non-empty, otherwise preserve existing + let api_key = if request.api_key.trim().is_empty() { + // Read existing API key from config + match doc + .get("llm") + .and_then(|llm| llm.get("provider")) + .and_then(|provider| provider.get("azure")) + .and_then(|azure| azure.get("api_key")) + .and_then(|v| v.as_str()) + .map(String::from) + { + Some(key) => key, + None => { + return Ok(Json(ProviderUpdateResponse { + success: false, + message: "API key is required but no existing key found".to_string(), + })); + } + } + } else { + request.api_key.trim().to_string() + }; + + if doc.get("llm").is_none() { + doc["llm"] = toml_edit::Item::Table(toml_edit::Table::new()); + } + if doc["llm"].get("provider").is_none() { + doc["llm"]["provider"] = toml_edit::Item::Table(toml_edit::Table::new()); + } + if doc["llm"]["provider"].get("azure").is_none() { + doc["llm"]["provider"]["azure"] = toml_edit::Item::Table(toml_edit::Table::new()); + } + + let azure_table = doc["llm"]["provider"]["azure"].as_table_mut().unwrap(); + azure_table["api_type"] = toml_edit::value("azure"); + azure_table["base_url"] = toml_edit::value(base_url.trim()); + azure_table["api_key"] = toml_edit::value(api_key.trim()); + azure_table["api_version"] = toml_edit::value(api_version.trim()); + azure_table["deployment"] = toml_edit::value(deployment.trim()); + + if doc.get("defaults").is_none() { + doc["defaults"] = toml_edit::Item::Table(toml_edit::Table::new()); + } + + if let Some(defaults) = doc.get_mut("defaults").and_then(|item| item.as_table_mut()) { + if defaults.get("routing").is_none() { + defaults["routing"] = toml_edit::Item::Table(toml_edit::Table::new()); + } + + if let Some(routing_table) = defaults + .get_mut("routing") + .and_then(|item| item.as_table_mut()) + { + if 
routing_table.get("channel").is_none() { + routing_table["channel"] = + toml_edit::value(format!("azure/{}", normalized_deployment)); + } + if routing_table.get("branch").is_none() { + routing_table["branch"] = + toml_edit::value(format!("azure/{}", normalized_deployment)); + } + if routing_table.get("worker").is_none() { + routing_table["worker"] = + toml_edit::value(format!("azure/{}", normalized_deployment)); + } + if routing_table.get("compactor").is_none() { + routing_table["compactor"] = + toml_edit::value(format!("azure/{}", normalized_deployment)); + } + if routing_table.get("cortex").is_none() { + routing_table["cortex"] = + toml_edit::value(format!("azure/{}", normalized_deployment)); + } + } + } + + tokio::fs::write(&config_path, doc.to_string()) + .await + .map_err(|e| { + tracing::error!(error = %e, "failed to write config.toml for azure provider"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + refresh_defaults_config(&state).await; + + state + .provider_setup_tx + .try_send(crate::ProviderSetupEvent::ProvidersConfigured) + .ok(); + + Ok(Json(ProviderUpdateResponse { + success: true, + message: format!( + "Azure provider configured. Deployment '{}' with model '{}' verified and applied to defaults and the default agent routing.", + deployment, normalized_model + ), + })) +} + +#[derive(Serialize, utoipa::ToSchema)] +pub(super) struct ProviderConfigResponse { + success: bool, + message: String, + #[serde(skip_serializing_if = "Option::is_none")] + base_url: Option, + #[serde(skip_serializing_if = "Option::is_none")] + api_version: Option, + #[serde(skip_serializing_if = "Option::is_none")] + deployment: Option, + // Note: api_key is intentionally excluded for security. + // Credentials should never be returned to the client. 
+} + +#[utoipa::path( + get, + path = "/providers/{provider}/config", + responses( + (status = 200, body = ProviderConfigResponse), + (status = 404, description = "Provider not found"), + ), + tag = "providers", + params( + ("provider" = String, Path, description = "Provider ID"), + ), +)] +pub(super) async fn get_provider_config( + State(state): State>, + axum::extract::Path(provider): axum::extract::Path, +) -> Result, StatusCode> { + let normalized_provider = provider.trim().to_lowercase(); + + // Only Azure needs special config retrieval + if normalized_provider != "azure" { + return Ok(Json(ProviderConfigResponse { + success: true, + message: "No additional configuration needed for this provider".to_string(), + base_url: None, + api_version: None, + deployment: None, + })); + } + + let config_path = state.config_path.read().await.clone(); + if !config_path.exists() { + return Ok(Json(ProviderConfigResponse { + success: false, + message: "No config file found".to_string(), + base_url: None, + api_version: None, + deployment: None, + })); + } + + let content = tokio::fs::read_to_string(&config_path) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + let doc: toml_edit::DocumentMut = content + .parse() + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // Get Azure config from [llm.provider.azure] + let azure_config = doc + .get("llm") + .and_then(|llm| llm.get("provider")) + .and_then(|provider| provider.get("azure")); + + if let Some(azure_table) = azure_config.and_then(|item| item.as_table_like()) { + let base_url = azure_table + .get("base_url") + .and_then(|v| v.as_str()) + .map(String::from); + let api_version = azure_table + .get("api_version") + .and_then(|v| v.as_str()) + .map(String::from); + let deployment = azure_table + .get("deployment") + .and_then(|v| v.as_str()) + .map(String::from); + + if base_url.is_some() || api_version.is_some() || deployment.is_some() { + return Ok(Json(ProviderConfigResponse { + success: true, + message: 
"Azure configuration found".to_string(), + base_url, + api_version, + deployment, + })); + } + } + + Ok(Json(ProviderConfigResponse { + success: false, + message: "Azure configuration not found".to_string(), + base_url: None, + api_version: None, + deployment: None, + })) +} + #[utoipa::path( post, path = "/providers/test-model", @@ -893,11 +1209,16 @@ pub(super) async fn update_provider( tag = "providers", )] pub(super) async fn test_provider_model( + State(state): State>, Json(request): Json, ) -> Result, StatusCode> { let normalized_provider = request.provider.trim().to_lowercase(); let normalized_model = request.model.trim().to_string(); - if provider_toml_key(&normalized_provider).is_none() { + + // Azure is handled specially and doesn't have a TOML key + if normalized_provider == "azure" { + // Azure validation happens later in the function + } else if provider_toml_key(&normalized_provider).is_none() { return Ok(Json(ProviderModelTestResponse { success: false, message: format!("Unknown provider: {}", request.provider), @@ -907,15 +1228,46 @@ pub(super) async fn test_provider_model( })); } - if request.api_key.trim().is_empty() { - return Ok(Json(ProviderModelTestResponse { - success: false, - message: "API key cannot be empty".to_string(), - provider: request.provider, - model: request.model, - sample: None, - })); - } + // Determine the API key to use + let api_key_to_use = if request.api_key.trim().is_empty() { + if normalized_provider == "azure" { + // For Azure, try to use the existing stored key from config + let config_path = state.config_path.read().await.clone(); + if config_path.exists() { + let content = tokio::fs::read_to_string(&config_path).await.ok(); + if let Some(doc) = content.and_then(|c| c.parse::().ok()) { + doc.get("llm") + .and_then(|llm| llm.get("provider")) + .and_then(|provider| provider.get("azure")) + .and_then(|azure| azure.get("api_key")) + .and_then(|v| v.as_str()) + .map(String::from) + } else { + None + } + } else { + None + } 
+ } else { + None + } + } else { + Some(request.api_key.trim().to_string()) + }; + + // If no key found, return error + let api_key = match api_key_to_use { + Some(key) => key, + None => { + return Ok(Json(ProviderModelTestResponse { + success: false, + message: "API key is required but not provided".to_string(), + provider: request.provider, + model: request.model, + sample: None, + })); + } + }; if normalized_model.is_empty() { return Ok(Json(ProviderModelTestResponse { @@ -940,7 +1292,138 @@ pub(super) async fn test_provider_model( })); } - let llm_config = build_test_llm_config(&normalized_provider, request.api_key.trim()); + if normalized_provider == "azure" { + let base_url = request.base_url.as_ref().ok_or(StatusCode::BAD_REQUEST)?; + let normalized_base_url = base_url.trim().trim_end_matches('/'); + + if !normalized_base_url.ends_with(".openai.azure.com") { + return Ok(Json(ProviderModelTestResponse { + success: false, + message: "Base URL must end with .openai.azure.com".to_string(), + provider: request.provider, + model: request.model, + sample: None, + })); + } + + let api_version = request + .api_version + .as_ref() + .ok_or(StatusCode::BAD_REQUEST)?; + let api_version_regex = + regex::Regex::new(r"^\d{4}-\d{2}-\d{2}(-preview)?$").map_err(|e| { + tracing::error!(error = %e, "failed to compile api_version regex"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + if !api_version_regex.is_match(api_version.trim()) { + return Ok(Json(ProviderModelTestResponse { + success: false, + message: "API version must match format: YYYY-MM-DD or YYYY-MM-DD-preview" + .to_string(), + provider: request.provider, + model: request.model, + sample: None, + })); + } + + let deployment = request.deployment.as_ref().ok_or(StatusCode::BAD_REQUEST)?; + let deployment_regex = regex::Regex::new(r"^[a-zA-Z0-9.-]+$").map_err(|e| { + tracing::error!(error = %e, "failed to compile deployment regex"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + if 
!deployment_regex.is_match(deployment.trim()) { + return Ok(Json(ProviderModelTestResponse { + success: false, + message: + "Deployment name must contain only alphanumeric characters, hyphens, and dots" + .to_string(), + provider: request.provider, + model: request.model, + sample: None, + })); + } + + let llm_config = crate::config::LlmConfig { + anthropic_key: None, + openai_key: None, + openrouter_key: None, + kilo_key: None, + zhipu_key: None, + groq_key: None, + together_key: None, + fireworks_key: None, + deepseek_key: None, + xai_key: None, + mistral_key: None, + gemini_key: None, + ollama_key: None, + ollama_base_url: None, + opencode_zen_key: None, + opencode_go_key: None, + nvidia_key: None, + minimax_key: None, + minimax_cn_key: None, + moonshot_key: None, + zai_coding_plan_key: None, + github_copilot_key: None, + providers: { + let mut providers = HashMap::new(); + providers.insert( + "azure".to_string(), + crate::config::ProviderConfig { + api_type: crate::config::ApiType::Azure, + base_url: base_url.trim().to_string(), + api_key: api_key.trim().to_string(), + name: None, + use_bearer_auth: false, + extra_headers: Vec::new(), + api_version: Some(api_version.trim().to_string()), + deployment: Some(deployment.trim().to_string()), + }, + ); + providers + }, + }; + + let llm_manager = match crate::llm::LlmManager::new(llm_config).await { + Ok(manager) => Arc::new(manager), + Err(error) => { + return Ok(Json(ProviderModelTestResponse { + success: false, + message: format!("Failed to initialize provider: {error}"), + provider: request.provider, + model: request.model, + sample: None, + })); + } + }; + + let model = crate::llm::SpacebotModel::make(&llm_manager, normalized_model); + let agent = AgentBuilder::new(model) + .preamble("You are running a provider connectivity check. 
Reply with exactly: OK") + .build(); + + return match agent.prompt("Connection test").await { + Ok(sample) => Ok(Json(ProviderModelTestResponse { + success: true, + message: "Model responded successfully".to_string(), + provider: request.provider, + model: request.model, + sample: Some(sample), + })), + Err(error) => Ok(Json(ProviderModelTestResponse { + success: false, + message: format!("Model test failed: {error}"), + provider: request.provider, + model: request.model, + sample: None, + })), + }; + } + + let llm_config = build_test_llm_config(&normalized_provider, api_key.trim()); let llm_manager = match crate::llm::LlmManager::new(llm_config).await { Ok(manager) => Arc::new(manager), Err(error) => { @@ -1029,6 +1512,47 @@ pub(super) async fn delete_provider( } } + if provider == "azure" { + let config_path = state.config_path.read().await.clone(); + if !config_path.exists() { + return Ok(Json(ProviderUpdateResponse { + success: false, + message: "No config file found".into(), + })); + } + + let content = tokio::fs::read_to_string(&config_path) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + let mut doc: toml_edit::DocumentMut = content + .parse() + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + if let Some(llm) = doc.get_mut("llm") + && let Some(llm_table) = llm.as_table_mut() + && let Some(provider_table) = llm_table.get_mut("provider") + && let Some(provider_tbl) = provider_table.as_table_mut() + { + provider_tbl.remove("azure"); + if provider_tbl.is_empty() { + llm_table.remove("provider"); + } + } + + tokio::fs::write(&config_path, doc.to_string()) + .await + .map_err(|e| { + tracing::error!(error = %e, "failed to write config after azure removal"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + return Ok(Json(ProviderUpdateResponse { + success: true, + message: "Provider 'azure' removed".into(), + })); + } + let Some(key_name) = provider_toml_key(&provider) else { return Ok(Json(ProviderUpdateResponse { success: false, diff --git 
a/src/api/server.rs b/src/api/server.rs index 4e35d4543..4e78f137d 100644 --- a/src/api/server.rs +++ b/src/api/server.rs @@ -185,6 +185,7 @@ pub fn api_router() -> OpenApiRouter> { .routes(routes!(providers::openai_browser_oauth_status)) .routes(routes!(providers::test_provider_model)) .routes(routes!(providers::delete_provider)) + .routes(routes!(providers::get_provider_config)) // Model routes .routes(routes!(models::get_models)) .routes(routes!(models::refresh_models)) diff --git a/src/config/load.rs b/src/config/load.rs index 046585d1d..bf34c991e 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -487,6 +487,8 @@ impl Config { name: None, use_bearer_auth: anthropic_from_auth_token, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -500,6 +502,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: openrouter_extra_headers(), + api_version: None, + deployment: None, }); } @@ -571,6 +575,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -584,6 +590,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -597,6 +605,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -610,6 +620,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: openrouter_extra_headers(), + api_version: None, + deployment: None, }); } @@ -651,6 +663,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -664,6 +678,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -677,6 +693,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -690,6 +708,8 @@ impl Config { name: None, use_bearer_auth: 
false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -703,6 +723,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -716,6 +738,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -729,6 +753,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -742,6 +768,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -755,6 +783,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -768,6 +798,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -781,6 +813,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -794,6 +828,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -807,6 +843,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -823,6 +861,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1149,6 +1189,8 @@ impl Config { name: config.name, use_bearer_auth: false, extra_headers, + api_version: config.api_version, + deployment: config.deployment, }, )) }) @@ -1175,6 +1217,8 @@ impl Config { name: None, use_bearer_auth: anthropic_from_auth_token, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1188,6 +1232,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1201,6 +1247,8 @@ impl Config { name: None, 
use_bearer_auth: false, extra_headers: openrouter_extra_headers(), + api_version: None, + deployment: None, }); } @@ -1272,6 +1320,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1285,6 +1335,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1298,6 +1350,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1311,6 +1365,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1324,6 +1380,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1337,6 +1395,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1350,6 +1410,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1363,6 +1425,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1376,6 +1440,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1389,6 +1455,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1402,6 +1470,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } @@ -1418,6 +1488,8 @@ impl Config { name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }); } diff --git a/src/config/providers.rs b/src/config/providers.rs index 091ca71cc..575b9cb2b 100644 --- a/src/config/providers.rs +++ b/src/config/providers.rs @@ -61,6 +61,8 @@ 
pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "openai" => ProviderConfig { api_type: ApiType::OpenAiCompletions, @@ -69,6 +71,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "openrouter" => ProviderConfig { api_type: ApiType::OpenAiCompletions, @@ -77,6 +81,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: openrouter_extra_headers(), + api_version: None, + deployment: None, }, "kilo" => ProviderConfig { api_type: ApiType::KiloGateway, @@ -85,6 +91,8 @@ pub(crate) fn default_provider_config( name: Some("Kilo Gateway".to_string()), use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "zhipu" => ProviderConfig { api_type: ApiType::OpenAiChatCompletions, @@ -93,6 +101,8 @@ pub(crate) fn default_provider_config( name: Some("Z.AI (GLM)".to_string()), use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "groq" => ProviderConfig { api_type: ApiType::OpenAiCompletions, @@ -101,6 +111,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "together" => ProviderConfig { api_type: ApiType::OpenAiCompletions, @@ -109,6 +121,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "fireworks" => ProviderConfig { api_type: ApiType::OpenAiCompletions, @@ -117,6 +131,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "deepseek" => ProviderConfig { api_type: ApiType::OpenAiCompletions, @@ -125,6 +141,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: 
vec![], + api_version: None, + deployment: None, }, "xai" => ProviderConfig { api_type: ApiType::OpenAiCompletions, @@ -133,6 +151,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "mistral" => ProviderConfig { api_type: ApiType::OpenAiCompletions, @@ -141,6 +161,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "gemini" => ProviderConfig { api_type: ApiType::Gemini, @@ -149,6 +171,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "ollama" => ProviderConfig { api_type: ApiType::OpenAiCompletions, @@ -157,6 +181,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "opencode-zen" => ProviderConfig { api_type: ApiType::OpenAiCompletions, @@ -165,6 +191,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "opencode-go" => ProviderConfig { api_type: ApiType::OpenAiCompletions, @@ -173,6 +201,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "nvidia" => ProviderConfig { api_type: ApiType::OpenAiCompletions, @@ -181,6 +211,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "minimax" => ProviderConfig { api_type: ApiType::Anthropic, @@ -189,6 +221,8 @@ pub(crate) fn default_provider_config( name: None, use_bearer_auth: false, extra_headers: vec![], + api_version: None, + deployment: None, }, "minimax-cn" => ProviderConfig { api_type: ApiType::Anthropic, @@ -197,6 +231,8 @@ pub(crate) fn default_provider_config( 
             name: None,
             use_bearer_auth: false,
             extra_headers: vec![],
+            api_version: None,
+            deployment: None,
         },
         "moonshot" => ProviderConfig {
             api_type: ApiType::OpenAiCompletions,
@@ -205,6 +241,8 @@ pub(crate) fn default_provider_config(
             name: None,
             use_bearer_auth: false,
             extra_headers: vec![],
+            api_version: None,
+            deployment: None,
         },
         "zai-coding-plan" => ProviderConfig {
             api_type: ApiType::OpenAiChatCompletions,
@@ -213,6 +251,8 @@ pub(crate) fn default_provider_config(
             name: Some("Z.AI Coding Plan".to_string()),
             use_bearer_auth: false,
             extra_headers: vec![],
+            api_version: None,
+            deployment: None,
         },
         // GitHub Copilot requires token exchange and dynamic base URL derivation.
         // The test path should use LlmManager::get_github_copilot_provider() instead.
@@ -240,6 +280,8 @@ pub(super) fn add_shorthand_provider(
             name: name.map(str::to_string),
             use_bearer_auth,
             extra_headers: vec![],
+            api_version: None,
+            deployment: None,
         });
     }
 }
diff --git a/src/config/toml_schema.rs b/src/config/toml_schema.rs
index 1ca93fdcf..5e19bb4d1 100644
--- a/src/config/toml_schema.rs
+++ b/src/config/toml_schema.rs
@@ -152,6 +152,10 @@ pub(super) struct TomlProviderConfig {
     pub(super) base_url: String,
     pub(super) api_key: String,
     pub(super) name: Option<String>,
+    #[serde(default)]
+    pub(super) api_version: Option<String>,
+    #[serde(default)]
+    pub(super) deployment: Option<String>,
 }
 
 #[derive(Deserialize, Default)]
diff --git a/src/config/types.rs b/src/config/types.rs
index 887b023a7..e4943bbb7 100644
--- a/src/config/types.rs
+++ b/src/config/types.rs
@@ -187,6 +187,8 @@ pub enum ApiType {
     Anthropic,
     /// Google Gemini API (https://generativelanguage.googleapis.com/v1beta/openai/chat/completions)
     Gemini,
+    /// Azure OpenAI API (https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version})
+    Azure,
 }
 
 impl<'de> serde::Deserialize<'de> for ApiType {
@@ -201,16 +203,17 @@ impl<'de> serde::Deserialize<'de> for ApiType {
             "openai_responses" => Ok(Self::OpenAiResponses),
             "anthropic" => Ok(Self::Anthropic),
             "gemini" => Ok(Self::Gemini),
+            "azure" => Ok(Self::Azure),
             other => Err(serde::de::Error::invalid_value(
                 serde::de::Unexpected::Str(other),
-                &"one of \"openai_completions\", \"openai_chat_completions\", \"kilo_gateway\", \"openai_responses\", \"anthropic\", or \"gemini\"",
+                &"one of \"openai_completions\", \"openai_chat_completions\", \"kilo_gateway\", \"openai_responses\", \"anthropic\", \"gemini\", or \"azure\"",
             )),
         }
     }
 }
 
 /// Configuration for a single LLM provider.
-#[derive(Clone)]
+#[derive(Clone, serde::Deserialize)]
 pub struct ProviderConfig {
     pub api_type: ApiType,
     pub base_url: String,
@@ -219,10 +222,18 @@ pub struct ProviderConfig {
     /// When true, use `Authorization: Bearer` instead of `x-api-key` for
     /// Anthropic requests. Set automatically when the key originates from
     /// `ANTHROPIC_AUTH_TOKEN` (proxy-compatible auth).
+    #[serde(default)]
     pub use_bearer_auth: bool,
     /// Additional HTTP headers included in requests to this provider.
     /// Currently applied in `call_openai()` (the `OpenAiCompletions` path).
+    #[serde(default)]
     pub extra_headers: Vec<(String, String)>,
+    /// Azure API version (e.g., "2024-12-01-preview"). Required for Azure providers.
+    #[serde(default)]
+    pub api_version: Option<String>,
+    /// Azure deployment name (e.g., "gpt-4o"). Required for Azure providers.
+    #[serde(default)]
+    pub deployment: Option<String>,
 }
 
 impl std::fmt::Debug for ProviderConfig {
diff --git a/src/llm/manager.rs b/src/llm/manager.rs
index d3fe79cb8..1c4d59883 100644
--- a/src/llm/manager.rs
+++ b/src/llm/manager.rs
@@ -224,6 +224,8 @@ impl LlmManager {
                 name: None,
                 use_bearer_auth: false,
                 extra_headers: vec![],
+                api_version: None,
+                deployment: None,
             }),
             (None, None) => Err(LlmError::UnknownProvider("anthropic".to_string()).into()),
         }
@@ -294,6 +296,8 @@ impl LlmManager {
                 name: None,
                 use_bearer_auth: false,
                 extra_headers: vec![],
+                api_version: None,
+                deployment: None,
             }),
             None => Err(LlmError::UnknownProvider("openai-chatgpt".to_string()).into()),
         }
@@ -406,6 +410,8 @@ impl LlmManager {
                     COPILOT_EDITOR_PLUGIN_VERSION.to_string(),
                 ),
             ],
+            api_version: None,
+            deployment: None,
         })
     }
diff --git a/src/llm/model.rs b/src/llm/model.rs
index ac418e44a..55f1b2745 100644
--- a/src/llm/model.rs
+++ b/src/llm/model.rs
@@ -164,6 +164,83 @@ impl SpacebotModel {
                 )
                 .await
             }
+            ApiType::Azure => {
+                // Azure OpenAI Service requires a specific endpoint structure.
+                // Supported domain: *.openai.azure.com (HTTPS only)
+                let base_url = provider_config.base_url.trim_end_matches('/');
+
+                // Validate HTTPS scheme
+                if !base_url.starts_with("https://") {
+                    return Err(CompletionError::ProviderError(format!(
+                        "Invalid Azure endpoint. Azure OpenAI Service requires HTTPS.\n\
+                        \n\
+                        Detected: {}\n\
+                        \n\
+                        The endpoint must use https:// (e.g., https://<resource>.openai.azure.com)",
+                        base_url
+                    )));
+                }
+
+                // Validate that the endpoint is actually an Azure OpenAI endpoint
+                if !base_url.ends_with(".openai.azure.com") {
+                    return Err(CompletionError::ProviderError(format!(
+                        "Invalid Azure endpoint. Azure OpenAI Service requires a base_url ending in '.openai.azure.com'.\n\
+                        \n\
+                        Detected: {}\n\
+                        \n\
+                        If you are using Azure AI Foundry (hosting Anthropic, Llama, Mistral, etc.) or other Azure-hosted models,\n\
+                        use the standard OpenAI-compatible provider instead:\n\
+                        - Set api_type = \"openai_chat_completions\"\n\
+                        - Point base_url directly to your endpoint (e.g., https://<resource>.services.ai.azure.com)\n\
+                        - Omit the 'deployment' and 'api_version' fields\n\
+                        \n\
+                        For Azure OpenAI Service, the endpoint must follow this pattern:\n\
+                        https://<resource>.openai.azure.com",
+                        base_url
+                    )));
+                }
+
+                let resource = base_url
+                    .trim_start_matches("https://")
+                    .trim_end_matches(".openai.azure.com");
+
+                let deployment = provider_config
+                    .deployment
+                    .as_ref()
+                    .ok_or_else(|| CompletionError::ProviderError(
+                        "Azure deployment name is required. Example: 'gpt-4o', 'gpt-35-turbo', etc.\n\
+                        This is the deployment name you created in the Azure Portal for your OpenAI model."
+                            .to_string()
+                    ))?;
+
+                let api_version = provider_config.api_version.as_ref().ok_or_else(|| {
+                    CompletionError::ProviderError(
+                        "Azure API version is required. Example: '2024-12-01-preview'\n\
+                        Find available API versions in the Azure OpenAI documentation."
+                            .to_string(),
+                    )
+                })?;
+
+                let endpoint = format!(
+                    "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
+                    resource, deployment, api_version
+                );
+
+                let display_name = provider_config.name.as_deref().unwrap_or("Azure OpenAI");
+
+                // Azure uses "api-key" header instead of Authorization
+                let headers: Vec<(&str, &str)> =
+                    vec![("api-key", provider_config.api_key.as_str())];
+
+                self.call_openai_compatible_with_optional_auth(
+                    request,
+                    display_name,
+                    &endpoint,
+                    None, // No Bearer token needed
+                    &headers,
+                )
+                .await
+            }
             ApiType::KiloGateway => {
                 let endpoint = format!(
                     "{}/chat/completions",
@@ -527,6 +604,82 @@ impl CompletionModel for SpacebotModel {
                 )
                 .await
             }
+            ApiType::Azure => {
+                // Azure OpenAI Service requires a specific endpoint structure.
+                // Supported domain: *.openai.azure.com (HTTPS only)
+                let base_url = provider_config.base_url.trim_end_matches('/');
+
+                // Validate HTTPS scheme
+                if !base_url.starts_with("https://") {
+                    return Err(CompletionError::ProviderError(format!(
+                        "Invalid Azure endpoint. Azure OpenAI Service requires HTTPS.\n\
+                        \n\
+                        Detected: {}\n\
+                        \n\
+                        The endpoint must use https:// (e.g., https://<resource>.openai.azure.com)",
+                        base_url
+                    )));
+                }
+
+                // Validate that the endpoint is actually an Azure OpenAI endpoint
+                if !base_url.ends_with(".openai.azure.com") {
+                    return Err(CompletionError::ProviderError(format!(
+                        "Invalid Azure endpoint. Azure OpenAI Service requires a base_url ending in '.openai.azure.com'.\n\
+                        \n\
+                        Detected: {}\n\
+                        \n\
+                        If you are using Azure AI Foundry (hosting Anthropic, Llama, Mistral, etc.) or other Azure-hosted models,\n\
+                        use the standard OpenAI-compatible provider instead:\n\
+                        - Set api_type = \"openai_chat_completions\"\n\
+                        - Point base_url directly to your endpoint (e.g., https://<resource>.services.ai.azure.com)\n\
+                        - Omit the 'deployment' and 'api_version' fields\n\
+                        \n\
+                        For Azure OpenAI Service, the endpoint must follow this pattern:\n\
+                        https://<resource>.openai.azure.com",
+                        base_url
+                    )));
+                }
+
+                let resource = base_url
+                    .trim_start_matches("https://")
+                    .trim_end_matches(".openai.azure.com");
+
+                let deployment = provider_config
+                    .deployment
+                    .as_ref()
+                    .ok_or_else(|| CompletionError::ProviderError(
+                        "Azure deployment name is required. Example: 'gpt-4o', 'gpt-35-turbo', etc.\n\
+                        This is the deployment name you created in the Azure Portal for your OpenAI model."
+                            .to_string()
+                    ))?;
+
+                let api_version = provider_config.api_version.as_ref().ok_or_else(|| {
+                    CompletionError::ProviderError(
+                        "Azure API version is required. Example: '2024-12-01-preview'\n\
+                        Find available API versions in the Azure OpenAI documentation."
+                            .to_string(),
+                    )
+                })?;
+
+                let endpoint = format!(
+                    "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
+                    resource, deployment, api_version
+                );
+
+                let display_name = provider_config.name.as_deref().unwrap_or("Azure OpenAI");
+
+                let headers: Vec<(&str, &str)> =
+                    vec![("api-key", provider_config.api_key.as_str())];
+
+                self.stream_openai_compatible_with_optional_auth(
+                    request,
+                    display_name,
+                    &endpoint,
+                    None,
+                    &headers,
+                )
+                .await
+            }
             ApiType::KiloGateway => {
                 let endpoint = format!(
                     "{}/chat/completions",
@@ -1166,6 +1319,14 @@ impl SpacebotModel {
         let endpoint_path = match provider_config.api_type {
             ApiType::OpenAiCompletions | ApiType::OpenAiResponses => "/v1/chat/completions",
             ApiType::OpenAiChatCompletions | ApiType::Gemini => "/chat/completions",
+            ApiType::Azure => {
+                // Azure handles its own endpoint construction in the call() match
+                // This fallback should not be reached for Azure
+                return Err(CompletionError::ProviderError(
+                    "Azure provider should use the dedicated Azure endpoint construction in call()"
+                        .to_string(),
+                ));
+            }
             ApiType::Anthropic => {
                 return Err(CompletionError::ProviderError(format!(
                     "{provider_display_name} is configured with anthropic API type, but this call expects an OpenAI-compatible API"
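
The `ApiType::Azure` arms above derive the resource name from `base_url` and assemble the deployment-scoped endpoint. A minimal std-only sketch of that derivation (the function name `build_azure_endpoint` is illustrative, not part of the patch):

```rust
// Sketch of the endpoint derivation performed by the ApiType::Azure arms:
// validate the scheme and domain suffix, extract the resource name, and
// build the deployment-scoped chat-completions URL.
fn build_azure_endpoint(
    base_url: &str,
    deployment: &str,
    api_version: &str,
) -> Result<String, String> {
    let base = base_url.trim_end_matches('/');
    if !base.starts_with("https://") {
        return Err("Azure OpenAI Service requires an https:// endpoint".to_string());
    }
    if !base.ends_with(".openai.azure.com") {
        return Err("base_url must end with .openai.azure.com".to_string());
    }
    // Whatever sits between the scheme and the Azure domain is the resource name.
    let resource = base
        .trim_start_matches("https://")
        .trim_end_matches(".openai.azure.com");
    Ok(format!(
        "https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={api_version}"
    ))
}

fn main() {
    let url = build_azure_endpoint("https://my-res.openai.azure.com/", "gpt-4o", "2024-06-01");
    println!("{url:?}");
}
```

Requests to that endpoint carry the key in an `api-key` header rather than `Authorization: Bearer`, which is why the arms pass `None` for the bearer token and supply the header pair explicitly.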
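
The `test_provider_model` handler validates `api_version` against `^\d{4}-\d{2}-\d{2}(-preview)?$` and `deployment` against `^[a-zA-Z0-9.-]+$` using the `regex` crate. The same checks can be mirrored with std only; a sketch under that assumption (function names are illustrative):

```rust
// Std-only mirror of the handler's regex checks: api_version must look like
// YYYY-MM-DD or YYYY-MM-DD-preview, and deployment names may contain only
// ASCII alphanumerics, hyphens, and dots.
fn is_valid_api_version(v: &str) -> bool {
    // Strip an optional "-preview" suffix, then require exactly YYYY-MM-DD.
    let core = v.strip_suffix("-preview").unwrap_or(v);
    let bytes = core.as_bytes();
    bytes.len() == 10
        && bytes.iter().enumerate().all(|(i, b)| match i {
            4 | 7 => *b == b'-',          // separator positions
            _ => b.is_ascii_digit(),      // all other positions are digits
        })
}

fn is_valid_deployment(d: &str) -> bool {
    !d.is_empty() && d.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.')
}

fn main() {
    println!("{}", is_valid_api_version("2024-12-01-preview"));
    println!("{}", is_valid_deployment("gpt-35-turbo"));
}
```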