diff --git a/py/packages/genkit/src/genkit/ai/_registry.py b/py/packages/genkit/src/genkit/ai/_registry.py index e186bdaf82..495f0ac935 100644 --- a/py/packages/genkit/src/genkit/ai/_registry.py +++ b/py/packages/genkit/src/genkit/ai/_registry.py @@ -494,6 +494,7 @@ def define_embedder( name: str, fn: EmbedderFn, options: EmbedderOptions | None = None, + metadata: dict[str, Any] | None = None, description: str | None = None, ) -> Action: """Define a custom embedder action. @@ -505,7 +506,10 @@ def define_embedder( metadata: Optional metadata for the model. description: Optional description for the embedder. """ - embedder_meta: dict[str, Any] = {} + embedder_meta: dict[str, Any] = metadata or {} + if 'embedder' not in embedder_meta: + embedder_meta['embedder'] = {} + if options: if options.label: embedder_meta['embedder']['label'] = options.label @@ -516,9 +520,6 @@ def define_embedder( if options.config_schema: embedder_meta['embedder']['customOptions'] = to_json_schema(options.config_schema) - if 'embedder' not in embedder_meta: - embedder_meta['embedder'] = {} - embedder_description = get_func_description(fn, description) return self.registry.register_action( name=name, diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py index 8f1b0ab0ea..437cf9506b 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py @@ -145,10 +145,15 @@ def initialize(self, ai: GenkitRegistry) -> None: for version in GeminiEmbeddingModels: embedder = Embedder(version=version, client=self._client) + embedder_info = default_embedder_info(version) ai.define_embedder( name=googleai_name(version), fn=embedder.generate, - metadata=default_embedder_info(version), + options=EmbedderOptions( + label=embedder_info.get('label'), + dimensions=embedder_info.get('dimensions'), + supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, + ), ) def resolve_action( @@ -211,10 +216,15 @@ def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None: _clean_name = name.replace(GOOGLEAI_PLUGIN_NAME + '/', '') if name.startswith(GOOGLEAI_PLUGIN_NAME) else name embedder = Embedder(version=_clean_name, client=self._client) + embedder_info = default_embedder_info(_clean_name) ai.define_embedder( name=googleai_name(_clean_name), fn=embedder.generate, - metadata=default_embedder_info(_clean_name), + options=EmbedderOptions( + label=embedder_info.get('label'), + dimensions=embedder_info.get('dimensions'), + supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, + ), ) @cached_property @@ -325,11 +335,15 @@ def initialize(self, ai: GenkitRegistry) -> None: for version in VertexEmbeddingModels: embedder = Embedder(version=version, client=self._client) + embedder_info = default_embedder_info(version) ai.define_embedder( name=vertexai_name(version), fn=embedder.generate, - metadata=default_embedder_info(version), - # config_schema=to_json_schema(EmbedContentConfig), + options=EmbedderOptions( + label=embedder_info.get('label'), + dimensions=embedder_info.get('dimensions'), + supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, + ), ) for version in ImagenVersion: @@ -407,10 +421,15 @@ def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None: _clean_name = name.replace(VERTEXAI_PLUGIN_NAME + '/', '') if name.startswith(VERTEXAI_PLUGIN_NAME) else name embedder = Embedder(version=_clean_name, client=self._client) + embedder_info = default_embedder_info(_clean_name) ai.define_embedder( name=vertexai_name(_clean_name), fn=embedder.generate, - metadata=default_embedder_info(_clean_name), + options=EmbedderOptions( + label=embedder_info.get('label'), + dimensions=embedder_info.get('dimensions'), + supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, + ), ) @cached_property diff --git a/py/plugins/google-genai/test/test_google_plugin.py b/py/plugins/google-genai/test/test_google_plugin.py index 57b96d9c73..e589d86eec 100644 --- a/py/plugins/google-genai/test/test_google_plugin.py +++ b/py/plugins/google-genai/test/test_google_plugin.py @@ -153,7 +153,7 @@ def test_googleai_initialize(): ai_mock.define_embedder.assert_any_call( name=googleai_name(version), fn=ANY, - metadata=ANY, + options=ANY, ) @@ -244,10 +244,15 @@ def test_googleai__resolve_embedder( name=model_name, ) - ai_mock.define_embedder.assert_called_once_with( - name=expected_model_name, fn=ANY, metadata=default_embedder_info(clean_name) + info = default_embedder_info(clean_name) + options = EmbedderOptions( + label=info.get('label'), + supports=EmbedderSupports(**info.get('supports', {})), + dimensions=info.get('dimensions'), ) + ai_mock.define_embedder.assert_called_once_with(name=expected_model_name, fn=ANY, options=options) + def test_googleai_list_actions(googleai_plugin_instance): """Unit test for list actions.""" @@ -491,6 +496,7 @@ def test_vertexai_initialize(vertexai_plugin_instance): plugin.initialize(ai_mock) assert ai_mock.define_model.call_count == len(VertexAIGeminiVersion) + len(ImagenVersion) + # The actual call passes EmbedderOptions, so verify we are calling passing options assert ai_mock.define_embedder.call_count == len(VertexEmbeddingModels) for version in VertexAIGeminiVersion: @@ -507,7 +513,7 @@ def test_vertexai_initialize(vertexai_plugin_instance): ai_mock.define_embedder.assert_any_call( name=vertexai_name(version), fn=ANY, - metadata=ANY, + options=ANY, ) @@ -648,10 +654,15 @@ def test_vertexai__resolve_embedder( name=model_name, ) - ai_mock.define_embedder.assert_called_once_with( - name=expected_model_name, fn=ANY, metadata=default_embedder_info(clean_name) + info = default_embedder_info(clean_name) + options = EmbedderOptions( + label=info.get('label'), + supports=EmbedderSupports(**info.get('supports', {})), + dimensions=info.get('dimensions'), ) + ai_mock.define_embedder.assert_called_once_with(name=expected_model_name, fn=ANY, options=options) + def test_vertexai_list_actions(vertexai_plugin_instance): """Unit test for list actions.""" diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py index 8e31ffce5d..c3aa01e889 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py @@ -198,14 +198,12 @@ def _define_ollama_embedder(self, ai: GenkitRegistry, embedder_ref: EmbeddingDef ai.define_embedder( name=ollama_name(embedder_ref.name), fn=embedder.embed, - config_schema=to_json_schema(ollama_api.Options), - metadata={ - 'label': f'Ollama Embedding - {_clean_name}', - 'dimensions': embedder_ref.dimensions, - 'supports': { - 'input': ['text'], - }, - }, + options=EmbedderOptions( + config_schema=to_json_schema(ollama_api.Options), + label=f'Ollama Embedding - {_clean_name}', + dimensions=embedder_ref.dimensions, + supports=EmbedderSupports(input=['text']), + ), ) @cached_property diff --git a/py/plugins/ollama/tests/test_plugin_api.py b/py/plugins/ollama/tests/test_plugin_api.py index 295c39843f..ff0c14d11b 100644 --- a/py/plugins/ollama/tests/test_plugin_api.py +++ b/py/plugins/ollama/tests/test_plugin_api.py @@ -24,6 +24,7 @@ from pydantic import BaseModel from genkit.ai import ActionKind, Genkit +from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports from genkit.core.schema import to_json_schema from genkit.plugins.ollama import Ollama, ollama_name from genkit.plugins.ollama.embedders import EmbeddingDefinition @@ -127,14 +128,14 @@ def test__initialize_embedders(ollama_plugin_instance): ai_mock.define_embedder.assert_called_once_with( name=ollama_name(name), fn=ANY, - config_schema=to_json_schema(ollama_api.Options), - metadata={ - 'label': f'Ollama Embedding - {name}', - 'dimensions': 1024, - 'supports': { - 'input': ['text'], - }, - }, + options=EmbedderOptions( + config_schema=to_json_schema(ollama_api.Options), + label=f'Ollama Embedding - {name}', + dimensions=1024, + supports=EmbedderSupports( + input=['text'], + ), + ), ) @@ -166,14 +167,14 @@ def test_resolve_action(kind, name, ollama_plugin_instance): ai_mock.define_embedder.assert_called_once_with( name=ollama_name(name), fn=ANY, - config_schema=to_json_schema(ollama_api.Options), - metadata={ - 'label': f'Ollama Embedding - {name}', - 'dimensions': None, - 'supports': { - 'input': ['text'], - }, - }, + options=EmbedderOptions( + config_schema=to_json_schema(ollama_api.Options), + label=f'Ollama Embedding - {name}', + dimensions=None, + supports=EmbedderSupports( + input=['text'], + ), + ), ) @@ -219,14 +220,14 @@ def test_define_ollama_embedder(name, expected_name, clean_name, ollama_plugin_i ai_mock.define_embedder.assert_called_once_with( name=expected_name, fn=ANY, - config_schema=to_json_schema(ollama_api.Options), - metadata={ - 'label': f'Ollama Embedding - {clean_name}', - 'dimensions': 1024, - 'supports': { - 'input': ['text'], - }, - }, + options=EmbedderOptions( + config_schema=to_json_schema(ollama_api.Options), + label=f'Ollama Embedding - {clean_name}', + dimensions=1024, + supports=EmbedderSupports( + input=['text'], + ), + ), )