Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 31 additions & 23 deletions providers/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -1347,6 +1347,8 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
}
}

sawMessageStop := false

for stream.Next() {
chunk := stream.Current()
_ = acc.Accumulate(chunk)
Expand Down Expand Up @@ -1552,41 +1554,47 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
}
}
case "message_stop":
sawMessageStop = true
}
}

err := stream.Err()
if err == nil || errors.Is(err, io.EOF) {
// Truncated stream: no terminal message_delta with stop_reason.
// Surface as a retryable error.
if acc.StopReason == "" {
yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
Error: fantasy.NewIncompleteStreamError(),
})
return
}
if err != nil && !errors.Is(err, io.EOF) {
yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeFinish,
ID: acc.ID,
FinishReason: mapFinishReason(string(acc.StopReason)),
Usage: fantasy.Usage{
InputTokens: acc.Usage.InputTokens,
OutputTokens: acc.Usage.OutputTokens,
TotalTokens: acc.Usage.InputTokens + acc.Usage.OutputTokens,
CacheCreationTokens: acc.Usage.CacheCreationInputTokens,
CacheReadTokens: acc.Usage.CacheReadInputTokens,
},
ProviderMetadata: fantasy.ProviderMetadata{},
Type: fantasy.StreamPartTypeError,
Error: toProviderErr(err),
})
return
} else { //nolint: revive
}

// Anthropic's SSE protocol reports the stop_reason in message_delta
// and then terminates the message with message_stop. Require both so
// a socket close after only one of those signals is retried.
if !sawMessageStop || acc.StopReason == "" {
err := ctx.Err()
if err == nil {
err = fantasy.NewIncompleteStreamError()
}
yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
Error: toProviderErr(err),
Error: err,
})
return
}

yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeFinish,
ID: acc.ID,
FinishReason: mapFinishReason(string(acc.StopReason)),
Usage: fantasy.Usage{
InputTokens: acc.Usage.InputTokens,
OutputTokens: acc.Usage.OutputTokens,
TotalTokens: acc.Usage.InputTokens + acc.Usage.OutputTokens,
CacheCreationTokens: acc.Usage.CacheCreationInputTokens,
CacheReadTokens: acc.Usage.CacheReadInputTokens,
},
ProviderMetadata: fantasy.ProviderMetadata{},
})
}, nil
}

Expand Down
116 changes: 113 additions & 3 deletions providers/anthropic/anthropic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,113 @@ func TestStream_SendsOutputConfigEffort(t *testing.T) {
requireAnthropicEffort(t, call.body, EffortHigh)
}

func TestStream_RequiresMessageStopBeforeFinish(t *testing.T) {
t.Parallel()

completeTextStream := []string{
anthropicSSEEvent("message_start", `{"type":"message_start","message":{"id":"msg_complete","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":1,"output_tokens":0}}}`),
anthropicSSEEvent("content_block_start", `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`),
anthropicSSEEvent("content_block_delta", `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hello"}}`),
anthropicSSEEvent("content_block_stop", `{"type":"content_block_stop","index":0}`),
anthropicSSEEvent("message_delta", `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":1}}`),
anthropicSSEEvent("message_stop", `{"type":"message_stop"}`),
}

missingStopReasonStream := append([]string(nil), completeTextStream[:len(completeTextStream)-2]...)
missingStopReasonStream = append(missingStopReasonStream, completeTextStream[len(completeTextStream)-1])

tests := []struct {
name string
chunks []string
wantFinish bool
wantRetryable bool
wantErrContain string
}{
{
name: "complete stream finishes",
chunks: completeTextStream,
wantFinish: true,
},
{
name: "message_stop without stop_reason errors",
chunks: missingStopReasonStream,
wantRetryable: true,
},
{
name: "text stream closed before message_stop errors",
chunks: completeTextStream[:len(completeTextStream)-1],
wantRetryable: true,
},
{
name: "provider error event is preserved",
chunks: []string{
anthropicSSEEvent("error", `{"type":"error","error":{"type":"api_error","message":"stream down"}}`),
},
wantErrContain: "stream down",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

server, calls := newAnthropicStreamingServer(tt.chunks)
defer server.Close()

provider, err := New(
WithAPIKey("test-api-key"),
WithBaseURL(server.URL),
)
require.NoError(t, err)

model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514")
require.NoError(t, err)

stream, err := model.Stream(context.Background(), fantasy.Call{Prompt: testPrompt()})
require.NoError(t, err)

parts := collectAnthropicStreamParts(stream)
_ = awaitAnthropicCall(t, calls)

var finishes, errorParts []fantasy.StreamPart
for _, part := range parts {
switch part.Type {
case fantasy.StreamPartTypeFinish:
finishes = append(finishes, part)
case fantasy.StreamPartTypeError:
errorParts = append(errorParts, part)
}
}

if tt.wantFinish {
require.Len(t, finishes, 1)
require.Empty(t, errorParts)
return
}

require.Empty(t, finishes)
require.Len(t, errorParts, 1)
require.Error(t, errorParts[0].Error)
if tt.wantErrContain != "" {
require.Contains(t, errorParts[0].Error.Error(), tt.wantErrContain)
}

var providerErr *fantasy.ProviderError
if tt.wantRetryable {
require.ErrorAs(t, errorParts[0].Error, &providerErr)
require.True(t, providerErr.IsRetryable())
require.ErrorIs(t, providerErr.Cause, io.ErrUnexpectedEOF)
} else {
require.NotErrorIs(t, errorParts[0].Error, io.ErrUnexpectedEOF)
if errors.As(errorParts[0].Error, &providerErr) {
require.False(t, providerErr.IsRetryable())
require.NotErrorIs(t, providerErr.Cause, io.ErrUnexpectedEOF)
}
}
})
}
}

type anthropicCall struct {
method string
path string
Expand Down Expand Up @@ -1125,7 +1232,8 @@ func TestGenerate_WebSearchResponse(t *testing.T) {

// TextContent with the final answer.
require.Len(t, texts, 1)
require.Equal(t,
require.Equal(
t,
"Based on recent search results, here is the latest AI news.",
texts[0].Text,
)
Expand Down Expand Up @@ -2587,7 +2695,8 @@ func TestGenerate_ComputerUseTool(t *testing.T) {

// Build the next prompt: append the assistant tool-call turn
// and the user screenshot-result turn.
prompt = append(prompt,
prompt = append(
prompt,
fantasy.Message{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
Expand Down Expand Up @@ -2767,7 +2876,8 @@ func TestStream_ComputerUseTool(t *testing.T) {
require.NoError(t, err, "turn %d", turn)
gotActions = append(gotActions, parsed.Action)

prompt = append(prompt,
prompt = append(
prompt,
fantasy.Message{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
Expand Down
Loading