Skip to content

Commit 7e5db41

Browse files
committed
Merge branch 'main' into pj/js-googleai-imagen-veo
2 parents 843a8a2 + 00407cf commit 7e5db41

File tree

27 files changed

+765
-298
lines changed

27 files changed

+765
-298
lines changed

genkit-tools/common/src/types/eval.ts

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import type {
2323
ListEvalKeysResponse,
2424
UpdateDatasetRequest,
2525
} from './apis';
26-
import { GenerateRequestSchema } from './model';
26+
import { GenerateActionOptionsSchema, GenerateRequestSchema } from './model';
2727

2828
/**
2929
* This file defines schema and types that are used by the Eval store.
@@ -56,6 +56,17 @@ export const GenerateRequestJSONSchema = zodToJsonSchema(
5656
}
5757
) as JSONSchema7;
5858

59+
/**
60+
* Combined GenerateInput JSON schema to support eval-inference using models
61+
*/
62+
export const GenerateInputJSONSchema = zodToJsonSchema(
63+
z.union([GenerateRequestSchema, GenerateActionOptionsSchema]),
64+
{
65+
$refStrategy: 'none',
66+
removeAdditionalStrategy: 'strict',
67+
}
68+
) as JSONSchema7;
69+
5970
/**
6071
* A single sample to be used for inference.
6172
**/

go/ai/generate.go

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,29 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
372372
modelName = genOpts.ModelName
373373
}
374374

375+
var dynamicTools []Tool
375376
tools := make([]string, len(genOpts.Tools))
376-
for i, tool := range genOpts.Tools {
377-
tools[i] = tool.Name()
377+
toolNames := make(map[string]bool)
378+
for i, toolRef := range genOpts.Tools {
379+
name := toolRef.Name()
380+
// Redundant duplicate tool check with GenerateWithRequest otherwise we will panic when we register the dynamic tools.
381+
if toolNames[name] {
382+
return nil, core.NewError(core.INVALID_ARGUMENT, "ai.Generate: duplicate tool %q", name)
383+
}
384+
toolNames[name] = true
385+
tools[i] = name
386+
// Dynamic tools wouldn't have been registered by this point.
387+
if LookupTool(r, name) == nil {
388+
if tool, ok := toolRef.(Tool); ok {
389+
dynamicTools = append(dynamicTools, tool)
390+
}
391+
}
392+
}
393+
if len(dynamicTools) > 0 {
394+
r = r.NewChild()
395+
for _, tool := range dynamicTools {
396+
tool.Register(r)
397+
}
378398
}
379399

380400
messages := []*Message{}
@@ -596,7 +616,7 @@ func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelReq
596616

597617
output, err := tool.RunRaw(ctx, toolReq.Input)
598618
if err != nil {
599-
var tie *ToolInterruptError
619+
var tie *toolInterruptError
600620
if errors.As(err, &tie) {
601621
logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", toolReq.Name, tie.Metadata)
602622

@@ -636,7 +656,7 @@ func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelReq
636656
for range toolCount {
637657
res := <-resultChan
638658
if res.err != nil {
639-
var tie *ToolInterruptError
659+
var tie *toolInterruptError
640660
if errors.As(res.err, &tie) {
641661
hasInterrupts = true
642662
continue
@@ -878,7 +898,7 @@ func handleResumedToolRequest(ctx context.Context, r *registry.Registry, genOpts
878898

879899
output, err := tool.RunRaw(resumedCtx, restartPart.ToolRequest.Input)
880900
if err != nil {
881-
var tie *ToolInterruptError
901+
var tie *toolInterruptError
882902
if errors.As(err, &tie) {
883903
logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", restartPart.ToolRequest.Name, tie.Metadata)
884904

go/ai/generate_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,124 @@ func TestGenerate(t *testing.T) {
621621
t.Errorf("got text %q, want %q", res.Text(), expectedText)
622622
}
623623
})
624+
625+
t.Run("registers dynamic tools", func(t *testing.T) {
626+
// Create a tool that is NOT registered in the global registry
627+
dynamicTool := NewTool("dynamicTestTool", "a tool that is dynamically registered",
628+
func(ctx *ToolContext, input struct {
629+
Message string
630+
}) (string, error) {
631+
return "Dynamic: " + input.Message, nil
632+
},
633+
)
634+
635+
// Verify the tool is not in the global registry
636+
if LookupTool(r, "dynamicTestTool") != nil {
637+
t.Fatal("dynamicTestTool should not be registered in global registry")
638+
}
639+
640+
// Create a model that will call the dynamic tool then provide a final response
641+
roundCount := 0
642+
info := &ModelInfo{
643+
Supports: &ModelSupports{
644+
Multiturn: true,
645+
Tools: true,
646+
},
647+
}
648+
toolCallModel := DefineModel(r, "test", "toolcall", info,
649+
func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
650+
roundCount++
651+
if roundCount == 1 {
652+
// First response: call the dynamic tool
653+
return &ModelResponse{
654+
Request: gr,
655+
Message: &Message{
656+
Role: RoleModel,
657+
Content: []*Part{
658+
NewToolRequestPart(&ToolRequest{
659+
Name: "dynamicTestTool",
660+
Input: map[string]any{"Message": "Hello from dynamic tool"},
661+
}),
662+
},
663+
},
664+
}, nil
665+
}
666+
// Second response: provide final answer based on tool response
667+
var toolResult string
668+
for _, msg := range gr.Messages {
669+
if msg.Role == RoleTool {
670+
for _, part := range msg.Content {
671+
if part.ToolResponse != nil {
672+
toolResult = part.ToolResponse.Output.(string)
673+
}
674+
}
675+
}
676+
}
677+
return &ModelResponse{
678+
Request: gr,
679+
Message: &Message{
680+
Role: RoleModel,
681+
Content: []*Part{
682+
NewTextPart(toolResult),
683+
},
684+
},
685+
}, nil
686+
})
687+
688+
// Use Generate with the dynamic tool - this should trigger the dynamic registration
689+
res, err := Generate(context.Background(), r,
690+
WithModel(toolCallModel),
691+
WithPrompt("call the dynamic tool"),
692+
WithTools(dynamicTool),
693+
)
694+
if err != nil {
695+
t.Fatal(err)
696+
}
697+
698+
// The tool should have been called and returned a response
699+
expectedText := "Dynamic: Hello from dynamic tool"
700+
if res.Text() != expectedText {
701+
t.Errorf("expected text %q, got %q", expectedText, res.Text())
702+
}
703+
704+
// Verify two rounds were executed: tool call + final response
705+
if roundCount != 2 {
706+
t.Errorf("expected 2 rounds, got %d", roundCount)
707+
}
708+
709+
// Verify the tool is still not in the global registry (it was registered in a child)
710+
if LookupTool(r, "dynamicTestTool") != nil {
711+
t.Error("dynamicTestTool should not be registered in global registry after generation")
712+
}
713+
})
714+
715+
t.Run("handles duplicate dynamic tools", func(t *testing.T) {
716+
// Create two tools with the same name
717+
dynamicTool1 := NewTool("duplicateTool", "first tool",
718+
func(ctx *ToolContext, input any) (string, error) {
719+
return "tool1", nil
720+
},
721+
)
722+
dynamicTool2 := NewTool("duplicateTool", "second tool",
723+
func(ctx *ToolContext, input any) (string, error) {
724+
return "tool2", nil
725+
},
726+
)
727+
728+
// Using both tools should result in an error
729+
_, err := Generate(context.Background(), r,
730+
WithModel(echoModel),
731+
WithPrompt("test duplicate tools"),
732+
WithTools(dynamicTool1, dynamicTool2),
733+
)
734+
735+
if err == nil {
736+
t.Fatal("expected error for duplicate tool names")
737+
}
738+
if !strings.Contains(err.Error(), "duplicate tool \"duplicateTool\"") {
739+
t.Errorf("unexpected error message: %v", err)
740+
}
741+
})
624742
}
625743

626744
func TestModelVersion(t *testing.T) {

go/ai/option.go

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ type ConfigOption interface {
4747
applyEmbedder(*embedderOptions) error
4848
applyRetriever(*retrieverOptions) error
4949
applyEvaluator(*evaluatorOptions) error
50-
applyIndexer(*indexerOptions) error
5150
}
5251

5352
// applyConfig applies the option to the config options.
@@ -96,11 +95,6 @@ func (o *configOptions) applyEvaluator(opts *evaluatorOptions) error {
9695
return o.applyConfig(&opts.configOptions)
9796
}
9897

99-
// applyIndexer applies the option to the indexer options.
100-
func (o *configOptions) applyIndexer(opts *indexerOptions) error {
101-
return o.applyConfig(&opts.configOptions)
102-
}
103-
10498
// WithConfig sets the configuration.
10599
func WithConfig(config any) ConfigOption {
106100
return &configOptions{Config: config}
@@ -580,7 +574,6 @@ type DocumentOption interface {
580574
applyPromptExecute(*promptExecutionOptions) error
581575
applyEmbedder(*embedderOptions) error
582576
applyRetriever(*retrieverOptions) error
583-
applyIndexer(*indexerOptions) error
584577
}
585578

586579
// applyDocument applies the option to the context options.
@@ -615,11 +608,6 @@ func (o *documentOptions) applyRetriever(retOpts *retrieverOptions) error {
615608
return o.applyDocument(&retOpts.documentOptions)
616609
}
617610

618-
// applyIndexer applies the option to the indexer options.
619-
func (o *documentOptions) applyIndexer(idxOpts *indexerOptions) error {
620-
return o.applyDocument(&idxOpts.documentOptions)
621-
}
622-
623611
// WithTextDocs sets the text to be used as context documents for generation or as input to an embedder.
624612
func WithTextDocs(text ...string) DocumentOption {
625613
docs := make([]*Document, len(text))
@@ -730,31 +718,6 @@ func (o *retrieverOptions) applyRetriever(retOpts *retrieverOptions) error {
730718
return nil
731719
}
732720

733-
// indexerOptions holds configuration and input for an embedder request.
734-
type indexerOptions struct {
735-
configOptions
736-
documentOptions
737-
}
738-
739-
// IndexerOption is an option for configuring an embedder request.
740-
// It applies only to [Index].
741-
type IndexerOption interface {
742-
applyIndexer(*indexerOptions) error
743-
}
744-
745-
// applyIndexer applies the option to the indexer options.
746-
func (o *indexerOptions) applyIndexer(idxOpts *indexerOptions) error {
747-
if err := o.applyConfig(&idxOpts.configOptions); err != nil {
748-
return err
749-
}
750-
751-
if err := o.applyDocument(&idxOpts.documentOptions); err != nil {
752-
return err
753-
}
754-
755-
return nil
756-
}
757-
758721
// generateOptions are options for generating a model response by calling a model directly.
759722
type generateOptions struct {
760723
commonGenOptions

0 commit comments

Comments
 (0)