Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions genkit-tools/genkit-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,74 @@
],
"additionalProperties": false
},
"EmbedderInfo": {
"type": "object",
"properties": {
"label": {
"type": "string"
},
"dimensions": {
"type": "number"
},
"supports": {
"$ref": "#/$defs/EmbedderSupports"
}
},
"additionalProperties": false
},
"EmbedderOptions": {
"type": "object",
"properties": {
"label": {
"type": "string"
},
"dimensions": {
"type": "number"
},
"supports": {
"$ref": "#/$defs/EmbedderSupports"
},
"configSchema": {
"type": "object",
"additionalProperties": {}
}
},
"additionalProperties": false
},
"EmbedderRef": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"info": {
"$ref": "#/$defs/EmbedderInfo"
},
"config": {},
"version": {
"type": "string"
}
},
"required": [
"name"
],
"additionalProperties": false
},
"EmbedderSupports": {
"type": "object",
"properties": {
"input": {
"type": "array",
"items": {
"type": "string"
}
},
"multiturn": {
"type": "boolean"
}
},
"additionalProperties": false
},
"Embedding": {
"type": "object",
"properties": {
Expand Down
25 changes: 25 additions & 0 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,31 @@ type EmbedResponse struct {
Embeddings []*Embedding `json:"embeddings,omitempty"`
}

// EmbedderInfo groups descriptive fields for an embedder: a label, a
// dimensions value and an optional EmbedderSupports pointer. Its fields match
// the properties of the "EmbedderInfo" definition in genkit-schema.json.
// Every field is tagged omitempty, so zero values (0, "", nil) are left out of
// the JSON encoding.
type EmbedderInfo struct {
// Dimensions is encoded under the "dimensions" key (schema type: number).
// NOTE(review): presumably the length of the produced vectors — confirm.
Dimensions float64 `json:"dimensions,omitempty"`
// Label is encoded under the "label" key.
Label string `json:"label,omitempty"`
// Supports is optional; a nil pointer is omitted from the JSON output.
Supports *EmbedderSupports `json:"supports,omitempty"`
}

// EmbedderOptions carries the same label/dimensions/supports fields as
// EmbedderInfo plus a free-form ConfigSchema map. Its fields match the
// properties of the "EmbedderOptions" definition in genkit-schema.json.
// Every field is tagged omitempty, so zero values are left out of the JSON
// encoding.
type EmbedderOptions struct {
// ConfigSchema is an arbitrary JSON object (the schema declares it as an
// object with unconstrained additionalProperties).
// NOTE(review): presumably a JSON Schema for embedder config — confirm.
ConfigSchema map[string]any `json:"configSchema,omitempty"`
// Dimensions is encoded under the "dimensions" key (schema type: number).
Dimensions float64 `json:"dimensions,omitempty"`
// Label is encoded under the "label" key.
Label string `json:"label,omitempty"`
// Supports is optional; a nil pointer is omitted from the JSON output.
Supports *EmbedderSupports `json:"supports,omitempty"`
}

// EmbedderRef identifies an embedder by name and may carry an untyped config
// value, an EmbedderInfo pointer and a version string. Its fields match the
// properties of the "EmbedderRef" definition in genkit-schema.json.
type EmbedderRef struct {
// Config is untyped; the schema places no constraints on "config".
Config any `json:"config,omitempty"`
// Info is optional; a nil pointer is omitted from the JSON output.
Info *EmbedderInfo `json:"info,omitempty"`
// Name is encoded under the "name" key.
// NOTE(review): genkit-schema.json lists "name" as required for
// EmbedderRef, yet omitempty drops an empty Name from the output —
// confirm this is intended.
Name string `json:"name,omitempty"`
// Version is encoded under the "version" key and omitted when empty.
Version string `json:"version,omitempty"`
}

// EmbedderSupports holds a list of input strings and a multiturn flag. Its
// fields match the properties of the "EmbedderSupports" definition in
// genkit-schema.json.
type EmbedderSupports struct {
// Input is a list of strings encoded under the "input" key.
// NOTE(review): presumably the accepted input kinds — confirm the
// allowed values.
Input []string `json:"input,omitempty"`
// Multiturn is encoded under the "multiturn" key; because of omitempty a
// false value is not written, so "false" and "absent" are indistinguishable
// in the output.
Multiturn bool `json:"multiturn,omitempty"`
}

type Embedding struct {
Embedding []float32 `json:"embedding,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Expand Down
22 changes: 22 additions & 0 deletions py/bin/sanitize_schema_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,29 @@ def add_header(content: str) -> str:
]
cleaned_content = '\n'.join(filtered_lines)

# Inject EmbedderFn type alias after imports
embedder_fn_injection = """
# EmbedderFn type alias for embedder functions
from typing import Callable, Awaitable

EmbedderFn = Callable[['EmbedRequest'], Awaitable['EmbedResponse']]
"""

final_output = header_text + future_import + '\n' + str_enum_block + '\n\n' + cleaned_content

# Insert EmbedderFn after the typing imports but before class definitions
# Find the position after "from pydantic import" and "from typing import"
lines = final_output.split('\n')
insert_index = -1
for i, line in enumerate(lines):
if line.startswith('from pydantic import'):
insert_index = i + 1
break

if insert_index > 0:
lines.insert(insert_index, embedder_fn_injection)
final_output = '\n'.join(lines)

if not final_output.endswith('\n'):
final_output += '\n'
return final_output
Expand Down
29 changes: 28 additions & 1 deletion py/packages/genkit/src/genkit/ai/_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class while customizing it with any plugins.

from genkit.aio import Channel
from genkit.blocks.document import Document
from genkit.blocks.embedding import EmbedRequest, EmbedResponse
from genkit.blocks.generate import (
StreamingCallback as ModelStreamingCallback,
generate_action,
Expand All @@ -39,6 +38,7 @@ class while customizing it with any plugins.
from genkit.blocks.prompt import PromptConfig, to_generate_action_options
from genkit.core.action import ActionRunContext
from genkit.core.action.types import ActionKind
from genkit.core.typing import EmbedRequest, EmbedResponse, EmbedderRef
from genkit.types import (
DocumentData,
GenerationCommonConfig,
Expand Down Expand Up @@ -335,3 +335,30 @@ async def retrieve(
retrieve_action = self.registry.lookup_action(ActionKind.RETRIEVER, retriever)

return (await retrieve_action.arun(RetrieverRequest(query=query, options=options))).response

async def embed(
    self,
    embedder: str | EmbedderRef | None = None,
    documents: list[Document] | None = None,
    options: dict[str, Any] | None = None,
) -> EmbedResponse:
    """Run a registered embedder action over the given documents.

    Args:
        embedder: Name of the embedder action, or an ``EmbedderRef`` whose
            ``name`` is used for the lookup and whose ``config`` / ``version``
            supply default options.
        documents: Documents forwarded unchanged as the request ``input``.
        options: Per-call options. On key collisions these win over the
            config (and ``version``) taken from an ``EmbedderRef``.

    Returns:
        The ``response`` attribute of the awaited action result.

    Raises:
        ValueError: If ``embedder`` is neither a string nor an
            ``EmbedderRef``, or if the registry lookup yields ``None``.
    """
    missing_embedder_msg = 'Embedder must be specified as a string name or an EmbedderRef.'

    # Reject the default up front so the error does not depend on later checks.
    if embedder is None:
        raise ValueError(missing_embedder_msg)

    ref_config: dict[str, Any] = {}
    if isinstance(embedder, EmbedderRef):
        embedder_name = embedder.name
        if embedder.config:
            # Copy: writing 'version' below must not leak into the caller's
            # EmbedderRef.config, which may be reused across calls.
            ref_config = dict(embedder.config)
        if embedder.version:
            ref_config['version'] = embedder.version
    elif isinstance(embedder, str):
        embedder_name = embedder
    else:
        raise ValueError(missing_embedder_msg)

    # Explicit per-call options override whatever the reference supplied.
    final_options = {**ref_config, **(options or {})}

    embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder_name)
    # NOTE(review): assumes lookup_action returns None for an unknown name —
    # confirm against the registry; this replaces an AttributeError on `.arun`.
    if embed_action is None:
        raise ValueError(f'Embedder "{embedder_name}" was not found in the registry.')

    request = EmbedRequest(input=documents, options=final_options)
    return (await embed_action.arun(request)).response
27 changes: 18 additions & 9 deletions py/packages/genkit/src/genkit/ai/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import structlog
from pydantic import BaseModel

from genkit.blocks.embedding import EmbedderFn
from genkit.blocks.evaluator import BatchEvaluatorFn, EvaluatorFn
from genkit.blocks.formats.types import FormatDef
from genkit.blocks.model import ModelFn, ModelMiddleware
Expand All @@ -72,6 +71,8 @@
Score,
SpanMetadata,
ToolChoice,
EmbedderOptions,
EmbedderFn,
)

EVALUATOR_METADATA_KEY_DISPLAY_NAME = 'evaluatorDisplayName'
Expand Down Expand Up @@ -458,8 +459,7 @@ def define_embedder(
self,
name: str,
fn: EmbedderFn,
config_schema: BaseModel | dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
options: EmbedderOptions | None = None,
description: str | None = None,
) -> Action:
"""Define a custom embedder action.
Expand All @@ -471,19 +471,28 @@ def define_embedder(
metadata: Optional metadata for the model.
description: Optional description for the embedder.
"""
embedder_meta: dict[str, Any] = metadata if metadata else {}
if 'embedder' not in embedder_meta:
embedder_meta['embedder'] = {}
embedder_metadata: dict[str, Any] = {}
if options:
if options.label:
embedder_metadata['embedder']['label'] = options.label
if options.dimensions:
embedder_metadata['embedder']['dimensions'] = options.dimensions
if options.supports:
embedder_metadata['embedder']['supports'] = options.supports.model_dump(
exclude_none=True, by_alias=True
)
if options.config_schema:
embedder_metadata['embedder']['customOptions'] = to_json_schema(options.config_schema)

if config_schema:
embedder_meta['embedder']['customOptions'] = to_json_schema(config_schema)
if 'embedder' not in embedder_metadata:
embedder_metadata['embedder'] = {}

embedder_description = get_func_description(fn, description)
return self.registry.register_action(
name=name,
kind=ActionKind.EMBEDDER,
fn=fn,
metadata=embedder_meta,
metadata=embedder_metadata,
description=embedder_description,
)

Expand Down
31 changes: 24 additions & 7 deletions py/packages/genkit/src/genkit/blocks/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,40 @@
from genkit.ai import ActionKind
from genkit.core.action import ActionMetadata
from genkit.core.schema import to_json_schema
from genkit.core.typing import EmbedRequest, EmbedResponse
from genkit.core.typing import EmbedRequest, EmbedResponse, EmbedderOptions, EmbedderRef, EmbedderSupports
from pydantic import BaseModel

# type EmbedderFn = Callable[[EmbedRequest], EmbedResponse]
EmbedderFn = Callable[[EmbedRequest], EmbedResponse]


def embedder_action_metadata(
    name: str,
    info: dict[str, Any] | None = None,
    config_schema: Any | None = None,
    options: EmbedderOptions | None = None,
) -> ActionMetadata:
    """Generate an ActionMetadata for an embedder.

    The ``'embedder'`` metadata entry is built from two sources: the legacy
    ``info`` / ``config_schema`` arguments and the ``options`` object. Values
    found on ``options`` take precedence over the legacy arguments.

    Args:
        name: Name stored on the resulting metadata.
        info: Legacy mapping copied into the ``'embedder'`` entry as a base.
        config_schema: Legacy config schema, converted with ``to_json_schema``
            when ``options`` does not provide one.
        options: Embedder options; ``label``, ``dimensions``, ``supports`` and
            ``config_schema`` are read from it when set.

    Returns:
        An ``ActionMetadata`` of kind ``ActionKind.EMBEDDER`` whose input and
        output schemas come from ``EmbedRequest`` and ``EmbedResponse``.
    """
    # Copy so the caller's `info` mapping is never mutated.
    embedder_meta: dict[str, Any] = dict(info) if info else {}
    schema_source = config_schema

    if options is not None:
        if options.label:
            embedder_meta['label'] = options.label
        # Only record dimensions that were actually provided instead of
        # writing an explicit None into the metadata.
        if options.dimensions is not None:
            embedder_meta['dimensions'] = options.dimensions
        if options.supports:
            embedder_meta['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True)
        if options.config_schema:
            schema_source = options.config_schema

    # 'customOptions' is always present; it is None when no schema was given.
    embedder_meta['customOptions'] = to_json_schema(schema_source) if schema_source else None

    return ActionMetadata(
        kind=ActionKind.EMBEDDER,
        name=name,
        input_json_schema=to_json_schema(EmbedRequest),
        output_json_schema=to_json_schema(EmbedResponse),
        metadata={'embedder': embedder_meta},
    )


def create_embedder_ref(name: str, config: dict[str, Any] | None = None, version: str | None = None) -> EmbedderRef:
    """Build an ``EmbedderRef`` from a name plus optional config and version.

    Args:
        name: Value passed as the reference's ``name``.
        config: Value passed as the reference's ``config``; may be ``None``.
        version: Value passed as the reference's ``version``; may be ``None``.

    Returns:
        The newly constructed ``EmbedderRef``.
    """
    ref_fields = {'name': name, 'config': config, 'version': version}
    return EmbedderRef(**ref_fields)
Loading
Loading