Skip to content

add a new unified route for count_tokens endpoint #1293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions src/handlers/messagesCountTokensHandler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { RouterError } from '../errors/RouterError';
import {
constructConfigFromRequestHeaders,
tryTargetsRecursively,
} from './handlerUtils';
import { Context } from 'hono';

/**
* Handles the '/messages' API request by selecting the appropriate provider(s) and making the request to them.
*
* @param {Context} c - The Cloudflare Worker context.
* @returns {Promise<Response>} - The response from the provider.
* @throws Will throw an error if no provider options can be determined or if the request to the provider(s) fails.
* @throws Will throw an 500 error if the handler fails due to some reasons
*/
Comment on lines +8 to +15

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝 Documentation Gap

Issue: JSDoc comment describes '/messages' API but function handles '/messages/count_tokens' endpoint
Fix: Update documentation to match actual functionality
Impact: Prevents developer confusion about handler purpose

Suggested change
/**
* Handles the '/messages' API request by selecting the appropriate provider(s) and making the request to them.
*
* @param {Context} c - The Cloudflare Worker context.
* @returns {Promise<Response>} - The response from the provider.
* @throws Will throw an error if no provider options can be determined or if the request to the provider(s) fails.
* @throws Will throw an 500 error if the handler fails due to some reasons
*/
/**
* Handles the '/messages/count_tokens' API request by selecting the appropriate provider(s) and making the request to them.
*
* @param {Context} c - The Cloudflare Worker context.
* @returns {Promise<Response>} - The response from the provider.
* @throws Will throw an error if no provider options can be determined or if the request to the provider(s) fails.
* @throws Will throw an 500 error if the handler fails due to some reasons
*/

export async function messagesCountTokensHandler(
c: Context
): Promise<Response> {
try {
let request = await c.req.json();
let requestHeaders = Object.fromEntries(c.req.raw.headers);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️ Performance Improvement

Issue: Object.fromEntries creates unnecessary object copy when headers are already iterable
Fix: Pass headers directly to constructConfigFromRequestHeaders if it accepts Headers object
Impact: Reduces object creation overhead and improves memory efficiency

Suggested change
let requestHeaders = Object.fromEntries(c.req.raw.headers);
let requestHeaders = c.req.raw.headers;

const camelCaseConfig = constructConfigFromRequestHeaders(requestHeaders);
const tryTargetsResponse = await tryTargetsRecursively(
c,
camelCaseConfig ?? {},
request,
requestHeaders,
'messagesCountTokens',
'POST',
'config'
);

return tryTargetsResponse;
} catch (err: any) {
console.log('messagesCountTokens error', err.message);
let statusCode = 500;
let errorMessage = 'Something went wrong';

if (err instanceof RouterError) {
statusCode = 400;
errorMessage = err.message;
}

return new Response(
JSON.stringify({
status: 'failure',
message: errorMessage,
}),
{
status: statusCode,
headers: {
'content-type': 'application/json',
},
}
);
}
}
7 changes: 7 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import { messagesHandler } from './handlers/messagesHandler';
// Config
import conf from '../conf.json';
import modelResponsesHandler from './handlers/modelResponsesHandler';
import { messagesCountTokensHandler } from './handlers/messagesCountTokensHandler';

// Create a new Hono server instance
const app = new Hono();
Expand Down Expand Up @@ -126,6 +127,12 @@ app.onError((err, c) => {
*/
app.post('/v1/messages', requestValidator, messagesHandler);

app.post(
'/v1/messages/count_tokens',
requestValidator,
messagesCountTokensHandler
);

/**
* POST route for '/v1/chat/completions'.
* Handles requests by passing them to the chatCompletionsHandler.
Expand Down
2 changes: 2 additions & 0 deletions src/providers/anthropic/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ const AnthropicAPIConfig: ProviderAPIConfig = {
return '/messages';
case 'messages':
return '/messages';
case 'messagesCountTokens':
return '/messages/count_tokens';
default:
return '';
}
Expand Down
1 change: 1 addition & 0 deletions src/providers/anthropic/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const AnthropicConfig: ProviderConfigs = {
complete: AnthropicCompleteConfig,
chatComplete: AnthropicChatCompleteConfig,
messages: AnthropicMessagesConfig,
messagesCountTokens: AnthropicMessagesConfig,
api: AnthropicAPIConfig,
responseTransforms: {
'stream-complete': AnthropicCompleteStreamChunkTransform,
Expand Down
2 changes: 2 additions & 0 deletions src/providers/google-vertex-ai/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ export const GoogleApiConfig: ProviderAPIConfig = {
mappedFn === 'stream-messages'
) {
return `${projectRoute}/publishers/${provider}/models/${model}:streamRawPredict`;
} else if (mappedFn === 'messagesCountTokens') {
return `${projectRoute}/publishers/${provider}/models/count-tokens:rawPredict`;
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/providers/google-vertex-ai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import {
VertexAnthropicMessagesConfig,
VertexAnthropicMessagesResponseTransform,
} from './messages';
import { VertexAnthropicMessagesCountTokensConfig } from './messagesCountTokens';

const VertexConfig: ProviderConfigs = {
api: VertexApiConfig,
Expand Down Expand Up @@ -117,6 +118,7 @@ const VertexConfig: ProviderConfigs = {
createBatch: GoogleBatchCreateConfig,
createFinetune: baseConfig.createFinetune,
messages: VertexAnthropicMessagesConfig,
messagesCountTokens: VertexAnthropicMessagesCountTokensConfig,
responseTransforms: {
'stream-chatComplete':
VertexAnthropicChatCompleteStreamChunkTransform,
Expand Down
14 changes: 14 additions & 0 deletions src/providers/google-vertex-ai/messagesCountTokens.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import { MessageCreateParamsBase } from '../../types/MessagesRequest';
import { getMessagesConfig } from '../anthropic-base/messages';

export const VertexAnthropicMessagesCountTokensConfig = {
...getMessagesConfig({}),
model: {
param: 'model',
required: true,
transform: (params: MessageCreateParamsBase) => {
let model = params.model ?? '';
return model.replace('anthropic.', '');
},
},
};
3 changes: 2 additions & 1 deletion src/providers/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ export type endpointStrings =
| 'getModelResponse'
| 'deleteModelResponse'
| 'listResponseInputItems'
| 'messages';
| 'messages'
| 'messagesCountTokens';

/**
* A collection of API configurations for multiple AI providers.
Expand Down