From bde34339ddd538f751d827520f6fbef1240431b6 Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Mon, 18 May 2026 07:48:53 +0000 Subject: [PATCH] fix openai stream truncation handling --- providers/openai/language_model.go | 24 +- providers/openai/openai_test.go | 341 ++++++++++++++++++- providers/openai/responses_language_model.go | 82 ++++- 3 files changed, 430 insertions(+), 17 deletions(-) diff --git a/providers/openai/language_model.go b/providers/openai/language_model.go index 84dd69445..231a77213 100644 --- a/providers/openai/language_model.go +++ b/providers/openai/language_model.go @@ -568,9 +568,13 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S // can't infer a tool-call turn. Surface as a retryable error so // the retry middleware re-runs the step. if finishReason == "" && mappedFinishReason != fantasy.FinishReasonToolCalls { + err := ctx.Err() + if err == nil { + err = fantasy.NewIncompleteStreamError() + } yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: fantasy.NewIncompleteStreamError(), + Error: err, }) return } @@ -868,8 +872,8 @@ func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantas var lastParsedObject any var usage fantasy.Usage var finishReason fantasy.FinishReason + var sawFinishReason bool var providerMetadata fantasy.ProviderMetadata - var streamErr error for stream.Next() { chunk := stream.Current() @@ -884,6 +888,7 @@ func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantas choice := chunk.Choices[0] if choice.FinishReason != "" { finishReason = o.mapFinishReasonFunc(choice.FinishReason) + sawFinishReason = true } if choice.Delta.Content != "" { @@ -928,10 +933,21 @@ func (o languageModel) streamObjectWithJSONMode(ctx context.Context, call fantas err := stream.Err() if err != nil && !errors.Is(err, io.EOF) { - streamErr = toProviderErr(err) yield(fantasy.ObjectStreamPart{ Type: fantasy.ObjectStreamPartTypeError, - Error: streamErr, + Error: toProviderErr(err), + }) + return + } + + if !sawFinishReason { + err := ctx.Err() + if err == nil { + err = fantasy.NewIncompleteStreamError() + } + yield(fantasy.ObjectStreamPart{ + Type: fantasy.ObjectStreamPartTypeError, + Error: err, }) return } diff --git a/providers/openai/openai_test.go b/providers/openai/openai_test.go index 249c3f99b..24f4a687c 100644 --- a/providers/openai/openai_test.go +++ b/providers/openai/openai_test.go @@ -2260,6 +2260,17 @@ func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) { sms.chunks = chunks } +// chatCompletionChunksBeforeFinishReason drops every chunk from the final +// finish_reason chunk onward, including any trailing usage-only chunk. +func chatCompletionChunksBeforeFinishReason(chunks []string) []string { + for i, chunk := range chunks { + if strings.Contains(chunk, `"finish_reason":"`) { + return append([]string(nil), chunks[:i]...) + } + } + return append([]string(nil), chunks...) +} + func (sms *streamingMockServer) prepareToolStreamResponse() { chunks := []string{ `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n", @@ -2328,6 +2339,90 @@ func collectStreamParts(stream fantasy.StreamResponse) ([]fantasy.StreamPart, er return parts, nil } +func TestChatCompletionsStreamObject_RequiresFinishReasonBeforeFinish(t *testing.T) { + t.Parallel() + + objectSchema := fantasy.Schema{ + Type: "object", + Properties: map[string]*fantasy.Schema{ + "answer": {Type: "string"}, + }, + Required: []string{"answer"}, + } + + tests := []struct { + name string + truncate bool + wantFinish bool + }{ + { + name: "complete stream finishes", + wantFinish: true, + }, + { + name: "stream closed before finish_reason errors", + truncate: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server := newStreamingMockServer() + defer server.close() + + server.prepareStreamResponse(map[string]any{ + "content": []string{`{"answer":"hello"}`}, + }) + if tt.truncate { + server.chunks = chatCompletionChunksBeforeFinishReason(server.chunks) + } + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), + ) + require.NoError(t, err) + model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo") + + stream, err := model.StreamObject(context.Background(), fantasy.ObjectCall{ + Prompt: testPrompt, + Schema: objectSchema, + }) + require.NoError(t, err) + + parts := collectObjectStreamParts(stream) + + var objects, finishes, errorParts []fantasy.ObjectStreamPart + for _, part := range parts { + switch part.Type { + case fantasy.ObjectStreamPartTypeObject: + objects = append(objects, part) + case fantasy.ObjectStreamPartTypeFinish: + finishes = append(finishes, part) + case fantasy.ObjectStreamPartTypeError: + errorParts = append(errorParts, part) + } + } + + require.NotEmpty(t, objects) + require.Equal(t, map[string]any{"answer": "hello"}, objects[len(objects)-1].Object) + + if tt.wantFinish { + require.Len(t, finishes, 1) + require.Empty(t, errorParts) + return + } + + require.Empty(t, finishes) + require.Len(t, errorParts, 1) + require.Error(t, errorParts[0].Error) + requireRetryableUnexpectedEOF(t, errorParts[0].Error) + }) + } +} + func TestDoStream(t *testing.T) { t.Parallel() @@ -3883,7 +3978,8 @@ func TestResponsesGenerate_WebSearchResponse(t *testing.T) { // TextContent with the final answer. require.Len(t, texts, 1) - require.Equal(t, + require.Equal( + t, "Based on recent search results, here is the latest AI news.", texts[0].Text, ) @@ -4265,6 +4361,249 @@ func TestResponsesToPrompt_ReasoningWithStore(t *testing.T) { }) } +func TestResponsesStream_RequiresTerminalEventBeforeFinish(t *testing.T) { + t.Parallel() + + textChunks := []string{ + responsesSSEEvent("response.created", `{"type":"response.created","response":{"id":"resp_01","status":"in_progress","output":[]}}`), + responsesSSEEvent("response.output_item.added", `{"type":"response.output_item.added","output_index":0,"item":{"id":"msg_01","type":"message","role":"assistant","status":"in_progress","content":[]}}`), + responsesSSEEvent("response.content_part.added", `{"type":"response.content_part.added","output_index":0,"content_index":0,"item_id":"msg_01","part":{"type":"output_text","text":""}}`), + responsesSSEEvent("response.output_text.delta", `{"type":"response.output_text.delta","output_index":0,"content_index":0,"item_id":"msg_01","delta":"hello"}`), + } + incompleteEvent := responsesSSEEvent("response.incomplete", `{"type":"response.incomplete","response":{"id":"resp_02","status":"incomplete","output":[],"incomplete_details":{"reason":"max_output_tokens"},"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`) + failedEvent := responsesSSEEvent("response.failed", `{"type":"response.failed","response":{"id":"resp_03","status":"failed","error":{"code":"server_error","message":"boom"},"output":[]}}`) + errorEvent := responsesSSEEvent("error", `{"type":"error","message":"stream down","code":"server_error","param":"","sequence_number":1}`) + + tests := []struct { + name string + chunks []string + wantFinish bool + wantFinishReason fantasy.FinishReason + wantRetryable bool + wantErrContain string + }{ + { + name: "incomplete terminal event finishes", + chunks: append(append([]string{}, textChunks...), incompleteEvent), + wantFinish: true, + wantFinishReason: fantasy.FinishReasonLength, + }, + { + name: "stream closed before terminal event errors", + chunks: textChunks, + wantRetryable: true, + }, + { + name: "response failed errors", + chunks: []string{failedEvent}, + wantErrContain: "response failed: boom (code: server_error)", + }, + { + name: "provider error event is preserved", + chunks: []string{errorEvent}, + wantErrContain: "response error: stream down (code: server_error)", + }, + { + name: "malformed event error is preserved", + chunks: []string{responsesSSEEvent("response.created", `{`)}, + wantErrContain: "unexpected end of JSON input", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + sms := newStreamingMockServer() + defer sms.close() + sms.chunks = tt.chunks + + model := newResponsesProvider(t, sms.server.URL) + stream, err := model.Stream(context.Background(), fantasy.Call{Prompt: testPrompt}) + require.NoError(t, err) + + parts, err := collectStreamParts(stream) + require.NoError(t, err) + + var finishes, errorParts []fantasy.StreamPart + for _, part := range parts { + switch part.Type { + case fantasy.StreamPartTypeFinish: + finishes = append(finishes, part) + case fantasy.StreamPartTypeError: + errorParts = append(errorParts, part) + } + } + + if tt.wantFinish { + require.Len(t, finishes, 1) + require.Empty(t, errorParts) + require.Equal(t, tt.wantFinishReason, finishes[0].FinishReason) + return + } + + require.Empty(t, finishes) + require.Len(t, errorParts, 1) + require.Error(t, errorParts[0].Error) + if tt.wantErrContain != "" { + require.Contains(t, errorParts[0].Error.Error(), tt.wantErrContain) + } + + if tt.wantRetryable { + requireRetryableUnexpectedEOF(t, errorParts[0].Error) + } else { + requireNotRetryableUnexpectedEOF(t, errorParts[0].Error) + } + }) + } +} + +func TestResponsesStreamObject_RequiresTerminalEventBeforeFinish(t *testing.T) { + t.Parallel() + + objectSchema := fantasy.Schema{ + Type: "object", + Properties: map[string]*fantasy.Schema{ + "answer": {Type: "string"}, + }, + Required: []string{"answer"}, + } + + objectChunks := []string{ + responsesSSEEvent("response.created", `{"type":"response.created","response":{"id":"resp_obj","status":"in_progress","output":[]}}`), + responsesSSEEvent("response.output_text.delta", `{"type":"response.output_text.delta","output_index":0,"content_index":0,"item_id":"msg_obj","delta":"{\"answer\":\"hello\"}"}`), + } + completedEvent := responsesSSEEvent("response.completed", `{"type":"response.completed","response":{"id":"resp_obj","status":"completed","output":[],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`) + failedEvent := responsesSSEEvent("response.failed", `{"type":"response.failed","response":{"id":"resp_failed","status":"failed","error":{"code":"server_error","message":"boom"},"output":[]}}`) + errorEvent := responsesSSEEvent("error", `{"type":"error","message":"stream down","code":"server_error","param":"","sequence_number":1}`) + + tests := []struct { + name string + chunks []string + wantFinish bool + wantObject bool + wantRetryable bool + wantErrContain string + }{ + { + name: "completed terminal event finishes", + chunks: append(append([]string{}, objectChunks...), completedEvent), + wantFinish: true, + wantObject: true, + }, + { + name: "object stream closed before terminal event errors", + chunks: objectChunks, + wantObject: true, + wantRetryable: true, + }, + { + name: "response failed errors", + chunks: []string{failedEvent}, + wantErrContain: "response failed: boom (code: server_error)", + }, + { + name: "provider error event is preserved", + chunks: []string{errorEvent}, + wantErrContain: "response error: stream down (code: server_error)", + }, + { + name: "malformed event error is preserved", + chunks: []string{responsesSSEEvent("response.output_text.delta", `{`)}, + wantErrContain: "unexpected end of JSON input", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + sms := newStreamingMockServer() + defer sms.close() + sms.chunks = tt.chunks + + model := newResponsesProvider(t, sms.server.URL) + stream, err := model.StreamObject(context.Background(), fantasy.ObjectCall{ + Prompt: testPrompt, + Schema: objectSchema, + }) + require.NoError(t, err) + + parts := collectObjectStreamParts(stream) + + var objects, finishes, errorParts []fantasy.ObjectStreamPart + for _, part := range parts { + switch part.Type { + case fantasy.ObjectStreamPartTypeObject: + objects = append(objects, part) + case fantasy.ObjectStreamPartTypeFinish: + finishes = append(finishes, part) + case fantasy.ObjectStreamPartTypeError: + errorParts = append(errorParts, part) + } + } + + if tt.wantObject { + require.NotEmpty(t, objects) + require.Equal(t, map[string]any{"answer": "hello"}, objects[len(objects)-1].Object) + } else { + require.Empty(t, objects) + } + + if tt.wantFinish { + require.Len(t, finishes, 1) + require.Empty(t, errorParts) + return + } + + require.Empty(t, finishes) + require.Len(t, errorParts, 1) + require.Error(t, errorParts[0].Error) + if tt.wantErrContain != "" { + require.Contains(t, errorParts[0].Error.Error(), tt.wantErrContain) + } + + if tt.wantRetryable { + requireRetryableUnexpectedEOF(t, errorParts[0].Error) + } else { + requireNotRetryableUnexpectedEOF(t, errorParts[0].Error) + } + }) + } +} + +func responsesSSEEvent(event, data string) string { + return "event: " + event + "\n" + "data: " + data + "\n\n" +} + +func collectObjectStreamParts(stream fantasy.ObjectStreamResponse) []fantasy.ObjectStreamPart { + var parts []fantasy.ObjectStreamPart + for part := range stream { + parts = append(parts, part) + } + return parts +} + +func requireNotRetryableUnexpectedEOF(t *testing.T, err error) { + t.Helper() + + require.NotErrorIs(t, err, io.ErrUnexpectedEOF) + var providerErr *fantasy.ProviderError + if errors.As(err, &providerErr) { + require.False(t, providerErr.IsRetryable()) + require.NotErrorIs(t, providerErr.Cause, io.ErrUnexpectedEOF) + } +} + +func requireRetryableUnexpectedEOF(t *testing.T, err error) { + t.Helper() + + var providerErr *fantasy.ProviderError + require.ErrorAs(t, err, &providerErr) + require.True(t, providerErr.IsRetryable()) + require.ErrorIs(t, providerErr.Cause, io.ErrUnexpectedEOF) +} + func TestResponsesStream_WebSearchResponse(t *testing.T) { t.Parallel() diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index fde8c23a8..9d9771360 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "reflect" "strings" @@ -789,6 +790,7 @@ func (o responsesLanguageModel) Generate(ctx context.Context, call fantasy.Call) if err != nil { return nil, toProviderErr(err) } + if response == nil { return nil, &fantasy.Error{Title: "no response", Message: "provider returned nil response"} } @@ -949,6 +951,7 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) ( // identical; the overwrites ensure we have the final value even if an event // is missed. responseID := "" + sawTerminalEvent := false ongoingToolCalls := make(map[int64]*ongoingToolCall) hasFunctionCall := false activeReasoning := make(map[string]*reasoningState) @@ -1208,22 +1211,34 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) ( } case "response.completed": + sawTerminalEvent = true completed := event.AsResponseCompleted() responseID = completed.Response.ID finishReason = mapResponsesFinishReason(completed.Response.IncompleteDetails.Reason, hasFunctionCall) usage = responsesUsage(completed.Response) case "response.incomplete": + sawTerminalEvent = true incomplete := event.AsResponseIncomplete() responseID = incomplete.Response.ID finishReason = mapResponsesFinishReason(incomplete.Response.IncompleteDetails.Reason, hasFunctionCall) usage = responsesUsage(incomplete.Response) + case "response.failed": + failed := event.AsResponseFailed() + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: responsesFailedStreamError(failed.Response.Error.Message, string(failed.Response.Error.Code)), + }) { + return + } + return + case "error": errorEvent := event.AsError() if !yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: fmt.Errorf("response error: %s (code: %s)", errorEvent.Message, errorEvent.Code), + Error: responsesErrorStreamError(errorEvent.Message, errorEvent.Code), }) { return } @@ -1232,7 +1247,7 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) ( } err := stream.Err() - if err != nil { + if err != nil && !errors.Is(err, io.EOF) { yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, Error: toProviderErr(err), @@ -1240,12 +1255,14 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) ( return } - // Truncated stream: no response.completed / response.incomplete event - // before close. Surface as a retryable error. - if finishReason == fantasy.FinishReasonUnknown { + if !sawTerminalEvent { + err := ctx.Err() + if err == nil { + err = fantasy.NewIncompleteStreamError() + } yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: fantasy.NewIncompleteStreamError(), + Error: err, }) return } @@ -1259,6 +1276,24 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) ( }, nil } +// responsesFailedStreamError intentionally returns a provider-declared failure +// instead of a retryable transport error. Only synthetic stream truncation +// errors are wrapped with io.ErrUnexpectedEOF. +func responsesFailedStreamError(message, code string) error { + return responsesStreamFailureError("response failed", message, code) +} + +func responsesErrorStreamError(message, code string) error { + return responsesStreamFailureError("response error", message, code) +} + +func responsesStreamFailureError(title, message, code string) error { + if code != "" { + message = fmt.Sprintf("%s (code: %s)", message, code) + } + return &fantasy.Error{Title: title, Message: message} +} + // toWebSearchToolParam converts a ProviderDefinedTool with ID // "web_search" into the OpenAI SDK's WebSearchToolParam. func toWebSearchToolParam(pt fantasy.ProviderDefinedTool) responses.ToolUnionParam { @@ -1518,7 +1553,7 @@ func (o responsesLanguageModel) streamObjectWithJSONMode(ctx context.Context, ca // identical; the overwrites ensure we have the final value even if an event // is missed. var responseID string - var streamErr error + var sawTerminalEvent bool hasFunctionCall := false for stream.Next() { @@ -1573,23 +1608,34 @@ func (o responsesLanguageModel) streamObjectWithJSONMode(ctx context.Context, ca } case "response.completed": + sawTerminalEvent = true completed := event.AsResponseCompleted() responseID = completed.Response.ID finishReason = mapResponsesFinishReason(completed.Response.IncompleteDetails.Reason, hasFunctionCall) usage = responsesUsage(completed.Response) case "response.incomplete": + sawTerminalEvent = true incomplete := event.AsResponseIncomplete() responseID = incomplete.Response.ID finishReason = mapResponsesFinishReason(incomplete.Response.IncompleteDetails.Reason, hasFunctionCall) usage = responsesUsage(incomplete.Response) + case "response.failed": + failed := event.AsResponseFailed() + if !yield(fantasy.ObjectStreamPart{ + Type: fantasy.ObjectStreamPartTypeError, + Error: responsesFailedStreamError(failed.Response.Error.Message, string(failed.Response.Error.Code)), + }) { + return + } + return + case "error": errorEvent := event.AsError() - streamErr = fmt.Errorf("response error: %s (code: %s)", errorEvent.Message, errorEvent.Code) if !yield(fantasy.ObjectStreamPart{ Type: fantasy.ObjectStreamPartTypeError, - Error: streamErr, + Error: responsesErrorStreamError(errorEvent.Message, errorEvent.Code), }) { return } @@ -1598,7 +1644,7 @@ func (o responsesLanguageModel) streamObjectWithJSONMode(ctx context.Context, ca } err := stream.Err() - if err != nil { + if err != nil && !errors.Is(err, io.EOF) { yield(fantasy.ObjectStreamPart{ Type: fantasy.ObjectStreamPartTypeError, Error: toProviderErr(err), @@ -1606,15 +1652,27 @@ func (o responsesLanguageModel) streamObjectWithJSONMode(ctx context.Context, ca return } + if !sawTerminalEvent { + err := ctx.Err() + if err == nil { + err = fantasy.NewIncompleteStreamError() + } + yield(fantasy.ObjectStreamPart{ + Type: fantasy.ObjectStreamPartTypeError, + Error: err, + }) + return + } + // Final validation and emit - if streamErr == nil && lastParsedObject != nil { + if lastParsedObject != nil { yield(fantasy.ObjectStreamPart{ Type: fantasy.ObjectStreamPartTypeFinish, Usage: usage, FinishReason: finishReason, ProviderMetadata: responsesProviderMetadata(responseID), }) - } else if streamErr == nil && lastParsedObject == nil { + } else { // No object was generated yield(fantasy.ObjectStreamPart{ Type: fantasy.ObjectStreamPartTypeError,