Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion desktop/frontend/src/__tests__/provider-model-refresh.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Run: tsx src/__tests__/provider-model-refresh.test.ts

import { mergedFetchedProviderModels, providerDefaultModel } from "../lib/providerModels";
import { isLikelyChatModel, mergedFetchedProviderModels, providerDefaultModel, providerModelCandidates } from "../lib/providerModels";

let passed = 0;
let failed = 0;
Expand Down Expand Up @@ -41,6 +41,29 @@ eq(
"manual access refresh preserves selected MiMo model instead of importing provider catalog",
);

eq(
providerModelCandidates(["mimo-v2.5-pro"], ["mimo-v2-flash", "mimo-v2-omni", "mimo-v2.5-pro"]),
["mimo-v2.5-pro", "mimo-v2-flash", "mimo-v2-omni"],
"manual access refresh can show provider catalog as unsaved candidates",
);

eq(
providerModelCandidates(["mimo-v2.5-pro"], ["mimo-v2.5-asr", "mimo-v2.5-tts", "mimo-v2.5", "mimo-v2.5-pro"]),
["mimo-v2.5-pro", "mimo-v2.5"],
"manual access refresh filters non-chat candidates before saving",
);

eq(
[
isLikelyChatModel("mimo-v2.5-pro"),
isLikelyChatModel("mimo-v2.5-asr"),
isLikelyChatModel("mimo-v2.5-tts"),
isLikelyChatModel("text-embedding-3-small"),
],
[true, false, false, false],
"matches backend non-chat model heuristic",
);

eq(
mergedFetchedProviderModels([], ["coding-pro", "chat"], { preserveCurated: true }),
["coding-pro", "chat"],
Expand Down
236 changes: 205 additions & 31 deletions desktop/frontend/src/components/SettingsPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { asArray } from "../lib/array";
import { useDeferredClose } from "../lib/useMountTransition";
import { app } from "../lib/bridge";
import { normalizeLangPref, useI18n, useT, type DictKey, type LangPref } from "../lib/i18n";
import { mergedFetchedProviderModels, providerDefaultModel } from "../lib/providerModels";
import { mergedFetchedProviderModels, providerDefaultModel, providerModelCandidates } from "../lib/providerModels";
import { useUpdater } from "../lib/useUpdater";
import {
THEME_STYLES,
Expand Down Expand Up @@ -1913,6 +1913,7 @@ function ProvidersSection({ s, busy, apply }: SectionProps) {
const [adding, setAdding] = useState<AddProviderMode>(null);
const [fetchingProvider, setFetchingProvider] = useState<string | null>(null);
const [fetchResults, setFetchResults] = useState<Record<string, ProviderFetchResult>>({});
const [modelDrafts, setModelDrafts] = useState<Record<string, ProviderModelDraft>>({});
const groups = providerAccessGroups(s.providers.filter((p) => p.added), t);

const setGroupFetchResult = (groupID: string, result: ProviderFetchResult | null) => {
Expand All @@ -1924,35 +1925,67 @@ function ProvidersSection({ s, busy, apply }: SectionProps) {
});
};

const setGroupModelDraft = (groupID: string, draft: ProviderModelDraft | null) => {
setModelDrafts((prev) => {
const next = { ...prev };
if (draft) next[groupID] = draft;
else delete next[groupID];
return next;
});
};

const modelDraftForFetch = (p: ProviderView, fetched: string[]): ProviderModelDraft => {
const candidates = providerModelCandidates(p.models, fetched);
const selected = mergedFetchedProviderModels(p.models, fetched, { preserveCurated: true });
return {
providerName: p.name,
candidates,
selected: candidates.filter((model) => selected.includes(model)),
};
};

const updateModelDraftSelection = (groupID: string, nextSelected: (draft: ProviderModelDraft) => string[]) => {
setModelDrafts((prev) => {
const draft = prev[groupID];
if (!draft) return prev;
const selectedSet = new Set(nextSelected(draft));
return {
...prev,
[groupID]: {
...draft,
selected: draft.candidates.filter((model) => selectedSet.has(model)),
},
};
});
};

const refreshModels = async (group: ProviderAccessGroup, p: ProviderView) => {
setFetchingProvider(group.id);
setGroupFetchResult(group.id, null);
setGroupModelDraft(group.id, null);
try {
await apply(async () => {
let fetched: string[];
try {
fetched = await app.FetchProviderModels(p);
} catch (e) {
setGroupFetchResult(group.id, {
kind: "warn",
text: t("settings.fetchModelsFailedForProvider", { provider: group.label, err: String((e as Error)?.message ?? e) }),
});
return;
}
if (fetched.length === 0) {
setGroupFetchResult(group.id, {
kind: "warn",
text: t("settings.fetchModelsEmptyForProvider", { provider: group.label }),
});
return;
}
const models = mergedFetchedProviderModels(p.models, fetched, { preserveCurated: true });
const currentDefault = providerDefaultModel(p.default, models);
await app.SaveProvider({ ...p, models, default: currentDefault });
let fetched: string[];
try {
fetched = await app.FetchProviderModels(p);
} catch (e) {
setGroupFetchResult(group.id, {
kind: "warn",
text: t("settings.fetchModelsFailedForProvider", { provider: group.label, err: String((e as Error)?.message ?? e) }),
});
return;
}
if (fetched.length === 0) {
setGroupFetchResult(group.id, {
kind: "ok",
text: t("settings.fetchModelsUpdatedForProvider", { provider: group.label, n: models.length }),
kind: "warn",
text: t("settings.fetchModelsEmptyForProvider", { provider: group.label }),
});
return;
}
const draft = modelDraftForFetch(p, fetched);
setGroupModelDraft(group.id, draft);
setGroupFetchResult(group.id, {
kind: "ok",
text: t("settings.fetchModelsReadyForProvider", { provider: group.label, n: draft.candidates.length }),
});
} finally {
setFetchingProvider(null);
Expand All @@ -1970,18 +2003,18 @@ function ProvidersSection({ s, busy, apply }: SectionProps) {
if (!probe || !apiKeyEnv) return;
setFetchingProvider(group.id);
setGroupFetchResult(group.id, null);
setGroupModelDraft(group.id, null);
try {
await apply(async () => {
await app.SetProviderKey(apiKeyEnv, value);
try {
const fetched = await app.FetchProviderModels({ ...probe, apiKeyEnv });
if (fetched.length > 0) {
const models = mergedFetchedProviderModels(probe.models, fetched, { preserveCurated: true });
const currentDefault = providerDefaultModel(probe.default, models);
await app.SaveProvider({ ...probe, apiKeyEnv, models, default: currentDefault });
const draft = modelDraftForFetch({ ...probe, apiKeyEnv }, fetched);
setGroupModelDraft(group.id, draft);
setGroupFetchResult(group.id, {
kind: "ok",
text: t("settings.fetchModelsUpdatedForProvider", { provider: group.label, n: models.length }),
text: t("settings.fetchModelsReadyForProvider", { provider: group.label, n: draft.candidates.length }),
});
return;
}
Expand All @@ -2004,6 +2037,7 @@ function ProvidersSection({ s, busy, apply }: SectionProps) {
const saveProviderKey = async (group: ProviderAccessGroup, apiKeyEnv: string, value: string) => {
if (!apiKeyEnv) return;
setGroupFetchResult(group.id, null);
setGroupModelDraft(group.id, null);
await apply(() => app.SetProviderKey(apiKeyEnv, value));
};

Expand All @@ -2012,6 +2046,24 @@ function ProvidersSection({ s, busy, apply }: SectionProps) {
await apply(() => app.ClearProviderKey(apiKeyEnv));
};

const saveModelDraft = async (group: ProviderAccessGroup) => {
const draft = modelDrafts[group.id];
const provider = draft ? group.providers.find((p) => p.name === draft.providerName) : null;
const models = uniqueStrings(draft?.selected ?? []);
if (!draft || !provider || models.length === 0) return;
let saved = false;
await apply(async () => {
await app.SaveProvider({ ...provider, models, default: providerDefaultModel(provider.default, models) });
saved = true;
});
if (!saved) return;
setGroupModelDraft(group.id, null);
setGroupFetchResult(group.id, {
kind: "ok",
text: t("settings.enabledModelsSavedForProvider", { provider: group.label, n: models.length }),
});
};

return (
<SettingsSection
title={t("settings.providerAccess")}
Expand Down Expand Up @@ -2055,13 +2107,26 @@ function ProvidersSection({ s, busy, apply }: SectionProps) {
busy={busy}
fetching={fetchingProvider === group.id || group.providers.some((p) => fetchingProvider === p.name)}
fetchResult={fetchResults[group.id]}
modelDraft={modelDrafts[group.id]}
defaultProvider={defaultProvider}
editing={editing}
kinds={s.providerKinds}
onEdit={setEditing}
onCancelEdit={() => setEditing(null)}
onSave={(pv) => apply(() => app.SaveProvider(pv)).then(() => setEditing(null))}
onSave={(pv) => apply(() => app.SaveProvider(pv)).then(() => {
setEditing(null);
setGroupModelDraft(group.id, null);
})}
onRefresh={() => void refreshGroup(group)}
onToggleDraftModel={(model) => updateModelDraftSelection(group.id, (draft) => (
draft.selected.includes(model)
? draft.selected.filter((candidate) => candidate !== model)
: [...draft.selected, model]
))}
onSelectAllDraftModels={() => updateModelDraftSelection(group.id, (draft) => draft.candidates)}
onClearDraftModels={() => updateModelDraftSelection(group.id, () => [])}
onCancelDraftModels={() => setGroupModelDraft(group.id, null)}
onSaveDraftModels={() => void saveModelDraft(group)}
onSaveEditorKey={(env, value) => group.builtIn ? saveProviderKey(group, env, value) : saveKeyEnvAndAutoRefresh(group, env, value)}
onClearEditorKey={clearProviderKey}
onDelete={(p) => apply(() => app.RemoveProviderAccess(p.name))}
Expand Down Expand Up @@ -2090,6 +2155,12 @@ type ProviderFetchResult = {
text: string;
};

type ProviderModelDraft = {
providerName: string;
candidates: string[];
selected: string[];
};

type AddProviderMode = null | "official" | "custom";
type OfficialProviderKind = "deepseek" | "mimo-api" | "mimo-token-plan";

Expand Down Expand Up @@ -2226,13 +2297,19 @@ function ProviderAccessCard({
busy,
fetching,
fetchResult,
modelDraft,
defaultProvider,
editing,
kinds,
onEdit,
onCancelEdit,
onSave,
onRefresh,
onToggleDraftModel,
onSelectAllDraftModels,
onClearDraftModels,
onCancelDraftModels,
onSaveDraftModels,
onSaveEditorKey,
onClearEditorKey,
onDelete,
Expand All @@ -2241,13 +2318,19 @@ function ProviderAccessCard({
busy: boolean;
fetching: boolean;
fetchResult?: ProviderFetchResult;
modelDraft?: ProviderModelDraft;
defaultProvider: string;
editing: string | null;
kinds: string[];
onEdit: (name: string) => void;
onCancelEdit: () => void;
onSave: (p: ProviderView) => void | Promise<void>;
onRefresh: () => void;
onToggleDraftModel: (model: string) => void;
onSelectAllDraftModels: () => void;
onClearDraftModels: () => void;
onCancelDraftModels: () => void;
onSaveDraftModels: () => void;
onSaveEditorKey: (apiKeyEnv: string, value: string) => Promise<void>;
onClearEditorKey?: (apiKeyEnv: string) => Promise<void>;
onDelete?: (p: ProviderView) => Promise<void>;
Expand Down Expand Up @@ -2318,8 +2401,8 @@ function ProviderAccessCard({
</div>

<div className="provider-card-block">
<div className="provider-card-block__label">{t(group.keySet ? "settings.availableModels" : "settings.modelList")}</div>
<div className="provider-model-chips" aria-label={t(group.keySet ? "settings.availableModels" : "settings.modelList")}>
<div className="provider-card-block__label">{t(group.keySet ? "settings.enabledModels" : "settings.modelList")}</div>
<div className="provider-model-chips" aria-label={t(group.keySet ? "settings.enabledModels" : "settings.modelList")}>
{visibleModels.length > 0 ? visibleModels.map((model) => (
<span className="provider-model-chip" key={model}>
{model}
Expand All @@ -2343,6 +2426,19 @@ function ProviderAccessCard({
)}
</div>

{modelDraft && (
<ProviderModelDraftPicker
draft={modelDraft}
busy={busy}
fetching={fetching}
onToggle={onToggleDraftModel}
onSelectAll={onSelectAllDraftModels}
onClear={onClearDraftModels}
onCancel={onCancelDraftModels}
onSave={onSaveDraftModels}
/>
)}

{group.providers.length > 1 && (
<div className="provider-profiles">
{group.providers.map((p) => {
Expand Down Expand Up @@ -2380,6 +2476,84 @@ function ProviderAccessCard({
);
}

function ProviderModelDraftPicker({
draft,
busy,
fetching,
onToggle,
onSelectAll,
onClear,
onCancel,
onSave,
}: {
draft: ProviderModelDraft;
busy: boolean;
fetching: boolean;
onToggle: (model: string) => void;
onSelectAll: () => void;
onClear: () => void;
onCancel: () => void;
onSave: () => void;
}) {
const t = useT();
const [query, setQuery] = useState("");
const selected = new Set(draft.selected);
const q = query.trim().toLowerCase();
const visibleCandidates = q
? draft.candidates.filter((model) => model.toLowerCase().includes(q))
: draft.candidates;
const disabled = busy || fetching;

return (
<div className="provider-model-draft">
<div className="provider-model-draft__head">
<div>
<div className="provider-card-block__label">{t("settings.modelCandidates")}</div>
<span>{t("settings.modelCandidatesSelected", { n: draft.selected.length })}</span>
</div>
<div className="provider-model-draft__tools">
<button type="button" className="btn btn--small" disabled={disabled || draft.selected.length === draft.candidates.length} onClick={onSelectAll}>
{t("settings.selectAllModels")}
</button>
<button type="button" className="btn btn--small" disabled={disabled || draft.selected.length === 0} onClick={onClear}>
{t("settings.clearModelSelection")}
</button>
</div>
</div>
<input
className="mem-input provider-model-draft__search"
placeholder={t("settings.modelCandidateSearch")}
value={query}
disabled={disabled}
onChange={(e) => setQuery(e.target.value)}
/>
<div className="provider-model-draft__list" role="list" aria-label={t("settings.modelCandidates")}>
{visibleCandidates.length > 0 ? visibleCandidates.map((model) => (
<label className="provider-model-draft__option" key={model}>
<input
type="checkbox"
checked={selected.has(model)}
disabled={disabled}
onChange={() => onToggle(model)}
/>
<span>{model}</span>
</label>
)) : (
<div className="provider-model-draft__empty">{t("settings.noMatchingCandidateModels")}</div>
)}
</div>
<div className="provider-model-draft__actions">
<button type="button" className="btn btn--small" disabled={disabled} onClick={onCancel}>
{t("common.cancel")}
</button>
<button type="button" className="btn btn--primary btn--small" disabled={disabled || draft.selected.length === 0} onClick={onSave}>
{t("settings.saveEnabledModels")}
</button>
</div>
</div>
);
}

function providerAccessGroups(providers: ProviderView[], t: ReturnType<typeof useT>): ProviderAccessGroup[] {
const groups = new Map<string, ProviderAccessGroup>();
for (const p of providers) {
Expand Down
Loading
Loading