Skip to content

feat(go):Support custom configs for all primitives #2883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,16 @@ type embedder core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}]
func DefineEmbedder(
	r *registry.Registry,
	provider, name string,
	options *EmbedderOptions,
	embed func(context.Context, *EmbedRequest) (*EmbedResponse, error),
) Embedder {
	// Metadata registered alongside the embedder action. "type" is always
	// present; the remaining entries come from options.
	metadata := map[string]any{"type": "embedder"}

	// options is optional. Guard against nil so callers with nothing to
	// advertise do not trigger a nil-pointer dereference.
	if options != nil {
		metadata["info"] = options.Info
		if options.ConfigSchema != nil {
			// The custom config schema is nested under embedder.customOptions.
			metadata["embedder"] = map[string]any{"customOptions": options.ConfigSchema}
		}
	}

	return (*embedder)(core.DefineAction(r, provider, name, atype.Embedder, metadata, embed))
}

// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
Expand Down
25 changes: 25 additions & 0 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,28 @@ type TraceMetadata struct {
Paths []*PathMetadata `json:"paths,omitempty"`
Timestamp float64 `json:"timestamp,omitempty"`
}

// EmbedderInfo represents the structure of the embedder information object.
// It mirrors the Zod schema EmbedderInfoSchema.
type EmbedderInfo struct {
// Label is a friendly, human-readable label for this embedder
// (e.g. "Google AI - Gemini Pro"). Omitted from JSON when empty.
Label string `json:"label,omitempty"`

// Supports describes the embedder's capabilities. Omitted from JSON when nil.
Supports *EmbedderSupports `json:"supports,omitempty"`

// Dimensions is presumably the length of the embedding vectors the embedder
// produces (TODO confirm against EmbedderInfoSchema). Omitted from JSON when zero.
Dimensions int `json:"dimensions,omitempty"`
}

// EmbedderSupports represents the supported capabilities of the embedder model.
type EmbedderSupports struct {
// Input lists the types of data the model can take as input.
// Expected values could be "text", "image", "video", but the field is a
// plain []string, so the struct itself does not restrict the values.
Input []string `json:"input,omitempty"`

// Multilingual presumably reports whether the model handles more than one
// language (inferred from the field name; TODO confirm). Omitted from JSON
// when false.
Multilingual bool `json:"multilingual,omitempty"`
}

// EmbedderOptions holds the optional settings supplied when an embedder is
// defined.
type EmbedderOptions struct {
// ConfigSchema is a schema, expressed as a generic map, describing the
// embedder's custom configuration options. Omitted from JSON when empty.
ConfigSchema map[string]any `json:"configSchema,omitempty"`
// Info carries descriptive information about the embedder. Omitted from
// JSON when nil.
Info *EmbedderInfo `json:"info,omitempty"`
}
6 changes: 4 additions & 2 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ type noStream = func(context.Context, struct{}) error
// DefineAction creates a new non-streaming Action and registers it.
func DefineAction[In, Out any](
r *registry.Registry,
provider, name string,
provider,
name string,
atype atype.ActionType,
metadata map[string]any,
fn Func[In, Out],
Expand Down Expand Up @@ -116,7 +117,8 @@ func DefineActionWithInputSchema[Out any](
// defineAction creates an action and registers it with the given Registry.
func defineAction[In, Out, Stream any](
r *registry.Registry,
provider, name string,
provider,
name string,
atype atype.ActionType,
metadata map[string]any,
inputSchema *jsonschema.Schema,
Expand Down
4 changes: 2 additions & 2 deletions go/genkit/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -703,8 +703,8 @@ func LookupIndexer(g *Genkit, provider, name string) ai.Indexer {
// The `provider` and `name` form the unique identifier. The `embed` function
// contains the logic to process an [ai.EmbedRequest] (containing documents or a query)
// and return an [ai.EmbedResponse] (containing the corresponding embeddings).
// The `options` argument is forwarded unchanged to [ai.DefineEmbedder], together
// with the Genkit instance's registry.
func DefineEmbedder(g *Genkit, provider, name string, embed func(context.Context, *ai.EmbedRequest) (*ai.EmbedResponse, error)) ai.Embedder {
return ai.DefineEmbedder(g.reg, provider, name, embed)
func DefineEmbedder(g *Genkit, provider string, name string, options *ai.EmbedderOptions, embed func(context.Context, *ai.EmbedRequest) (*ai.EmbedResponse, error)) ai.Embedder {
return ai.DefineEmbedder(g.reg, provider, name, options, embed)
}

// LookupEmbedder retrieves a registered [ai.Embedder] by its provider and name.
Expand Down
32 changes: 30 additions & 2 deletions go/plugins/googlegenai/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,29 @@ type EmbedOptions struct {
TaskType string `json:"task_type,omitempty"`
}

// GeminiEmbeddingConfigSchema describes the custom configuration options for
// Gemini embedders. It is reflected into a schema map via configToMap when an
// embedder is defined.
type GeminiEmbeddingConfigSchema struct {
// Override the API key provided at plugin initialization.
APIKey string `json:"apiKey,omitempty"`

// The `task_type` parameter is defined as the intended downstream application
// to help the model produce better quality embeddings.
// NOTE: Assuming TaskTypeSchema resolves to a string. If it's a different
// complex type, this field would need to be adjusted accordingly (e.g., a struct).
// NOTE(review): this field serializes as "taskType", whereas
// EmbedOptions.TaskType uses "task_type"; confirm the difference is intended.
TaskType *string `json:"taskType,omitempty"`

// Title is presumably an optional title for the content being embedded
// (inferred from the field name; TODO confirm).
Title string `json:"title,omitempty"`

// Version is presumably the model version to use (inferred from the field
// name; TODO confirm).
Version string `json:"version,omitempty"`

// The `outputDimensionality` parameter allows you to specify the dimensionality
// of the embedding output. By default, the model generates embeddings
// with 768 dimensions. Models such as `text-embedding-004`, `text-embedding-005`,
// and `text-multilingual-embedding-002` allow the output dimensionality
// to be adjusted between 1 and 768.
// NOTE: The min(1) and max(768) constraints are not expressed on this field;
// it is a plain int, so nothing in this struct enforces them.
OutputDimensionality int `json:"outputDimensionality,omitempty"`
}

// configToMap converts a config struct to a map[string]any.
func configToMap(config any) map[string]any {
r := jsonschema.Reflector{
Expand Down Expand Up @@ -272,13 +295,18 @@ func defineModel(g *genkit.Genkit, client *genai.Client, name string, info ai.Mo

// defineEmbedder defines an embedder for the named embedding model,
// registered under the provider that matches the client's backend.
func defineEmbedder(g *genkit.Genkit, client *genai.Client, name string) ai.Embedder {
func defineEmbedder(g *genkit.Genkit, client *genai.Client, name string, embedOptions ai.EmbedderOptions) ai.Embedder {
provider := googleAIProvider
if client.ClientConfig().Backend == genai.BackendVertexAI {
provider = vertexAIProvider
}

return genkit.DefineEmbedder(g, provider, name, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
emdOpts := &ai.EmbedderOptions{
Info: embedOptions.Info,
ConfigSchema: configToMap(&GeminiEmbeddingConfigSchema{}),
}

return genkit.DefineEmbedder(g, provider, name, emdOpts, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
var content []*genai.Content
var embedConfig *genai.EmbedContentConfig

Expand Down
16 changes: 8 additions & 8 deletions go/plugins/googlegenai/googlegenai.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ func (ga *GoogleAI) Init(ctx context.Context, g *genkit.Genkit) (err error) {
if err != nil {
return err
}
for _, e := range embedders {
defineEmbedder(g, ga.gclient, e)
for e, eOpts := range embedders {
defineEmbedder(g, ga.gclient, e, eOpts)
}

return nil
Expand Down Expand Up @@ -182,8 +182,8 @@ func (v *VertexAI) Init(ctx context.Context, g *genkit.Genkit) (err error) {
if err != nil {
return err
}
for _, e := range embedders {
defineEmbedder(g, v.gclient, e)
for e, eOpts := range embedders {
defineEmbedder(g, v.gclient, e, eOpts)
}

return nil
Expand Down Expand Up @@ -248,23 +248,23 @@ func (v *VertexAI) DefineModel(g *genkit.Genkit, name string, info *ai.ModelInfo
}

// DefineEmbedder defines an embedder with a given name.
// It holds ga.mu for the duration of the call and returns an error if the
// plugin has not been initialized. The embedOptions value is passed through
// to defineEmbedder.
func (ga *GoogleAI) DefineEmbedder(g *genkit.Genkit, name string) (ai.Embedder, error) {
func (ga *GoogleAI) DefineEmbedder(g *genkit.Genkit, name string, embedOptions ai.EmbedderOptions) (ai.Embedder, error) {
ga.mu.Lock()
defer ga.mu.Unlock()
if !ga.initted {
return nil, errors.New("GoogleAI plugin not initialized")
}
return defineEmbedder(g, ga.gclient, name), nil
return defineEmbedder(g, ga.gclient, name, embedOptions), nil
}

// DefineEmbedder defines an embedder with a given name.
// It holds v.mu for the duration of the call and returns an error if the
// plugin has not been initialized. The embedOptions value is passed through
// to defineEmbedder.
func (v *VertexAI) DefineEmbedder(g *genkit.Genkit, name string) (ai.Embedder, error) {
func (v *VertexAI) DefineEmbedder(g *genkit.Genkit, name string, embedOptions ai.EmbedderOptions) (ai.Embedder, error) {
v.mu.Lock()
defer v.mu.Unlock()
if !v.initted {
return nil, errors.New("VertexAI plugin not initialized")
}
return defineEmbedder(g, v.gclient, name), nil
return defineEmbedder(g, v.gclient, name, embedOptions), nil
}

// IsDefinedEmbedder reports whether the named [Embedder] is defined by this plugin.
Expand Down
124 changes: 108 additions & 16 deletions go/plugins/googlegenai/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,17 @@ const (

gemini25FlashPreview0417 = "gemini-2.5-flash-preview-04-17"

gemini25ProExp0325 = "gemini-2.5-pro-exp-03-25"
gemini25ProPreview0325 = "gemini-2.5-pro-preview-03-25"
gemini25ProPreview0506 = "gemini-2.5-pro-preview-05-06"
gemini25ProExp0325 = "gemini-2.5-pro-exp-03-25"
gemini25ProPreview0325 = "gemini-2.5-pro-preview-03-25"
gemini25ProPreview0506 = "gemini-2.5-pro-preview-05-06"
textembedding004 = "text-embedding-004"
embedding001 = "embedding-001"
textembeddinggecko003 = "textembedding-gecko@003"
textembeddinggecko002 = "textembedding-gecko@002"
textembeddinggecko001 = "textembedding-gecko@001"
textembeddinggeckomultilingual001 = "textembedding-gecko-multilingual@001"
textmultilingualembedding002 = "text-multilingual-embedding-002"
multimodalembedding = "multimodalembedding"
)

var (
Expand Down Expand Up @@ -163,18 +171,97 @@ var (
}

// googleAIEmbedders lists the embedder model names that listEmbedders
// returns for the Gemini API backend.
googleAIEmbedders = []string{
"text-embedding-004",
"embedding-001",
textembedding004,
embedding001,
}

// googleAIEmbedderConfig maps embedder model names to their options.
// Despite its name it holds entries for both Gemini API and Vertex AI
// embedders; listEmbedders looks up the names of either backend here.
googleAIEmbedderConfig = map[string]ai.EmbedderOptions{
	textembedding004: {
		Info: &ai.EmbedderInfo{
			Dimensions: 768,
			Label:      "Google Gen AI - Text Embedding 004",
			Supports: &ai.EmbedderSupports{
				Input: []string{"text"},
			},
		},
	},
	embedding001: {
		Info: &ai.EmbedderInfo{
			Dimensions: 768,
			Label:      "Google Gen AI - Text Embedding Gecko (Legacy)",
			Supports: &ai.EmbedderSupports{
				Input: []string{"text"},
			},
		},
	},
	textembeddinggecko003: {
		Info: &ai.EmbedderInfo{
			Dimensions: 768,
			Label:      "Vertex AI - Text Embedding Gecko 003",
			Supports: &ai.EmbedderSupports{
				Input: []string{"text"},
			},
		},
	},
	textembeddinggecko002: {
		Info: &ai.EmbedderInfo{
			Dimensions: 768,
			Label:      "Vertex AI - Text Embedding Gecko 002",
			Supports: &ai.EmbedderSupports{
				Input: []string{"text"},
			},
		},
	},
	textembeddinggecko001: {
		Info: &ai.EmbedderInfo{
			Dimensions: 768,
			Label:      "Vertex AI - Text Embedding Gecko 001",
			Supports: &ai.EmbedderSupports{
				Input: []string{"text"},
			},
		},
	},
	textembeddinggeckomultilingual001: {
		Info: &ai.EmbedderInfo{
			Dimensions: 768,
			Label:      "Vertex AI - Text Embedding Gecko Multilingual 001",
			Supports: &ai.EmbedderSupports{
				Input: []string{"text"},
			},
		},
	},
	textmultilingualembedding002: {
		Info: &ai.EmbedderInfo{
			Dimensions: 768,
			Label:      "Vertex AI - Text Multilingual Embedding 002",
			Supports: &ai.EmbedderSupports{
				Input: []string{"text"},
			},
		},
	},
	multimodalembedding: {
		Info: &ai.EmbedderInfo{
			// The multimodal embedding model defaults to 1408-dimensional
			// vectors, unlike the 768-dimensional text embedders above.
			Dimensions: 1408,
			Label:      "Vertex AI - Multimodal Embedding",
			Supports: &ai.EmbedderSupports{
				Input: []string{
					"text",
					"image",
					"video",
				},
			},
		},
	},
}

// vertexAIEmbedders lists the embedder model names that listEmbedders
// returns for the Vertex AI backend.
vertexAIEmbedders = []string{
"textembedding-gecko@003",
"textembedding-gecko@002",
"textembedding-gecko@001",
"text-embedding-004",
"textembedding-gecko-multilingual@001",
"text-multilingual-embedding-002",
"multimodalembedding",
textembeddinggecko003,
textembeddinggecko002,
textembeddinggecko001,
textembedding004,
textembeddinggeckomultilingual001,
textmultilingualembedding002,
multimodalembedding,
}
)

Expand Down Expand Up @@ -212,17 +299,22 @@ func listModels(provider string) (map[string]ai.ModelInfo, error) {

// listEmbedders returns a list of supported embedders based on the
// detected backend
func listEmbedders(backend genai.Backend) ([]string, error) {
embedders := []string{}
func listEmbedders(backend genai.Backend) (map[string]ai.EmbedderOptions, error) {
embeddersNames := []string{}

switch backend {
case genai.BackendGeminiAPI:
embedders = googleAIEmbedders
embeddersNames = googleAIEmbedders
case genai.BackendVertexAI:
embedders = vertexAIEmbedders
embeddersNames = vertexAIEmbedders
default:
return nil, fmt.Errorf("embedders for backend %s not found", backend)
}

embedders := make(map[string]ai.EmbedderOptions, 0)
for _, n := range embeddersNames {
embedders[n] = googleAIEmbedderConfig[n]
}

return embedders, nil
}
Loading