Skip to content

feat(go): add custom configs for all primitives #2883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jul 8, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/registry"
)

Expand All @@ -32,17 +33,60 @@ type Embedder interface {
Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error)
}

// EmbedderInfo represents the structure of the embedder information object.
// Every field is optional and is left out of the JSON encoding when it holds
// its zero value.
type EmbedderInfo struct {
// Friendly label for this embedder (e.g. "Google AI - Gemini Pro")
Label string `json:"label,omitempty"`

// Supports describes the capabilities of the embedder; see EmbedderSupports.
Supports *EmbedderSupports `json:"supports,omitempty"`

// Dimensions is presumably the length of the embedding vectors produced by
// the embedder — TODO confirm against the embedder implementations.
Dimensions int `json:"dimensions,omitempty"`
}

// EmbedderSupports represents the supported capabilities of the embedder model.
type EmbedderSupports struct {
// Model can input this type of data.
// Expected values could be "text", "image", "video", but the field is a
// plain []string, so the type itself does not restrict the values.
Input []string `json:"input,omitempty"`

// Multilingual presumably reports whether the embedder handles input in
// more than one language — TODO confirm.
Multilingual bool `json:"multilingual,omitempty"`
}

// EmbedderOptions holds the optional settings accepted by DefineEmbedder.
type EmbedderOptions struct {
// ConfigSchema is a sample value whose JSON schema is inferred and used to
// describe the embedder's custom options. Nil means no custom options.
ConfigSchema any `json:"configSchema,omitempty"`
// Info is descriptive information about the embedder; DefineEmbedder stores
// it under the "info" key of the action metadata.
Info *EmbedderInfo `json:"info,omitempty"`
}

// An embedder is used to convert a document to a multidimensional vector.
type embedder core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]

// configToMap infers the JSON schema of config and returns that schema
// converted to its map form.
func configToMap(config any) map[string]any {
	return base.SchemaAsMap(base.InferJSONSchema(config))
}

// DefineEmbedder registers the given embed function as an action, and returns an
// [Embedder] that runs it.
//
// options carries the embedder's descriptive info and an optional custom
// config value. A nil options is treated as an empty [EmbedderOptions]. When
// options.ConfigSchema is non-nil, its inferred JSON schema is recorded in the
// action metadata under "embedder" -> "customOptions" and also replaces the
// schema of the "options" property of the action's input schema.
func DefineEmbedder(
	r *registry.Registry,
	provider, name string,
	options *EmbedderOptions,
	embed func(context.Context, *EmbedRequest) (*EmbedResponse, error),
) Embedder {
	// Guard against callers that have no options to supply; dereferencing a
	// nil pointer below would otherwise panic.
	if options == nil {
		options = &EmbedderOptions{}
	}

	metadata := map[string]any{
		"type": "embedder",
		"info": options.Info,
	}
	if options.ConfigSchema != nil {
		metadata["embedder"] = map[string]any{"customOptions": configToMap(options.ConfigSchema)}
	}

	// Start from the schema inferred for EmbedRequest and, when a custom config
	// was supplied, narrow its "options" property to that config's schema.
	inputSchema := base.InferJSONSchema(EmbedRequest{})
	if inputSchema.Properties != nil && options.ConfigSchema != nil {
		if _, ok := inputSchema.Properties.Get("options"); ok {
			inputSchema.Properties.Set("options", base.InferJSONSchema(options.ConfigSchema))
		}
	}

	return (*embedder)(core.DefineActionWithInputSchema(r, provider, name, core.ActionTypeEmbedder, metadata, inputSchema, embed))
}

// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
Expand Down
24 changes: 22 additions & 2 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,32 @@ type Retriever interface {
Retrieve(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error)
}

// RetrieverInfo holds descriptive information about a retriever. Both fields
// are optional and are left out of the JSON encoding when they hold their
// zero value.
type RetrieverInfo struct {
// Label is a friendly, human-readable name for the retriever.
Label string `json:"label,omitempty"`
// Supports describes the capabilities of the retriever; see MediaSupports.
Supports *MediaSupports `json:"supports,omitempty"`
}

// MediaSupports describes the supported capabilities of a retriever.
type MediaSupports struct {
// Media presumably reports whether the retriever can handle media content
// in addition to text — TODO confirm.
Media bool `json:"media,omitempty"`
}

// RetrieverOptions holds the optional settings accepted by DefineRetriever.
// The JSON tags mirror those on EmbedderOptions so both option types encode
// with the same key names.
type RetrieverOptions struct {
	// ConfigSchema is a sample value whose JSON schema is inferred and used to
	// describe the retriever's custom options. Nil means no custom options.
	ConfigSchema any `json:"configSchema,omitempty"`
	// Info is descriptive information about the retriever; DefineRetriever
	// stores it under the "info" key of the action metadata.
	Info *RetrieverInfo `json:"info,omitempty"`
}

type retriever core.ActionDef[*RetrieverRequest, *RetrieverResponse, struct{}]

// DefineRetriever registers the given retrieve function as an action, and returns a
// [Retriever] that runs it.
func DefineRetriever(r *registry.Registry, provider, name string, fn RetrieverFunc) Retriever {
return (*retriever)(core.DefineAction(r, provider, name, core.ActionTypeRetriever, nil, fn))
func DefineRetriever(r *registry.Registry, provider, name string, options *RetrieverOptions, fn RetrieverFunc) Retriever {
metadata := map[string]any{}
metadata["type"] = "retriever"
metadata["info"] = options.Info
if options.ConfigSchema != nil {
metadata["retriever"] = map[string]any{"customOptions": configToMap(options.ConfigSchema)}
}
return (*retriever)(core.DefineAction(r, provider, name, core.ActionTypeRetriever, metadata, fn))
}

// LookupRetriever looks up a [Retriever] registered by [DefineRetriever].
Expand Down
14 changes: 8 additions & 6 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ type noStream = func(context.Context, struct{}) error
// DefineAction creates a new non-streaming Action and registers it.
func DefineAction[In, Out any](
r *registry.Registry,
provider, name string,
provider,
name string,
atype ActionType,
metadata map[string]any,
fn Func[In, Out],
Expand All @@ -121,24 +122,25 @@ func DefineStreamingAction[In, Out, Stream any](
// This differs from DefineAction in that the input schema is
// supplied explicitly by the caller rather than derived from the
// static input type In.
// This is used for prompts and for embedders with a custom config schema.
func DefineActionWithInputSchema[Out any](
func DefineActionWithInputSchema[In, Out any](
r *registry.Registry,
provider, name string,
atype ActionType,
metadata map[string]any,
inputSchema *jsonschema.Schema,
fn Func[any, Out],
) *ActionDef[any, Out, struct{}] {
fn Func[In, Out],
) *ActionDef[In, Out, struct{}] {
return defineAction(r, provider, name, atype, metadata, inputSchema,
func(ctx context.Context, in any, _ noStream) (Out, error) {
func(ctx context.Context, in In, _ noStream) (Out, error) {
return fn(ctx, in)
})
}

// defineAction creates an action and registers it with the given Registry.
func defineAction[In, Out, Stream any](
r *registry.Registry,
provider, name string,
provider,
name string,
atype ActionType,
metadata map[string]any,
inputSchema *jsonschema.Schema,
Expand Down
8 changes: 4 additions & 4 deletions go/genkit/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -651,8 +651,8 @@ func GenerateData[Out any](ctx context.Context, g *Genkit, opts ...ai.GenerateOp
// The `provider` and `name` form the unique identifier. The `ret` function
// contains the logic to process an [ai.RetrieverRequest] (containing the query)
// and return an [ai.RetrieverResponse] (containing the relevant documents).
func DefineRetriever(g *Genkit, provider, name string, ret func(context.Context, *ai.RetrieverRequest) (*ai.RetrieverResponse, error)) ai.Retriever {
return ai.DefineRetriever(g.reg, provider, name, ret)
func DefineRetriever(g *Genkit, provider, name string, options *ai.RetrieverOptions, ret func(context.Context, *ai.RetrieverRequest) (*ai.RetrieverResponse, error)) ai.Retriever {
return ai.DefineRetriever(g.reg, provider, name, options, ret)
}

// LookupRetriever retrieves a registered [ai.Retriever] by its provider and name.
Expand All @@ -669,8 +669,8 @@ func LookupRetriever(g *Genkit, provider, name string) ai.Retriever {
// The `provider` and `name` form the unique identifier. The `embed` function
// contains the logic to process an [ai.EmbedRequest] (containing documents or a query)
// and return an [ai.EmbedResponse] (containing the corresponding embeddings).
func DefineEmbedder(g *Genkit, provider, name string, embed func(context.Context, *ai.EmbedRequest) (*ai.EmbedResponse, error)) ai.Embedder {
return ai.DefineEmbedder(g.reg, provider, name, embed)
func DefineEmbedder(g *Genkit, provider string, name string, options *ai.EmbedderOptions, embed func(context.Context, *ai.EmbedRequest) (*ai.EmbedResponse, error)) ai.Embedder {
return ai.DefineEmbedder(g.reg, provider, name, options, embed)
}

// LookupEmbedder retrieves a registered [ai.Embedder] by its provider and name.
Expand Down
14 changes: 13 additions & 1 deletion go/internal/doc-snippets/pinecone.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,22 @@ func pineconeEx(ctx context.Context) error {
var docChunks []*ai.Document

// [START defineretriever]
retOpts := &ai.RetrieverOptions{
ConfigSchema: pinecone.PineconeRetrieverOptions{
K: 1,
Namespace: "Test",
},
Info: &ai.RetrieverInfo{
Label: "Pinecone",
Supports: &ai.MediaSupports{
Media: false,
},
},
}
ds, menuRetriever, err := pinecone.DefineRetriever(ctx, g, pinecone.Config{
IndexID: "menu_data", // Your Pinecone index
Embedder: googlegenai.GoogleAIEmbedder(g, "text-embedding-004"), // Embedding model of your choice
})
}, retOpts)
if err != nil {
return err
}
Expand Down
48 changes: 48 additions & 0 deletions go/internal/doc-snippets/rag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,25 @@ func main() {
if err != nil {
log.Fatal(err)
}
retOpts := &ai.RetrieverOptions{
ConfigSchema: localvec.RetrieverOptions{
K: 3,
},
Info: &ai.RetrieverInfo{
Label: "menuQA",
Supports: &ai.MediaSupports{
Media: false,
},
},
}

docStore, _, err := localvec.DefineRetriever(
g,
"menuQA",
localvec.Config{
Embedder: googlegenai.VertexAIEmbedder(g, "text-embedding-004"),
},
retOpts,
)
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -155,12 +167,25 @@ func menuQA() {

model := googlegenai.VertexAIModel(g, "gemini-1.5-flash")

retOpts := &ai.RetrieverOptions{
ConfigSchema: localvec.RetrieverOptions{
K: 3,
},
Info: &ai.RetrieverInfo{
Label: "menuQA",
Supports: &ai.MediaSupports{
Media: false,
},
},
}

_, menuPdfRetriever, err := localvec.DefineRetriever(
g,
"menuQA",
localvec.Config{
Embedder: googlegenai.VertexAIEmbedder(g, "text-embedding-004"),
},
retOpts,
)
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -207,23 +232,46 @@ func customret() {
log.Fatal(err)
}

retOpts := &ai.RetrieverOptions{
ConfigSchema: localvec.RetrieverOptions{
K: 3,
},
Info: &ai.RetrieverInfo{
Label: "menuQA",
Supports: &ai.MediaSupports{
Media: false,
},
},
}

_, menuPDFRetriever, _ := localvec.DefineRetriever(
g,
"menuQA",
localvec.Config{
Embedder: googlegenai.VertexAIEmbedder(g, "text-embedding-004"),
},
retOpts,
)

// [START customret]
type CustomMenuRetrieverOptions struct {
K int
PreRerankK int
}
genRetOpts := &ai.RetrieverOptions{
ConfigSchema: CustomMenuRetrieverOptions{},
Info: &ai.RetrieverInfo{
Label: "advancedMenuRetriever",
Supports: &ai.MediaSupports{
Media: false,
},
},
}
advancedMenuRetriever := genkit.DefineRetriever(
g,
"custom",
"advancedMenuRetriever",
genRetOpts,
func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
// Handle options passed using our custom type.
opts, _ := req.Options.(CustomMenuRetrieverOptions)
Expand Down
12 changes: 11 additions & 1 deletion go/internal/fakeembedder/fakeembedder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,17 @@ func TestFakeEmbedder(t *testing.T) {
}

embed := New()
emb := ai.DefineEmbedder(r, "fake", "embed", embed.Embed)
emdOpts := &ai.EmbedderOptions{
Info: &ai.EmbedderInfo{
Dimensions: 32,
Label: "embed",
Supports: &ai.EmbedderSupports{
Input: []string{"text"},
},
},
ConfigSchema: nil,
}
emb := ai.DefineEmbedder(r, "fake", "embed", emdOpts, embed.Embed)
d := ai.DocumentFromText("fakeembedder test", nil)

vals := []float32{1, 2}
Expand Down
4 changes: 2 additions & 2 deletions go/plugins/compat_oai/compat_oai.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ func (o *OpenAICompatible) DefineModel(g *genkit.Genkit, provider, name string,
}

// DefineEmbedder defines an embedder with a given name.
func (o *OpenAICompatible) DefineEmbedder(g *genkit.Genkit, provider, name string) (ai.Embedder, error) {
func (o *OpenAICompatible) DefineEmbedder(g *genkit.Genkit, provider, name string, embedOptions ai.EmbedderOptions) (ai.Embedder, error) {
o.mu.Lock()
defer o.mu.Unlock()
if !o.initted {
return nil, errors.New("OpenAICompatible.Init not called")
}

return genkit.DefineEmbedder(g, provider, name, func(ctx context.Context, input *ai.EmbedRequest) (*ai.EmbedResponse, error) {
return genkit.DefineEmbedder(g, provider, name, &embedOptions, func(ctx context.Context, input *ai.EmbedRequest) (*ai.EmbedResponse, error) {
var data openaiGo.EmbeddingNewParamsInputArrayOfStrings
for _, doc := range input.Input {
for _, p := range doc.Content {
Expand Down
Loading
Loading