diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d71849463..de3a45a2a 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -5,6 +5,10 @@ on: jobs: lint: + permissions: + checks: write + contents: read + pull-requests: read uses: charmbracelet/meta/.github/workflows/lint.yml@main with: golangci_path: .golangci.yml diff --git a/agent.go b/agent.go index 62e219457..fa09475f6 100644 --- a/agent.go +++ b/agent.go @@ -129,6 +129,9 @@ type ( // RepairToolCallFunction defines a function that repairs a tool call. RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) + + // PrepareCallFunction defines a hook that runs once per call, before the first model invocation; it may modify call in place and return a derived context, and a non-nil error aborts the call. + PrepareCallFunction = func(ctx context.Context, call *AgentCall) (context.Context, error) ) type agentSettings struct { @@ -152,6 +155,7 @@ type agentSettings struct { model LanguageModel stopWhen []StopCondition + prepareCall PrepareCallFunction prepareStep PrepareStepFunction repairToolCall RepairToolCallFunction onRetry OnRetryCallback @@ -159,6 +163,8 @@ type agentSettings struct { // AgentCall represents a call to an agent. type AgentCall struct { + // SystemPrompt overrides the agent's system prompt for this call when non-nil. + SystemPrompt *string `json:"system_prompt"` Prompt string `json:"prompt"` Files []FilePart `json:"files"` Messages []Message `json:"messages"` @@ -174,7 +180,11 @@ type AgentCall struct { OnRetry OnRetryCallback MaxRetries *int + // CallOptions carries application-defined data that PrepareCall can read. + CallOptions any + StopWhen []StopCondition + PrepareCall PrepareCallFunction PrepareStep PrepareStepFunction RepairToolCall RepairToolCallFunction } @@ -250,6 +260,8 @@ type ( // AgentStreamCall represents a streaming call to an agent. type AgentStreamCall struct { + // SystemPrompt overrides the agent's system prompt for this call when non-nil. 
+ SystemPrompt *string `json:"system_prompt"` Prompt string `json:"prompt"` Files []FilePart `json:"files"` Messages []Message `json:"messages"` @@ -266,7 +278,11 @@ type AgentStreamCall struct { OnRetry OnRetryCallback MaxRetries *int + // CallOptions carries application-defined data that PrepareCall can read. + CallOptions any + StopWhen []StopCondition + PrepareCall PrepareCallFunction PrepareStep PrepareStepFunction RepairToolCall RepairToolCallFunction @@ -330,7 +346,7 @@ func NewAgent(model LanguageModel, opts ...AgentOption) Agent { } } -func (a *agent) prepareCall(call AgentCall) AgentCall { +func (a *agent) prepareCall(ctx context.Context, call AgentCall) (context.Context, AgentCall, error) { call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens) call.Temperature = cmp.Or(call.Temperature, a.settings.temperature) call.TopP = cmp.Or(call.TopP, a.settings.topP) @@ -340,6 +356,10 @@ func (a *agent) prepareCall(call AgentCall) AgentCall { call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries) call.ToolChoice = cmp.Or(call.ToolChoice, a.settings.toolChoice) + if call.SystemPrompt == nil { + sp := a.settings.systemPrompt + call.SystemPrompt = &sp + } if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 { call.StopWhen = a.settings.stopWhen } @@ -368,13 +388,35 @@ func (a *agent) prepareCall(call AgentCall) AgentCall { maps.Copy(headers, a.settings.headers) } - return call + prepareFn := call.PrepareCall + if prepareFn == nil { + prepareFn = a.settings.prepareCall + } + if prepareFn != nil { + var err error + if ctx, err = prepareFn(ctx, &call); err != nil { + return ctx, call, err + } + } + + // Re-resolve in case the hook cleared SystemPrompt to opt back into the agent default. + if call.SystemPrompt == nil { + sp := a.settings.systemPrompt + call.SystemPrompt = &sp + } + + return ctx, call, nil } // Generate implements Agent. 
func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) { - opts = a.prepareCall(opts) - initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...) + ctx, opts, err := a.prepareCall(ctx, opts) + if err != nil { + return nil, err + } + // prepareCall guarantees SystemPrompt is non-nil at this point. + systemPrompt := *opts.SystemPrompt + initialPrompt, err := a.createPrompt(systemPrompt, opts.Prompt, opts.Messages, opts.Files...) if err != nil { return nil, err } @@ -384,7 +426,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err for { stepInputMessages := append(initialPrompt, responseMessages...) stepModel := a.settings.model - stepSystemPrompt := a.settings.systemPrompt + stepSystemPrompt := systemPrompt stepActiveTools := opts.ActiveTools stepToolChoice := ToolChoiceAuto if opts.ToolChoice != nil { @@ -428,7 +470,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err } // Recreate prompt with potentially modified system prompt - if stepSystemPrompt != a.settings.systemPrompt { + if stepSystemPrompt != systemPrompt { stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...) 
if err != nil { return nil, err @@ -784,6 +826,7 @@ func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentT func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) { // Convert AgentStreamCall to AgentCall for preparation call := AgentCall{ + SystemPrompt: opts.SystemPrompt, Prompt: opts.Prompt, Files: opts.Files, Messages: opts.Messages, @@ -798,14 +841,21 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, ProviderOptions: opts.ProviderOptions, MaxRetries: opts.MaxRetries, OnRetry: opts.OnRetry, + CallOptions: opts.CallOptions, StopWhen: opts.StopWhen, + PrepareCall: opts.PrepareCall, PrepareStep: opts.PrepareStep, RepairToolCall: opts.RepairToolCall, } - call = a.prepareCall(call) + ctx, call, err := a.prepareCall(ctx, call) + if err != nil { + return nil, err + } - initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...) + // prepareCall guarantees SystemPrompt is non-nil at this point. + systemPrompt := *call.SystemPrompt + initialPrompt, err := a.createPrompt(systemPrompt, call.Prompt, call.Messages, call.Files...) if err != nil { return nil, err } @@ -822,7 +872,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, for stepNumber := 0; ; stepNumber++ { stepInputMessages := append(initialPrompt, responseMessages...) stepModel := a.settings.model - stepSystemPrompt := a.settings.systemPrompt + stepSystemPrompt := systemPrompt stepActiveTools := call.ActiveTools stepToolChoice := ToolChoiceAuto if call.ToolChoice != nil { @@ -866,7 +916,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, } // Recreate prompt with potentially modified system prompt - if stepSystemPrompt != a.settings.systemPrompt { + if stepSystemPrompt != systemPrompt { stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...) 
if err != nil { return nil, err @@ -918,7 +968,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, } // Process the stream - result, err := a.processStepStream(ctx, stream, opts, steps, stepTools, stepExecProviderTools) + result, err := a.processStepStream(ctx, stream, opts, steps, stepSystemPrompt, stepTools, stepExecProviderTools) if err != nil { return stepExecutionResult{}, err } @@ -1246,6 +1296,13 @@ func WithRepairToolCall(fn RepairToolCallFunction) AgentOption { } } +// WithPrepareCall sets the agent-level PrepareCall hook, which is used when a call does not set its own PrepareCall. +func WithPrepareCall(fn PrepareCallFunction) AgentOption { + return func(s *agentSettings) { + s.prepareCall = fn + } +} + // WithMaxRetries sets the maximum number of retries for the agent. func WithMaxRetries(maxRetries int) AgentOption { return func(s *agentSettings) { @@ -1261,7 +1318,7 @@ func WithOnRetry(callback OnRetryCallback) AgentOption { } // processStepStream processes a single step's stream and returns the step result. 
-func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool, execProviderTools []ExecutableProviderTool) (stepExecutionResult, error) { +func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepSystemPrompt string, stepTools []AgentTool, execProviderTools []ExecutableProviderTool) (stepExecutionResult, error) { var stepContent []Content var stepToolCalls []ToolCallContent var stepUsage Usage @@ -1452,7 +1509,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op delete(activeToolCalls, part.ID) } else { // Validate and potentially repair the tool call - validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, execProviderTools, a.settings.systemPrompt, nil, opts.RepairToolCall) + validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, execProviderTools, stepSystemPrompt, nil, opts.RepairToolCall) stepToolCalls = append(stepToolCalls, validatedToolCall) stepContent = append(stepContent, validatedToolCall) diff --git a/agent_test.go b/agent_test.go index da10747cd..152e1f083 100644 --- a/agent_test.go +++ b/agent_test.go @@ -2558,3 +2558,176 @@ func TestAgent_Generate_StopTurn_NotSet(t *testing.T) { require.Len(t, toolResults, 1) require.False(t, toolResults[0].StopTurn) } + +func TestPrepareCall(t *testing.T) { + t.Parallel() + + systemFromCall := func(call Call) string { + if len(call.Prompt) == 0 || call.Prompt[0].Role != MessageRoleSystem { + return "" + } + if len(call.Prompt[0].Content) == 0 { + return "" + } + text, ok := AsContentType[TextPart](call.Prompt[0].Content[0]) + if !ok { + return "" + } + return text.Text + } + + t.Run("hook can override system prompt via CallOptions", func(t *testing.T) { + t.Parallel() + + var captured string + model := &mockLanguageModel{ + generateFunc: func(_ context.Context, call Call) (*Response, error) { + captured = 
systemFromCall(call) + return &Response{ + Content: ResponseContent{TextContent{Text: "ok"}}, + Usage: Usage{TotalTokens: 1}, + FinishReason: FinishReasonStop, + }, nil + }, + } + + type promptRef struct{ Name string } + hook := func(ctx context.Context, c *AgentCall) (context.Context, error) { + ref, ok := c.CallOptions.(promptRef) + require.True(t, ok) + require.NotNil(t, c.SystemPrompt, "prepareCall should resolve SystemPrompt before the hook") + require.Equal(t, "agent default", *c.SystemPrompt) + s := "fetched: " + ref.Name + c.SystemPrompt = &s + return ctx, nil + } + + agent := NewAgent(model, WithSystemPrompt("agent default"), WithPrepareCall(hook)) + + _, err := agent.Generate(context.Background(), AgentCall{ + Prompt: "user input", + CallOptions: promptRef{Name: "support-bot"}, + }) + + require.NoError(t, err) + require.Equal(t, "fetched: support-bot", captured) + }) + + t.Run("explicit SystemPrompt without hook bypasses agent default", func(t *testing.T) { + t.Parallel() + + var captured string + model := &mockLanguageModel{ + generateFunc: func(_ context.Context, call Call) (*Response, error) { + captured = systemFromCall(call) + return &Response{ + Content: ResponseContent{TextContent{Text: "ok"}}, + Usage: Usage{TotalTokens: 1}, + FinishReason: FinishReasonStop, + }, nil + }, + } + + agent := NewAgent(model, WithSystemPrompt("agent default")) + + override := "explicit per-call system" + _, err := agent.Generate(context.Background(), AgentCall{ + SystemPrompt: &override, + Prompt: "user input", + }) + + require.NoError(t, err) + require.Equal(t, "explicit per-call system", captured) + }) + + t.Run("PrepareStep can still override the prepared system", func(t *testing.T) { + t.Parallel() + + var captured string + model := &mockLanguageModel{ + generateFunc: func(_ context.Context, call Call) (*Response, error) { + captured = systemFromCall(call) + return &Response{ + Content: ResponseContent{TextContent{Text: "ok"}}, + Usage: Usage{TotalTokens: 1}, + 
FinishReason: FinishReasonStop, + }, nil + }, + } + + hook := func(ctx context.Context, c *AgentCall) (context.Context, error) { + s := "from prepare call" + c.SystemPrompt = &s + return ctx, nil + } + prepareStep := func(ctx context.Context, _ PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) { + s := "from prepare step" + return ctx, PrepareStepResult{System: &s}, nil + } + + agent := NewAgent(model, WithSystemPrompt("agent default"), WithPrepareCall(hook)) + + _, err := agent.Generate(context.Background(), AgentCall{ + Prompt: "user input", + PrepareStep: prepareStep, + }) + + require.NoError(t, err) + require.Equal(t, "from prepare step", captured) + }) + + t.Run("hook clearing SystemPrompt falls back to agent default", func(t *testing.T) { + t.Parallel() + + var captured string + model := &mockLanguageModel{ + generateFunc: func(_ context.Context, call Call) (*Response, error) { + captured = systemFromCall(call) + return &Response{ + Content: ResponseContent{TextContent{Text: "ok"}}, + Usage: Usage{TotalTokens: 1}, + FinishReason: FinishReasonStop, + }, nil + }, + } + + hook := func(ctx context.Context, c *AgentCall) (context.Context, error) { + c.SystemPrompt = nil + return ctx, nil + } + + agent := NewAgent(model, WithSystemPrompt("agent default"), WithPrepareCall(hook)) + + _, err := agent.Generate(context.Background(), AgentCall{Prompt: "user input"}) + require.NoError(t, err) + require.Equal(t, "agent default", captured) + }) + + t.Run("Stream forwards explicit SystemPrompt through AgentCall conversion", func(t *testing.T) { + t.Parallel() + + var captured string + model := &mockLanguageModel{ + streamFunc: func(_ context.Context, call Call) (StreamResponse, error) { + captured = systemFromCall(call) + return func(yield func(StreamPart) bool) { + yield(StreamPart{ + Type: StreamPartTypeFinish, + Usage: Usage{TotalTokens: 1}, + FinishReason: FinishReasonStop, + }) + }, nil + }, + } + + agent := NewAgent(model, WithSystemPrompt("agent 
default")) + + override := "explicit per-call system" + _, err := agent.Stream(context.Background(), AgentStreamCall{ + SystemPrompt: &override, + Prompt: "user input", + }) + require.NoError(t, err) + require.Equal(t, "explicit per-call system", captured) + }) +}