diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 0ee8e76519..b17ce2999f 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -61,6 +61,74 @@ ], "additionalProperties": false }, + "EmbedderInfo": { + "type": "object", + "properties": { + "label": { + "type": "string" + }, + "dimensions": { + "type": "number" + }, + "supports": { + "$ref": "#/$defs/EmbedderSupports" + } + }, + "additionalProperties": false + }, + "EmbedderOptions": { + "type": "object", + "properties": { + "label": { + "type": "string" + }, + "dimensions": { + "type": "number" + }, + "supports": { + "$ref": "#/$defs/EmbedderSupports" + }, + "configSchema": { + "type": "object", + "additionalProperties": {} + } + }, + "additionalProperties": false + }, + "EmbedderRef": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "info": { + "$ref": "#/$defs/EmbedderInfo" + }, + "config": {}, + "version": { + "type": "string" + } + }, + "required": [ + "name" + ], + "additionalProperties": false + }, + "EmbedderSupports": { + "type": "object", + "properties": { + "input": { + "type": "array", + "items": { + "type": "string" + } + }, + "multiturn": { + "type": "boolean" + } + }, + "additionalProperties": false + }, "Embedding": { "type": "object", "properties": { diff --git a/go/ai/gen.go b/go/ai/gen.go index 4d3fcfcd43..e2e99958e3 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -80,6 +80,31 @@ type EmbedResponse struct { Embeddings []*Embedding `json:"embeddings,omitempty"` } +type EmbedderInfo struct { + Dimensions float64 `json:"dimensions,omitempty"` + Label string `json:"label,omitempty"` + Supports *EmbedderSupports `json:"supports,omitempty"` +} + +type EmbedderOptions struct { + ConfigSchema map[string]any `json:"configSchema,omitempty"` + Dimensions float64 `json:"dimensions,omitempty"` + Label string `json:"label,omitempty"` + Supports *EmbedderSupports `json:"supports,omitempty"` +} + +type EmbedderRef 
struct { + Config any `json:"config,omitempty"` + Info *EmbedderInfo `json:"info,omitempty"` + Name string `json:"name,omitempty"` + Version string `json:"version,omitempty"` +} + +type EmbedderSupports struct { + Input []string `json:"input,omitempty"` + Multiturn bool `json:"multiturn,omitempty"` +} + type Embedding struct { Embedding []float32 `json:"embedding,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` diff --git a/py/bin/sanitize_schema_typing.py b/py/bin/sanitize_schema_typing.py index 6138127e9f..e8b5ff7f0d 100644 --- a/py/bin/sanitize_schema_typing.py +++ b/py/bin/sanitize_schema_typing.py @@ -262,7 +262,29 @@ def add_header(content: str) -> str: ] cleaned_content = '\n'.join(filtered_lines) + # Inject EmbedderFn type alias after imports + embedder_fn_injection = """ +# EmbedderFn type alias for embedder functions +from typing import Callable, Awaitable + +EmbedderFn = Callable[['EmbedRequest'], Awaitable['EmbedResponse']] +""" + final_output = header_text + future_import + '\n' + str_enum_block + '\n\n' + cleaned_content + + # Insert EmbedderFn after the typing imports but before class definitions + # Find the position after "from pydantic import" and "from typing import" + lines = final_output.split('\n') + insert_index = -1 + for i, line in enumerate(lines): + if line.startswith('from pydantic import'): + insert_index = i + 1 + break + + if insert_index > 0: + lines.insert(insert_index, embedder_fn_injection) + final_output = '\n'.join(lines) + if not final_output.endswith('\n'): final_output += '\n' return final_output diff --git a/py/packages/genkit/src/genkit/ai/_aio.py b/py/packages/genkit/src/genkit/ai/_aio.py index 7d77d881d8..da7784ca80 100644 --- a/py/packages/genkit/src/genkit/ai/_aio.py +++ b/py/packages/genkit/src/genkit/ai/_aio.py @@ -26,7 +26,6 @@ class while customizing it with any plugins. 
from genkit.aio import Channel from genkit.blocks.document import Document -from genkit.blocks.embedding import EmbedRequest, EmbedResponse from genkit.blocks.generate import ( StreamingCallback as ModelStreamingCallback, generate_action, @@ -39,6 +38,7 @@ class while customizing it with any plugins. from genkit.blocks.prompt import PromptConfig, to_generate_action_options from genkit.core.action import ActionRunContext from genkit.core.action.types import ActionKind +from genkit.core.typing import EmbedRequest, EmbedResponse, EmbedderRef from genkit.types import ( DocumentData, GenerationCommonConfig, @@ -335,3 +335,30 @@ async def retrieve( retrieve_action = self.registry.lookup_action(ActionKind.RETRIEVER, retriever) return (await retrieve_action.arun(RetrieverRequest(query=query, options=options))).response + + async def embed( + self, + embedder: str | EmbedderRef | None = None, + documents: list[Document] | None = None, + options: dict[str, Any] | None = None, + ) -> EmbedResponse: + embedder_name: str + embedder_config: dict[str, Any] = {} + + if isinstance(embedder, EmbedderRef): + embedder_name = embedder.name + embedder_config = dict(embedder.config or {}) # Shallow copy so the caller's ref.config is never mutated + if embedder.version: + embedder_config['version'] = embedder.version # Handle version from ref + elif isinstance(embedder, str): + embedder_name = embedder + else: + # Handle case where embedder is None + raise ValueError('Embedder must be specified as a string name or an EmbedderRef.') + + # Merge options passed to embed() with config from EmbedderRef + final_options = {**(embedder_config or {}), **(options or {})} + + embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder_name) + + return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response diff --git a/py/packages/genkit/src/genkit/ai/_registry.py b/py/packages/genkit/src/genkit/ai/_registry.py index 4bb19db11e..6c239643b3 100644 --- a/py/packages/genkit/src/genkit/ai/_registry.py +++ 
b/py/packages/genkit/src/genkit/ai/_registry.py @@ -47,7 +47,6 @@ import structlog from pydantic import BaseModel -from genkit.blocks.embedding import EmbedderFn from genkit.blocks.evaluator import BatchEvaluatorFn, EvaluatorFn from genkit.blocks.formats.types import FormatDef from genkit.blocks.model import ModelFn, ModelMiddleware @@ -72,6 +71,8 @@ Score, SpanMetadata, ToolChoice, + EmbedderOptions, + EmbedderFn, ) EVALUATOR_METADATA_KEY_DISPLAY_NAME = 'evaluatorDisplayName' @@ -458,8 +459,7 @@ def define_embedder( self, name: str, fn: EmbedderFn, - config_schema: BaseModel | dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, + options: EmbedderOptions | None = None, description: str | None = None, ) -> Action: """Define a custom embedder action. @@ -471,19 +471,28 @@ def define_embedder( metadata: Optional metadata for the model. description: Optional description for the embedder. """ - embedder_meta: dict[str, Any] = metadata if metadata else {} - if 'embedder' not in embedder_meta: - embedder_meta['embedder'] = {} + embedder_metadata: dict[str, Any] = {'embedder': {}} # Pre-create the nested dict so the assignments below cannot raise KeyError + if options: + if options.label: + embedder_metadata['embedder']['label'] = options.label + if options.dimensions: + embedder_metadata['embedder']['dimensions'] = options.dimensions + if options.supports: + embedder_metadata['embedder']['supports'] = options.supports.model_dump( + exclude_none=True, by_alias=True + ) + if options.config_schema: + embedder_metadata['embedder']['customOptions'] = to_json_schema(options.config_schema) - if config_schema: - embedder_meta['embedder']['customOptions'] = to_json_schema(config_schema) + if 'embedder' not in embedder_metadata: + embedder_metadata['embedder'] = {} embedder_description = get_func_description(fn, description) return self.registry.register_action( name=name, kind=ActionKind.EMBEDDER, fn=fn, - metadata=embedder_meta, + metadata=embedder_metadata, description=embedder_description, ) diff --git 
a/py/packages/genkit/src/genkit/blocks/embedding.py b/py/packages/genkit/src/genkit/blocks/embedding.py index 582ec5f6ac..6efc728c0b 100644 --- a/py/packages/genkit/src/genkit/blocks/embedding.py +++ b/py/packages/genkit/src/genkit/blocks/embedding.py @@ -22,23 +22,40 @@ from genkit.ai import ActionKind from genkit.core.action import ActionMetadata from genkit.core.schema import to_json_schema -from genkit.core.typing import EmbedRequest, EmbedResponse +from genkit.core.typing import EmbedRequest, EmbedResponse, EmbedderOptions, EmbedderRef, EmbedderSupports +from pydantic import BaseModel -# type EmbedderFn = Callable[[EmbedRequest], EmbedResponse] EmbedderFn = Callable[[EmbedRequest], EmbedResponse] def embedder_action_metadata( name: str, - info: dict[str, Any] | None = None, - config_schema: Any | None = None, + options: EmbedderOptions | None = None, ) -> ActionMetadata: - """Generates an ActionMetadata for embedders.""" - info = info if info is not None else {} + options = options if options is not None else EmbedderOptions() + embedder_metadata_dict = {'embedder': {}} + + if options.label: + embedder_metadata_dict['embedder']['label'] = options.label + + embedder_metadata_dict['embedder']['dimensions'] = options.dimensions + + if options.supports: + embedder_metadata_dict['embedder']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True) + + embedder_metadata_dict['embedder']['customOptions'] = ( + to_json_schema(options.config_schema) if options.config_schema else None + ) + return ActionMetadata( kind=ActionKind.EMBEDDER, name=name, input_json_schema=to_json_schema(EmbedRequest), output_json_schema=to_json_schema(EmbedResponse), - metadata={'embedder': {**info, 'customOptions': to_json_schema(config_schema) if config_schema else None}}, + metadata=embedder_metadata_dict, ) + + +def create_embedder_ref(name: str, config: dict[str, Any] | None = None, version: str | None = None) -> EmbedderRef: + """Creates an EmbedderRef instance.""" + 
return EmbedderRef(name=name, config=config, version=version) diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index 40baffe314..1af28ecdd6 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -33,10 +33,14 @@ from enum import StrEnum # noqa +# EmbedderFn type alias for embedder functions +from collections.abc import Awaitable, Callable from typing import Any from pydantic import BaseModel, ConfigDict, Field, RootModel +EmbedderFn = Callable[['EmbedRequest'], Awaitable['EmbedResponse']] + class Model(RootModel[Any]): """Root model for model.""" @@ -44,6 +48,61 @@ class Model(RootModel[Any]): root: Any +class CustomPart(BaseModel): + """Model for custompart data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + text: Any | None = None + media: Any | None = None + tool_request: Any | None = Field(None, alias='toolRequest') + tool_response: Any | None = Field(None, alias='toolResponse') + data: Any | None = None + metadata: dict[str, Any] | None = None + custom: dict[str, Any] + reasoning: Any | None = None + resource: Any | None = None + + +class Media(BaseModel): + """Model for media data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + content_type: str | None = Field(None, alias='contentType') + url: str + + +class Resource1(BaseModel): + """Model for resource1 data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + uri: str + + +class ToolRequest(BaseModel): + """Model for toolrequest data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + ref: str | None = None + name: str + input: Any | None = None + + +class ToolResponse(BaseModel): + """Model for toolresponse data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + ref: str | None = None + name: str + output: Any | None = None + + +class EmbedderSupports(BaseModel): + """Model 
for embeddersupports data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + input: list[str] | None = None + multiturn: bool | None = None + class Embedding(BaseModel): """Model for embedding data.""" @@ -747,6 +806,91 @@ class ToolResponsePart(BaseModel): resource: Resource | None = None +class EmbedResponse(BaseModel): + """Model for embedresponse data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + embeddings: list[Embedding] + + +class EmbedderInfo(BaseModel): + """Model for embedderinfo data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + label: str | None = None + dimensions: float | None = None + supports: EmbedderSupports | None = None + + +class EmbedderOptions(BaseModel): + """Model for embedderoptions data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + label: str | None = None + dimensions: float | None = None + supports: EmbedderSupports | None = None + config_schema: dict[str, Any] | None = Field(None, alias='configSchema') + + +class EmbedderRef(BaseModel): + """Model for embedderref data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + name: str + info: EmbedderInfo | None = None + config: Any | None = None + version: str | None = None + + +class BaseEvalDataPoint(BaseModel): + """Model for baseevaldatapoint data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + input: Input | None = None + output: Output | None = None + context: Context | None = None + reference: Reference | None = None + test_case_id: str = Field(..., alias='testCaseId') + trace_ids: TraceIds | None = Field(None, alias='traceIds') + + +class EvalFnResponse(BaseModel): + """Model for evalfnresponse data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + sample_index: float | None = Field(None, alias='sampleIndex') + test_case_id: str = Field(..., alias='testCaseId') + trace_id: str | None = Field(None, 
alias='traceId') + span_id: str | None = Field(None, alias='spanId') + evaluation: Score | list[Score] + + +class EvalResponse(RootModel[list[EvalFnResponse]]): + """Root model for evalresponse.""" + + root: list[EvalFnResponse] + + +class Resume(BaseModel): + """Model for resume data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + respond: list[ToolResponsePart] | None = None + restart: list[ToolRequestPart] | None = None + metadata: dict[str, Any] | None = None + + +class Part( + RootModel[ + TextPart | MediaPart | ToolRequestPart | ToolResponsePart | DataPart | CustomPart | ReasoningPart | ResourcePart + ] +): + """Root model for part.""" + + root: ( + TextPart | MediaPart | ToolRequestPart | ToolResponsePart | DataPart | CustomPart | ReasoningPart | ResourcePart + ) + class Link(BaseModel): """Model for link data.""" diff --git a/py/packages/genkit/src/genkit/types/__init__.py b/py/packages/genkit/src/genkit/types/__init__.py index fa2f6b24c6..674e4f8e0e 100644 --- a/py/packages/genkit/src/genkit/types/__init__.py +++ b/py/packages/genkit/src/genkit/types/__init__.py @@ -32,6 +32,9 @@ Embedding, EmbedRequest, EmbedResponse, + EmbedderOptions, + EmbedderRef, + EmbedderSupports, EvalFnResponse, EvalRequest, EvalResponse, @@ -108,4 +111,7 @@ ToolRequestPart.__name__, ToolResponse.__name__, ToolResponsePart.__name__, + EmbedderOptions.__name__, + EmbedderRef.__name__, + EmbedderSupports.__name__, ] diff --git a/py/packages/genkit/tests/genkit/blocks/embedding_test.py b/py/packages/genkit/tests/genkit/blocks/embedding_test.py index 0a54d6aa2e..a29ee4b661 100644 --- a/py/packages/genkit/tests/genkit/blocks/embedding_test.py +++ b/py/packages/genkit/tests/genkit/blocks/embedding_test.py @@ -16,19 +16,234 @@ """Tests for the action module.""" -from genkit.blocks.embedding import embedder_action_metadata -from genkit.core.action import ActionMetadata +from genkit.blocks.embedding import embedder_action_metadata, create_embedder_ref +from 
genkit.core.action import ActionMetadata, Action +from genkit.core.typing import ( + EmbedderOptions, + EmbedderSupports, + EmbedderRef, + EmbedRequest, + EmbedResponse, + Embedding, + Part, +) +from genkit.blocks.document import Document +from genkit.core.schema import to_json_schema +from genkit.core.action.types import ActionResponse +import pytest +from unittest.mock import AsyncMock, MagicMock +from genkit.ai._aio import Genkit +from pydantic import BaseModel def test_embedder_action_metadata(): - """Test for embedder_action_metadata.""" + """Test for embedder_action_metadata with basic options.""" + options = EmbedderOptions(label='Test Embedder', dimensions=128) action_metadata = embedder_action_metadata( name='test_model', - info={'label': 'test_label'}, - config_schema=None, + options=options, ) assert isinstance(action_metadata, ActionMetadata) assert action_metadata.input_json_schema is not None assert action_metadata.output_json_schema is not None - assert action_metadata.metadata == {'embedder': {'customOptions': None, 'label': 'test_label'}} + assert action_metadata.metadata == { + 'embedder': { + 'label': options.label, + 'dimensions': options.dimensions, + 'customOptions': None, + } + } + + +def test_embedder_action_metadata_with_supports_and_config_schema(): + """Test for embedder_action_metadata with supports and config_schema.""" + + class CustomConfig(BaseModel): + param1: str + param2: int + + options = EmbedderOptions( + label='Advanced Embedder', + dimensions=256, + supports=EmbedderSupports(input=['text', 'image']), + config_schema=to_json_schema(CustomConfig), + ) + action_metadata = embedder_action_metadata( + name='advanced_model', + options=options, + ) + + assert isinstance(action_metadata, ActionMetadata) + assert action_metadata.metadata['embedder']['label'] == 'Advanced Embedder' + assert action_metadata.metadata['embedder']['dimensions'] == options.dimensions + assert action_metadata.metadata['embedder']['supports'] == { + 'input': 
['text', 'image'], + } + assert action_metadata.metadata['embedder']['customOptions'] == { + 'title': 'CustomConfig', + 'type': 'object', + 'properties': { + 'param1': {'title': 'Param1', 'type': 'string'}, + 'param2': {'title': 'Param2', 'type': 'integer'}, + }, + 'required': ['param1', 'param2'], + } + + +def test_embedder_action_metadata_no_options(): + """Test embedder_action_metadata when no options are provided.""" + action_metadata = embedder_action_metadata(name='default_model') + assert isinstance(action_metadata, ActionMetadata) + assert action_metadata.metadata == {'embedder': {'customOptions': None, 'dimensions': None}} + + +def test_create_embedder_ref_basic(): + """Test basic creation of EmbedderRef.""" + ref = create_embedder_ref('my-embedder') + assert ref.name == 'my-embedder' + assert ref.config is None + assert ref.version is None + + +def test_create_embedder_ref_with_config(): + """Test creation of EmbedderRef with configuration.""" + config = {'temperature': 0.5, 'max_tokens': 100} + ref = create_embedder_ref('configured-embedder', config=config) + assert ref.name == 'configured-embedder' + assert ref.config == config + assert ref.version is None + + +def test_create_embedder_ref_with_version(): + """Test creation of EmbedderRef with a version.""" + ref = create_embedder_ref('versioned-embedder', version='v1.0') + assert ref.name == 'versioned-embedder' + assert ref.config is None + assert ref.version == 'v1.0' + + +def test_create_embedder_ref_with_config_and_version(): + """Test creation of EmbedderRef with both config and version.""" + config = {'task_type': 'retrieval'} + ref = create_embedder_ref('full-embedder', config=config, version='beta') + assert ref.name == 'full-embedder' + assert ref.config == config + assert ref.version == 'beta' + + +class MockGenkitRegistry: + """A mock registry to simulate action lookup.""" + + def __init__(self): + self.actions = {} + + def register_action(self, name, kind, fn, metadata, description): + 
mock_action = MagicMock(spec=Action) + mock_action.name = name + mock_action.kind = kind + mock_action.metadata = metadata + mock_action.description = description + + async def mock_arun_side_effect(request, *args, **kwargs): + # Call the actual (fake) embedder function directly + embed_response = await fn(request) + return ActionResponse(response=embed_response, trace_id='mock_trace_id') + + mock_action.arun = AsyncMock(side_effect=mock_arun_side_effect) + self.actions[(kind, name)] = mock_action + return mock_action + + def lookup_action(self, kind, name): + return self.actions.get((kind, name)) + + +@pytest.fixture +def mock_genkit_instance(): + """Fixture for a Genkit instance with a mock registry.""" + registry = MockGenkitRegistry() + genkit_instance = Genkit() + genkit_instance.registry = registry + return genkit_instance, registry + + +@pytest.mark.asyncio +async def test_embed_with_embedder_ref(mock_genkit_instance): + """Test the embed method using EmbedderRef.""" + genkit_instance, registry = mock_genkit_instance + + async def fake_embedder_fn(request: EmbedRequest) -> EmbedResponse: + return EmbedResponse(embeddings=[Embedding(embedding=[1.0, 2.0, 3.0])]) + + embedder_options = EmbedderOptions( + label='Fake Embedder', + dimensions=3, + supports=EmbedderSupports(input=['text']), + config_schema={'type': 'object', 'properties': {'param': {'type': 'string'}}}, + ) + registry.register_action( + name='my-plugin/my-embedder', + kind='embedder', + fn=fake_embedder_fn, + metadata=embedder_action_metadata('my-plugin/my-embedder', options=embedder_options).metadata, + description='A fake embedder for testing', + ) + + embedder_ref = create_embedder_ref('my-plugin/my-embedder', config={'param': 'value'}, version='v1') + + documents = [Document.from_text('hello world')] + + response = await genkit_instance.embed( + embedder=embedder_ref, documents=documents, options={'additional_option': True} + ) + + assert response.embeddings[0].embedding == [1.0, 2.0, 3.0] + + 
embed_action = registry.lookup_action('embedder', 'my-plugin/my-embedder') + assert embed_action is not None + embed_action.arun.assert_called_once() + + called_request = embed_action.arun.call_args[0][0] + assert isinstance(called_request, EmbedRequest) + assert called_request.input == documents + # Check if config from EmbedderRef and options are merged correctly + assert called_request.options == {'param': 'value', 'additional_option': True, 'version': 'v1'} + + +@pytest.mark.asyncio +async def test_embed_with_string_name_and_options(mock_genkit_instance): + """Test the embed method using a string name for embedder and options.""" + genkit_instance, registry = mock_genkit_instance + + async def fake_embedder_fn(request: EmbedRequest) -> EmbedResponse: + return EmbedResponse(embeddings=[Embedding(embedding=[4.0, 5.0, 6.0])]) + + embedder_options = EmbedderOptions(label='Another Fake', dimensions=3) + registry.register_action( + name='another-embedder', + kind='embedder', + fn=fake_embedder_fn, + metadata=embedder_action_metadata('another-embedder', options=embedder_options).metadata, + description='Another fake embedder', + ) + + documents = [Document.from_text('test text')] + + response = await genkit_instance.embed( + embedder='another-embedder', documents=documents, options={'custom_setting': 'high'} + ) + + assert response.embeddings[0].embedding == [4.0, 5.0, 6.0] + embed_action = registry.lookup_action('embedder', 'another-embedder') + called_request = embed_action.arun.call_args[0][0] + assert called_request.options == {'custom_setting': 'high'} + + +@pytest.mark.asyncio +async def test_embed_missing_embedder_raises_error(mock_genkit_instance): + """Test that embedding with a missing embedder raises an error.""" + genkit_instance, _ = mock_genkit_instance + documents = [Document.from_text('some text')] + + with pytest.raises(ValueError, match='Embedder must be specified as a string name or an EmbedderRef.'): + await 
genkit_instance.embed(documents=documents) diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py index f7d9739254..43dfdd5e89 100644 --- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py +++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py @@ -29,7 +29,8 @@ from genkit.blocks.model import model_action_metadata from genkit.core.action import ActionMetadata from genkit.core.action.types import ActionKind -from genkit.core.typing import GenerationCommonConfig +from genkit.core.typing import GenerationCommonConfig, EmbedderOptions, EmbedderSupports +from genkit.core.schema import to_json_schema from genkit.plugins.compat_oai.models import ( SUPPORTED_OPENAI_COMPAT_MODELS, SUPPORTED_OPENAI_MODELS, @@ -191,14 +192,11 @@ def list_actions(self) -> list[ActionMetadata]: actions.append( embedder_action_metadata( name=open_ai_name(_name), - config_schema=Embedding, - info={ - 'label': f'OpenAI Embedding - {_name}', - 'dimensions': None, - 'supports': { - 'input': ['text'], - }, - }, + options=EmbedderOptions( + config_schema=to_json_schema(Embedding), + label=f'OpenAI Embedding - {_name}', + supports=EmbedderSupports(input=['text']), + ), ) ) else: diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py index 36433d9eeb..2e5b9a4541 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py @@ -28,6 +28,8 @@ from genkit.blocks.model import model_action_metadata from genkit.core.action import ActionMetadata from genkit.core.registry import ActionKind +from genkit.core.typing import EmbedderOptions, EmbedderSupports +from genkit.core.schema import to_json_schema from genkit.plugins.google_genai.models.embedder import ( Embedder, GeminiEmbeddingModels, @@ 
-242,11 +244,16 @@ def list_actions(self) -> list[ActionMetadata]: ) if 'embedContent' in m.supported_actions: + embed_info = default_embedder_info(name) actions_list.append( embedder_action_metadata( name=googleai_name(name), - info=default_embedder_info(name), - config_schema=EmbedContentConfig, + options=EmbedderOptions( + label=embed_info.get('label'), + supports=EmbedderSupports(input=embed_info.get('supports', {}).get('input')), + dimensions=embed_info.get('dimensions'), + config_schema=to_json_schema(EmbedContentConfig), + ), ) ) @@ -428,11 +435,16 @@ def list_actions(self) -> list[ActionMetadata]: for m in self._client.models.list(): name = m.name.replace('publishers/google/models/', '') if 'embed' in name.lower(): + embed_info = default_embedder_info(name) actions_list.append( embedder_action_metadata( name=vertexai_name(name), - info=default_embedder_info(name), - config_schema=EmbedContentConfig, + options=EmbedderOptions( + label=embed_info.get('label'), + supports=EmbedderSupports(input=embed_info.get('supports', {}).get('input')), + dimensions=embed_info.get('dimensions'), + config_schema=to_json_schema(EmbedContentConfig), + ), ) ) # List all the vertexai models for generate actions diff --git a/py/plugins/google-genai/test/test_google_plugin.py b/py/plugins/google-genai/test/test_google_plugin.py index d621f422e6..53517a4ac6 100644 --- a/py/plugins/google-genai/test/test_google_plugin.py +++ b/py/plugins/google-genai/test/test_google_plugin.py @@ -31,6 +31,8 @@ from genkit.blocks.embedding import embedder_action_metadata from genkit.blocks.model import model_action_metadata from genkit.core.registry import ActionKind +from genkit.core.typing import EmbedderOptions, EmbedderSupports +from genkit.core.schema import to_json_schema from genkit.plugins.google_genai import ( GoogleAI, VertexAI, @@ -279,8 
+281,12 @@ class MockModel(BaseModel): ), embedder_action_metadata( name=googleai_name('model2'), - info=default_embedder_info('model2'), - config_schema=EmbedContentConfig, + options=EmbedderOptions( + label=default_embedder_info('model2').get('label'), + supports=EmbedderSupports(input=default_embedder_info('model2').get('supports', {}).get('input')), + dimensions=default_embedder_info('model2').get('dimensions'), + config_schema=to_json_schema(EmbedContentConfig), + ), ), model_action_metadata( name=googleai_name('model3'), @@ -289,8 +295,12 @@ class MockModel(BaseModel): ), embedder_action_metadata( name=googleai_name('model3'), - info=default_embedder_info('model3'), - config_schema=EmbedContentConfig, + options=EmbedderOptions( + label=default_embedder_info('model3').get('label'), + supports=EmbedderSupports(input=default_embedder_info('model3').get('supports', {}).get('input')), + dimensions=default_embedder_info('model3').get('dimensions'), + config_schema=to_json_schema(EmbedContentConfig), + ), ), ] @@ -684,8 +694,14 @@ class MockModel(BaseModel): ), embedder_action_metadata( name=vertexai_name('model2_embeddings'), - info=default_embedder_info('model2_embeddings'), - config_schema=EmbedContentConfig, + options=EmbedderOptions( + label=default_embedder_info('model2_embeddings').get('label'), + supports=EmbedderSupports( + input=default_embedder_info('model2_embeddings').get('supports', {}).get('input') + ), + dimensions=default_embedder_info('model2_embeddings').get('dimensions'), + config_schema=to_json_schema(EmbedContentConfig), + ), ), model_action_metadata( name=vertexai_name('model2_embeddings'), @@ -694,9 +710,18 @@ class MockModel(BaseModel): ), embedder_action_metadata( name=vertexai_name('model3_embedder'), - info=default_embedder_info('model3_embedder'), - config_schema=EmbedContentConfig, - ), + options=EmbedderOptions( + label=default_embedder_info('model3_embedder').get('label'), + supports=EmbedderSupports( + 
input=default_embedder_info('model3_embedder').get('supports', {}).get('input') + ), + dimensions=default_embedder_info('model3_embedder').get('dimensions'), + config_schema=to_json_schema(EmbedContentConfig), + ), + ), + # info=default_embedder_info('model3_embedder'), + # config_schema=EmbedContentConfig, + # ), model_action_metadata( name=vertexai_name('model3_embedder'), info=google_model_info('model3_embedder').model_dump(), diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py index 472d1c0a7b..35bda84b5e 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py @@ -25,6 +25,8 @@ from genkit.ai import GenkitRegistry, Plugin from genkit.blocks.embedding import embedder_action_metadata from genkit.blocks.model import model_action_metadata +from genkit.core.schema import to_json_schema +from genkit.core.typing import EmbedderOptions, EmbedderSupports from genkit.core.registry import ActionKind from genkit.plugins.ollama.constants import ( DEFAULT_OLLAMA_SERVER_URL, @@ -234,14 +236,11 @@ def list_actions(self) -> list[dict[str, str]]: actions.append( embedder_action_metadata( name=ollama_name(_name), - config_schema=ollama_api.Options, - info={ - 'label': f'Ollama Embedding - {_name}', - 'dimensions': None, - 'supports': { - 'input': ['text'], - }, - }, + options=EmbedderOptions( + config_schema=to_json_schema(ollama_api.Options), + label=f'Ollama Embedding - {_name}', + supports=EmbedderSupports(input=['text']), + ), ) ) else: