diff --git a/go/ai/generate.go b/go/ai/generate.go index 7ae2eba241..9a0d37aac5 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -372,9 +372,29 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) modelName = genOpts.ModelName } + var dynamicTools []Tool tools := make([]string, len(genOpts.Tools)) - for i, tool := range genOpts.Tools { - tools[i] = tool.Name() + toolNames := make(map[string]bool) + for i, toolRef := range genOpts.Tools { + name := toolRef.Name() + // Redundant duplicate tool check with GenerateWithRequest otherwise we will panic when we register the dynamic tools. + if toolNames[name] { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.Generate: duplicate tool %q", name) + } + toolNames[name] = true + tools[i] = name + // Dynamic tools wouldn't have been registered by this point. + if LookupTool(r, name) == nil { + if tool, ok := toolRef.(Tool); ok { + dynamicTools = append(dynamicTools, tool) + } + } + } + if len(dynamicTools) > 0 { + r = r.NewChild() + for _, tool := range dynamicTools { + tool.Register(r) + } } messages := []*Message{} @@ -596,7 +616,7 @@ func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelReq output, err := tool.RunRaw(ctx, toolReq.Input) if err != nil { - var tie *ToolInterruptError + var tie *toolInterruptError if errors.As(err, &tie) { logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", toolReq.Name, tie.Metadata) @@ -636,7 +656,7 @@ func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelReq for range toolCount { res := <-resultChan if res.err != nil { - var tie *ToolInterruptError + var tie *toolInterruptError if errors.As(res.err, &tie) { hasInterrupts = true continue @@ -878,7 +898,7 @@ func handleResumedToolRequest(ctx context.Context, r *registry.Registry, genOpts output, err := tool.RunRaw(resumedCtx, restartPart.ToolRequest.Input) if err != nil { - var tie *ToolInterruptError + var tie *toolInterruptError if errors.As(err, &tie) { logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", restartPart.ToolRequest.Name, tie.Metadata) diff --git a/go/ai/generate_test.go b/go/ai/generate_test.go index c4c37fb5a3..fa67e5ccce 100644 --- a/go/ai/generate_test.go +++ b/go/ai/generate_test.go @@ -621,6 +621,124 @@ func TestGenerate(t *testing.T) { t.Errorf("got text %q, want %q", res.Text(), expectedText) } }) + + t.Run("registers dynamic tools", func(t *testing.T) { + // Create a tool that is NOT registered in the global registry + dynamicTool := NewTool("dynamicTestTool", "a tool that is dynamically registered", + func(ctx *ToolContext, input struct { + Message string + }) (string, error) { + return "Dynamic: " + input.Message, nil + }, + ) + + // Verify the tool is not in the global registry + if LookupTool(r, "dynamicTestTool") != nil { + t.Fatal("dynamicTestTool should not be registered in global registry") + } + + // Create a model that will call the dynamic tool then provide a final response + roundCount := 0 + info := &ModelInfo{ + Supports: &ModelSupports{ + Multiturn: true, + Tools: true, + }, + } + toolCallModel := DefineModel(r, "test", "toolcall", info, + func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) { + roundCount++ + if roundCount == 1 { + // First response: call the dynamic tool + return &ModelResponse{ + Request: gr, + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewToolRequestPart(&ToolRequest{ + Name: "dynamicTestTool", + Input: map[string]any{"Message": "Hello from dynamic tool"}, + }), + }, + }, + }, nil + } + // Second response: provide final answer based on tool response + var toolResult string + for _, msg := range gr.Messages { + if msg.Role == RoleTool { + for _, part := range msg.Content { + if part.ToolResponse != nil { + toolResult = part.ToolResponse.Output.(string) + } + } + } + } + return &ModelResponse{ + Request: gr, + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewTextPart(toolResult), + }, + }, + }, nil + }) + + // Use Generate with the dynamic tool - this should trigger the dynamic registration + res, err := Generate(context.Background(), r, + WithModel(toolCallModel), + WithPrompt("call the dynamic tool"), + WithTools(dynamicTool), + ) + if err != nil { + t.Fatal(err) + } + + // The tool should have been called and returned a response + expectedText := "Dynamic: Hello from dynamic tool" + if res.Text() != expectedText { + t.Errorf("expected text %q, got %q", expectedText, res.Text()) + } + + // Verify two rounds were executed: tool call + final response + if roundCount != 2 { + t.Errorf("expected 2 rounds, got %d", roundCount) + } + + // Verify the tool is still not in the global registry (it was registered in a child) + if LookupTool(r, "dynamicTestTool") != nil { + t.Error("dynamicTestTool should not be registered in global registry after generation") + } + }) + + t.Run("handles duplicate dynamic tools", func(t *testing.T) { + // Create two tools with the same name + dynamicTool1 := NewTool("duplicateTool", "first tool", + func(ctx *ToolContext, input any) (string, error) { + return "tool1", nil + }, + ) + dynamicTool2 := NewTool("duplicateTool", "second tool", + func(ctx *ToolContext, input any) (string, error) { + return "tool2", nil + }, + ) + + // Using both tools should result in an error + _, err := Generate(context.Background(), r, + WithModel(echoModel), + WithPrompt("test duplicate tools"), + WithTools(dynamicTool1, dynamicTool2), + ) + + if err == nil { + t.Fatal("expected error for duplicate tool names") + } + if !strings.Contains(err.Error(), "duplicate tool \"duplicateTool\"") { + t.Errorf("unexpected error message: %v", err) + } + }) } func TestModelVersion(t *testing.T) { diff --git a/go/ai/tools.go b/go/ai/tools.go index 5ee31ffe0d..ff4350ce17 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -60,19 +60,21 @@ type Tool interface { Definition() *ToolDefinition // RunRaw runs this tool using the provided raw input. RunRaw(ctx context.Context, input any) (any, error) + // Register sets the tracing state on the action and registers it with the registry. + Register(r *registry.Registry) // Respond constructs a *Part with a ToolResponse for a given interrupted tool request. Respond(toolReq *Part, outputData any, opts *RespondOptions) *Part // Restart constructs a *Part with a new ToolRequest to re-trigger a tool, - // potentially with new input and resumedMetadata. + // potentially with new input and metadata. Restart(toolReq *Part, opts *RestartOptions) *Part } -// ToolInterruptError represents an intentional interruption of tool execution. -type ToolInterruptError struct { +// toolInterruptError represents an intentional interruption of tool execution. +type toolInterruptError struct { Metadata map[string]any } -func (e *ToolInterruptError) Error() string { +func (e *toolInterruptError) Error() string { return "tool execution interrupted" } @@ -112,58 +114,54 @@ type ToolContext struct { OriginalInput any } -// DefineTool defines a tool function with interrupt capability +// DefineTool defines a tool. func DefineTool[In, Out any](r *registry.Registry, name, description string, fn func(ctx *ToolContext, input In) (Out, error)) Tool { - wrappedFn := func(ctx context.Context, input In) (Out, error) { - toolCtx := &ToolContext{ - Context: ctx, - Interrupt: func(opts *InterruptOptions) error { - return &ToolInterruptError{ - Metadata: opts.Metadata, - } - }, - Resumed: resumedCtxKey.FromContext(ctx), - OriginalInput: origInputCtxKey.FromContext(ctx), - } - return fn(toolCtx, input) - } - - metadata := map[string]any{ - "type": "tool", - "name": name, - "description": description, - } + metadata, wrappedFn := implementTool(name, description, fn) toolAction := core.DefineAction(r, "", name, core.ActionTypeTool, metadata, wrappedFn) - return &tool{Action: toolAction} } -// DefineToolWithInputSchema defines a tool function with a custom input schema and interrupt capability. -// The input schema allows specifying a JSON Schema for validating tool inputs. +// DefineToolWithInputSchema defines a tool function with a custom input schema. func DefineToolWithInputSchema[Out any](r *registry.Registry, name, description string, inputSchema *jsonschema.Schema, fn func(ctx *ToolContext, input any) (Out, error)) Tool { - metadata := make(map[string]any) - metadata["type"] = "tool" - metadata["name"] = name - metadata["description"] = description + metadata, wrappedFn := implementTool(name, description, fn) + toolAction := core.DefineActionWithInputSchema(r, "", name, core.ActionTypeTool, metadata, inputSchema, wrappedFn) + return &tool{Action: toolAction} +} - wrappedFn := func(ctx context.Context, input any) (Out, error) { +// NewTool creates a tool but does not register it in the registry. It can be passed directly to [Generate]. +func NewTool[In, Out any](name, description string, + fn func(ctx *ToolContext, input In) (Out, error)) Tool { + metadata, wrappedFn := implementTool(name, description, fn) + metadata["dynamic"] = true + toolAction := core.NewAction("", name, core.ActionTypeTool, metadata, wrappedFn) + return &tool{Action: toolAction} +} + +// implementTool creates the metadata and wrapped function common to both DefineTool and NewTool. +func implementTool[In, Out any](name, description string, fn func(ctx *ToolContext, input In) (Out, error)) (map[string]any, func(context.Context, In) (Out, error)) { + metadata := map[string]any{ + "type": core.ActionTypeTool, + "name": name, + "description": description, + } + wrappedFn := func(ctx context.Context, input In) (Out, error) { toolCtx := &ToolContext{ Context: ctx, Interrupt: func(opts *InterruptOptions) error { - return &ToolInterruptError{ + return &toolInterruptError{ Metadata: opts.Metadata, } }, + Resumed: resumedCtxKey.FromContext(ctx), + OriginalInput: origInputCtxKey.FromContext(ctx), } return fn(toolCtx, input) } - toolAction := core.DefineActionWithInputSchema(r, "", name, core.ActionTypeTool, metadata, inputSchema, wrappedFn) - - return &tool{Action: toolAction} + return metadata, wrappedFn } // Name returns the name of the tool. @@ -193,6 +191,12 @@ func (t *tool) RunRaw(ctx context.Context, input any) (any, error) { return runAction(ctx, t.Definition(), t.Action, input) } +// Register sets the tracing state on the action and registers it with the registry. +func (t *tool) Register(r *registry.Registry) { + t.Action.SetTracingState(r.TracingState()) + r.RegisterAction(fmt.Sprintf("/%s/%s", core.ActionTypeTool, t.Action.Name()), t.Action) +} + // runAction runs the given action with the provided raw input and returns the output in raw format. func runAction(ctx context.Context, def *ToolDefinition, action core.Action, input any) (any, error) { mi, err := json.Marshal(input) diff --git a/go/core/action.go b/go/core/action.go index 4636864907..0d66fdb270 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -48,6 +48,8 @@ type Action interface { RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) // Desc returns a descriptor of the action. Desc() ActionDesc + // SetTracingState sets the tracing state on the action. + SetTracingState(tstate *tracing.State) } // An ActionType is the kind of an action. @@ -106,6 +108,23 @@ func DefineAction[In, Out any]( }) } +// NewAction creates a new non-streaming Action without registering it. +func NewAction[In, Out any]( + provider, name string, + atype ActionType, + metadata map[string]any, + fn Func[In, Out], +) *ActionDef[In, Out, struct{}] { + fullName := name + if provider != "" { + fullName = provider + "/" + name + } + return newAction(nil, fullName, atype, metadata, nil, + func(ctx context.Context, in In, cb noStream) (Out, error) { + return fn(ctx, in) + }) +} + // DefineStreamingAction creates a new streaming action and registers it. func DefineStreamingAction[In, Out, Stream any]( r *registry.Registry, @@ -155,6 +174,7 @@ func defineAction[In, Out, Stream any]( } // newAction creates a new Action with the given name and arguments. +// If registry is nil, tracing state is left nil to be set later. // If inputSchema is nil, it is inferred from In. func newAction[In, Out, Stream any]( r *registry.Registry, @@ -164,23 +184,31 @@ func newAction[In, Out, Stream any]( inputSchema *jsonschema.Schema, fn StreamingFunc[In, Out, Stream], ) *ActionDef[In, Out, Stream] { - var i In - var o Out if inputSchema == nil { + var i In if reflect.ValueOf(i).Kind() != reflect.Invalid { inputSchema = base.InferJSONSchema(i) } } + + var o Out var outputSchema *jsonschema.Schema if reflect.ValueOf(o).Kind() != reflect.Invalid { outputSchema = base.InferJSONSchema(o) } + var description string if desc, ok := metadata["description"].(string); ok { description = desc } + + var tstate *tracing.State + if r != nil { + tstate = r.TracingState() + } + return &ActionDef[In, Out, Stream]{ - tstate: r.TracingState(), + tstate: tstate, fn: func(ctx context.Context, input In, cb StreamCallback[Stream]) (Out, error) { tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype)) return fn(ctx, input, cb) @@ -200,6 +228,12 @@ func newAction[In, Out, Stream any]( // Name returns the Action's Name. func (a *ActionDef[In, Out, Stream]) Name() string { return a.desc.Name } +// SetTracingState sets the tracing state on the action. This is used when an action +// created without a registry needs to have its tracing state set later. +func (a *ActionDef[In, Out, Stream]) SetTracingState(tstate *tracing.State) { + a.tstate = tstate +} + // Run executes the Action's function in a new trace span. func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb StreamCallback[Stream]) (output Out, err error) { logger.FromContext(ctx).Debug("Action.Run", diff --git a/go/core/flow.go b/go/core/flow.go index 017f6fdd51..f8be665804 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -115,6 +115,11 @@ func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In) (Out, error) return (*ActionDef[In, Out, Stream])(f).Run(ctx, input, nil) } +// SetTracingState sets the tracing state on the flow. +func (f *Flow[In, Out, Stream]) SetTracingState(tstate *tracing.State) { + (*ActionDef[In, Out, Stream])(f).SetTracingState(tstate) +} + // Stream runs the flow in the context of another flow and streams the output. // It returns a function whose argument function (the "yield function") will be repeatedly // called with the results. diff --git a/go/internal/registry/registry.go b/go/internal/registry/registry.go index 69132ccbfa..c543267025 100644 --- a/go/internal/registry/registry.go +++ b/go/internal/registry/registry.go @@ -19,6 +19,7 @@ package registry import ( "fmt" "log/slog" + "maps" "os" "sync" @@ -37,12 +38,15 @@ const ( type Registry struct { tstate *tracing.State mu sync.Mutex + frozen bool // when true, no more additions + parent *Registry // parent registry for hierarchical lookups actions map[string]any // Values follow interface core.Action but we can't reference it here. plugins map[string]any // Values follow interface genkit.Plugin but we can't reference it here. values map[string]any // Values can truly be anything. Dotprompt *dotprompt.Dotprompt } +// New creates a new root registry. func New() (*Registry, error) { r := &Registry{ actions: map[string]any{}, @@ -60,6 +64,21 @@ func New() (*Registry, error) { return r, nil } +// NewChild creates a new child registry that inherits from this registry. +// Child registries are cheap to create and will fall back to the parent +// for lookups if a value is not found in the child. +func (r *Registry) NewChild() *Registry { + child := &Registry{ + parent: r, + tstate: r.tstate, + actions: map[string]any{}, + plugins: map[string]any{}, + values: map[string]any{}, + Dotprompt: r.Dotprompt, + } + return child +} + func (r *Registry) TracingState() *tracing.State { return r.tstate } // RegisterPlugin records the plugin in the registry. @@ -88,11 +107,22 @@ func (r *Registry) RegisterAction(key string, action any) { slog.Debug("RegisterAction", "key", key) } -// LookupPlugin returns the plugin for the given name, or nil if there is none. +// LookupPlugin returns the plugin for the given name. +// It first checks the current registry, then falls back to the parent if not found. +// Returns nil if the plugin is not found in the registry hierarchy. func (r *Registry) LookupPlugin(name string) any { r.mu.Lock() defer r.mu.Unlock() - return r.plugins[name] + + if plugin, ok := r.plugins[name]; ok { + return plugin + } + + if r.parent != nil { + return r.parent.LookupPlugin(name) + } + + return nil } // RegisterValue records an arbitrary value in the registry. @@ -107,21 +137,45 @@ func (r *Registry) RegisterValue(name string, value any) { slog.Debug("RegisterValue", "name", name) } -// LookupValue returns the value for the given name, or nil if there is none. +// LookupValue returns the value for the given name. +// It first checks the current registry, then falls back to the parent if not found. +// Returns nil if the value is not found in the registry hierarchy. func (r *Registry) LookupValue(name string) any { r.mu.Lock() defer r.mu.Unlock() - return r.values[name] + + if value, ok := r.values[name]; ok { + return value + } + + if r.parent != nil { + return r.parent.LookupValue(name) + } + + return nil } -// LookupAction returns the action for the given key, or nil if there is none. +// LookupAction returns the action for the given key. +// It first checks the current registry, then falls back to the parent if not found. +// Returns nil if the action is not found in the registry hierarchy. func (r *Registry) LookupAction(key string) any { r.mu.Lock() defer r.mu.Unlock() - return r.actions[key] + + if action, ok := r.actions[key]; ok { + return action + } + + if r.parent != nil { + return r.parent.LookupAction(key) + } + + return nil } // ListActions returns a list of all registered actions. +// This includes actions from both the current registry and its parent hierarchy. +// Child registry actions take precedence over parent actions with the same key. func (r *Registry) ListActions() []any { r.mu.Lock() defer r.mu.Unlock() @@ -148,10 +202,24 @@ func (r *Registry) RegisterSpanProcessor(sp sdktrace.SpanProcessor) { } // ListValues returns a list of values of all registered values. +// This includes values from both the current registry and its parent hierarchy. +// Child registry values take precedence over parent values with the same key. func (r *Registry) ListValues() map[string]any { r.mu.Lock() defer r.mu.Unlock() - return r.values + + allValues := make(map[string]any) + + if r.parent != nil { + parentValues := r.parent.ListValues() + for key, value := range parentValues { + allValues[key] = value + } + } + + maps.Copy(allValues, r.values) + + return allValues } // An Environment is the execution context in which the program is running.