Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 103 additions & 27 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"maps"
"slices"
"sync"
"sync/atomic"
"time"

"charm.land/fantasy/schema"
"github.com/charmbracelet/x/exp/slice"
Expand Down Expand Up @@ -172,9 +174,10 @@ type AgentCall struct {
OnRetry OnRetryCallback
MaxRetries *int

StopWhen []StopCondition
PrepareStep PrepareStepFunction
RepairToolCall RepairToolCallFunction
StopWhen []StopCondition
PrepareStep PrepareStepFunction
RepairToolCall RepairToolCallFunction
StreamIdleTimeout time.Duration // Cancels the stream if no data arrives within this duration.
}

// Agent-level callbacks.
Expand Down Expand Up @@ -263,9 +266,10 @@ type AgentStreamCall struct {
OnRetry OnRetryCallback
MaxRetries *int

StopWhen []StopCondition
PrepareStep PrepareStepFunction
RepairToolCall RepairToolCallFunction
StopWhen []StopCondition
PrepareStep PrepareStepFunction
RepairToolCall RepairToolCallFunction
StreamIdleTimeout time.Duration // Cancels the stream if no data arrives within this duration.

// Agent-level callbacks
OnAgentStart OnAgentStartFunc // Called when agent starts
Expand Down Expand Up @@ -761,22 +765,23 @@ func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentT
func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
// Convert AgentStreamCall to AgentCall for preparation
call := AgentCall{
Prompt: opts.Prompt,
Files: opts.Files,
Messages: opts.Messages,
MaxOutputTokens: opts.MaxOutputTokens,
Temperature: opts.Temperature,
TopP: opts.TopP,
TopK: opts.TopK,
PresencePenalty: opts.PresencePenalty,
FrequencyPenalty: opts.FrequencyPenalty,
ActiveTools: opts.ActiveTools,
ProviderOptions: opts.ProviderOptions,
MaxRetries: opts.MaxRetries,
OnRetry: opts.OnRetry,
StopWhen: opts.StopWhen,
PrepareStep: opts.PrepareStep,
RepairToolCall: opts.RepairToolCall,
Prompt: opts.Prompt,
Files: opts.Files,
Messages: opts.Messages,
MaxOutputTokens: opts.MaxOutputTokens,
Temperature: opts.Temperature,
TopP: opts.TopP,
TopK: opts.TopK,
PresencePenalty: opts.PresencePenalty,
FrequencyPenalty: opts.FrequencyPenalty,
ActiveTools: opts.ActiveTools,
ProviderOptions: opts.ProviderOptions,
MaxRetries: opts.MaxRetries,
OnRetry: opts.OnRetry,
StopWhen: opts.StopWhen,
PrepareStep: opts.PrepareStep,
RepairToolCall: opts.RepairToolCall,
StreamIdleTimeout: opts.StreamIdleTimeout,
}

call = a.prepareCall(call)
Expand Down Expand Up @@ -884,14 +889,28 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)

result, err := retry(ctx, func() (stepExecutionResult, error) {
// Create the stream
stream, err := stepModel.Stream(ctx, streamCall)
streamCtx := ctx
var streamCancel context.CancelFunc
if call.StreamIdleTimeout > 0 {
streamCtx, streamCancel = context.WithCancel(ctx)
}

stream, err := stepModel.Stream(streamCtx, streamCall)
if err != nil {
if streamCancel != nil {
streamCancel()
}
return stepExecutionResult{}, err
}

// Process the stream
if call.StreamIdleTimeout > 0 {
stream = withIdleTimeout(stream, call.StreamIdleTimeout, streamCancel)
}

result, err := a.processStepStream(ctx, stream, opts, steps, stepTools, stepExecProviderTools)
if streamCancel != nil {
streamCancel()
}
if err != nil {
return stepExecutionResult{}, err
}
Expand Down Expand Up @@ -1248,11 +1267,17 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
parallel bool
}
toolChan := make(chan toolExecutionRequest, 10)
var closeToolChan sync.Once
var toolExecutionWg sync.WaitGroup
var toolStateMu sync.Mutex
toolResults := make([]ToolResultContent, 0)
var toolExecutionErr error

defer func() {
closeToolChan.Do(func() { close(toolChan) })
toolExecutionWg.Wait()
}()

// Create a map for quick tool lookup
toolMap := make(map[string]AgentTool)
for _, tool := range stepTools {
Expand Down Expand Up @@ -1534,8 +1559,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}
}

// Close the tool execution channel and wait for all executions to complete
close(toolChan)
// Ensure the tool channel is closed and all tool executions complete.
// This is also handled by the deferred cleanup, but closing eagerly
// here allows us to inspect tool execution errors before returning.
closeToolChan.Do(func() { close(toolChan) })
toolExecutionWg.Wait()

// Check for tool execution errors
Expand Down Expand Up @@ -1602,3 +1629,52 @@ func WithProviderOptions(providerOptions ProviderOptions) AgentOption {
s.providerOptions = providerOptions
}
}

// withIdleTimeout wraps a StreamResponse so that if no stream part is
// received within the given timeout, cancelFn is called to cancel the
// underlying HTTP request context. Each received part resets the timer.
// The timer goroutine exits when the wrapped iterator returns.
func withIdleTimeout(stream StreamResponse, timeout time.Duration, cancelFn context.CancelFunc) StreamResponse {
return func(yield func(StreamPart) bool) {
timer := time.NewTimer(timeout)
done := make(chan struct{})
var timedOut atomic.Bool

go func() {
select {
case <-timer.C:
timedOut.Store(true)
cancelFn()
case <-done:
}
timer.Stop()
}()

defer close(done)

var stopped bool
stream(func(part StreamPart) bool {
timer.Reset(timeout)
// When the idle timeout fired and cancelled streamCtx, the
// provider will yield a StreamPartTypeError carrying
// context.Canceled. Replace it with a descriptive,
// non-context error so the retry logic treats it as a
// retryable failure rather than a user-initiated abort.
if part.Type == StreamPartTypeError && timedOut.Load() {
part.Error = fmt.Errorf("%w: no data received for %s", errStreamIdleTimeout, timeout)
}
if !yield(part) {
stopped = true
return false
}
return true
})

if timedOut.Load() && !stopped {
yield(StreamPart{
Type: StreamPartTypeError,
Error: fmt.Errorf("%w: no data received for %s", errStreamIdleTimeout, timeout),
})
}
}
}
177 changes: 177 additions & 0 deletions agent_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package fantasy
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -595,3 +598,177 @@ func TestStreamingAgentSources(t *testing.T) {
resultSources := result.Response.Content.Sources()
require.Equal(t, 2, len(resultSources))
}

// TestStreamingAgentIdleTimeout verifies that a hanging stream is cancelled
// after the idle timeout fires and that retries are attempted.
func TestStreamingAgentIdleTimeout(t *testing.T) {
t.Parallel()

var attempts atomic.Int32

mockModel := &mockLanguageModel{
streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
attempts.Add(1)
return func(yield func(StreamPart) bool) {
if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
return
}
if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
return
}
// Simulate a hang: block until context is cancelled.
<-ctx.Done()
}, nil
},
}

agent := NewAgent(mockModel)
ctx := context.Background()

streamCall := AgentStreamCall{
Prompt: "Say hello",
StreamIdleTimeout: 100 * time.Millisecond,
MaxRetries: ptrTo(2),
}

start := time.Now()
_, err := agent.Stream(ctx, streamCall)
elapsed := time.Since(start)

require.Error(t, err)
require.ErrorIs(t, err, errStreamIdleTimeout)
// 2 retries with 2s initial delay and 2x backoff (2+4 = 6s).
require.Less(t, elapsed, 10*time.Second, "should not block for a long time")
// 3 total attempts (1 initial + 2 retries).
require.Equal(t, int32(3), attempts.Load(), "should retry on idle timeout")
}

// TestStreamingAgentIdleTimeoutResetsOnChunks verifies that the idle timer
// resets with each chunk so a slow-but-active stream succeeds.
func TestStreamingAgentIdleTimeoutResetsOnChunks(t *testing.T) {
t.Parallel()

mockModel := &mockLanguageModel{
streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
return func(yield func(StreamPart) bool) {
if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
return
}
// Yield deltas with pauses shorter than the idle timeout.
for _, word := range []string{"Hello", ", ", "world", "!"} {
time.Sleep(30 * time.Millisecond)
if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: word}) {
return
}
}
if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
return
}
yield(StreamPart{
Type: StreamPartTypeFinish,
Usage: Usage{InputTokens: 3, OutputTokens: 4, TotalTokens: 7},
FinishReason: FinishReasonStop,
})
}, nil
},
}

agent := NewAgent(mockModel)
ctx := context.Background()

streamCall := AgentStreamCall{
Prompt: "Say hello",
StreamIdleTimeout: 100 * time.Millisecond,
}

result, err := agent.Stream(ctx, streamCall)
require.NoError(t, err)
require.Equal(t, "Hello, world!", result.Response.Content.Text())
}

// TestStreamingAgentCallbackErrorCleanup verifies that an early return from a
// callback error properly cleans up the tool coordinator goroutine (no leak).
func TestStreamingAgentCallbackErrorCleanup(t *testing.T) {
t.Parallel()

callbackErr := errors.New("callback forced error")

mockModel := &mockLanguageModel{
streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
return func(yield func(StreamPart) bool) {
if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
return
}
if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
return
}
if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
return
}
yield(StreamPart{
Type: StreamPartTypeFinish,
Usage: Usage{InputTokens: 1, OutputTokens: 1, TotalTokens: 2},
FinishReason: FinishReasonStop,
})
}, nil
},
}

agent := NewAgent(mockModel)
ctx := context.Background()

streamCall := AgentStreamCall{
Prompt: "Say hello",
OnTextDelta: func(_, _ string) error {
return callbackErr
},
}

_, err := agent.Stream(ctx, streamCall)
require.ErrorIs(t, err, callbackErr)
}

// TestStreamingAgentIdleTimeoutNoYieldAfterStop verifies that withIdleTimeout
// does not call yield after the consumer has stopped iteration, which would
// panic with "range function continued iteration after loop body returned false".
func TestStreamingAgentIdleTimeoutNoYieldAfterStop(t *testing.T) {
t.Parallel()

mockModel := &mockLanguageModel{
streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
return func(yield func(StreamPart) bool) {
if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
return
}
if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
return
}
// Simulate a hang so the idle timeout fires while we're blocked.
<-ctx.Done()
// After context cancellation, the provider may still yield an error.
yield(StreamPart{
Type: StreamPartTypeError,
Error: ctx.Err(),
})
}, nil
},
}

agent := NewAgent(mockModel)
ctx := context.Background()

streamCall := AgentStreamCall{
Prompt: "Say hello",
StreamIdleTimeout: 50 * time.Millisecond,
OnTextDelta: func(_, _ string) error {
return errors.New("consumer error")
},
}

// This must not panic with "range function continued iteration after
// loop body returned false".
_, err := agent.Stream(ctx, streamCall)
require.Error(t, err)
}

func ptrTo[T any](v T) *T { return &v }
4 changes: 4 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ func ErrorTitleForStatusCode(statusCode int) string {
return strings.ToLower(http.StatusText(statusCode))
}

// errStreamIdleTimeout is a sentinel used to identify stream-idle-timeout
// errors so the retry logic can treat them as retryable.
var errStreamIdleTimeout = errors.New("stream idle timeout")

// NoObjectGeneratedError is returned when object generation fails
// due to parsing errors, validation errors, or model failures.
type NoObjectGeneratedError struct {
Expand Down
Loading
Loading