diff --git a/src/acp-agent.ts b/src/acp-agent.ts index 86f82e6..8e87062 100644 --- a/src/acp-agent.ts +++ b/src/acp-agent.ts @@ -112,7 +112,9 @@ type AccumulatedUsage = { type Session = { query: Query; input: Pushable; - cancelled: boolean; + /** Bumped by cancel(); prompt() captures the value at entry so a later + * prompt() call cannot clobber the cancellation signal. */ + cancelGeneration: number; cwd: string; /** Serialized snapshot of session-defining params (cwd, mcpServers) used to * detect when loadSession/resumeSession is called with changed values. */ @@ -545,7 +547,9 @@ export class ClaudeAcpAgent implements Agent { throw new Error("Session not found"); } - session.cancelled = false; + const cancelGeneration = session.cancelGeneration; + const isCancelled = () => session.cancelGeneration !== cancelGeneration; + session.accumulatedUsage = { inputTokens: 0, outputTokens: 0, @@ -591,7 +595,7 @@ export class ClaudeAcpAgent implements Agent { const { value: message, done } = await session.query.next(); if (done || !message) { - if (session.cancelled) { + if (isCancelled()) { return { stopReason: "cancelled" }; } break; @@ -716,7 +720,7 @@ export class ClaudeAcpAgent implements Agent { }); } - if (session.cancelled) { + if (isCancelled()) { stopReason = "cancelled"; break; } @@ -798,7 +802,7 @@ export class ClaudeAcpAgent implements Agent { } case "user": case "assistant": { - if (session.cancelled) { + if (isCancelled()) { break; } @@ -970,7 +974,7 @@ export class ClaudeAcpAgent implements Agent { if (!session) { throw new Error("Session not found"); } - session.cancelled = true; + session.cancelGeneration++; for (const [, pending] of session.pendingMessages) { pending.resolve(true); } @@ -1628,7 +1632,7 @@ export class ClaudeAcpAgent implements Agent { this.sessions[sessionId] = { query: q, input: input, - cancelled: false, + cancelGeneration: 0, cwd: params.cwd, sessionFingerprint: computeSessionFingerprint(params), settingsManager, diff --git a/src/tests/acp-agent.test.ts b/src/tests/acp-agent.test.ts index 1f9d43b..b49c05b 100644 --- a/src/tests/acp-agent.test.ts +++ b/src/tests/acp-agent.test.ts @@ -1,5 +1,7 @@ import { describe, it, expect, beforeAll, afterAll, vi } from "vitest"; import { spawn, spawnSync } from "child_process"; +import { readFileSync } from "fs"; +import { fileURLToPath } from "url"; import { Agent, AgentSideConnection, @@ -1326,7 +1328,7 @@ describe("stop reason propagation", () => { agent.sessions["test-session"] = { query: messageGenerator() as any, input, - cancelled: false, + cancelGeneration: 0, cwd: "/test", sessionFingerprint: JSON.stringify({ cwd: "/test", mcpServers: [] }), modes: { @@ -1469,7 +1471,7 @@ describe("stop reason propagation", () => { input, cwd: "/tmp/test", sessionFingerprint: JSON.stringify({ cwd: "/tmp/test", mcpServers: [] }), - cancelled: false, + cancelGeneration: 0, modes: { currentModeId: "default", availableModes: [], @@ -1508,6 +1510,140 @@ describe("stop reason propagation", () => { ); }); + it("does not let a stale idle from a cancelled turn complete the next prompt", async () => { + class RecordingClient { + receivedText = ""; + + async sessionUpdate(params: SessionNotification): Promise { + if ( + params.update.sessionUpdate === "agent_message_chunk" && + params.update.content.type === "text" + ) { + this.receivedText += params.update.content.text; + } + } + } + + const fakeCliPath = fileURLToPath( + new URL("./fixtures/fake-claude-cli-cancel-stale-idle.mjs", import.meta.url), + ); + const fakeCliLogPath = `/tmp/fake-claude-cli-${randomUUID()}.log`; + const transportLog: string[] = []; + const client = new RecordingClient(); + const agent = new ClaudeAcpAgent(client as unknown as AgentSideConnection, { + log: () => {}, + error: () => {}, + }); + + function debugContext() { + let fakeCliLog = ""; + try { + fakeCliLog = readFileSync(fakeCliLogPath, "utf8"); + } catch { + fakeCliLog = ""; + } + + return `transport log:\n${transportLog.join("\n") || ""}\n\nfake cli log:\n${fakeCliLog}`; + } + + async function withTimeout( + label: string, + promise: Promise, + timeoutMs = 2000, + ): Promise { + return await Promise.race([ + promise, + new Promise((_, reject) => { + setTimeout(() => { + reject(new Error(`${label} timed out\n\n${debugContext()}`)); + }, timeoutMs); + }), + ]); + } + + await agent.initialize({ + protocolVersion: 1, + clientCapabilities: { + fs: { + readTextFile: true, + writeTextFile: true, + }, + }, + }); + + const newSessionResponse = await withTimeout( + "newSession", + agent.newSession({ + cwd: __dirname, + mcpServers: [], + _meta: { + claudeCode: { + options: { + pathToClaudeCodeExecutable: fakeCliPath, + env: { + FAKE_CLAUDE_CLI_LOG_PATH: fakeCliLogPath, + }, + spawnClaudeCodeProcess: (options: any) => { + transportLog.push( + `spawn ${options.command} ${Array.isArray(options.args) ? options.args.join(" ") : ""}`, + ); + const child = spawn(options.command, options.args, { + cwd: options.cwd, + env: options.env, + signal: options.signal, + stdio: ["pipe", "pipe", "pipe"], + }); + child.stdout.on("data", (chunk) => { + transportLog.push(`stdout ${chunk.toString("utf8").trimEnd()}`); + }); + child.stderr.on("data", (chunk) => { + transportLog.push(`stderr ${chunk.toString("utf8").trimEnd()}`); + }); + + const originalWrite = child.stdin.write.bind(child.stdin); + child.stdin.write = ((chunk: any, ...args: any[]) => { + transportLog.push(`stdin ${Buffer.from(chunk).toString("utf8").trimEnd()}`); + return originalWrite(chunk, ...args); + }) as any; + + return child as any; + }, + }, + }, + }, + }), + ); + + try { + const firstPrompt = agent.prompt({ + prompt: [{ type: "text", text: "first prompt" }], + sessionId: newSessionResponse.sessionId, + }); + + await withTimeout("cancel", agent.cancel({ sessionId: newSessionResponse.sessionId })); + + const secondResponse = await withTimeout( + "second prompt", + agent.prompt({ + prompt: [{ type: "text", text: "second prompt" }], + sessionId: newSessionResponse.sessionId, + }), + ); + const firstResponse = await withTimeout("first prompt", firstPrompt); + + expect(firstResponse.stopReason).toBe("cancelled"); + expect(secondResponse.stopReason).toBe("end_turn"); + // The first prompt now exits early (correctly returning "cancelled" + // without consuming the interrupted turn's result message). The second + // prompt's loop picks up that leftover result, so its totalTokens + // includes both the cancelled turn's usage and its own. + expect(secondResponse.usage?.totalTokens).toBe(26); + expect(client.receivedText).toContain("actual second prompt output"); + } finally { + await agent.unstable_closeSession({ sessionId: newSessionResponse.sessionId }); + } + }, 10000); + it("should throw internal error for success with is_error true and no max_tokens", async () => { const agent = createMockAgent(); injectSession(agent, [ @@ -1542,7 +1678,7 @@ describe("session/close", () => { agent.sessions[sessionId] = { query: gen as any, input: new Pushable(), - cancelled: false, + cancelGeneration: 0, cwd: "/test", sessionFingerprint: JSON.stringify({ cwd: "/test", mcpServers: [] }), modes: { @@ -1638,7 +1774,7 @@ describe("getOrCreateSession param change detection", () => { agent.sessions[sessionId] = { query: gen as any, input: new Pushable(), - cancelled: false, + cancelGeneration: 0, cwd, sessionFingerprint: JSON.stringify({ cwd, @@ -1850,7 +1986,7 @@ describe("usage_update computation", () => { agent.sessions["test-session"] = { query: messageGenerator() as any, input, - cancelled: false, + cancelGeneration: 0, cwd: "/test", sessionFingerprint: JSON.stringify({ cwd: "/test", mcpServers: [] }), modes: { diff --git a/src/tests/cancel-idle-drain.test.ts b/src/tests/cancel-idle-drain.test.ts new file mode 100644 index 0000000..5f4b92d --- /dev/null +++ b/src/tests/cancel-idle-drain.test.ts @@ -0,0 +1,238 @@ +// Regression: first prompt after cancel returns 0-token end_turn because the +// shared `cancelled` boolean is clobbered by a concurrent prompt() call. + +import { describe, it, expect, vi, beforeEach } from "vitest"; +import type { AgentSideConnection } from "@agentclientprotocol/sdk"; +import type { + Query, + SDKMessage, + SDKResultSuccess, + SDKSessionStateChangedMessage, + SDKUserMessage, +} from "@anthropic-ai/claude-agent-sdk"; +import { ClaudeAcpAgent } from "../acp-agent.js"; +import { Pushable } from "../utils.js"; + +const SESSION_ID = "test-session"; +const ZERO_USAGE = { + input_tokens: 0, + output_tokens: 0, + cache_read_input_tokens: 0, + cache_creation_input_tokens: 0, +}; + +function makeResultMessage(overrides?: Record): SDKResultSuccess { + return { + type: "result", + subtype: "success", + duration_ms: 100, + duration_api_ms: 80, + is_error: false, + num_turns: 1, + result: "", + stop_reason: "end_turn", + total_cost_usd: 0, + usage: ZERO_USAGE, + modelUsage: [], + session_id: SESSION_ID, + ...overrides, + } as SDKResultSuccess; +} + +function makeIdleMessage(): SDKSessionStateChangedMessage { + return { + type: "system", + subtype: "session_state_changed", + state: "idle", + uuid: "idle-uuid", + session_id: SESSION_ID, + }; +} + +function createMockClient() { + return { + sessionUpdate: vi.fn().mockResolvedValue(undefined), + } as unknown as AgentSideConnection; +} + +function injectSession( + agent: ClaudeAcpAgent, + sessionId: string, + query: Query, + input: Pushable, +) { + (agent.sessions as any)[sessionId] = { + query, + input, + cancelGeneration: 0, + cwd: "/tmp", + sessionFingerprint: "test", + settingsManager: { dispose: vi.fn() }, + accumulatedUsage: { + inputTokens: 0, + outputTokens: 0, + cachedReadTokens: 0, + cachedWriteTokens: 0, + }, + modes: { currentModeId: "default", availableModes: [] }, + models: { currentModelId: "default", availableModels: [] }, + configOptions: [], + promptRunning: false, + pendingMessages: new Map(), + nextPendingOrder: 0, + abortController: new AbortController(), + }; +} + +const queryStubs = { + setPermissionMode: vi.fn(), + setModel: vi.fn(), + setMaxThinkingTokens: vi.fn(), + applyFlagSettings: vi.fn(), + initializationResult: vi.fn(), + supportedCommands: vi.fn(), + supportedModels: vi.fn(), + supportedAgents: vi.fn(), + mcpServerStatus: vi.fn(), + getContextUsage: vi.fn(), + reloadPlugins: vi.fn(), + accountInfo: vi.fn(), + rewindFiles: vi.fn(), + seedReadState: vi.fn(), + reconnectMcpServer: vi.fn(), + toggleMcpServer: vi.fn(), + setMcpServers: vi.fn(), + streamInput: vi.fn(), + stopTask: vi.fn(), + close: vi.fn(), +}; + +function createMockQuery(): { + query: Query; + feed: Pushable; +} { + const feed = new Pushable(); + const iterator = feed[Symbol.asyncIterator](); + + const query = { + next: () => iterator.next() as Promise>, + return: (v: any) => Promise.resolve({ value: v, done: true as const }), + throw: (e: any) => Promise.reject(e), + [Symbol.asyncIterator]() { + return this; + }, + interrupt: vi.fn(async () => {}), + ...queryStubs, + } as unknown as Query; + + return { query, feed }; +} + +describe("cancel → prompt sequencing", () => { + let client: AgentSideConnection; + let agent: ClaudeAcpAgent; + + beforeEach(() => { + client = createMockClient(); + agent = new ClaudeAcpAgent(client); + }); + + it("normal cancel: idle consumed by first prompt, second prompt works", async () => { + const { query, feed } = createMockQuery(); + const input = new Pushable(); + injectSession(agent, SESSION_ID, query, input); + + // 1. First prompt starts, enters loop, blocks on next(). + const p1 = agent.prompt({ + sessionId: SESSION_ID, + prompt: [{ type: "text", text: "sleep 30" }], + }); + await new Promise((r) => setTimeout(r, 5)); + + // 2. Cancel. + const cancelP = agent.cancel({ sessionId: SESSION_ID }); + + // SDK yields result then idle (normal post-interrupt sequence). + feed.push(makeResultMessage()); + feed.push(makeIdleMessage()); + + await cancelP; + const r1 = await p1; + expect(r1.stopReason).toBe("cancelled"); + + // 3. Second prompt. Feed a real turn's messages. + const p2 = agent.prompt({ + sessionId: SESSION_ID, + prompt: [{ type: "text", text: "hello?" }], + }); + await new Promise((r) => setTimeout(r, 5)); + + feed.push( + makeResultMessage({ + usage: { + input_tokens: 100, + output_tokens: 50, + cache_read_input_tokens: 0, + cache_creation_input_tokens: 0, + }, + }), + ); + feed.push(makeIdleMessage()); + + const r2 = await p2; + expect(r2.stopReason).toBe("end_turn"); + + const session = (agent.sessions as any)[SESSION_ID]; + expect(session.accumulatedUsage.inputTokens).toBe(100); + }); + + // Race: cancel() returns → second prompt() resets cancelled=false → + // first prompt's loop sees cancelled=false → returns "end_turn" instead of "cancelled". + it("race: second prompt resets cancelled before first prompt checks it", async () => { + const { query, feed } = createMockQuery(); + const input = new Pushable(); + injectSession(agent, SESSION_ID, query, input); + + // 1. First prompt starts. + const p1 = agent.prompt({ + sessionId: SESSION_ID, + prompt: [{ type: "text", text: "sleep 30" }], + }); + await new Promise((r) => setTimeout(r, 5)); + + // 2. Cancel resolves before the feed has any messages. + await agent.cancel({ sessionId: SESSION_ID }); + + // 3. Second prompt arrives before first prompt processes any message. + const p2 = agent.prompt({ + sessionId: SESSION_ID, + prompt: [{ type: "text", text: "hello?" }], + }); + + // Feed the SDK cleanup messages for the cancelled turn. + feed.push(makeResultMessage()); + feed.push(makeIdleMessage()); + + const r1 = await p1; + expect(r1.stopReason).toBe("cancelled"); + + // Second prompt enters the loop after the first prompt's finally block. + await new Promise((r) => setTimeout(r, 5)); + feed.push( + makeResultMessage({ + usage: { + input_tokens: 200, + output_tokens: 100, + cache_read_input_tokens: 0, + cache_creation_input_tokens: 0, + }, + }), + ); + feed.push(makeIdleMessage()); + + const r2 = await p2; + expect(r2.stopReason).toBe("end_turn"); + const session = (agent.sessions as any)[SESSION_ID]; + expect(session.accumulatedUsage.inputTokens).toBe(200); + }); +}); diff --git a/src/tests/fixtures/fake-claude-cli-cancel-stale-idle.mjs b/src/tests/fixtures/fake-claude-cli-cancel-stale-idle.mjs new file mode 100644 index 0000000..42d7e3b --- /dev/null +++ b/src/tests/fixtures/fake-claude-cli-cancel-stale-idle.mjs @@ -0,0 +1,205 @@ +#!/usr/bin/env node + +import { randomUUID } from "node:crypto"; +import { appendFileSync, readSync, writeSync } from "node:fs"; + +const sessionId = "fake-session"; +const logPath = process.env.FAKE_CLAUDE_CLI_LOG_PATH; + +let firstUser = null; +let secondUser = null; +let interrupted = false; +let scriptedTransferSent = false; + +function log(message, payload) { + if (!logPath) { + return; + } + + appendFileSync( + logPath, + `${new Date().toISOString()} ${message}${payload ? ` ${JSON.stringify(payload)}` : ""}\n`, + ); +} + +log("started", { argv: process.argv.slice(2) }); + +function send(message) { + log("send", message); + writeSync(1, `${JSON.stringify(message)}\n`); +} + +function controlSuccess(requestId, response = {}) { + send({ + type: "control_response", + response: { + subtype: "success", + request_id: requestId, + response, + }, + }); +} + +function replayUser(userMessage) { + return { + type: "user", + message: userMessage.message, + parent_tool_use_id: null, + uuid: userMessage.uuid, + session_id: sessionId, + isReplay: true, + }; +} + +function buildResult({ + stopReason = null, + inputTokens = 0, + outputTokens = 0, + cachedReadTokens = 0, + cachedWriteTokens = 0, +}) { + return { + type: "result", + subtype: "success", + is_error: false, + duration_ms: 0, + duration_api_ms: 0, + num_turns: 1, + result: "", + stop_reason: stopReason, + total_cost_usd: 0, + usage: { + input_tokens: inputTokens, + output_tokens: outputTokens, + cache_read_input_tokens: cachedReadTokens, + cache_creation_input_tokens: cachedWriteTokens, + }, + modelUsage: { + default: { + inputTokens, + outputTokens, + cacheReadInputTokens: cachedReadTokens, + cacheCreationInputTokens: cachedWriteTokens, + webSearchRequests: 0, + costUSD: 0, + contextWindow: 200000, + maxOutputTokens: 8192, + }, + }, + permission_denials: [], + uuid: randomUUID(), + session_id: sessionId, + }; +} + +function maybeEmitInterruptedTurnTransfer() { + if (!interrupted || !secondUser || scriptedTransferSent) { + return; + } + + scriptedTransferSent = true; + log("emit-transfer"); + + send(buildResult({ inputTokens: 11, outputTokens: 7 })); + send(replayUser(secondUser)); + send({ + type: "system", + subtype: "session_state_changed", + state: "idle", + uuid: randomUUID(), + session_id: sessionId, + }); + + send({ + type: "system", + subtype: "local_command_output", + content: "actual second prompt output", + uuid: randomUUID(), + session_id: sessionId, + }); + send(buildResult({ inputTokens: 3, outputTokens: 5 })); + send({ + type: "system", + subtype: "session_state_changed", + state: "idle", + uuid: randomUUID(), + session_id: sessionId, + }); +} + +function handleMessage(message) { + log("recv", message); + + if (message.type === "control_request") { + switch (message.request.subtype) { + case "initialize": + controlSuccess(message.request_id, { + commands: [], + agents: [], + output_style: "default", + available_output_styles: ["default"], + models: [{ value: "default", displayName: "Default", description: "Fake model" }], + account: {}, + }); + send({ + type: "system", + subtype: "init", + uuid: randomUUID(), + session_id: sessionId, + }); + return; + case "interrupt": + interrupted = true; + controlSuccess(message.request_id, {}); + maybeEmitInterruptedTurnTransfer(); + return; + default: + controlSuccess(message.request_id, {}); + return; + } + } + + if (message.type !== "user") { + return; + } + + if (!firstUser) { + firstUser = message; + send(replayUser(message)); + return; + } + + if (!secondUser) { + secondUser = message; + maybeEmitInterruptedTurnTransfer(); + } +} + +const chunk = Buffer.alloc(4096); +let buffered = ""; + +while (true) { + const bytesRead = readSync(0, chunk, 0, chunk.length, null); + if (bytesRead === 0) { + log("stdin-ended"); + break; + } + + buffered += chunk.toString("utf8", 0, bytesRead); + + while (true) { + const newlineIndex = buffered.indexOf("\n"); + if (newlineIndex === -1) { + break; + } + + const line = buffered.slice(0, newlineIndex).trim(); + buffered = buffered.slice(newlineIndex + 1); + + if (!line) { + continue; + } + + handleMessage(JSON.parse(line)); + } +} diff --git a/src/tests/session-config-options.test.ts b/src/tests/session-config-options.test.ts index 35f75f9..fdc4fd5 100644 --- a/src/tests/session-config-options.test.ts +++ b/src/tests/session-config-options.test.ts @@ -90,7 +90,7 @@ describe("session config options", () => { supportedCommands: async () => [], }, input: null, - cancelled: false, + cancelGeneration: 0, permissionMode: "default", settingsManager: {}, configOptions: structuredClone(MOCK_CONFIG_OPTIONS),