Add taskId and checkpointNumber parameters to completePrompt and rela… #43

Open · wants to merge 2 commits into base: main
9 changes: 7 additions & 2 deletions src/api/index.ts
@@ -23,11 +23,16 @@ import { HumanRelayHandler } from "./providers/human-relay"
import { KiloCodeHandler } from "./providers/kilocode"

export interface SingleCompletionHandler {
-	completePrompt(prompt: string): Promise<string>
+	completePrompt(prompt: string, taskId?: string, checkpointNumber?: number): Promise<string>
}

export interface ApiHandler {
-	createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
+	createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		taskId?: string,
+		checkpointNumber?: number,
+	): ApiStream
getModel(): { id: string; info: ModelInfo }

/**
7 changes: 6 additions & 1 deletion src/api/providers/base-provider.ts
@@ -14,7 +14,12 @@ const TOKEN_FUDGE_FACTOR = 1.5
export abstract class BaseProvider implements ApiHandler {
// Cache the Tiktoken encoder instance since it's stateless
private encoder: Tiktoken | null = null
-	abstract createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
+	abstract createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		taskId?: string,
+		checkpointNumber?: number,
+	): ApiStream
abstract getModel(): { id: string; info: ModelInfo }

/**
119 changes: 75 additions & 44 deletions src/api/providers/kilocode.ts
@@ -20,19 +20,62 @@ export class KiloCodeHandler extends BaseProvider implements SingleCompletionHan
})
}

-	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+	private getIdempotencyKey(taskId: string, checkpointNumber: number): string {
+		// Create a deterministic idempotency key based on task_id and checkpoint number
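+		// e.g. getIdempotencyKey("task-123", 4) -> "task-123-4" (illustrative values)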
+		return `${taskId}-${checkpointNumber}`
+	}

+	async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		taskId?: string,
+		checkpointNumber?: number,
+	): ApiStream {
let stream: AnthropicStream<Anthropic.Messages.RawMessageStreamEvent>
const cacheControl: CacheControlEphemeral = { type: "ephemeral" }
-		let { id: modelId, maxTokens, thinking, temperature, virtualId } = this.getModel()
+		const { id: modelId, maxTokens, thinking, temperature, virtualId } = this.getModel()

-		const userMsgIndices = messages.reduce(
-			(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
-			[] as number[],
-		)
+		// Use a for loop instead of reduce with spread to avoid linting error
Review comment (Collaborator): No need for this comment. This could've been a self-comment on this PR instead.

Review comment (Collaborator): Also: if you convert this to a for...of loop, it's compliant with modern JS best practice, more readable, and removes the need for the comment.
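
A minimal sketch of that for...of form (not code from the PR; it uses Array.prototype.entries() to get the index):

const userMsgIndices: number[] = []
for (const [index, message] of messages.entries()) {
	if (message.role === "user") {
		userMsgIndices.push(index)
	}
}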

+		const userMsgIndices: number[] = []
+		for (let i = 0; i < messages.length; i++) {
+			if (messages[i].role === "user") {
+				userMsgIndices.push(i)
+			}
+		}

const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1

+		// Prepare request options with headers
+		const requestOptions: { headers: Record<string, string> } = (() => {
Review comment (Collaborator): Either this typing here or the one down on 64 is redundant, making it harder to maintain.

+			const betas: string[] = []
+
+			// Check for models that support prompt caching
+			switch (modelId) {
+				case "claude-3-7-sonnet-20250219":
+				case "claude-3-5-sonnet-20241022":
+				case "claude-3-5-haiku-20241022":
+				case "claude-3-opus-20240229":
+				case "claude-3-haiku-20240307":
+					betas.push("prompt-caching-2024-07-31")
+					break
+			}
+
+			const headers: Record<string, string> = {}
+
+			// Add beta features if any
+			if (betas.length > 0) {
+				headers["anthropic-beta"] = betas.join(",")
+			}
+
+			// Add idempotency key if task_id and checkpoint number are provided
+			if (taskId && checkpointNumber !== undefined) {
+				headers["idempotency-key"] = this.getIdempotencyKey(taskId, checkpointNumber)
+			}
Review comment (Collaborator): This is the AI way of doing it; the more efficient (preventing multiple object mutations) and, imho, more readable way to do it would be like this:

const headers = {
	'anthropic-beta': betas.length > 0 ? betas.join(",") : undefined,
	'idempotency-key':
		taskId && checkpointNumber !== undefined ? this.getIdempotencyKey(taskId, checkpointNumber) : undefined,
}


+			return { headers }
+		})()

stream = await this.client.messages.create(
{
model: modelId,
@@ -62,38 +105,12 @@ export class KiloCodeHandler extends BaseProvider implements SingleCompletionHan
// tools: tools,
stream: true,
},
-			(() => {
-				// prompt caching: https://x.com/alexalbert__/status/1823751995901272068
-				// https://github.com/anthropics/anthropic-sdk-typescript?tab=readme-ov-file#default-headers
-				// https://github.com/anthropics/anthropic-sdk-typescript/commit/c920b77fc67bd839bfeb6716ceab9d7c9bbe7393
Review comment (Collaborator): These comments seem useful. Preserve them?


-				const betas = []
-
-				// // Check for the thinking-128k variant first
-				// if (virtualId === "claude-3-7-sonnet-20250219:thinking") {
-				// 	betas.push("output-128k-2025-02-19")
-				// }
-
-				// Then check for models that support prompt caching
-				switch (modelId) {
-					case "claude-3-7-sonnet-20250219":
-					case "claude-3-5-sonnet-20241022":
-					case "claude-3-5-haiku-20241022":
-					case "claude-3-opus-20240229":
-					case "claude-3-haiku-20240307":
-						betas.push("prompt-caching-2024-07-31")
-						return {
-							headers: { "anthropic-beta": betas.join(",") },
-						}
-					default:
-						return undefined
-				}
-			})(),
+			requestOptions,
)

for await (const chunk of stream) {
switch (chunk.type) {
case "message_start":
case "message_start": {
// Tells us cache reads/writes/input/output.
const usage = chunk.message.usage

@@ -106,6 +123,7 @@ export class KiloCodeHandler extends BaseProvider implements SingleCompletionHan
}

break
+				}
case "message_delta":
// Tells us stop_reason, stop_sequence, and output tokens
// along the way and at the end of the message.
@@ -174,17 +192,30 @@ export class KiloCodeHandler extends BaseProvider implements SingleCompletionHan
}
}

-	async completePrompt(prompt: string) {
-		let { id: modelId, temperature } = this.getModel()
+	async completePrompt(prompt: string, taskId?: string, checkpointNumber?: number) {
+		const { id: modelId, temperature } = this.getModel()

-		const message = await this.client.messages.create({
-			model: modelId,
-			max_tokens: ANTHROPIC_DEFAULT_MAX_TOKENS,
-			thinking: undefined,
-			temperature,
-			messages: [{ role: "user", content: prompt }],
-			stream: false,
-		})
+		// Prepare request options with headers
+		const requestOptions: { headers: Record<string, string> } = {
+			headers: {},
+		}
+
+		// Add idempotency key if task_id and checkpoint number are provided
+		if (taskId && checkpointNumber !== undefined) {
+			requestOptions.headers["idempotency-key"] = this.getIdempotencyKey(taskId, checkpointNumber)
+		}
Review comment (Collaborator): Maybe write this a bit more compactly? It feels like "a lot" in the way it's spread over so many lines, while in reality almost nothing is happening here. Code "weight" should feel proportional to the weight of the content.
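
A compact variant along those lines might look like this (a sketch, not code from the PR; the conditional spread keeps the header absent when the inputs are missing):

const requestOptions = {
	headers: {
		...(taskId && checkpointNumber !== undefined
			? { "idempotency-key": this.getIdempotencyKey(taskId, checkpointNumber) }
			: {}),
	},
}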


+		const message = await this.client.messages.create(
+			{
+				model: modelId,
+				max_tokens: ANTHROPIC_DEFAULT_MAX_TOKENS,
+				thinking: undefined,
+				temperature,
+				messages: [{ role: "user", content: prompt }],
+				stream: false,
+			},
+			requestOptions,
+		)

const content = message.content.find(({ type }) => type === "text")
return content?.type === "text" ? content.text : ""
6 changes: 5 additions & 1 deletion src/core/Cline.ts
@@ -1162,7 +1162,11 @@ export class Cline {
}
return { role, content }
})
-		const stream = this.api.createMessage(systemPrompt, cleanConversationHistory)
+		// Get the current checkpoint number for idempotency key generation
+		const checkpointNumber = this.clineMessages.filter(({ say }) => say === "checkpoint_saved").length
+
+		// Pass task_id and checkpoint number to the API for idempotency key generation
+		const stream = this.api.createMessage(systemPrompt, cleanConversationHistory, this.taskId, checkpointNumber)
Review comment (Collaborator): Same point about weight. Remove those comments.

const iterator = stream[Symbol.asyncIterator]()

try {
4 changes: 4 additions & 0 deletions src/core/webview/ClineProvider.ts
@@ -1507,6 +1507,10 @@ export class ClineProvider implements vscode.WebviewViewProvider {
},
customSupportPrompts,
),
+				this.getCurrentCline()?.taskId,
+				this.getCurrentCline()?.clineMessages.filter(
+					({ say }) => say === "checkpoint_saved",
+				).length,
)

await this.postMessageToWebview({
14 changes: 9 additions & 5 deletions src/utils/__tests__/enhance-prompt.test.ts
@@ -20,7 +20,7 @@ describe("enhancePrompt", () => {

// Mock the API handler with a completePrompt method
;(buildApiHandler as jest.Mock).mockReturnValue({
-			completePrompt: jest.fn().mockResolvedValue("Enhanced prompt"),
+			completePrompt: jest.fn((prompt, taskId, checkpointNumber) => Promise.resolve("Enhanced prompt")),
createMessage: jest.fn(),
getModel: jest.fn().mockReturnValue({
id: "test-model",
@@ -38,7 +38,7 @@

expect(result).toBe("Enhanced prompt")
const handler = buildApiHandler(mockApiConfig)
-		expect((handler as any).completePrompt).toHaveBeenCalledWith(`Test prompt`)
+		expect((handler as any).completePrompt).toHaveBeenCalledWith(`Test prompt`, undefined, undefined)
})

it("enhances prompt using custom enhancement prompt when provided", async () => {
@@ -60,7 +60,11 @@

expect(result).toBe("Enhanced prompt")
const handler = buildApiHandler(mockApiConfig)
-		expect((handler as any).completePrompt).toHaveBeenCalledWith(`${customEnhancePrompt}\n\nTest prompt`)
+		expect((handler as any).completePrompt).toHaveBeenCalledWith(
+			`${customEnhancePrompt}\n\nTest prompt`,
+			undefined,
+			undefined,
+		)
})

it("throws error for empty prompt input", async () => {
@@ -101,7 +105,7 @@

// Mock successful enhancement
;(buildApiHandler as jest.Mock).mockReturnValue({
-			completePrompt: jest.fn().mockResolvedValue("Enhanced prompt"),
+			completePrompt: jest.fn((prompt, taskId, checkpointNumber) => Promise.resolve("Enhanced prompt")),
createMessage: jest.fn(),
getModel: jest.fn().mockReturnValue({
id: "test-model",
@@ -121,7 +125,7 @@

it("propagates API errors", async () => {
;(buildApiHandler as jest.Mock).mockReturnValue({
-			completePrompt: jest.fn().mockRejectedValue(new Error("API Error")),
+			completePrompt: jest.fn((prompt, taskId, checkpointNumber) => Promise.reject(new Error("API Error"))),
createMessage: jest.fn(),
getModel: jest.fn().mockReturnValue({
id: "test-model",
9 changes: 7 additions & 2 deletions src/utils/single-completion-handler.ts
@@ -5,7 +5,12 @@ import { buildApiHandler, SingleCompletionHandler } from "../api"
* Enhances a prompt using the configured API without creating a full Cline instance or task history.
* This is a lightweight alternative that only uses the API's completion functionality.
*/
-export async function singleCompletionHandler(apiConfiguration: ApiConfiguration, promptText: string): Promise<string> {
+export async function singleCompletionHandler(
+	apiConfiguration: ApiConfiguration,
+	promptText: string,
+	taskId?: string,
+	checkpointNumber?: number,
+): Promise<string> {
if (!promptText) {
throw new Error("No prompt text provided")
}
@@ -20,5 +25,5 @@ export async function singleCompletionHandler(apiConfiguration: ApiConfiguration
throw new Error("The selected API provider does not support prompt enhancement")
}

-	return (handler as SingleCompletionHandler).completePrompt(promptText)
+	return (handler as SingleCompletionHandler).completePrompt(promptText, taskId, checkpointNumber)
}
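
For orientation, a call site might thread the new parameters through like this (a sketch mirroring the ClineProvider.ts change above; apiConfiguration and currentCline are assumed to be in scope):

const enhanced = await singleCompletionHandler(
	apiConfiguration,
	"Improve this prompt",
	currentCline?.taskId,
	currentCline?.clineMessages.filter(({ say }) => say === "checkpoint_saved").length,
)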