diff --git a/src/client.ts b/src/client.ts index 848cc21..6ca7cb1 100644 --- a/src/client.ts +++ b/src/client.ts @@ -135,7 +135,7 @@ export class InferenceGatewayClient { * Creates a chat completion. */ async createChatCompletion( - request: SchemaCreateChatCompletionRequest, + request: Omit<SchemaCreateChatCompletionRequest, 'stream' | 'stream_options'>, provider?: Provider ): Promise<SchemaCreateChatCompletionResponse> { const query: Record<string, string> = {}; @@ -146,7 +146,7 @@ export class InferenceGatewayClient { '/chat/completions', { method: 'POST', - body: JSON.stringify(request), + body: JSON.stringify({ ...request, stream: false }), }, query ); @@ -154,9 +154,17 @@ export class InferenceGatewayClient { /** * Creates a streaming chat completion. + * This method always sets stream=true internally, so there's no need to specify it in the request. + * + * @param request - Chat completion request (must include at least model and messages) + * @param callbacks - Callbacks for handling streaming events + * @param provider - Optional provider to use for this request */ async streamChatCompletion( - request: SchemaCreateChatCompletionRequest, + request: Omit< + SchemaCreateChatCompletionRequest, + 'stream' | 'stream_options' + >, callbacks: ChatCompletionStreamCallbacks, provider?: Provider ): Promise<void> { @@ -195,6 +203,9 @@ export class InferenceGatewayClient { body: JSON.stringify({ ...request, stream: true, + stream_options: { + include_usage: true, + }, }), signal: controller.signal, }); diff --git a/tests/client.test.ts b/tests/client.test.ts index 863d357..01081f8 100644 --- a/tests/client.test.ts +++ b/tests/client.test.ts @@ -116,7 +116,6 @@ describe('InferenceGatewayClient', () => { { role: MessageRole.system, content: 'You are a helpful assistant' }, { role: MessageRole.user, content: 'Hello' }, ], - stream: false, }; const mockResponse: SchemaCreateChatCompletionResponse = { @@ -152,7 +151,7 @@ describe('InferenceGatewayClient', () => { 'http://localhost:8080/v1/chat/completions', expect.objectContaining({ method: 'POST', - body: JSON.stringify(mockRequest), + body: 
JSON.stringify({ ...mockRequest, stream: false }), }) ); }); @@ -161,7 +160,6 @@ describe('InferenceGatewayClient', () => { const mockRequest = { model: 'claude-3-opus-20240229', messages: [{ role: MessageRole.user, content: 'Hello' }], - stream: false, }; const mockResponse: SchemaCreateChatCompletionResponse = { @@ -200,7 +198,7 @@ describe('InferenceGatewayClient', () => { 'http://localhost:8080/v1/chat/completions?provider=anthropic', expect.objectContaining({ method: 'POST', - body: JSON.stringify(mockRequest), + body: JSON.stringify({ ...mockRequest, stream: false }), }) ); }); @@ -211,7 +209,6 @@ describe('InferenceGatewayClient', () => { const mockRequest = { model: 'gpt-4o', messages: [{ role: MessageRole.user, content: 'Hello' }], - stream: true, }; const mockStream = new TransformStream(); @@ -258,6 +255,9 @@ describe('InferenceGatewayClient', () => { body: JSON.stringify({ ...mockRequest, stream: true, + stream_options: { + include_usage: true, + }, }), }) ); @@ -267,7 +267,6 @@ describe('InferenceGatewayClient', () => { const mockRequest = { model: 'gpt-4o', messages: [{ role: MessageRole.user, content: 'Hello' }], - stream: true, }; const mockStream = new TransformStream(); const writer = mockStream.writable.getWriter(); @@ -318,6 +317,9 @@ describe('InferenceGatewayClient', () => { body: JSON.stringify({ ...mockRequest, stream: true, + stream_options: { + include_usage: true, + }, }), }) ); @@ -341,7 +343,6 @@ describe('InferenceGatewayClient', () => { }, }, ], - stream: true, }; const mockStream = new TransformStream(); @@ -390,13 +391,25 @@ describe('InferenceGatewayClient', () => { }, }); expect(callbacks.onFinish).toHaveBeenCalledTimes(1); + expect(mockFetch).toHaveBeenCalledWith( + 'http://localhost:8080/v1/chat/completions', + expect.objectContaining({ + method: 'POST', + body: JSON.stringify({ + ...mockRequest, + stream: true, + stream_options: { + include_usage: true, + }, + }), + }) + ); }); it('should handle errors in streaming chat 
completions', async () => { const mockRequest = { model: 'gpt-4o', messages: [{ role: MessageRole.user, content: 'Hello' }], - stream: true, }; mockFetch.mockResolvedValueOnce({ @@ -420,10 +433,6 @@ describe('InferenceGatewayClient', () => { const mockRequest = { model: 'gpt-4o', messages: [{ role: MessageRole.user, content: 'Hello' }], - stream: true, - stream_options: { - include_usage: true, - }, }; const mockStream = new TransformStream(); @@ -478,6 +487,9 @@ describe('InferenceGatewayClient', () => { body: JSON.stringify({ ...mockRequest, stream: true, + stream_options: { + include_usage: true, + }, }), }) );