diff --git a/providers/anthropic/anthropic.go b/providers/anthropic/anthropic.go index 2f05b2f12..86486b83d 100644 --- a/providers/anthropic/anthropic.go +++ b/providers/anthropic/anthropic.go @@ -1244,6 +1244,7 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S stream := a.client.Messages.NewStreaming(ctx, *params, reqOpts...) acc := anthropic.Message{} + var sawMessageStop bool return func(yield func(fantasy.StreamPart) bool) { if len(warnings) > 0 { if !yield(fantasy.StreamPart{ @@ -1448,11 +1449,22 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S } } case "message_stop": + sawMessageStop = true } } err := stream.Err() if err == nil || errors.Is(err, io.EOF) { + if !sawMessageStop { + if err == nil { + err = io.EOF + } + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: fmt.Errorf("anthropic stream closed before message_stop: %w", err), + }) + return + } yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeFinish, ID: acc.ID, diff --git a/providers/anthropic/anthropic_test.go b/providers/anthropic/anthropic_test.go index 4387a34fa..0f91262a2 100644 --- a/providers/anthropic/anthropic_test.go +++ b/providers/anthropic/anthropic_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "math" "net/http" "net/http/httptest" @@ -504,6 +505,97 @@ func TestStream_SendsOutputConfigEffort(t *testing.T) { requireAnthropicEffort(t, call.body, EffortHigh) } +func TestStream_RequiresMessageStopBeforeFinish(t *testing.T) { + t.Parallel() + + completeTextStream := []string{ + anthropicSSEEvent("message_start", `{"type":"message_start","message":{"id":"msg_complete","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":1,"output_tokens":0}}}`), + anthropicSSEEvent("content_block_start", `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`), + anthropicSSEEvent("content_block_delta", `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hello"}}`), + anthropicSSEEvent("content_block_stop", `{"type":"content_block_stop","index":0}`), + anthropicSSEEvent("message_delta", `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":1,"output_tokens":1}}`), + anthropicSSEEvent("message_stop", `{"type":"message_stop"}`), + } + truncatedTextStream := completeTextStream[:len(completeTextStream)-1] + + tests := []struct { + name string + chunks []string + wantFinish bool + wantEOF bool + wantError string + }{ + { + name: "complete stream finishes", + chunks: completeTextStream, + wantFinish: true, + }, + { + name: "eof before message_stop returns EOF error", + chunks: truncatedTextStream, + wantEOF: true, + }, + { + name: "empty stream returns EOF error", + wantEOF: true, + }, + { + name: "error event keeps existing error path", + chunks: []string{ + anthropicSSEEvent("error", `{"type":"error","error":{"type":"overloaded_error","message":"stream down"}}`), + }, + wantError: "stream down", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server, calls := newAnthropicStreamingServer(tt.chunks) + defer server.Close() + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.URL), + ) + require.NoError(t, err) + + model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514") + require.NoError(t, err) + + stream, err := model.Stream(context.Background(), fantasy.Call{ + Prompt: testPrompt(), + }) + require.NoError(t, err) + + parts := collectAnthropicStreamParts(stream) + _ = awaitAnthropicCall(t, calls) + + finishParts := streamPartsByType(parts, fantasy.StreamPartTypeFinish) + errorParts := streamPartsByType(parts, fantasy.StreamPartTypeError) + + if tt.wantFinish { + require.Len(t, finishParts, 1) + require.Empty(t, errorParts) + require.Equal(t, fantasy.FinishReasonStop, finishParts[0].FinishReason) + return + } + + require.Empty(t, finishParts) + require.Len(t, errorParts, 1) + require.Error(t, errorParts[0].Error) + if tt.wantEOF { + require.ErrorIs(t, errorParts[0].Error, io.EOF) + require.Contains(t, errorParts[0].Error.Error(), "message_stop") + } else { + require.NotContains(t, errorParts[0].Error.Error(), "message_stop") + require.Contains(t, errorParts[0].Error.Error(), tt.wantError) + } + }) + } +} + type anthropicCall struct { method string path string @@ -563,6 +655,29 @@ func newAnthropicStreamingServer(chunks []string) (*httptest.Server, <-chan anth return server, calls } +func anthropicSSEEvent(event, data string) string { + return fmt.Sprintf("event: %s\ndata: %s\n\n", event, data) +} + +func collectAnthropicStreamParts(stream fantasy.StreamResponse) []fantasy.StreamPart { + var parts []fantasy.StreamPart + stream(func(part fantasy.StreamPart) bool { + parts = append(parts, part) + return true + }) + return parts +} + +func streamPartsByType(parts []fantasy.StreamPart, typ fantasy.StreamPartType) []fantasy.StreamPart { + var matches []fantasy.StreamPart + for _, part := range parts { + if part.Type == typ { + matches = append(matches, part) + } + } + return matches +} + func awaitAnthropicCall(t *testing.T, calls <-chan anthropicCall) anthropicCall { t.Helper() @@ -1574,7 +1689,8 @@ func TestComputerUseToolJSON(t *testing.T) { } _, err := computerUseToolJSON(pdt) require.Error(t, err) - require.Contains(t, err.Error(), "tool_version arg is missing") }) + require.Contains(t, err.Error(), "tool_version arg is missing") + }) t.Run("returns error for unsupported version", func(t *testing.T) { t.Parallel() diff --git a/providers/openai/openai_test.go b/providers/openai/openai_test.go index db37452a7..c344e6fad 100644 --- a/providers/openai/openai_test.go +++ b/providers/openai/openai_test.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "errors" + "io" "net/http" "net/http/httptest" "strings" @@ -3643,6 +3644,21 @@ func newResponsesProvider(t *testing.T, serverURL string) fantasy.LanguageModel return model } +func responsesSSEEvent(event, data string) string { + return "event: " + event + "\ndata: " + data + "\n\n" +} + +func collectObjectStreamParts(stream fantasy.ObjectStreamResponse) []fantasy.ObjectStreamPart { + var parts []fantasy.ObjectStreamPart + for part := range stream { + parts = append(parts, part) + if part.Type == fantasy.ObjectStreamPartTypeError || part.Type == fantasy.ObjectStreamPartTypeFinish { + break + } + } + return parts +} + func TestResponsesGenerate_WebSearchResponse(t *testing.T) { t.Parallel() @@ -4333,6 +4349,199 @@ func TestResponsesToPrompt_ReasoningWithFunctionCallCombined(t *testing.T) { require.Equal(t, functionCallID, input[3].OfFunctionCallOutput.CallID) } +func TestResponsesStream_RequiresTerminalEventBeforeFinish(t *testing.T) { + t.Parallel() + + textChunks := []string{ + responsesSSEEvent("response.output_item.added", `{"type":"response.output_item.added","output_index":0,"item":{"type":"message","id":"msg_01","role":"assistant","status":"in_progress","content":[]}}`), + responsesSSEEvent("response.output_text.delta", `{"type":"response.output_text.delta","output_index":0,"content_index":0,"item_id":"msg_01","delta":"hello"}`), + responsesSSEEvent("response.output_item.done", `{"type":"response.output_item.done","output_index":0,"item":{"type":"message","id":"msg_01","role":"assistant","status":"completed","content":[{"type":"output_text","text":"hello","annotations":[]}]}}`), + } + completedEvent := responsesSSEEvent("response.completed", `{"type":"response.completed","response":{"id":"resp_01","status":"completed","output":[],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`) + incompleteEvent := responsesSSEEvent("response.incomplete", `{"type":"response.incomplete","response":{"id":"resp_02","status":"incomplete","output":[],"incomplete_details":{"reason":"max_output_tokens"},"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`) + failedEvent := responsesSSEEvent("response.failed", `{"type":"response.failed","response":{"id":"resp_03","status":"failed","error":{"code":"server_error","message":"boom"},"output":[]}}`) + errorEvent := responsesSSEEvent("error", `{"type":"error","message":"stream down","code":"server_error"}`) + + tests := []struct { + name string + chunks []string + wantFinishReason fantasy.FinishReason + wantEOF bool + wantError string + }{ + { + name: "completed stream finishes", + chunks: append(append([]string{}, textChunks...), completedEvent), + wantFinishReason: fantasy.FinishReasonStop, + }, + { + name: "incomplete stream is terminal", + chunks: append(append([]string{}, textChunks...), incompleteEvent), + wantFinishReason: fantasy.FinishReasonLength, + }, + { + name: "eof before terminal event returns EOF error", + chunks: textChunks, + wantEOF: true, + }, + { + name: "empty stream returns EOF error", + wantEOF: true, + }, + { + name: "failed event returns provider error", + chunks: []string{failedEvent}, + wantError: "boom", + }, + { + name: "error event keeps existing error path", + chunks: []string{errorEvent}, + wantError: "stream down", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + sms := newStreamingMockServer() + defer sms.close() + sms.chunks = tt.chunks + + model := newResponsesProvider(t, sms.server.URL) + + stream, err := model.Stream(context.Background(), fantasy.Call{ + Prompt: testPrompt, + }) + require.NoError(t, err) + + parts, err := collectStreamParts(stream) + require.NoError(t, err) + + var finishes []fantasy.StreamPart + var errors []fantasy.StreamPart + for _, part := range parts { + switch part.Type { + case fantasy.StreamPartTypeFinish: + finishes = append(finishes, part) + case fantasy.StreamPartTypeError: + errors = append(errors, part) + } + } + + if tt.wantFinishReason != "" { + require.Len(t, finishes, 1) + require.Empty(t, errors) + require.Equal(t, tt.wantFinishReason, finishes[0].FinishReason) + return + } + + require.Empty(t, finishes) + require.Len(t, errors, 1) + require.Error(t, errors[0].Error) + if tt.wantEOF { + require.ErrorIs(t, errors[0].Error, io.EOF) + require.Contains(t, errors[0].Error.Error(), "terminal event") + } else { + require.NotContains(t, errors[0].Error.Error(), "terminal event") + require.Contains(t, errors[0].Error.Error(), tt.wantError) + } + }) + } +} + +func TestResponsesStreamObject_RequiresTerminalEventBeforeFinish(t *testing.T) { + t.Parallel() + + objectChunks := []string{ + responsesSSEEvent("response.output_text.delta", `{"type":"response.output_text.delta","output_index":0,"content_index":0,"item_id":"msg_01","delta":"{\"name\":\"Alice\"}"}`), + } + completedEvent := responsesSSEEvent("response.completed", `{"type":"response.completed","response":{"id":"resp_01","status":"completed","output":[],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`) + failedEvent := responsesSSEEvent("response.failed", `{"type":"response.failed","response":{"id":"resp_02","status":"failed","error":{"code":"server_error","message":"boom"},"output":[]}}`) + errorEvent := responsesSSEEvent("error", `{"type":"error","message":"stream down","code":"server_error"}`) + + tests := []struct { + name string + chunks []string + wantFinish bool + wantError string + }{ + { + name: "completed stream finishes", + chunks: append(append([]string{}, objectChunks...), completedEvent), + wantFinish: true, + }, + { + name: "eof before terminal event returns EOF error", + chunks: objectChunks, + }, + { + name: "failed event returns provider error", + chunks: []string{failedEvent}, + wantError: "boom", + }, + { + name: "error event keeps existing error path", + chunks: []string{errorEvent}, + wantError: "stream down", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + sms := newStreamingMockServer() + defer sms.close() + sms.chunks = tt.chunks + + model := newResponsesProvider(t, sms.server.URL) + stream, err := model.StreamObject(context.Background(), fantasy.ObjectCall{ + Prompt: fantasy.Prompt{fantasy.NewUserMessage("Generate a person.")}, + Schema: fantasy.Schema{ + Type: "object", + Properties: map[string]*fantasy.Schema{ + "name": {Type: "string"}, + }, + Required: []string{"name"}, + }, + SchemaName: "Person", + }) + require.NoError(t, err) + + parts := collectObjectStreamParts(stream) + + var finishes []fantasy.ObjectStreamPart + var errors []fantasy.ObjectStreamPart + for _, part := range parts { + switch part.Type { + case fantasy.ObjectStreamPartTypeFinish: + finishes = append(finishes, part) + case fantasy.ObjectStreamPartTypeError: + errors = append(errors, part) + } + } + + if tt.wantFinish { + require.Len(t, finishes, 1) + require.Empty(t, errors) + require.Equal(t, fantasy.FinishReasonStop, finishes[0].FinishReason) + return + } + + require.Empty(t, finishes) + require.Len(t, errors, 1) + if tt.wantError != "" { + require.NotContains(t, errors[0].Error.Error(), "terminal event") + require.Contains(t, errors[0].Error.Error(), tt.wantError) + return + } + require.ErrorIs(t, errors[0].Error, io.EOF) + require.Contains(t, errors[0].Error.Error(), "terminal event") + }) + } +} + func TestResponsesStream_WebSearchResponse(t *testing.T) { t.Parallel() @@ -4696,7 +4905,8 @@ func TestComputerUseGenerateRoundTrip_NonImageResult(t *testing.T) { }, } - input, warnings := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false) + require.NoError(t, err) // Should warn about non-image result. var foundWarning bool diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index d6d2db4b3..077c27325 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "reflect" "slices" "strings" @@ -1130,6 +1131,17 @@ func mapResponsesFinishReason(reason string, hasFunctionCall bool) fantasy.Finis } } +func responsesStreamClosedBeforeTerminalEventError(err error) error { + if err == nil { + err = io.EOF + } + return fmt.Errorf("openai responses stream closed before terminal event: %w", err) +} + +func responsesFailedStreamError(response responses.Response) error { + return fmt.Errorf("response failed: %s (code: %s)", response.Error.Message, response.Error.Code) +} + func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { params, warnings, err := o.prepareParams(call) if err != nil { @@ -1148,6 +1160,7 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) ( responseID := "" ongoingToolCalls := make(map[int64]*ongoingToolCall) hasFunctionCall := false + sawTerminalEvent := false activeReasoning := make(map[string]*reasoningState) return func(yield func(fantasy.StreamPart) bool) { @@ -1449,17 +1462,29 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) ( } case "response.completed": + sawTerminalEvent = true completed := event.AsResponseCompleted() responseID = completed.Response.ID finishReason = mapResponsesFinishReason(completed.Response.IncompleteDetails.Reason, hasFunctionCall) usage = responsesUsage(completed.Response) case "response.incomplete": + sawTerminalEvent = true incomplete := event.AsResponseIncomplete() responseID = incomplete.Response.ID finishReason = mapResponsesFinishReason(incomplete.Response.IncompleteDetails.Reason, hasFunctionCall) usage = responsesUsage(incomplete.Response) + case "response.failed": + failed := event.AsResponseFailed() + if !yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: responsesFailedStreamError(failed.Response), + }) { + return + } + return + case "error": errorEvent := event.AsError() if !yield(fantasy.StreamPart{ @@ -1473,13 +1498,20 @@ func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) ( } err := stream.Err() - if err != nil { + if err != nil && !errors.Is(err, io.EOF) { yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, Error: toProviderErr(err), }) return } + if !sawTerminalEvent { + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: responsesStreamClosedBeforeTerminalEventError(err), + }) + return + } yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeFinish, @@ -1757,6 +1789,7 @@ func (o responsesLanguageModel) streamObjectWithJSONMode(ctx context.Context, ca var responseID string var streamErr error hasFunctionCall := false + sawTerminalEvent := false for stream.Next() { event := stream.Current() @@ -1810,17 +1843,30 @@ func (o responsesLanguageModel) streamObjectWithJSONMode(ctx context.Context, ca } case "response.completed": + sawTerminalEvent = true completed := event.AsResponseCompleted() responseID = completed.Response.ID finishReason = mapResponsesFinishReason(completed.Response.IncompleteDetails.Reason, hasFunctionCall) usage = responsesUsage(completed.Response) case "response.incomplete": + sawTerminalEvent = true incomplete := event.AsResponseIncomplete() responseID = incomplete.Response.ID finishReason = mapResponsesFinishReason(incomplete.Response.IncompleteDetails.Reason, hasFunctionCall) usage = responsesUsage(incomplete.Response) + case "response.failed": + failed := event.AsResponseFailed() + streamErr = responsesFailedStreamError(failed.Response) + if !yield(fantasy.ObjectStreamPart{ + Type: fantasy.ObjectStreamPartTypeError, + Error: streamErr, + }) { + return + } + return + case "error": errorEvent := event.AsError() streamErr = fmt.Errorf("response error: %s (code: %s)", errorEvent.Message, errorEvent.Code) @@ -1835,13 +1881,20 @@ func (o responsesLanguageModel) streamObjectWithJSONMode(ctx context.Context, ca } err := stream.Err() - if err != nil { + if err != nil && !errors.Is(err, io.EOF) { yield(fantasy.ObjectStreamPart{ Type: fantasy.ObjectStreamPartTypeError, Error: toProviderErr(err), }) return } + if !sawTerminalEvent { + yield(fantasy.ObjectStreamPart{ + Type: fantasy.ObjectStreamPartTypeError, + Error: responsesStreamClosedBeforeTerminalEventError(err), + }) + return + } // Final validation and emit if streamErr == nil && lastParsedObject != nil {