Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/acp-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ type AccumulatedUsage = {
type Session = {
query: Query;
input: Pushable<SDKUserMessage>;
cancelled: boolean;
/** Bumped by cancel(); prompt() captures the value at entry so a later
* prompt() call cannot clobber the cancellation signal. */
cancelGeneration: number;
cwd: string;
/** Serialized snapshot of session-defining params (cwd, mcpServers) used to
* detect when loadSession/resumeSession is called with changed values. */
Expand Down Expand Up @@ -545,7 +547,9 @@ export class ClaudeAcpAgent implements Agent {
throw new Error("Session not found");
}

session.cancelled = false;
const cancelGeneration = session.cancelGeneration;
const isCancelled = () => session.cancelGeneration !== cancelGeneration;

session.accumulatedUsage = {
inputTokens: 0,
outputTokens: 0,
Expand Down Expand Up @@ -591,7 +595,7 @@ export class ClaudeAcpAgent implements Agent {
const { value: message, done } = await session.query.next();

if (done || !message) {
if (session.cancelled) {
if (isCancelled()) {
return { stopReason: "cancelled" };
}
break;
Expand Down Expand Up @@ -716,7 +720,7 @@ export class ClaudeAcpAgent implements Agent {
});
}

if (session.cancelled) {
if (isCancelled()) {
stopReason = "cancelled";
break;
}
Expand Down Expand Up @@ -798,7 +802,7 @@ export class ClaudeAcpAgent implements Agent {
}
case "user":
case "assistant": {
if (session.cancelled) {
if (isCancelled()) {
break;
}

Expand Down Expand Up @@ -970,7 +974,7 @@ export class ClaudeAcpAgent implements Agent {
if (!session) {
throw new Error("Session not found");
}
session.cancelled = true;
session.cancelGeneration++;
for (const [, pending] of session.pendingMessages) {
pending.resolve(true);
}
Expand Down Expand Up @@ -1628,7 +1632,7 @@ export class ClaudeAcpAgent implements Agent {
this.sessions[sessionId] = {
query: q,
input: input,
cancelled: false,
cancelGeneration: 0,
cwd: params.cwd,
sessionFingerprint: computeSessionFingerprint(params),
settingsManager,
Expand Down
146 changes: 141 additions & 5 deletions src/tests/acp-agent.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { describe, it, expect, beforeAll, afterAll, vi } from "vitest";
import { spawn, spawnSync } from "child_process";
import { readFileSync } from "fs";
import { fileURLToPath } from "url";
import {
Agent,
AgentSideConnection,
Expand Down Expand Up @@ -1326,7 +1328,7 @@ describe("stop reason propagation", () => {
agent.sessions["test-session"] = {
query: messageGenerator() as any,
input,
cancelled: false,
cancelGeneration: 0,
cwd: "/test",
sessionFingerprint: JSON.stringify({ cwd: "/test", mcpServers: [] }),
modes: {
Expand Down Expand Up @@ -1469,7 +1471,7 @@ describe("stop reason propagation", () => {
input,
cwd: "/tmp/test",
sessionFingerprint: JSON.stringify({ cwd: "/tmp/test", mcpServers: [] }),
cancelled: false,
cancelGeneration: 0,
modes: {
currentModeId: "default",
availableModes: [],
Expand Down Expand Up @@ -1508,6 +1510,140 @@ describe("stop reason propagation", () => {
);
});

it("does not let a stale idle from a cancelled turn complete the next prompt", async () => {
class RecordingClient {
receivedText = "";

async sessionUpdate(params: SessionNotification): Promise<void> {
if (
params.update.sessionUpdate === "agent_message_chunk" &&
params.update.content.type === "text"
) {
this.receivedText += params.update.content.text;
}
}
}

const fakeCliPath = fileURLToPath(
new URL("./fixtures/fake-claude-cli-cancel-stale-idle.mjs", import.meta.url),
);
const fakeCliLogPath = `/tmp/fake-claude-cli-${randomUUID()}.log`;
const transportLog: string[] = [];
const client = new RecordingClient();
const agent = new ClaudeAcpAgent(client as unknown as AgentSideConnection, {
log: () => {},
error: () => {},
});

function debugContext() {
let fakeCliLog = "";
try {
fakeCliLog = readFileSync(fakeCliLogPath, "utf8");
} catch {
fakeCliLog = "<missing fake cli log>";
}

return `transport log:\n${transportLog.join("\n") || "<empty>"}\n\nfake cli log:\n${fakeCliLog}`;
}

async function withTimeout<T>(
label: string,
promise: Promise<T>,
timeoutMs = 2000,
): Promise<T> {
return await Promise.race([
promise,
new Promise<T>((_, reject) => {
setTimeout(() => {
reject(new Error(`${label} timed out\n\n${debugContext()}`));
}, timeoutMs);
}),
]);
}

await agent.initialize({
protocolVersion: 1,
clientCapabilities: {
fs: {
readTextFile: true,
writeTextFile: true,
},
},
});

const newSessionResponse = await withTimeout(
"newSession",
agent.newSession({
cwd: __dirname,
mcpServers: [],
_meta: {
claudeCode: {
options: {
pathToClaudeCodeExecutable: fakeCliPath,
env: {
FAKE_CLAUDE_CLI_LOG_PATH: fakeCliLogPath,
},
spawnClaudeCodeProcess: (options: any) => {
transportLog.push(
`spawn ${options.command} ${Array.isArray(options.args) ? options.args.join(" ") : ""}`,
);
const child = spawn(options.command, options.args, {
cwd: options.cwd,
env: options.env,
signal: options.signal,
stdio: ["pipe", "pipe", "pipe"],
});
child.stdout.on("data", (chunk) => {
transportLog.push(`stdout ${chunk.toString("utf8").trimEnd()}`);
});
child.stderr.on("data", (chunk) => {
transportLog.push(`stderr ${chunk.toString("utf8").trimEnd()}`);
});

const originalWrite = child.stdin.write.bind(child.stdin);
child.stdin.write = ((chunk: any, ...args: any[]) => {
transportLog.push(`stdin ${Buffer.from(chunk).toString("utf8").trimEnd()}`);
return originalWrite(chunk, ...args);
}) as any;

return child as any;
},
},
},
},
}),
);

try {
const firstPrompt = agent.prompt({
prompt: [{ type: "text", text: "first prompt" }],
sessionId: newSessionResponse.sessionId,
});

await withTimeout("cancel", agent.cancel({ sessionId: newSessionResponse.sessionId }));

const secondResponse = await withTimeout(
"second prompt",
agent.prompt({
prompt: [{ type: "text", text: "second prompt" }],
sessionId: newSessionResponse.sessionId,
}),
);
const firstResponse = await withTimeout("first prompt", firstPrompt);

expect(firstResponse.stopReason).toBe("cancelled");
expect(secondResponse.stopReason).toBe("end_turn");
// The first prompt now exits early (correctly returning "cancelled"
// without consuming the interrupted turn's result message). The second
// prompt's loop picks up that leftover result, so its totalTokens
// includes both the cancelled turn's usage and its own.
expect(secondResponse.usage?.totalTokens).toBe(26);
expect(client.receivedText).toContain("actual second prompt output");
} finally {
await agent.unstable_closeSession({ sessionId: newSessionResponse.sessionId });
}
}, 10000);

it("should throw internal error for success with is_error true and no max_tokens", async () => {
const agent = createMockAgent();
injectSession(agent, [
Expand Down Expand Up @@ -1542,7 +1678,7 @@ describe("session/close", () => {
agent.sessions[sessionId] = {
query: gen as any,
input: new Pushable(),
cancelled: false,
cancelGeneration: 0,
cwd: "/test",
sessionFingerprint: JSON.stringify({ cwd: "/test", mcpServers: [] }),
modes: {
Expand Down Expand Up @@ -1638,7 +1774,7 @@ describe("getOrCreateSession param change detection", () => {
agent.sessions[sessionId] = {
query: gen as any,
input: new Pushable(),
cancelled: false,
cancelGeneration: 0,
cwd,
sessionFingerprint: JSON.stringify({
cwd,
Expand Down Expand Up @@ -1850,7 +1986,7 @@ describe("usage_update computation", () => {
agent.sessions["test-session"] = {
query: messageGenerator() as any,
input,
cancelled: false,
cancelGeneration: 0,
cwd: "/test",
sessionFingerprint: JSON.stringify({ cwd: "/test", mcpServers: [] }),
modes: {
Expand Down
Loading