diff --git a/providers/anthropic/anthropic.go b/providers/anthropic/anthropic.go index f3aa6e829..eb5b735e7 100644 --- a/providers/anthropic/anthropic.go +++ b/providers/anthropic/anthropic.go @@ -1347,6 +1347,8 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S } } + sawMessageStop := false + for stream.Next() { chunk := stream.Current() _ = acc.Accumulate(chunk) @@ -1552,41 +1554,47 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S } } case "message_stop": + sawMessageStop = true } } err := stream.Err() - if err == nil || errors.Is(err, io.EOF) { - // Truncated stream: no terminal message_delta with stop_reason. - // Surface as a retryable error. - if acc.StopReason == "" { - yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeError, - Error: fantasy.NewIncompleteStreamError(), - }) - return - } + if err != nil && !errors.Is(err, io.EOF) { yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeFinish, - ID: acc.ID, - FinishReason: mapFinishReason(string(acc.StopReason)), - Usage: fantasy.Usage{ - InputTokens: acc.Usage.InputTokens, - OutputTokens: acc.Usage.OutputTokens, - TotalTokens: acc.Usage.InputTokens + acc.Usage.OutputTokens, - CacheCreationTokens: acc.Usage.CacheCreationInputTokens, - CacheReadTokens: acc.Usage.CacheReadInputTokens, - }, - ProviderMetadata: fantasy.ProviderMetadata{}, + Type: fantasy.StreamPartTypeError, + Error: toProviderErr(err), }) return - } else { //nolint: revive + } + + // Anthropic's SSE protocol reports the stop_reason in message_delta + // and then terminates the message with message_stop. Require both so + // a socket close after only one of those signals is retried. + if !sawMessageStop || acc.StopReason == "" { + err := ctx.Err() + if err == nil { + err = fantasy.NewIncompleteStreamError() + } yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: toProviderErr(err), + Error: err, }) return } + + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeFinish, + ID: acc.ID, + FinishReason: mapFinishReason(string(acc.StopReason)), + Usage: fantasy.Usage{ + InputTokens: acc.Usage.InputTokens, + OutputTokens: acc.Usage.OutputTokens, + TotalTokens: acc.Usage.InputTokens + acc.Usage.OutputTokens, + CacheCreationTokens: acc.Usage.CacheCreationInputTokens, + CacheReadTokens: acc.Usage.CacheReadInputTokens, + }, + ProviderMetadata: fantasy.ProviderMetadata{}, + }) }, nil } diff --git a/providers/anthropic/anthropic_test.go b/providers/anthropic/anthropic_test.go index 880bd528b..9d8af299d 100644 --- a/providers/anthropic/anthropic_test.go +++ b/providers/anthropic/anthropic_test.go @@ -600,6 +600,113 @@ func TestStream_SendsOutputConfigEffort(t *testing.T) { requireAnthropicEffort(t, call.body, EffortHigh) } +func TestStream_RequiresMessageStopBeforeFinish(t *testing.T) { + t.Parallel() + + completeTextStream := []string{ + anthropicSSEEvent("message_start", `{"type":"message_start","message":{"id":"msg_complete","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":1,"output_tokens":0}}}`), + anthropicSSEEvent("content_block_start", `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`), + anthropicSSEEvent("content_block_delta", `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hello"}}`), + anthropicSSEEvent("content_block_stop", `{"type":"content_block_stop","index":0}`), + anthropicSSEEvent("message_delta", `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":1}}`), + anthropicSSEEvent("message_stop", `{"type":"message_stop"}`), + } + + missingStopReasonStream := append([]string(nil), completeTextStream[:len(completeTextStream)-2]...) + missingStopReasonStream = append(missingStopReasonStream, completeTextStream[len(completeTextStream)-1]) + + tests := []struct { + name string + chunks []string + wantFinish bool + wantRetryable bool + wantErrContain string + }{ + { + name: "complete stream finishes", + chunks: completeTextStream, + wantFinish: true, + }, + { + name: "message_stop without stop_reason errors", + chunks: missingStopReasonStream, + wantRetryable: true, + }, + { + name: "text stream closed before message_stop errors", + chunks: completeTextStream[:len(completeTextStream)-1], + wantRetryable: true, + }, + { + name: "provider error event is preserved", + chunks: []string{ + anthropicSSEEvent("error", `{"type":"error","error":{"type":"api_error","message":"stream down"}}`), + }, + wantErrContain: "stream down", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server, calls := newAnthropicStreamingServer(tt.chunks) + defer server.Close() + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.URL), + ) + require.NoError(t, err) + + model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514") + require.NoError(t, err) + + stream, err := model.Stream(context.Background(), fantasy.Call{Prompt: testPrompt()}) + require.NoError(t, err) + + parts := collectAnthropicStreamParts(stream) + _ = awaitAnthropicCall(t, calls) + + var finishes, errorParts []fantasy.StreamPart + for _, part := range parts { + switch part.Type { + case fantasy.StreamPartTypeFinish: + finishes = append(finishes, part) + case fantasy.StreamPartTypeError: + errorParts = append(errorParts, part) + } + } + + if tt.wantFinish { + require.Len(t, finishes, 1) + require.Empty(t, errorParts) + return + } + + require.Empty(t, finishes) + require.Len(t, errorParts, 1) + require.Error(t, errorParts[0].Error) + if tt.wantErrContain != "" { + require.Contains(t, errorParts[0].Error.Error(), tt.wantErrContain) + } + + var providerErr *fantasy.ProviderError + if tt.wantRetryable { + require.ErrorAs(t, errorParts[0].Error, &providerErr) + require.True(t, providerErr.IsRetryable()) + require.ErrorIs(t, providerErr.Cause, io.ErrUnexpectedEOF) + } else { + require.NotErrorIs(t, errorParts[0].Error, io.ErrUnexpectedEOF) + if errors.As(errorParts[0].Error, &providerErr) { + require.False(t, providerErr.IsRetryable()) + require.NotErrorIs(t, providerErr.Cause, io.ErrUnexpectedEOF) + } + } + }) + } +} + type anthropicCall struct { method string path string @@ -1125,7 +1232,8 @@ func TestGenerate_WebSearchResponse(t *testing.T) { // TextContent with the final answer. require.Len(t, texts, 1) - require.Equal(t, + require.Equal( + t, "Based on recent search results, here is the latest AI news.", texts[0].Text, ) @@ -2587,7 +2695,8 @@ func TestGenerate_ComputerUseTool(t *testing.T) { // Build the next prompt: append the assistant tool-call turn // and the user screenshot-result turn. - prompt = append(prompt, + prompt = append( + prompt, fantasy.Message{ Role: fantasy.MessageRoleAssistant, Content: []fantasy.MessagePart{ @@ -2767,7 +2876,8 @@ func TestStream_ComputerUseTool(t *testing.T) { require.NoError(t, err, "turn %d", turn) gotActions = append(gotActions, parsed.Action) - prompt = append(prompt, + prompt = append( + prompt, fantasy.Message{ Role: fantasy.MessageRoleAssistant, Content: []fantasy.MessagePart{