diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 0f40c17..78ab342 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -16,7 +16,8 @@
         "eamodio.gitlens",
         "VisualStudioExptTeam.vscodeintellicode",
         "christian-kohler.path-intellisense",
-        "christian-kohler.npm-intellisense"
+        "christian-kohler.npm-intellisense",
+        "orta.vscode-jest"
       ],
       "settings": {
         "terminal.integrated.defaultProfile.linux": "zsh",
diff --git a/README.md b/README.md
index ebe5f15..4d2d907 100644
--- a/README.md
+++ b/README.md
@@ -127,6 +127,7 @@ try {
       onOpen: () => console.log('Stream opened'),
       onContent: (content) => process.stdout.write(content),
       onChunk: (chunk) => console.log('Received chunk:', chunk.id),
+      onUsageMetrics: (metrics) => console.log('Usage metrics:', metrics),
       onFinish: () => console.log('\nStream completed'),
       onError: (error) => console.error('Stream error:', error),
     },
diff --git a/src/client.ts b/src/client.ts
index 650a9a1..848cc21 100644
--- a/src/client.ts
+++ b/src/client.ts
@@ -1,6 +1,7 @@
 import type {
   Provider,
   SchemaChatCompletionMessageToolCall,
+  SchemaCompletionUsage,
   SchemaCreateChatCompletionRequest,
   SchemaCreateChatCompletionResponse,
   SchemaCreateChatCompletionStreamResponse,
@@ -15,6 +16,7 @@ interface ChatCompletionStreamCallbacks {
   onReasoning?: (reasoningContent: string) => void;
   onContent?: (content: string) => void;
   onTool?: (toolCall: SchemaChatCompletionMessageToolCall) => void;
+  onUsageMetrics?: (usage: SchemaCompletionUsage) => void;
   onFinish?: (
     response: SchemaCreateChatCompletionStreamResponse | null
   ) => void;
@@ -258,6 +260,10 @@ export class InferenceGatewayClient {
               JSON.parse(data);
             callbacks.onChunk?.(chunk);

+            if (chunk.usage && callbacks.onUsageMetrics) {
+              callbacks.onUsageMetrics(chunk.usage);
+            }
+
             const reasoning_content =
               chunk.choices[0]?.delta?.reasoning_content;
             if (reasoning_content !== undefined) {
diff --git a/tests/client.test.ts b/tests/client.test.ts
index eabc7a8..863d357 100644
--- a/tests/client.test.ts
+++ b/tests/client.test.ts
@@ -415,6 +415,73 @@ describe('InferenceGatewayClient', () => {
       expect(callbacks.onError).toHaveBeenCalledTimes(1);
     });
+
+    it('should handle streaming chat completions with usage metrics', async () => {
+      const mockRequest = {
+        model: 'gpt-4o',
+        messages: [{ role: MessageRole.user, content: 'Hello' }],
+        stream: true,
+        stream_options: {
+          include_usage: true,
+        },
+      };
+
+      const mockStream = new TransformStream();
+      const writer = mockStream.writable.getWriter();
+      const encoder = new TextEncoder();
+
+      mockFetch.mockResolvedValueOnce({
+        ok: true,
+        body: mockStream.readable,
+      });
+
+      const callbacks = {
+        onOpen: jest.fn(),
+        onChunk: jest.fn(),
+        onContent: jest.fn(),
+        onUsageMetrics: jest.fn(),
+        onFinish: jest.fn(),
+        onError: jest.fn(),
+      };
+
+      const streamPromise = client.streamChatCompletion(mockRequest, callbacks);
+
+      await writer.write(
+        encoder.encode(
+          'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' +
+            'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}\n\n' +
+            'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}\n\n' +
+            'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}\n\n' +
+            'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[],"usage":{"prompt_tokens":10,"completion_tokens":8,"total_tokens":18}}\n\n' +
+            'data: [DONE]\n\n'
+        )
+      );
+
+      await writer.close();
+      await streamPromise;
+
+      expect(callbacks.onOpen).toHaveBeenCalledTimes(1);
+      expect(callbacks.onChunk).toHaveBeenCalledTimes(5);
+      expect(callbacks.onContent).toHaveBeenCalledWith('Hello');
+      expect(callbacks.onContent).toHaveBeenCalledWith('!');
+      expect(callbacks.onUsageMetrics).toHaveBeenCalledTimes(1);
+      expect(callbacks.onUsageMetrics).toHaveBeenCalledWith({
+        prompt_tokens: 10,
+        completion_tokens: 8,
+        total_tokens: 18,
+      });
+      expect(callbacks.onFinish).toHaveBeenCalledTimes(1);
+      expect(mockFetch).toHaveBeenCalledWith(
+        'http://localhost:8080/v1/chat/completions',
+        expect.objectContaining({
+          method: 'POST',
+          body: JSON.stringify({
+            ...mockRequest,
+            stream: true,
+          }),
+        })
+      );
+    });
   });

   describe('proxy', () => {