|
| 1 | +import { Effect } from "effect" |
| 2 | +import { Config } from "@/config/config" |
| 3 | +import { Provider } from "@/provider/provider" |
| 4 | +import { ProviderV2 } from "@opencode-ai/core/provider" |
| 5 | +import { ModelV2 } from "@opencode-ai/core/model" |
| 6 | +import type { SessionV1 } from "@opencode-ai/core/v1/session" |
| 7 | +import { isSafeAllowlisted } from "./allowlist" |
| 8 | +import { resolvePolicy } from "./prompt" |
| 9 | +import { ownModelProvider } from "./provider/own-model" |
| 10 | +import { buildTranscript, projectToolInput } from "./transcript" |
| 11 | +import type { ClassifierDecision } from "./types" |
| 12 | + |
| 13 | +const ALLOW: ClassifierDecision = { kind: "allow" } |
| 14 | +const ask = (reason: string): ClassifierDecision => ({ kind: "ask", reason }) |
| 15 | +const block = (reason: string): ClassifierDecision => ({ kind: "block", reason }) |
| 16 | + |
| 17 | +// Escalation backstop: too many denials in one turn → escalate to the human. |
| 18 | +const MAX_CONSECUTIVE_DENIALS = 3 |
| 19 | +const MAX_TOTAL_DENIALS = 20 |
| 20 | + |
| 21 | +/** |
| 22 | + * Per-session denial counters. Reset when the latest user message changes |
| 23 | + * (i.e. on a new user turn). Keyed by sessionID. |
| 24 | + */ |
| 25 | +const counters = new Map<string, { lastUser: string; consecutive: number; total: number }>() |
| 26 | + |
| 27 | +function lastUserId(messages: SessionV1.WithParts[]): string { |
| 28 | + for (let i = messages.length - 1; i >= 0; i--) { |
| 29 | + if (messages[i]!.info.role === "user") return messages[i]!.info.id |
| 30 | + } |
| 31 | + return "" |
| 32 | +} |
| 33 | + |
| 34 | +function parseModel(s: string): [ProviderV2.ID, ModelV2.ID] { |
| 35 | + const i = s.indexOf("/") |
| 36 | + return i === -1 |
| 37 | + ? [ProviderV2.ID.make(s), ModelV2.ID.make(s)] |
| 38 | + : [ProviderV2.ID.make(s.slice(0, i)), ModelV2.ID.make(s.slice(i + 1))] |
| 39 | +} |
| 40 | + |
| 41 | +/** |
| 42 | + * Decide whether a would-auto-approve tool call should proceed, be blocked |
| 43 | + * (deny-and-continue), or be escalated to the human (`ask`). |
| 44 | + * |
| 45 | + * Returns `undefined` when the classifier is disabled or the tool is on the |
| 46 | + * safe allowlist — the caller then proceeds exactly as today (no gating). |
| 47 | + * |
| 48 | + * Fails CLOSED: any backend error / unparseable response → `ask`. |
| 49 | + * |
| 50 | + * Requires `Config` + `Provider`; the call site runs this through the request |
| 51 | + * EffectBridge so the captured context provides them (the thunk stays R=never). |
| 52 | + */ |
| 53 | +export const evaluate = Effect.fn("Classifier.evaluate")(function* (input: { |
| 54 | + tool: string |
| 55 | + toolInput: unknown |
| 56 | + messages: SessionV1.WithParts[] |
| 57 | + fallbackModel: Provider.Model |
| 58 | + sessionID: string |
| 59 | + abort: AbortSignal |
| 60 | +}) { |
| 61 | + const cfg = (yield* (yield* Config.Service).get()).classifier |
| 62 | + if (!cfg?.enabled) return undefined |
| 63 | + if (isSafeAllowlisted(input.tool)) return undefined |
| 64 | + |
| 65 | + const backend = cfg.backend ?? "own" |
| 66 | + if (backend !== "own") { |
| 67 | + // og-local / og-saas land in a later step. Until then, fail closed. |
| 68 | + return ask(`classifier backend '${backend}' is not implemented yet`) |
| 69 | + } |
| 70 | + |
| 71 | + const provider = yield* Provider.Service |
| 72 | + |
| 73 | + // Counter state, reset on a new user turn. |
| 74 | + const sid = input.sessionID |
| 75 | + const lu = lastUserId(input.messages) |
| 76 | + const c = counters.get(sid) ?? { lastUser: lu, consecutive: 0, total: 0 } |
| 77 | + if (c.lastUser !== lu) { |
| 78 | + c.lastUser = lu |
| 79 | + c.consecutive = 0 |
| 80 | + c.total = 0 |
| 81 | + } |
| 82 | + |
| 83 | + const policy = resolvePolicy(cfg) |
| 84 | + const verdict = yield* Effect.gen(function* () { |
| 85 | + let model: Provider.Model |
| 86 | + if (cfg.model) { |
| 87 | + const [providerID, modelID] = parseModel(cfg.model) |
| 88 | + model = yield* provider.getModel(providerID, modelID) |
| 89 | + } else { |
| 90 | + model = input.fallbackModel |
| 91 | + } |
| 92 | + const language = yield* provider.getLanguage(model) |
| 93 | + const classifier = ownModelProvider(language, `${model.providerID}/${model.id}`) |
| 94 | + const action = { tool: input.tool, input: projectToolInput(input.tool, input.toolInput) } |
| 95 | + return yield* Effect.promise(() => |
| 96 | + classifier.classify({ transcript: buildTranscript(input.messages), action, policy }, input.abort), |
| 97 | + ) |
| 98 | + }).pipe( |
| 99 | + Effect.catch((e) => |
| 100 | + Effect.succeed({ |
| 101 | + shouldBlock: true, |
| 102 | + unavailable: true, |
| 103 | + reason: e instanceof Error ? e.message : String(e), |
| 104 | + model: "own", |
| 105 | + }), |
| 106 | + ), |
| 107 | + ) |
| 108 | + |
| 109 | + if (verdict.unavailable) { |
| 110 | + counters.set(sid, c) |
| 111 | + return ask(verdict.reason ?? "classifier unavailable") |
| 112 | + } |
| 113 | + if (verdict.shouldBlock) { |
| 114 | + c.consecutive += 1 |
| 115 | + c.total += 1 |
| 116 | + counters.set(sid, c) |
| 117 | + if (c.consecutive >= MAX_CONSECUTIVE_DENIALS || c.total >= MAX_TOTAL_DENIALS) { |
| 118 | + return ask("Repeated classifier denials this turn — escalating to you for review.") |
| 119 | + } |
| 120 | + return block(verdict.reason ?? "blocked by the command-approval classifier") |
| 121 | + } |
| 122 | + c.consecutive = 0 |
| 123 | + counters.set(sid, c) |
| 124 | + return ALLOW |
| 125 | +}) |
| 126 | + |
| 127 | +export * as Classifier from "./index" |
0 commit comments