Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ on:

jobs:
lint:
permissions:
checks: write
contents: read
pull-requests: read
uses: charmbracelet/meta/.github/workflows/lint.yml@main
with:
golangci_path: .golangci.yml
Expand Down
83 changes: 70 additions & 13 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ type (

// RepairToolCallFunction defines a function that repairs a tool call.
RepairToolCallFunction = func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error)

// PrepareCallFunction defines a function that prepares the call before the first model invocation.
PrepareCallFunction = func(ctx context.Context, call *AgentCall) (context.Context, error)
)

type agentSettings struct {
Expand All @@ -152,13 +155,16 @@ type agentSettings struct {
model LanguageModel

stopWhen []StopCondition
prepareCall PrepareCallFunction
prepareStep PrepareStepFunction
repairToolCall RepairToolCallFunction
onRetry OnRetryCallback
}

// AgentCall represents a call to an agent.
type AgentCall struct {
// SystemPrompt overrides the agent's system prompt for this call when non-nil.
SystemPrompt *string `json:"system_prompt"`
Prompt string `json:"prompt"`
Files []FilePart `json:"files"`
Messages []Message `json:"messages"`
Expand All @@ -174,7 +180,11 @@ type AgentCall struct {
OnRetry OnRetryCallback
MaxRetries *int

// CallOptions carries application-defined data that PrepareCall can read.
CallOptions any

StopWhen []StopCondition
PrepareCall PrepareCallFunction
PrepareStep PrepareStepFunction
RepairToolCall RepairToolCallFunction
}
Expand Down Expand Up @@ -250,6 +260,8 @@ type (

// AgentStreamCall represents a streaming call to an agent.
type AgentStreamCall struct {
// SystemPrompt overrides the agent's system prompt for this call when non-nil.
SystemPrompt *string `json:"system_prompt"`
Prompt string `json:"prompt"`
Files []FilePart `json:"files"`
Messages []Message `json:"messages"`
Expand All @@ -266,7 +278,11 @@ type AgentStreamCall struct {
OnRetry OnRetryCallback
MaxRetries *int

// CallOptions carries application-defined data that PrepareCall can read.
CallOptions any

StopWhen []StopCondition
PrepareCall PrepareCallFunction
PrepareStep PrepareStepFunction
RepairToolCall RepairToolCallFunction

Expand Down Expand Up @@ -330,7 +346,7 @@ func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
}
}

func (a *agent) prepareCall(call AgentCall) AgentCall {
func (a *agent) prepareCall(ctx context.Context, call AgentCall) (context.Context, AgentCall, error) {
call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
call.TopP = cmp.Or(call.TopP, a.settings.topP)
Expand All @@ -340,6 +356,10 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)
call.ToolChoice = cmp.Or(call.ToolChoice, a.settings.toolChoice)

if call.SystemPrompt == nil {
sp := a.settings.systemPrompt
call.SystemPrompt = &sp
}
if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
call.StopWhen = a.settings.stopWhen
}
Expand Down Expand Up @@ -368,13 +388,35 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
maps.Copy(headers, a.settings.headers)
}

return call
prepareFn := call.PrepareCall
if prepareFn == nil {
prepareFn = a.settings.prepareCall
}
if prepareFn != nil {
var err error
if ctx, err = prepareFn(ctx, &call); err != nil {
return ctx, call, err
}
}

// Re-resolve in case the hook cleared SystemPrompt to opt back into the agent default.
if call.SystemPrompt == nil {
sp := a.settings.systemPrompt
call.SystemPrompt = &sp
}

return ctx, call, nil
}

// Generate implements Agent.
func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
opts = a.prepareCall(opts)
initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
ctx, opts, err := a.prepareCall(ctx, opts)
if err != nil {
return nil, err
}
// prepareCall guarantees SystemPrompt is non-nil at this point.
systemPrompt := *opts.SystemPrompt
initialPrompt, err := a.createPrompt(systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
if err != nil {
return nil, err
}
Expand All @@ -384,7 +426,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
for {
stepInputMessages := append(initialPrompt, responseMessages...)
stepModel := a.settings.model
stepSystemPrompt := a.settings.systemPrompt
stepSystemPrompt := systemPrompt
stepActiveTools := opts.ActiveTools
stepToolChoice := ToolChoiceAuto
if opts.ToolChoice != nil {
Expand Down Expand Up @@ -428,7 +470,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
}

// Recreate prompt with potentially modified system prompt
if stepSystemPrompt != a.settings.systemPrompt {
if stepSystemPrompt != systemPrompt {
stepPrompt, err := a.createPrompt(stepSystemPrompt, opts.Prompt, opts.Messages, opts.Files...)
if err != nil {
return nil, err
Expand Down Expand Up @@ -784,6 +826,7 @@ func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentT
func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
// Convert AgentStreamCall to AgentCall for preparation
call := AgentCall{
SystemPrompt: opts.SystemPrompt,
Prompt: opts.Prompt,
Files: opts.Files,
Messages: opts.Messages,
Expand All @@ -798,14 +841,21 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
ProviderOptions: opts.ProviderOptions,
MaxRetries: opts.MaxRetries,
OnRetry: opts.OnRetry,
CallOptions: opts.CallOptions,
StopWhen: opts.StopWhen,
PrepareCall: opts.PrepareCall,
PrepareStep: opts.PrepareStep,
RepairToolCall: opts.RepairToolCall,
}

call = a.prepareCall(call)
ctx, call, err := a.prepareCall(ctx, call)
if err != nil {
return nil, err
}

initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
// prepareCall guarantees SystemPrompt is non-nil at this point.
systemPrompt := *call.SystemPrompt
initialPrompt, err := a.createPrompt(systemPrompt, call.Prompt, call.Messages, call.Files...)
if err != nil {
return nil, err
}
Expand All @@ -822,7 +872,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
for stepNumber := 0; ; stepNumber++ {
stepInputMessages := append(initialPrompt, responseMessages...)
stepModel := a.settings.model
stepSystemPrompt := a.settings.systemPrompt
stepSystemPrompt := systemPrompt
stepActiveTools := call.ActiveTools
stepToolChoice := ToolChoiceAuto
if call.ToolChoice != nil {
Expand Down Expand Up @@ -866,7 +916,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
}

// Recreate prompt with potentially modified system prompt
if stepSystemPrompt != a.settings.systemPrompt {
if stepSystemPrompt != systemPrompt {
stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
if err != nil {
return nil, err
Expand Down Expand Up @@ -918,7 +968,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
}

// Process the stream
result, err := a.processStepStream(ctx, stream, opts, steps, stepTools, stepExecProviderTools)
result, err := a.processStepStream(ctx, stream, opts, steps, stepSystemPrompt, stepTools, stepExecProviderTools)
if err != nil {
return stepExecutionResult{}, err
}
Expand Down Expand Up @@ -1246,6 +1296,13 @@ func WithRepairToolCall(fn RepairToolCallFunction) AgentOption {
}
}

// WithPrepareCall sets the prepare call function for the agent.
func WithPrepareCall(fn PrepareCallFunction) AgentOption {
return func(s *agentSettings) {
s.prepareCall = fn
}
}

// WithMaxRetries sets the maximum number of retries for the agent.
func WithMaxRetries(maxRetries int) AgentOption {
return func(s *agentSettings) {
Expand All @@ -1261,7 +1318,7 @@ func WithOnRetry(callback OnRetryCallback) AgentOption {
}

// processStepStream processes a single step's stream and returns the step result.
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool, execProviderTools []ExecutableProviderTool) (stepExecutionResult, error) {
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepSystemPrompt string, stepTools []AgentTool, execProviderTools []ExecutableProviderTool) (stepExecutionResult, error) {
var stepContent []Content
var stepToolCalls []ToolCallContent
var stepUsage Usage
Expand Down Expand Up @@ -1452,7 +1509,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
delete(activeToolCalls, part.ID)
} else {
// Validate and potentially repair the tool call
validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, execProviderTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, execProviderTools, stepSystemPrompt, nil, opts.RepairToolCall)
stepToolCalls = append(stepToolCalls, validatedToolCall)
stepContent = append(stepContent, validatedToolCall)

Expand Down
Loading
Loading