Skip to content

perf(llm): Optimize pruneLines functions in countTokens #5310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions core/llm/countTokens.ts
Original file line number Diff line number Diff line change
Expand Up @@ -167,27 +167,55 @@ function pruneLinesFromTop(
maxTokens: number,
modelName: string,
): string {
let totalTokens = countTokens(prompt, modelName);
const lines = prompt.split("\n");
while (totalTokens > maxTokens && lines.length > 0) {
totalTokens -= countTokens(lines.shift()!, modelName);
// Preprocess tokens for all lines and cache them.
const lineTokens = lines.map((line) => countTokens(line, modelName));
let totalTokens = lineTokens.reduce((sum, tokens) => sum + tokens, 0);
let start = 0;
let currentLines = lines.length;

// Calculate initial token count including newlines
totalTokens += Math.max(0, currentLines - 1); // Add tokens for joining newlines

// Using indexes instead of array modifications.
// Remove lines from the top until the token count is within the limit.
while (totalTokens > maxTokens && start < currentLines) {
totalTokens -= lineTokens[start];
// Decrement token count for the removed line and its preceding/joining newline (if not the last line)
if (currentLines - start > 1) {
totalTokens--;
}
start++;
}

return lines.join("\n");
return lines.slice(start).join("\n");
}

function pruneLinesFromBottom(
prompt: string,
maxTokens: number,
modelName: string,
): string {
let totalTokens = countTokens(prompt, modelName);
const lines = prompt.split("\n");
while (totalTokens > maxTokens && lines.length > 0) {
totalTokens -= countTokens(lines.pop()!, modelName);
const lineTokens = lines.map((line) => countTokens(line, modelName));
let totalTokens = lineTokens.reduce((sum, tokens) => sum + tokens, 0);
let end = lines.length;

// Calculate initial token count including newlines
totalTokens += Math.max(0, end - 1); // Add tokens for joining newlines

// Reverse traversal to avoid array modification
// Remove lines from the bottom until the token count is within the limit.
while (totalTokens > maxTokens && end > 0) {
end--;
totalTokens -= lineTokens[end];
// Decrement token count for the removed line and its following/joining newline (if not the first line)
if (end > 0) {
totalTokens--;
}
}

return lines.join("\n");
return lines.slice(0, end).join("\n");
}

function pruneStringFromBottom(
Expand Down Expand Up @@ -606,5 +634,5 @@ export {
pruneLinesFromTop,
pruneRawPromptFromTop,
pruneStringFromBottom,
pruneStringFromTop,
pruneStringFromTop
};
Loading