9 changes: 5 additions & 4 deletions py/packages/genkit/src/genkit/ai/_registry.py
@@ -494,6 +494,7 @@ def define_embedder(
name: str,
fn: EmbedderFn,
options: EmbedderOptions | None = None,
metadata: dict[str, Any] | None = None,
description: str | None = None,
) -> Action:
"""Define a custom embedder action.
@@ -505,7 +506,10 @@ def define_embedder(
metadata: Optional metadata for the embedder.
description: Optional description for the embedder.
"""
embedder_meta: dict[str, Any] = {}
embedder_meta: dict[str, Any] = metadata or {}
if 'embedder' not in embedder_meta:
embedder_meta['embedder'] = {}

if options:
if options.label:
embedder_meta['embedder']['label'] = options.label
@@ -516,9 +520,6 @@ def define_embedder(
if options.config_schema:
embedder_meta['embedder']['customOptions'] = to_json_schema(options.config_schema)

if 'embedder' not in embedder_meta:
embedder_meta['embedder'] = {}

embedder_description = get_func_description(fn, description)
return self.registry.register_action(
name=name,
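For readers of this hunk, a condensed sketch of the merge order define_embedder follows after the change: the caller-supplied metadata dict (the new parameter) is the base, the nested 'embedder' key is guaranteed up front, and EmbedderOptions fields are layered on top. Parameter names come from the diff above; the remaining option branches (dimensions, supports, customOptions) are elided.

    # Sketch only; mirrors the order shown in the hunk above.
    from typing import Any

    def build_embedder_meta(metadata: dict[str, Any] | None, options: Any) -> dict[str, Any]:
        embedder_meta: dict[str, Any] = metadata or {}   # caller metadata is now the base dict
        if 'embedder' not in embedder_meta:
            embedder_meta['embedder'] = {}               # ensured before any option is applied
        if options is not None and getattr(options, 'label', None):
            embedder_meta['embedder']['label'] = options.label
        return embedder_meta

    # e.g. build_embedder_meta({'embedder': {'dimensions': 768}}, None)
    # preserves the caller's nested dict instead of replacing it.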
29 changes: 24 additions & 5 deletions py/plugins/google-genai/src/genkit/plugins/google_genai/google.py
@@ -145,10 +145,15 @@ def initialize(self, ai: GenkitRegistry) -> None:

for version in GeminiEmbeddingModels:
embedder = Embedder(version=version, client=self._client)
embedder_info = default_embedder_info(version)
ai.define_embedder(
name=googleai_name(version),
fn=embedder.generate,
metadata=default_embedder_info(version),
options=EmbedderOptions(
label=embedder_info.get('label'),
dimensions=embedder_info.get('dimensions'),
supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None,
),
)

def resolve_action(
@@ -211,10 +216,15 @@ def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None:
_clean_name = name.replace(GOOGLEAI_PLUGIN_NAME + '/', '') if name.startswith(GOOGLEAI_PLUGIN_NAME) else name
embedder = Embedder(version=_clean_name, client=self._client)

embedder_info = default_embedder_info(_clean_name)
ai.define_embedder(
name=googleai_name(_clean_name),
fn=embedder.generate,
metadata=default_embedder_info(_clean_name),
options=EmbedderOptions(
label=embedder_info.get('label'),
dimensions=embedder_info.get('dimensions'),
supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None,
),
)

@cached_property
@@ -325,11 +335,15 @@ def initialize(self, ai: GenkitRegistry) -> None:

for version in VertexEmbeddingModels:
embedder = Embedder(version=version, client=self._client)
embedder_info = default_embedder_info(version)
ai.define_embedder(
name=vertexai_name(version),
fn=embedder.generate,
metadata=default_embedder_info(version),
# config_schema=to_json_schema(EmbedContentConfig),
options=EmbedderOptions(
label=embedder_info.get('label'),
dimensions=embedder_info.get('dimensions'),
supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None,
),
)

for version in ImagenVersion:
@@ -407,10 +421,15 @@ def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None:
_clean_name = name.replace(VERTEXAI_PLUGIN_NAME + '/', '') if name.startswith(VERTEXAI_PLUGIN_NAME) else name
embedder = Embedder(version=_clean_name, client=self._client)

embedder_info = default_embedder_info(_clean_name)
ai.define_embedder(
name=vertexai_name(_clean_name),
fn=embedder.generate,
metadata=default_embedder_info(_clean_name),
options=EmbedderOptions(
label=embedder_info.get('label'),
dimensions=embedder_info.get('dimensions'),
supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None,
),
)

@cached_property
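Both the GoogleAI and Vertex AI paths now repeat the same EmbedderOptions construction at four call sites. Below is a hypothetical helper (not part of this PR) restating that shared shape; the import path for EmbedderOptions and EmbedderSupports is taken from the test file further down.

    # Hypothetical helper, shown only to summarize the repeated pattern in google.py.
    from typing import Any

    from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports

    def options_from_info(info: dict[str, Any]) -> EmbedderOptions:
        return EmbedderOptions(
            label=info.get('label'),
            dimensions=info.get('dimensions'),
            supports=EmbedderSupports(**info['supports']) if info.get('supports') else None,
        )

    # Matching the diff, each call site would then read:
    # ai.define_embedder(name=..., fn=embedder.generate,
    #                    options=options_from_info(default_embedder_info(version)))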
23 changes: 17 additions & 6 deletions py/plugins/google-genai/test/test_google_plugin.py
@@ -153,7 +153,7 @@ def test_googleai_initialize():
ai_mock.define_embedder.assert_any_call(
name=googleai_name(version),
fn=ANY,
metadata=ANY,
options=ANY,
)


@@ -244,10 +244,15 @@ def test_googleai__resolve_embedder(
name=model_name,
)

ai_mock.define_embedder.assert_called_once_with(
name=expected_model_name, fn=ANY, metadata=default_embedder_info(clean_name)
info = default_embedder_info(clean_name)
options = EmbedderOptions(
label=info.get('label'),
supports=EmbedderSupports(**info.get('supports', {})),
dimensions=info.get('dimensions'),
)
Contributor comment on lines +248 to 252 (severity: medium):

The logic for constructing EmbedderOptions, specifically for the supports field, diverges from the implementation in google.py. The implementation results in supports=None if the supports dictionary is missing or empty, while this test would create an EmbedderSupports object with default values. This discrepancy could lead to the test failing incorrectly if default_embedder_info changes its return value in the future. To ensure consistency and robustness, the test should replicate the implementation's logic.

Suggested change (before, then after):

    options = EmbedderOptions(
        label=info.get('label'),
        supports=EmbedderSupports(**info.get('supports', {})),
        dimensions=info.get('dimensions'),
    )

    options = EmbedderOptions(
        label=info.get('label'),
        supports=EmbedderSupports(**s) if (s := info.get('supports')) else None,
        dimensions=info.get('dimensions'),
    )
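A minimal illustration of the mismatch the reviewer describes, assuming EmbedderSupports accepts no arguments (the current test already relies on that when 'supports' is absent):

    from genkit.blocks.embedding import EmbedderSupports

    info: dict = {}  # pretend default_embedder_info() stopped returning a 'supports' entry

    test_style = EmbedderSupports(**info.get('supports', {}))                    # defaults object, not None
    impl_style = EmbedderSupports(**s) if (s := info.get('supports')) else None  # None, as in google.py

    print(test_style == impl_style)  # False, so assert_called_once_with would no longer match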


ai_mock.define_embedder.assert_called_once_with(name=expected_model_name, fn=ANY, options=options)


def test_googleai_list_actions(googleai_plugin_instance):
"""Unit test for list actions."""
@@ -491,6 +496,7 @@ def test_vertexai_initialize(vertexai_plugin_instance):
plugin.initialize(ai_mock)

assert ai_mock.define_model.call_count == len(VertexAIGeminiVersion) + len(ImagenVersion)
# The actual call passes EmbedderOptions, so verify the call is made with options
assert ai_mock.define_embedder.call_count == len(VertexEmbeddingModels)

for version in VertexAIGeminiVersion:
@@ -507,7 +513,7 @@ def test_vertexai_initialize(vertexai_plugin_instance):
ai_mock.define_embedder.assert_any_call(
name=vertexai_name(version),
fn=ANY,
metadata=ANY,
options=ANY,
)


@@ -648,10 +654,15 @@ def test_vertexai__resolve_embedder(
name=model_name,
)

ai_mock.define_embedder.assert_called_once_with(
name=expected_model_name, fn=ANY, metadata=default_embedder_info(clean_name)
info = default_embedder_info(clean_name)
options = EmbedderOptions(
label=info.get('label'),
supports=EmbedderSupports(**info.get('supports', {})),
dimensions=info.get('dimensions'),
)
Contributor comment on lines +658 to 662 (severity: medium):

Similar to a previous comment, the logic for constructing EmbedderOptions here is inconsistent with the implementation in google.py. The implementation sets supports to None if the supports data is missing or empty, but this test creates an EmbedderSupports object with default values. This should be aligned to ensure the test is accurate and robust against future changes.

Suggested change (before, then after):

    options = EmbedderOptions(
        label=info.get('label'),
        supports=EmbedderSupports(**info.get('supports', {})),
        dimensions=info.get('dimensions'),
    )

    options = EmbedderOptions(
        label=info.get('label'),
        supports=EmbedderSupports(**s) if (s := info.get('supports')) else None,
        dimensions=info.get('dimensions'),
    )


ai_mock.define_embedder.assert_called_once_with(name=expected_model_name, fn=ANY, options=options)


def test_vertexai_list_actions(vertexai_plugin_instance):
"""Unit test for list actions."""
14 changes: 6 additions & 8 deletions py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py
@@ -198,14 +198,12 @@ def _define_ollama_embedder(self, ai: GenkitRegistry, embedder_ref: EmbeddingDef
ai.define_embedder(
name=ollama_name(embedder_ref.name),
fn=embedder.embed,
config_schema=to_json_schema(ollama_api.Options),
metadata={
'label': f'Ollama Embedding - {_clean_name}',
'dimensions': embedder_ref.dimensions,
'supports': {
'input': ['text'],
},
},
options=EmbedderOptions(
config_schema=to_json_schema(ollama_api.Options),
label=f'Ollama Embedding - {_clean_name}',
dimensions=embedder_ref.dimensions,
supports=EmbedderSupports(input=['text']),
),
)

@cached_property
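The net effect for the Ollama plugin is that config_schema moves from a standalone define_embedder kwarg into EmbedderOptions, alongside label, dimensions, and supports. A self-contained sketch of the options object being built, with FakeOllamaOptions standing in for ollama_api.Options and placeholder values:

    from pydantic import BaseModel

    from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports
    from genkit.core.schema import to_json_schema

    class FakeOllamaOptions(BaseModel):
        """Stand-in for ollama_api.Options, used only for this sketch."""
        temperature: float | None = None

    options = EmbedderOptions(
        config_schema=to_json_schema(FakeOllamaOptions),  # previously a separate define_embedder kwarg
        label='Ollama Embedding - nomic-embed-text',      # f'Ollama Embedding - {_clean_name}' in the plugin
        dimensions=768,                                    # embedder_ref.dimensions
        supports=EmbedderSupports(input=['text']),         # text-only input, as in the diff
    )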
49 changes: 25 additions & 24 deletions py/plugins/ollama/tests/test_plugin_api.py
@@ -24,6 +24,7 @@
from pydantic import BaseModel

from genkit.ai import ActionKind, Genkit
from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports
from genkit.core.schema import to_json_schema
from genkit.plugins.ollama import Ollama, ollama_name
from genkit.plugins.ollama.embedders import EmbeddingDefinition
@@ -127,14 +128,14 @@ def test__initialize_embedders(ollama_plugin_instance):
ai_mock.define_embedder.assert_called_once_with(
name=ollama_name(name),
fn=ANY,
config_schema=to_json_schema(ollama_api.Options),
metadata={
'label': f'Ollama Embedding - {name}',
'dimensions': 1024,
'supports': {
'input': ['text'],
},
},
options=EmbedderOptions(
config_schema=to_json_schema(ollama_api.Options),
label=f'Ollama Embedding - {name}',
dimensions=1024,
supports=EmbedderSupports(
input=['text'],
),
),
)


@@ -166,14 +167,14 @@ def test_resolve_action(kind, name, ollama_plugin_instance):
ai_mock.define_embedder.assert_called_once_with(
name=ollama_name(name),
fn=ANY,
config_schema=to_json_schema(ollama_api.Options),
metadata={
'label': f'Ollama Embedding - {name}',
'dimensions': None,
'supports': {
'input': ['text'],
},
},
options=EmbedderOptions(
config_schema=to_json_schema(ollama_api.Options),
label=f'Ollama Embedding - {name}',
dimensions=None,
supports=EmbedderSupports(
input=['text'],
),
),
)


@@ -219,14 +220,14 @@ def test_define_ollama_embedder(name, expected_name, clean_name, ollama_plugin_i
ai_mock.define_embedder.assert_called_once_with(
name=expected_name,
fn=ANY,
config_schema=to_json_schema(ollama_api.Options),
metadata={
'label': f'Ollama Embedding - {clean_name}',
'dimensions': 1024,
'supports': {
'input': ['text'],
},
},
options=EmbedderOptions(
config_schema=to_json_schema(ollama_api.Options),
label=f'Ollama Embedding - {clean_name}',
dimensions=1024,
supports=EmbedderSupports(
input=['text'],
),
),
)

