diff --git a/py/packages/genkit/src/genkit/ai/_base.py b/py/packages/genkit/src/genkit/ai/_base.py index c00de0833a..da6ddbc2e0 100644 --- a/py/packages/genkit/src/genkit/ai/_base.py +++ b/py/packages/genkit/src/genkit/ai/_base.py @@ -118,6 +118,7 @@ def resolver(kind, name, plugin=plugin): return plugin.resolve_action(self, kind, name) self.registry.register_action_resolver(plugin.plugin_name(), resolver) + self.registry.register_list_actions_resolver(plugin.plugin_name(), plugin.list_actions) else: raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin`') diff --git a/py/packages/genkit/src/genkit/ai/_plugin.py b/py/packages/genkit/src/genkit/ai/_plugin.py index e1165e7915..d90a0f30e1 100644 --- a/py/packages/genkit/src/genkit/ai/_plugin.py +++ b/py/packages/genkit/src/genkit/ai/_plugin.py @@ -74,3 +74,15 @@ def initialize(self, ai: GenkitRegistry) -> None: None, initialization is done by side-effect on the registry. """ pass + + def list_actions(self) -> list[dict[str, str]]: + """Generate a list of available actions or models. + + Returns: + list of actions dicts with the following shape: + { + 'name': str, + 'kind': ActionKind, + } + """ + return [] diff --git a/py/packages/genkit/src/genkit/ai/_registry.py b/py/packages/genkit/src/genkit/ai/_registry.py index 82f377351e..cc56994fda 100644 --- a/py/packages/genkit/src/genkit/ai/_registry.py +++ b/py/packages/genkit/src/genkit/ai/_registry.py @@ -456,6 +456,7 @@ def define_embedder( self, name: str, fn: EmbedderFn, + config_schema: BaseModel | dict[str, Any] | None = None, metadata: dict[str, Any] | None = None, description: str | None = None, ) -> Action: @@ -464,15 +465,23 @@ def define_embedder( Args: name: Name of the model. fn: Function implementing the embedder behavior. + config_schema: Optional schema for embedder configuration. metadata: Optional metadata for the model. description: Optional description for the embedder. """ + embedder_meta: dict[str, Any] = metadata if metadata else {} + if 'embedder' not in embedder_meta: + embedder_meta['embedder'] = {} + + if config_schema: + embedder_meta['embedder']['customOptions'] = to_json_schema(config_schema) + embedder_description = get_func_description(fn, description) return self.registry.register_action( name=name, kind=ActionKind.EMBEDDER, fn=fn, - metadata=metadata, + metadata=embedder_meta, description=embedder_description, ) diff --git a/py/packages/genkit/src/genkit/blocks/embedding.py b/py/packages/genkit/src/genkit/blocks/embedding.py index 50f5d3307c..582ec5f6ac 100644 --- a/py/packages/genkit/src/genkit/blocks/embedding.py +++ b/py/packages/genkit/src/genkit/blocks/embedding.py @@ -17,8 +17,28 @@ """Embedding actions.""" from collections.abc import Callable +from typing import Any +from genkit.ai import ActionKind +from genkit.core.action import ActionMetadata +from genkit.core.schema import to_json_schema from genkit.core.typing import EmbedRequest, EmbedResponse # type EmbedderFn = Callable[[EmbedRequest], EmbedResponse] EmbedderFn = Callable[[EmbedRequest], EmbedResponse] + + +def embedder_action_metadata( + name: str, + info: dict[str, Any] | None = None, + config_schema: Any | None = None, +) -> ActionMetadata: + """Generates an ActionMetadata for embedders.""" + info = info if info is not None else {} + return ActionMetadata( + kind=ActionKind.EMBEDDER, + name=name, + input_json_schema=to_json_schema(EmbedRequest), + output_json_schema=to_json_schema(EmbedResponse), + metadata={'embedder': {**info, 'customOptions': to_json_schema(config_schema) if config_schema else None}}, + ) diff --git a/py/packages/genkit/src/genkit/blocks/model.py b/py/packages/genkit/src/genkit/blocks/model.py index 9227aba8c4..5f3d800fa4 100644 --- a/py/packages/genkit/src/genkit/blocks/model.py +++ b/py/packages/genkit/src/genkit/blocks/model.py @@ -36,8 +36,10 @@ def my_model(request: GenerateRequest) -> GenerateResponse: from pydantic import BaseModel, Field -from genkit.core.action import ActionRunContext +from genkit.ai import ActionKind +from genkit.core.action import ActionMetadata, ActionRunContext from genkit.core.extract import extract_json +from genkit.core.schema import to_json_schema from genkit.core.typing import ( Candidate, DocumentPart, @@ -423,3 +425,19 @@ def get_part_counts(parts: list[Part]) -> PartCounts: part_counts.audio += 1 if is_audio else 0 return part_counts + + +def model_action_metadata( + name: str, + info: dict[str, Any] | None = None, + config_schema: Any | None = None, +) -> ActionMetadata: + """Generates an ActionMetadata for models.""" + info = info if info is not None else {} + return ActionMetadata( + kind=ActionKind.MODEL, + name=name, + input_json_schema=to_json_schema(GenerateRequest), + output_json_schema=to_json_schema(GenerateResponse), + metadata={'model': {**info, 'customOptions': to_json_schema(config_schema) if config_schema else None}}, + ) diff --git a/py/packages/genkit/src/genkit/core/action/__init__.py b/py/packages/genkit/src/genkit/core/action/__init__.py index cedcf13b49..e5fbc94737 100644 --- a/py/packages/genkit/src/genkit/core/action/__init__.py +++ b/py/packages/genkit/src/genkit/core/action/__init__.py @@ -18,6 +18,7 @@ from ._action import ( Action, + ActionMetadata, ActionRunContext, ) from ._key import ( @@ -28,6 +29,7 @@ __all__ = [ Action.__name__, + ActionMetadata.__name__, ActionRunContext.__name__, create_action_key.__name__, parse_action_key.__name__, diff --git a/py/packages/genkit/src/genkit/core/action/_action.py b/py/packages/genkit/src/genkit/core/action/_action.py index 9fafeca780..370562c736 100644 --- a/py/packages/genkit/src/genkit/core/action/_action.py +++ b/py/packages/genkit/src/genkit/core/action/_action.py @@ -91,7 +91,7 @@ from functools import cached_property from typing import Any -from pydantic import TypeAdapter +from pydantic import BaseModel, TypeAdapter from genkit.aio import Channel, ensure_async from genkit.core.error import GenkitError @@ -457,6 +457,20 @@ def _initialize_io_schemas( self._metadata[ActionMetadataKey.OUTPUT_KEY] = self._output_schema +class ActionMetadata(BaseModel): + """Metadata for actions.""" + + kind: ActionKind + name: str + description: str | None = None + input_schema: Any | None = None + input_json_schema: dict[str, Any] | None = None + output_schema: Any | None = None + output_json_schema: dict[str, Any] | None = None + stream_schema: Any | None = None + metadata: dict[str, Any] | None = None + + _SyncTracingWrapper = Callable[[Any | None, ActionRunContext], ActionResponse] _AsyncTracingWrapper = Callable[[Any | None, ActionRunContext], ActionResponse] diff --git a/py/packages/genkit/src/genkit/core/reflection.py b/py/packages/genkit/src/genkit/core/reflection.py index c93fb12da4..e5789e1749 100644 --- a/py/packages/genkit/src/genkit/core/reflection.py +++ b/py/packages/genkit/src/genkit/core/reflection.py @@ -122,6 +122,7 @@ def do_GET(self) -> None: # noqa: N802 self.send_header('content-type', 'application/json') self.end_headers() actions = registry.list_serializable_actions() + actions = registry.list_actions(actions) self.wfile.write(bytes(json.dumps(actions), encoding)) else: self.send_response(404) diff --git a/py/packages/genkit/src/genkit/core/registry.py b/py/packages/genkit/src/genkit/core/registry.py index e4c0d87a13..09ca6ac766 100644 --- a/py/packages/genkit/src/genkit/core/registry.py +++ b/py/packages/genkit/src/genkit/core/registry.py @@ -31,14 +31,19 @@ from collections.abc import Callable from typing import Any +import structlog + from genkit.core.action import ( Action, + ActionMetadata, create_action_key, parse_action_key, parse_plugin_name_from_action_name, ) from genkit.core.action.types import ActionKind, ActionName, ActionResolver +logger = structlog.get_logger(__name__) + # An action store is a nested dictionary mapping ActionKind to a dictionary of # action names and their corresponding Action instances. # @@ -75,6 +80,7 @@ class Registry: def __init__(self): """Initialize an empty Registry instance.""" self._action_resolvers: dict[str, ActionResolver] = {} + self._list_actions_resolvers: dict[str, Callable] = {} self._entries: ActionStore = {} self._value_by_kind_and_name: dict[str, dict[str, Any]] = {} self._lock = threading.RLock() @@ -82,7 +88,7 @@ def __init__(self): # TODO: Figure out how to set this. self.api_stability: str = 'stable' - def register_action_resolver(self, plugin_name: str, resolver: ActionResolver): + def register_action_resolver(self, plugin_name: str, resolver: ActionResolver) -> None: """Registers an ActionResolver function for a given plugin. Args: @@ -97,6 +103,21 @@ def register_action_resolver(self, plugin_name: str, resolver: ActionResolver): raise ValueError(f'Plugin {plugin_name} already registered') self._action_resolvers[plugin_name] = resolver + def register_list_actions_resolver(self, plugin_name: str, resolver: Callable) -> None: + """Registers an Callable function to list available actions or models. + + Args: + plugin_name: The name of the plugin. + resolver: The Callable function to list models. + + Raises: + ValueError: If a resolver is already registered for the plugin. + """ + with self._lock: + if plugin_name in self._list_actions_resolvers: + raise ValueError(f'Plugin {plugin_name} already registered') + self._list_actions_resolvers[plugin_name] = resolver + def register_action( self, kind: ActionKind, @@ -212,6 +233,49 @@ def list_serializable_actions(self, allowed_kinds: set[ActionKind] | None = None } return actions + def list_actions( + self, + actions: dict[str, Action] | None = None, + allowed_kinds: set[ActionKind] | None = None, + ) -> dict[str, Action] | None: + """Add actions or models. + + Args: + actions: dictionary of serializable actions. + allowed_kinds: The types of actions to list. If None, all actions + are listed. + + Returns: + A dictionary of serializable Actions updated. + """ + if actions is None: + actions = {} + + for plugin_name in self._list_actions_resolvers: + actions_lister = self._list_actions_resolvers[plugin_name] + + # TODO: Set all the list_actions plugins' methods as cached_properties. + if isinstance(actions_lister, list): + actions_list = actions_lister + else: + actions_list = actions_lister() + + for _action in actions_list: + kind = _action.kind + if allowed_kinds is not None and kind not in allowed_kinds: + continue + key = create_action_key(kind, _action.name) + + if key not in actions: + actions[key] = { + 'key': key, + 'name': _action.name, + 'inputSchema': _action.input_json_schema, + 'outputSchema': _action.output_json_schema, + 'metadata': _action.metadata, + } + return actions + def register_value(self, kind: str, name: str, value: Any): """Registers a value with a given kind and name. diff --git a/py/packages/genkit/tests/genkit/blocks/embedding_test.py b/py/packages/genkit/tests/genkit/blocks/embedding_test.py new file mode 100644 index 0000000000..0a54d6aa2e --- /dev/null +++ b/py/packages/genkit/tests/genkit/blocks/embedding_test.py @@ -0,0 +1,34 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the action module.""" + +from genkit.blocks.embedding import embedder_action_metadata +from genkit.core.action import ActionMetadata + + +def test_embedder_action_metadata(): + """Test for embedder_action_metadata.""" + action_metadata = embedder_action_metadata( + name='test_model', + info={'label': 'test_label'}, + config_schema=None, + ) + + assert isinstance(action_metadata, ActionMetadata) + assert action_metadata.input_json_schema is not None + assert action_metadata.output_json_schema is not None + assert action_metadata.metadata == {'embedder': {'customOptions': None, 'label': 'test_label'}} diff --git a/py/packages/genkit/tests/genkit/blocks/model_test.py b/py/packages/genkit/tests/genkit/blocks/model_test.py index 52925faae4..eca64fee58 100644 --- a/py/packages/genkit/tests/genkit/blocks/model_test.py +++ b/py/packages/genkit/tests/genkit/blocks/model_test.py @@ -14,7 +14,9 @@ PartCounts, get_basic_usage_stats, get_part_counts, + model_action_metadata, ) +from genkit.core.action import ActionMetadata from genkit.core.typing import ( Candidate, GenerateRequest, @@ -453,3 +455,17 @@ def test_response_wrapper_interrupts() -> None: tool_request=ToolRequest(name='tool2', input={'bcd': 4}), metadata={'interrupt': {'banana': 'yes'}} ) ] + + +def test_model_action_metadata(): + """Test for model_action_metadata.""" + action_metadata = model_action_metadata( + name='test_model', + info={'label': 'test_label'}, + config_schema=None, + ) + + assert isinstance(action_metadata, ActionMetadata) + assert action_metadata.input_json_schema is not None + assert action_metadata.output_json_schema is not None + assert action_metadata.metadata == {'model': {'customOptions': None, 'label': 'test_label'}} diff --git a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py index c009b150d0..173c6a8037 100644 --- a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py +++ b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py @@ -82,6 +82,7 @@ async def test_health_check(asgi_client): async def test_list_actions(asgi_client, mock_registry): """Test that the actions list endpoint returns registered actions.""" mock_registry.list_serializable_actions.return_value = {'action1': {'name': 'Action 1'}} + mock_registry.list_actions.return_value = {'action1': {'name': 'Action 1'}} response = await asgi_client.get('/api/actions') assert response.status_code == 200 assert response.json() == {'action1': {'name': 'Action 1'}} diff --git a/py/packages/genkit/tests/genkit/core/registry_test.py b/py/packages/genkit/tests/genkit/core/registry_test.py index de38e37aa0..6c5b2f52f2 100644 --- a/py/packages/genkit/tests/genkit/core/registry_test.py +++ b/py/packages/genkit/tests/genkit/core/registry_test.py @@ -12,10 +12,36 @@ import pytest from genkit.ai import Genkit, GenkitRegistry, Plugin +from genkit.core.action import ActionMetadata from genkit.core.action.types import ActionKind, ActionMetadataKey from genkit.core.registry import Registry +def test_register_list_actions_resolver(): + """Test for register list actions resolver.""" + registry = Registry() + + def list_actions_mock(): + return [] + + registry.register_list_actions_resolver('test_plugin', list_actions_mock) + + assert 'test_plugin' in registry._list_actions_resolvers + + +def test_register_list_actions_resolver_raises_exception(): + """Test when ValueError is raised.""" + registry = Registry() + + def list_actions_mock(): + return [] + + registry._list_actions_resolvers['test_plugin'] = list_actions_mock + + with pytest.raises(ValueError, match=r'Plugin .* already registered'): + registry.register_list_actions_resolver('test_plugin', list_actions_mock) + + def test_register_action_with_name_and_kind() -> None: """Ensure we can register an action with a name and kind.""" registry = Registry() @@ -65,7 +91,84 @@ def test_list_serializable_actions() -> None: } +@pytest.mark.parametrize( + 'allowed_kind, expected', + [ + ( + set([ActionKind.CUSTOM]), + { + '/custom/test_action': { + 'key': '/custom/test_action', + 'name': 'test_action', + 'inputSchema': None, + 'outputSchema': None, + 'metadata': None, + }, + }, + ), + ( + None, + { + '/custom/test_action': { + 'key': '/custom/test_action', + 'name': 'test_action', + 'inputSchema': None, + 'outputSchema': None, + 'metadata': None, + }, + '/tool/test_tool': { + 'key': '/tool/test_tool', + 'name': 'test_tool', + 'inputSchema': None, + 'outputSchema': None, + 'metadata': None, + }, + }, + ), + ( + set([ActionKind.CUSTOM, ActionKind.TOOL]), + { + '/custom/test_action': { + 'key': '/custom/test_action', + 'name': 'test_action', + 'inputSchema': None, + 'outputSchema': None, + 'metadata': None, + }, + '/tool/test_tool': { + 'key': '/tool/test_tool', + 'name': 'test_tool', + 'inputSchema': None, + 'outputSchema': None, + 'metadata': None, + }, + }, + ), + ], +) +def test_list_actions(allowed_kind, expected) -> None: + """Ensure we can list actions.""" + + def list_actions_mock(): + return [ + ActionMetadata( + kind=ActionKind.CUSTOM, + name='test_action', + ), + ActionMetadata(kind=ActionKind.TOOL, name='test_tool'), + ] + + registry = Registry() + registry._list_actions_resolvers['test_plugin'] = list_actions_mock + registry._entries[ActionKind.CUSTOM] = {} + registry._entries[ActionKind.TOOL] = {} + + got = registry.list_actions({}, allowed_kind) + assert got == expected + + def test_resolve_action_from_plugin(): + """Resolve action from plugin test.""" resolver_calls = [] class MyPlugin(Plugin): @@ -98,6 +201,7 @@ def initialize(self, ai: GenkitRegistry) -> None: def test_register_value(): + """Register a value and lookup test.""" registry = Registry() registry.register_value('format', 'json', [1, 2, 3]) diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py index f31b40c707..bf2b74644b 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py @@ -15,19 +15,23 @@ # SPDX-License-Identifier: Apache-2.0 import os +from functools import cached_property from google import genai from google.auth.credentials import Credentials from google.genai.client import DebugConfig -from google.genai.types import GenerateImagesConfigOrDict, HttpOptions, HttpOptionsDict +from google.genai.types import EmbedContentConfig, GenerateImagesConfigOrDict, HttpOptions, HttpOptionsDict import genkit.plugins.google_genai.constants as const from genkit.ai import GENKIT_CLIENT_HEADER, GenkitRegistry, Plugin +from genkit.blocks.embedding import embedder_action_metadata +from genkit.blocks.model import model_action_metadata from genkit.core.registry import ActionKind from genkit.plugins.google_genai.models.embedder import ( Embedder, GeminiEmbeddingModels, VertexEmbeddingModels, + default_embedder_info, ) from genkit.plugins.google_genai.models.gemini import ( SUPPORTED_MODELS, @@ -139,24 +143,29 @@ def initialize(self, ai: GenkitRegistry) -> None: for version in GeminiEmbeddingModels: embedder = Embedder(version=version, client=self._client) - ai.define_embedder(name=googleai_name(version), fn=embedder.generate) + ai.define_embedder( + name=googleai_name(version), + fn=embedder.generate, + metadata=default_embedder_info(version), + config_schema=EmbedContentConfig, + ) def resolve_action( self, ai: GenkitRegistry, - type: ActionKind, + kind: ActionKind, name: str, ) -> None: """Resolves and action. Args: ai: The Genkit registry. - type: The kind of action to resolve. + kind: The kind of action to resolve. name: The name of the action to resolve. """ - if type == ActionKind.MODEL: + if kind == ActionKind.MODEL: self._resolve_model(ai, name) - elif type == ActionKind.EMBEDDER: + elif kind == ActionKind.EMBEDDER: self._resolve_embedder(ai, name) def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: @@ -204,8 +213,44 @@ def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None: ai.define_embedder( name=googleai_name(_clean_name), fn=embedder.generate, + metadata=default_embedder_info(_clean_name), + config_schema=EmbedContentConfig, ) + @cached_property + def list_actions(self) -> list[dict[str, str]]: + """Generate a list of available actions or models. + + Returns: + list of actions dicts with the following shape: + { + 'name': str, + 'kind': ActionKind, + } + """ + actions_list = list() + for m in self._client.models.list(): + name = m.name.replace('models/', '') + if 'generateContent' in m.supported_actions: + actions_list.append( + model_action_metadata( + name=googleai_name(name), + info=google_model_info(name).model_dump(), + config_schema=GeminiConfigSchema, + ), + ) + + if 'embedContent' in m.supported_actions: + actions_list.append( + embedder_action_metadata( + name=googleai_name(name), + info=default_embedder_info(name), + config_schema=EmbedContentConfig, + ) + ) + + return actions_list + class VertexAI(Plugin): """Vertex AI plugin for Genkit. @@ -280,6 +325,8 @@ def initialize(self, ai: GenkitRegistry) -> None: ai.define_embedder( name=vertexai_name(version), fn=embedder.generate, + metadata=default_embedder_info(version), + config_schema=EmbedContentConfig, ) for version in ImagenVersion: @@ -294,19 +341,19 @@ def initialize(self, ai: GenkitRegistry) -> None: def resolve_action( self, ai: GenkitRegistry, - type: ActionKind, + kind: ActionKind, name: str, ) -> None: """Resolves and action. Args: ai: The Genkit registry. - type: The kind of action to resolve. + kind: The kind of action to resolve. name: The name of the action to resolve. """ - if type == ActionKind.MODEL: + if kind == ActionKind.MODEL: self._resolve_model(ai, name) - elif type == ActionKind.EMBEDDER: + elif kind == ActionKind.EMBEDDER: self._resolve_embedder(ai, name) def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: @@ -361,8 +408,43 @@ def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None: ai.define_embedder( name=vertexai_name(_clean_name), fn=embedder.generate, + metadata=default_embedder_info(_clean_name), + config_schema=EmbedContentConfig, ) + @cached_property + def list_actions(self) -> list[dict[str, str]]: + """Generate a list of available actions or models. + + Returns: + list of actions dicts with the following shape: + { + 'name': str, + 'kind': ActionKind, + } + """ + actions_list = list() + for m in self._client.models.list(): + name = m.name.replace('publishers/google/models/', '') + if 'embed' in name.lower(): + actions_list.append( + embedder_action_metadata( + name=vertexai_name(name), + info=default_embedder_info(name), + config_schema=EmbedContentConfig, + ) + ) + # List all the vertexai models for generate actions + actions_list.append( + model_action_metadata( + name=vertexai_name(name), + info=google_model_info(name).model_dump(), + config_schema=GeminiConfigSchema, + ), + ) + + return actions_list + def _inject_attribution_headers(http_options: HttpOptions | dict | None = None): """Adds genkit client info to the appropriate http headers.""" diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/embedder.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/embedder.py index 80c9ab142e..eeaa64bd6f 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/embedder.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/embedder.py @@ -16,7 +16,8 @@ """Google-Genai embedder model.""" -import sys # noqa +import sys +from typing import Any # noqa if sys.version_info < (3, 11): # noqa from strenum import StrEnum # noqa @@ -59,6 +60,11 @@ class EmbeddingTaskType(StrEnum): FACT_VERIFICATION = 'FACT_VERIFICATION' +def default_embedder_info(name: str) -> dict[str, Any]: + """Returns default info for embedders given a name.""" + return {'dimensions': 768, 'label': f'Google AI - {name}', 'supports': {'input': ['text']}} + + class Embedder: """Embedder for Google-Genai.""" diff --git a/py/plugins/google-genai/test/test_google_plugin.py b/py/plugins/google-genai/test/test_google_plugin.py index ba0dbaa308..b1d52a14a6 100644 --- a/py/plugins/google-genai/test/test_google_plugin.py +++ b/py/plugins/google-genai/test/test_google_plugin.py @@ -23,10 +23,13 @@ from unittest.mock import MagicMock, patch, ANY from google.auth.credentials import Credentials -from google.genai.types import GenerateImagesConfigOrDict, HttpOptions +from pydantic import BaseModel +from google.genai.types import EmbedContentConfig, GenerateImagesConfigOrDict, HttpOptions import pytest from genkit.ai import Genkit, GENKIT_CLIENT_HEADER +from genkit.blocks.embedding import embedder_action_metadata +from genkit.blocks.model import model_action_metadata from genkit.core.registry import ActionKind from genkit.plugins.google_genai import ( GoogleAI, @@ -38,6 +41,7 @@ from genkit.plugins.google_genai.models.embedder import ( GeminiEmbeddingModels, VertexEmbeddingModels, + default_embedder_info, ) from genkit.plugins.google_genai.models.gemini import ( DEFAULT_SUPPORTS_MODEL, @@ -45,11 +49,11 @@ SUPPORTED_MODELS, GoogleAIGeminiVersion, VertexAIGeminiVersion, + google_model_info, ) from genkit.plugins.google_genai.models.imagen import ( SUPPORTED_MODELS as IMAGE_SUPPORTED_MODELS, DEFAULT_IMAGE_SUPPORT, - ImagenModel, ImagenVersion, ) from genkit.types import ( @@ -57,6 +61,14 @@ ) +@pytest.fixture +@patch('google.genai.client.Client') +def googleai_plugin_instance(client): + """GoogleAI fixture.""" + api_key = 'test_api_key' + return GoogleAI(api_key=api_key) + + class TestGoogleAIInit(unittest.TestCase): """Test cases for __init__ plugin.""" @@ -140,28 +152,28 @@ def test_googleai_initialize(): ai_mock.define_embedder.assert_any_call( name=googleai_name(version), fn=ANY, + metadata=ANY, + config_schema=EmbedContentConfig, ) @patch('genkit.plugins.google_genai.GoogleAI._resolve_model') -def test_googleai_resolve_action_model(mock_resolve_action): +def test_googleai_resolve_action_model(mock_resolve_action, googleai_plugin_instance): """Test resolve action for model.""" - api_key = 'test_api_key' - plugin = GoogleAI(api_key=api_key) + plugin = googleai_plugin_instance ai_mock = MagicMock(spec=Genkit) - plugin.resolve_action(ai=ai_mock, type=ActionKind.MODEL, name='lazaro-model') + plugin.resolve_action(ai=ai_mock, kind=ActionKind.MODEL, name='lazaro-model') mock_resolve_action.assert_called_once_with(ai_mock, 'lazaro-model') @patch('genkit.plugins.google_genai.GoogleAI._resolve_embedder') -def test_googleai_resolve_action_embedder(mock_resolve_action): +def test_googleai_resolve_action_embedder(mock_resolve_action, googleai_plugin_instance): """Test resolve action for embedder.""" - api_key = 'test_api_key' - plugin = GoogleAI(api_key=api_key) + plugin = googleai_plugin_instance ai_mock = MagicMock(spec=Genkit) - plugin.resolve_action(ai=ai_mock, type=ActionKind.EMBEDDER, name='lazaro-model') + plugin.resolve_action(ai=ai_mock, kind=ActionKind.EMBEDDER, name='lazaro-model') mock_resolve_action.assert_called_once_with(ai_mock, 'lazaro-model') @@ -186,10 +198,10 @@ def test_googleai__resolve_model( model_name, expected_model_name, key, + googleai_plugin_instance, ): """Tests for GoogleAI._resolve_model method.""" - api_key = 'test_api_key' - plugin = GoogleAI(api_key=api_key) + plugin = googleai_plugin_instance ai_mock = MagicMock(spec=Genkit) mock_google_model_info.return_value = ModelInfo( @@ -212,25 +224,20 @@ def test_googleai__resolve_model( @pytest.mark.parametrize( - 'model_name, expected_model_name', + 'model_name, expected_model_name, clean_name', [ - ( - 'gemini-pro-deluxe-max', - 'googleai/gemini-pro-deluxe-max', - ), - ( - 'googleai/gemini-pro-deluxe-max', - 'googleai/gemini-pro-deluxe-max', - ), + ('gemini-pro-deluxe-max', 'googleai/gemini-pro-deluxe-max', 'gemini-pro-deluxe-max'), + ('googleai/gemini-pro-deluxe-max', 'googleai/gemini-pro-deluxe-max', 'gemini-pro-deluxe-max'), ], ) def test_googleai__resolve_embedder( model_name, expected_model_name, + clean_name, + googleai_plugin_instance, ): """Tests for GoogleAI._resolve_embedder method.""" - api_key = 'test_api_key' - plugin = GoogleAI(api_key=api_key) + plugin = googleai_plugin_instance ai_mock = MagicMock(spec=Genkit) plugin._resolve_embedder( @@ -239,11 +246,54 @@ def test_googleai__resolve_embedder( ) ai_mock.define_embedder.assert_called_once_with( - name=expected_model_name, - fn=ANY, + name=expected_model_name, fn=ANY, config_schema=EmbedContentConfig, metadata=default_embedder_info(clean_name) ) +def test_googleai_list_actions(googleai_plugin_instance): + """Unit test for list actions.""" + + class MockModel(BaseModel): + """mock.""" + + supported_actions: list[str] + name: str + + models_return_value = [ + MockModel(supported_actions=['generateContent'], name='models/model1'), + MockModel(supported_actions=['embedContent'], name='models/model2'), + MockModel(supported_actions=['generateContent', 'embedContent'], name='models/model3'), + ] + + mock_client = MagicMock() + mock_client.models.list.return_value = models_return_value + googleai_plugin_instance._client = mock_client + + result = googleai_plugin_instance.list_actions + assert result == [ + model_action_metadata( + name=googleai_name('model1'), + info=google_model_info('model1').model_dump(), + config_schema=GeminiConfigSchema, + ), + embedder_action_metadata( + name=googleai_name('model2'), + info=default_embedder_info('model2'), + config_schema=EmbedContentConfig, + ), + model_action_metadata( + name=googleai_name('model3'), + info=google_model_info('model3').model_dump(), + config_schema=GeminiConfigSchema, + ), + embedder_action_metadata( + name=googleai_name('model3'), + info=default_embedder_info('model3'), + config_schema=EmbedContentConfig, + ), + ] + + @pytest.mark.parametrize( 'input_options, expected_headers', [ @@ -457,6 +507,8 @@ def test_vertexai_initialize(vertexai_plugin_instance): ai_mock.define_embedder.assert_any_call( name=vertexai_name(version), fn=ANY, + metadata=ANY, + config_schema=EmbedContentConfig, ) @@ -466,7 +518,7 @@ def test_vertexai_resolve_action_model(mock_resolve_action, vertexai_plugin_inst plugin = vertexai_plugin_instance ai_mock = MagicMock(spec=Genkit) - plugin.resolve_action(ai=ai_mock, type=ActionKind.MODEL, name='lazaro-model') + plugin.resolve_action(ai=ai_mock, kind=ActionKind.MODEL, name='lazaro-model') mock_resolve_action.assert_called_once_with(ai_mock, 'lazaro-model') @@ -476,7 +528,7 @@ def test_vertexai_resolve_action_embedder(mock_resolve_action, vertexai_plugin_i plugin = vertexai_plugin_instance ai_mock = MagicMock(spec=Genkit) - plugin.resolve_action(ai=ai_mock, type=ActionKind.EMBEDDER, name='lazaro-model') + plugin.resolve_action(ai=ai_mock, kind=ActionKind.EMBEDDER, name='lazaro-model') mock_resolve_action.assert_called_once_with(ai_mock, 'lazaro-model') @@ -570,21 +622,24 @@ def test_vertexai__resolve_model( @pytest.mark.parametrize( - 'model_name, expected_model_name', + 'model_name, expected_model_name, clean_name', [ ( 'gemini-pro-deluxe-max', 'vertexai/gemini-pro-deluxe-max', + 'gemini-pro-deluxe-max', ), ( 'vertexai/gemini-pro-deluxe-max', 'vertexai/gemini-pro-deluxe-max', + 'gemini-pro-deluxe-max', ), ], ) def test_vertexai__resolve_embedder( model_name, expected_model_name, + clean_name, vertexai_plugin_instance, ): """Tests for VertexAI._resolve_embedder method.""" @@ -597,6 +652,53 @@ def test_vertexai__resolve_embedder( ) ai_mock.define_embedder.assert_called_once_with( - name=expected_model_name, - fn=ANY, + name=expected_model_name, fn=ANY, config_schema=EmbedContentConfig, metadata=default_embedder_info(clean_name) ) + + +def test_vertexai_list_actions(vertexai_plugin_instance): + """Unit test for list actions.""" + + class MockModel(BaseModel): + """mock.""" + + name: str + + models_return_value = [ + MockModel(name='publishers/google/models/model1'), + MockModel(name='publishers/google/models/model2_embeddings'), + MockModel(name='publishers/google/models/model3_embedder'), + ] + + mock_client = MagicMock() + mock_client.models.list.return_value = models_return_value + vertexai_plugin_instance._client = mock_client + + result = vertexai_plugin_instance.list_actions + assert result == [ + model_action_metadata( + name=vertexai_name('model1'), + info=google_model_info('model1').model_dump(), + config_schema=GeminiConfigSchema, + ), + embedder_action_metadata( + name=vertexai_name('model2_embeddings'), + info=default_embedder_info('model2_embeddings'), + config_schema=EmbedContentConfig, + ), + model_action_metadata( + name=vertexai_name('model2_embeddings'), + info=google_model_info('model2_embeddings').model_dump(), + config_schema=GeminiConfigSchema, + ), + embedder_action_metadata( + name=vertexai_name('model3_embedder'), + info=default_embedder_info('model3_embedder'), + config_schema=EmbedContentConfig, + ), + model_action_metadata( + name=vertexai_name('model3_embedder'), + info=google_model_info('model3_embedder').model_dump(), + config_schema=GeminiConfigSchema, + ), + ]