diff --git a/src/acp-agent.ts b/src/acp-agent.ts index 1ad57477..b6396bb2 100644 --- a/src/acp-agent.ts +++ b/src/acp-agent.ts @@ -109,6 +109,20 @@ type AccumulatedUsage = { cachedWriteTokens: number; }; +type UsageSnapshot = { + input_tokens: number; + output_tokens: number; + cache_read_input_tokens: number; + cache_creation_input_tokens: number; +}; + +const ZERO_USAGE: UsageSnapshot = { + input_tokens: 0, + output_tokens: 0, + cache_read_input_tokens: 0, + cache_creation_input_tokens: 0, +}; + type Session = { query: Query; input: Pushable<SDKUserMessage>; @@ -488,6 +502,7 @@ export class ClaudeAcpAgent implements Agent { }; let lastAssistantTotalUsage: number | null = null; + let lastAssistantUsage: UsageSnapshot | null = null; let lastAssistantModel: string | null = null; let lastContextWindowSize: number = 200000; @@ -704,6 +719,51 @@ export class ClaudeAcpAgent implements Agent { break; } case "stream_event": { + if (message.parent_tool_use_id === null) { + if (message.event.type === "message_start") { + const usage = message.event.message.usage; + lastAssistantUsage = usage + ? { + input_tokens: usage.input_tokens ?? 0, + output_tokens: usage.output_tokens ?? 0, + cache_read_input_tokens: usage.cache_read_input_tokens ?? 0, + cache_creation_input_tokens: usage.cache_creation_input_tokens ?? 0, + } + : null; + if ( + message.event.message.model && + message.event.message.model !== "" + ) { + lastAssistantModel = message.event.message.model; + } + } else if (message.event.type === "message_delta") { + const usage = message.event.usage as Partial<UsageSnapshot> | undefined; + if (usage) { + const prev: UsageSnapshot = lastAssistantUsage ?? ZERO_USAGE; + lastAssistantUsage = { + input_tokens: usage.input_tokens ?? prev.input_tokens, + output_tokens: usage.output_tokens ?? prev.output_tokens, + cache_read_input_tokens: + usage.cache_read_input_tokens ?? prev.cache_read_input_tokens, + cache_creation_input_tokens: + usage.cache_creation_input_tokens ?? 
prev.cache_creation_input_tokens, + }; + } + } + + const nextUsage = lastAssistantUsage ? totalTokens(lastAssistantUsage) : null; + if (nextUsage !== null && nextUsage !== lastAssistantTotalUsage) { + lastAssistantTotalUsage = nextUsage; + await this.client.sessionUpdate({ + sessionId: params.sessionId, + update: { + sessionUpdate: "usage_update", + used: nextUsage, + size: lastContextWindowSize, + }, + }); + } + } for (const notification of streamEventToAcpNotifications( message, params.sessionId, @@ -754,11 +814,13 @@ export class ClaudeAcpAgent implements Agent { // all four fields is not double-counting. if ((message.message as any).usage && message.parent_tool_use_id === null) { const messageWithUsage = message.message as unknown as SDKResultMessage; - lastAssistantTotalUsage = - messageWithUsage.usage.input_tokens + - messageWithUsage.usage.output_tokens + - messageWithUsage.usage.cache_read_input_tokens + - messageWithUsage.usage.cache_creation_input_tokens; + lastAssistantUsage = { + input_tokens: messageWithUsage.usage.input_tokens, + output_tokens: messageWithUsage.usage.output_tokens, + cache_read_input_tokens: messageWithUsage.usage.cache_read_input_tokens, + cache_creation_input_tokens: messageWithUsage.usage.cache_creation_input_tokens, + }; + lastAssistantTotalUsage = totalTokens(lastAssistantUsage); } // Track the current top-level model for context window size lookup // (exclude subagent messages to stay in sync with lastAssistantTotalUsage) @@ -1569,6 +1631,15 @@ function sessionUsage(session: Session) { }; } +function totalTokens(usage: UsageSnapshot): number { + return ( + usage.input_tokens + + usage.output_tokens + + usage.cache_read_input_tokens + + usage.cache_creation_input_tokens + ); +} + function createEnvForGateway(gatewayMeta?: GatewayAuthMeta) { if (!gatewayMeta) { return {}; diff --git a/src/tests/acp-agent.test.ts b/src/tests/acp-agent.test.ts index 6d77aad9..28d3110c 100644 --- a/src/tests/acp-agent.test.ts +++ 
b/src/tests/acp-agent.test.ts @@ -1676,6 +1676,23 @@ describe("usage_update computation", () => { }; } + function createStreamEvent( + eventType: "message_start" | "message_delta", + payload: Record<string, unknown>, + parentToolUseId: string | null = null, + ) { + return { + type: "stream_event" as const, + parent_tool_use_id: parentToolUseId, + uuid: randomUUID(), + session_id: "test-session", + event: + eventType === "message_start" + ? { type: "message_start" as const, message: payload } + : { type: "message_delta" as const, ...payload }, + }; + } + function createMockAgentWithCapture() { const updates: any[] = []; const mockClient = { @@ -1770,6 +1787,128 @@ describe("usage_update computation", () => { expect(usageUpdate.update.used).toBe(1800); }); + + it("stream_event message_start emits usage_update before result", async () => { + const { agent, updates } = createMockAgentWithCapture(); + injectSession(agent, [ + createStreamEvent("message_start", { + model: "claude-opus-4-20250514", + usage: { + input_tokens: 1000, + output_tokens: 500, + cache_read_input_tokens: 200, + cache_creation_input_tokens: 100, + }, + }), + createResultMessageWithModel({ + modelUsage: { + "claude-opus-4-20250514": { + inputTokens: 1000, + outputTokens: 500, + cacheReadInputTokens: 200, + cacheCreationInputTokens: 100, + webSearchRequests: 0, + costUSD: 0.01, + contextWindow: 1000000, + maxOutputTokens: 16384, + }, + }, + }), + { type: "system", subtype: "session_state_changed", state: "idle" }, + ]); + + await agent.prompt({ sessionId: "test-session", prompt: [{ type: "text", text: "test" }] }); + + const usageUpdates = updates.filter((u: any) => u.update?.sessionUpdate === "usage_update"); + expect(usageUpdates).toHaveLength(2); + expect(usageUpdates[0].update.used).toBe(1800); + expect(usageUpdates[0].update.cost).toBeUndefined(); + expect(usageUpdates[1].update.used).toBe(1800); + expect(usageUpdates[1].update.cost).toBeDefined(); + }); + + it("stream_event message_delta patches previous 
snapshot", async () => { + const { agent, updates } = createMockAgentWithCapture(); + injectSession(agent, [ + createStreamEvent("message_start", { + model: "claude-opus-4-20250514", + usage: { + input_tokens: 1000, + output_tokens: 0, + cache_read_input_tokens: 200, + cache_creation_input_tokens: 100, + }, + }), + createStreamEvent("message_delta", { + usage: { output_tokens: 500 }, + }), + createResultMessageWithModel({ + modelUsage: { + "claude-opus-4-20250514": { + inputTokens: 1000, + outputTokens: 500, + cacheReadInputTokens: 200, + cacheCreationInputTokens: 100, + webSearchRequests: 0, + costUSD: 0.01, + contextWindow: 1000000, + maxOutputTokens: 16384, + }, + }, + }), + { type: "system", subtype: "session_state_changed", state: "idle" }, + ]); + + await agent.prompt({ sessionId: "test-session", prompt: [{ type: "text", text: "test" }] }); + + const usageUpdates = updates.filter((u: any) => u.update?.sessionUpdate === "usage_update"); + expect(usageUpdates).toHaveLength(3); + expect(usageUpdates[0].update.used).toBe(1300); + expect(usageUpdates[0].update.cost).toBeUndefined(); + expect(usageUpdates[1].update.used).toBe(1800); + expect(usageUpdates[1].update.cost).toBeUndefined(); + expect(usageUpdates[2].update.used).toBe(1800); + expect(usageUpdates[2].update.cost).toBeDefined(); + }); + + it("subagent stream_event does not emit usage_update", async () => { + const { agent, updates } = createMockAgentWithCapture(); + injectSession(agent, [ + createStreamEvent( + "message_start", + { + model: "claude-haiku-4-5-20251001", + usage: { + input_tokens: 500, + output_tokens: 100, + cache_read_input_tokens: 0, + cache_creation_input_tokens: 0, + }, + }, + "tool_use_123", + ), + createResultMessageWithModel({ + modelUsage: { + "claude-haiku-4-5-20251001": { + inputTokens: 500, + outputTokens: 100, + cacheReadInputTokens: 0, + cacheCreationInputTokens: 0, + webSearchRequests: 0, + costUSD: 0.001, + contextWindow: 200000, + maxOutputTokens: 8192, + }, + }, + }), + { 
type: "system", subtype: "session_state_changed", state: "idle" }, + ]); + + await agent.prompt({ sessionId: "test-session", prompt: [{ type: "text", text: "test" }] }); + + const usageUpdates = updates.filter((u: any) => u.update?.sessionUpdate === "usage_update"); + expect(usageUpdates).toHaveLength(0); + }); + it("size reflects the current model's context window, not min across all", async () => { const { agent, updates } = createMockAgentWithCapture(); injectSession(agent, [