diff --git a/interface/src/components/ModelSelect.tsx b/interface/src/components/ModelSelect.tsx index c44b5173c..693b246f7 100644 --- a/interface/src/components/ModelSelect.tsx +++ b/interface/src/components/ModelSelect.tsx @@ -1,7 +1,9 @@ import { api, type ModelInfo } from "@/api/client"; import { Input } from "@/ui"; import { useQuery } from "@tanstack/react-query"; -import { useEffect, useMemo, useRef, useState } from "react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { ArrowDown01Icon, Search01Icon } from "@hugeicons/core-free-icons"; +import { HugeiconsIcon } from "@hugeicons/react"; interface ModelSelectProps { label: string; @@ -40,6 +42,28 @@ function formatContextWindow(tokens: number | null): string { return `${Math.round(tokens / 1000)}K`; } +const providerOrder = [ + "openrouter", + "kilo", + "anthropic", + "openai", + "openai-chatgpt", + "github-copilot", + "ollama", + "deepseek", + "xai", + "mistral", + "gemini", + "groq", + "together", + "fireworks", + "zhipu", + "opencode-zen", + "opencode-go", + "minimax", + "minimax-cn", +]; + export function ModelSelect({ label, description, @@ -50,10 +74,12 @@ export function ModelSelect({ }: ModelSelectProps) { const [open, setOpen] = useState(false); const [filter, setFilter] = useState(""); + const [highlightIndex, setHighlightIndex] = useState(-1); const containerRef = useRef(null); const inputRef = useRef(null); + const listRef = useRef(null); - const { data } = useQuery({ + const { data, isLoading, isError } = useQuery({ queryKey: ["models", provider ?? "configured", capability ?? "all"], queryFn: () => api.models(provider, capability), staleTime: 60_000, @@ -69,10 +95,16 @@ export function ModelSelect({ (m) => m.id.toLowerCase().includes(query) || m.name.toLowerCase().includes(query) || - m.provider.toLowerCase().includes(query), + m.provider.toLowerCase().includes(query) || + (PROVIDER_LABELS[m.provider] ?? "").toLowerCase().includes(query), ); }, [models, filter]); + const providerRank = (p: string) => { + const index = providerOrder.indexOf(p); + return index === -1 ? Number.MAX_SAFE_INTEGER : index; + }; + const grouped = useMemo(() => { const groups: Record = {}; for (const model of filtered) { @@ -86,6 +118,28 @@ export function ModelSelect({ return groups; }, [filtered]); + const sortedProviders = useMemo( + () => Object.keys(grouped).sort((a, b) => providerRank(a) - providerRank(b)), + [grouped], + ); + + // Flat list for keyboard navigation + const flatList = useMemo(() => { + const items: ModelInfo[] = []; + for (const p of sortedProviders) { + for (const m of grouped[p]) { + items.push(m); + } + } + return items; + }, [sortedProviders, grouped]); + + // Find display name for current value + const selectedModel = useMemo( + () => models.find((m) => m.id === value), + [models, value], + ); + // Close on outside click useEffect(() => { const handler = (e: MouseEvent) => { @@ -95,69 +149,98 @@ export function ModelSelect({ ) { setOpen(false); setFilter(""); + setHighlightIndex(-1); } }; document.addEventListener("mousedown", handler); return () => document.removeEventListener("mousedown", handler); }, []); - const handleSelect = (modelId: string) => { - onChange(modelId); - setOpen(false); - setFilter(""); - }; + // Scroll highlighted item into view + useEffect(() => { + if (highlightIndex < 0 || !listRef.current) return; + const items = listRef.current.querySelectorAll("[data-model-item]"); + items[highlightIndex]?.scrollIntoView({ block: "nearest" }); + }, [highlightIndex]); + + const handleSelect = useCallback( + (modelId: string) => { + onChange(modelId); + setOpen(false); + setFilter(""); + setHighlightIndex(-1); + }, + [onChange], + ); + + // Track whether blur should skip committing (e.g. after Escape) + const suppressBlurCommitRef = useRef(false); const handleInputChange = (e: React.ChangeEvent) => { const val = e.target.value; setFilter(val); + setHighlightIndex(-1); if (!open) setOpen(true); - // Allow free-form input for custom model IDs - onChange(val); }; const handleFocus = () => { setOpen(true); - // Start filtering from current value setFilter(value); + setHighlightIndex(-1); + }; + + const handleBlur = () => { + if (suppressBlurCommitRef.current) { + suppressBlurCommitRef.current = false; + return; + } + // Commit typed custom model ID when input loses focus (e.g. clicking Save/Test) + const trimmed = filter.trim(); + if (trimmed && trimmed !== value) { + onChange(trimmed); + } }; const handleKeyDown = (e: React.KeyboardEvent) => { if (e.key === "Escape") { + suppressBlurCommitRef.current = true; setOpen(false); setFilter(""); + setHighlightIndex(-1); + inputRef.current?.blur(); + return; + } + if (e.key === "ArrowDown") { + e.preventDefault(); + if (!open) { + setOpen(true); + setFilter(value); + } + setHighlightIndex((prev) => + prev < flatList.length - 1 ? prev + 1 : 0, + ); + return; + } + if (e.key === "ArrowUp") { + e.preventDefault(); + setHighlightIndex((prev) => + prev > 0 ? prev - 1 : flatList.length - 1, + ); + return; + } + if (e.key === "Enter") { + e.preventDefault(); + suppressBlurCommitRef.current = true; + if (highlightIndex >= 0 && highlightIndex < flatList.length) { + handleSelect(flatList[highlightIndex].id); + } else if (filter.trim()) { + // Commit raw model ID typed by the user + handleSelect(filter.trim()); + } inputRef.current?.blur(); } }; - const providerOrder = [ - "openrouter", - "kilo", - "anthropic", - "openai", - "openai-chatgpt", - "github-copilot", - "ollama", - "deepseek", - "xai", - "mistral", - "gemini", - "groq", - "together", - "fireworks", - "zhipu", - "opencode-zen", - "opencode-go", - "minimax", - "minimax-cn", - ]; - const providerRank = (provider: string) => { - const index = providerOrder.indexOf(provider); - return index === -1 ? Number.MAX_SAFE_INTEGER : index; - }; - const sortedProviders = Object.keys(grouped).sort( - (a, b) => providerRank(a) - providerRank(b), - ); - return (
@@ -169,64 +252,149 @@ export function ModelSelect({ value={open ? filter : value} onChange={handleInputChange} onFocus={handleFocus} + onBlur={handleBlur} onKeyDown={handleKeyDown} - placeholder="Type to search models..." + placeholder="Search models..." className="border-app-line/50 bg-app-darkBox/30" + icon={ + open ? ( + + ) : undefined + } + right={ + + } /> - {open && filtered.length > 0 && ( -
- {sortedProviders.map((provider) => ( -
-
- {PROVIDER_LABELS[provider] ?? provider} + {/* Selected model badge (shown when closed and a known model is selected) */} + {!open && selectedModel && selectedModel.id === value && ( +
+ + {PROVIDER_LABELS[selectedModel.provider] ?? selectedModel.provider} + + / + {selectedModel.name} + {selectedModel.context_window && ( + <> + · + + {formatContextWindow(selectedModel.context_window)} + + + )} +
+ )} + {open && ( +
+ {isLoading ? ( +
+
+
+ Loading models...
- {grouped[provider].map((model) => ( - - ))}
- ))} + ) : filtered.length === 0 ? ( +
+ {isError + ? "Failed to load models — check your connection" + : models.length === 0 + ? "No models available — configure a provider first" + : "No models match your search"} +
+ ) : ( + sortedProviders.map((prov) => ( +
+
+ {PROVIDER_LABELS[prov] ?? prov} +
+ {grouped[prov].map((model) => { + const flatIdx = flatList.indexOf(model); + const isHighlighted = flatIdx === highlightIndex; + const isSelected = model.id === value; + return ( + + ); + })} +
+ )) + )}
)}
diff --git a/src/api/models.rs b/src/api/models.rs index f988bf6b6..0e7c155ad 100644 --- a/src/api/models.rs +++ b/src/api/models.rs @@ -118,7 +118,11 @@ fn direct_provider_mapping(models_dev_id: &str) -> Option<&'static str> { "opencode-go" => Some("opencode-go"), "zai-coding-plan" => Some("zai-coding-plan"), "minimax" => Some("minimax"), - "moonshotai" => Some("moonshot"), + "minimax-cn" => Some("minimax-cn"), + "moonshotai" | "moonshotai-cn" => Some("moonshot"), + "nvidia" => Some("nvidia"), + "ollama" | "ollama-cloud" => Some("ollama"), + "github-copilot" => Some("github-copilot"), _ => None, } } @@ -144,32 +148,6 @@ fn as_openai_chatgpt_model(model: &ModelInfo) -> Option { }) } -/// Models from providers not in models.dev (private/custom endpoints). -fn extra_models() -> Vec { - vec![ - // MiniMax CN - China-specific endpoint, not on models.dev - ModelInfo { - id: "minimax-cn/MiniMax-M2.5".into(), - name: "MiniMax M2.5".into(), - provider: "minimax-cn".into(), - context_window: Some(200000), - tool_call: true, - reasoning: true, - input_audio: false, - }, - // Moonshot AI (Kimi) - moonshot-v1-8k not on models.dev - ModelInfo { - id: "moonshot/moonshot-v1-8k".into(), - name: "Moonshot V1 8K".into(), - provider: "moonshot".into(), - context_window: Some(8000), - tool_call: false, - reasoning: false, - input_audio: false, - }, - ] -} - /// Fetch the full model catalog from models.dev and transform into ModelInfo entries. async fn fetch_models_dev() -> anyhow::Result> { let client = reqwest::Client::new(); @@ -290,6 +268,14 @@ pub(super) async fn configured_providers(config_path: &std::path::Path) -> Vec<& if has_key("anthropic_key", "ANTHROPIC_API_KEY") { providers.push("anthropic"); } + // Anthropic OAuth stores credentials as a separate JSON file + if !providers.contains(&"anthropic") + && config_path + .parent() + .is_some_and(|instance_dir| crate::auth::credentials_path(instance_dir).exists()) + { + providers.push("anthropic"); + } if has_key("openai_key", "OPENAI_API_KEY") { providers.push("openai"); } @@ -329,12 +315,18 @@ pub(super) async fn configured_providers(config_path: &std::path::Path) -> Vec<& if has_key("gemini_key", "GEMINI_API_KEY") { providers.push("gemini"); } + if has_key("ollama_base_url", "OLLAMA_BASE_URL") || has_key("ollama_key", "OLLAMA_API_KEY") { + providers.push("ollama"); + } if has_key("opencode_zen_key", "OPENCODE_ZEN_API_KEY") { providers.push("opencode-zen"); } if has_key("opencode_go_key", "OPENCODE_GO_API_KEY") { providers.push("opencode-go"); } + if has_key("nvidia_key", "NVIDIA_API_KEY") { + providers.push("nvidia"); + } if has_key("minimax_key", "MINIMAX_API_KEY") { providers.push("minimax"); } @@ -347,6 +339,9 @@ pub(super) async fn configured_providers(config_path: &std::path::Path) -> Vec<& if has_key("zai_coding_plan_key", "ZAI_CODING_PLAN_API_KEY") { providers.push("zai-coding-plan"); } + if has_key("github_copilot_key", "GITHUB_COPILOT_API_KEY") { + providers.push("github-copilot"); + } providers } @@ -418,26 +413,6 @@ pub(super) async fn get_models( models.extend(chatgpt_models); } - for model in extra_models() { - if let Some(capability) = requested_capability { - if capability == "input_audio" && !model.input_audio { - continue; - } - if capability == "voice_transcription" - && (!model.input_audio || !is_known_voice_transcription_model(&model.id)) - { - continue; - } - } - if let Some(provider) = requested_provider { - if model.provider == provider { - models.push(model); - } - } else if configured.contains(&model.provider.as_str()) { - models.push(model); - } - } - Ok(Json(ModelsResponse { models })) } diff --git a/src/api/providers.rs b/src/api/providers.rs index f3c00b00d..0d21d66ef 100644 --- a/src/api/providers.rs +++ b/src/api/providers.rs @@ -349,6 +349,7 @@ pub(super) async fn get_providers( let config_path = state.config_path.read().await.clone(); let instance_dir = (**state.instance_dir.load()).clone(); let secrets_store = state.secrets_store.load(); + let anthropic_oauth_configured = crate::auth::credentials_path(&instance_dir).exists(); let openai_oauth_configured = crate::openai_auth::credentials_path(&instance_dir).exists(); let env_set = |name: &str| { std::env::var(name) @@ -423,7 +424,7 @@ pub(super) async fn get_providers( }; ( - has_value("anthropic_key", "ANTHROPIC_API_KEY"), + has_value("anthropic_key", "ANTHROPIC_API_KEY") || anthropic_oauth_configured, has_value("openai_key", "OPENAI_API_KEY"), openai_oauth_configured, has_value("openrouter_key", "OPENROUTER_API_KEY"), @@ -449,7 +450,7 @@ pub(super) async fn get_providers( ) } else { ( - env_set("ANTHROPIC_API_KEY"), + env_set("ANTHROPIC_API_KEY") || anthropic_oauth_configured, env_set("OPENAI_API_KEY"), openai_oauth_configured, env_set("OPENROUTER_API_KEY"), diff --git a/src/llm/routing.rs b/src/llm/routing.rs index eb677a8c8..b091e982b 100644 --- a/src/llm/routing.rs +++ b/src/llm/routing.rs @@ -391,6 +391,7 @@ pub fn defaults_for_provider(provider: &str) -> RoutingConfig { ..RoutingConfig::default() } } + "ollama" => RoutingConfig::for_model("ollama/llama3".into()), "nvidia" => RoutingConfig::for_model("nvidia/meta/llama-3.1-405b-instruct".into()), "minimax" => RoutingConfig::for_model("minimax/MiniMax-M2.5".into()), "minimax-cn" => RoutingConfig::for_model("minimax-cn/MiniMax-M2.5".into()), @@ -433,6 +434,7 @@ pub fn provider_to_prefix(provider: &str) -> &str { "xai" => "xai/", "mistral" => "mistral/", "gemini" => "gemini/", + "ollama" => "ollama/", "nvidia" => "nvidia/", "opencode-zen" => "opencode-zen/", "opencode-go" => "opencode-go/",