diff --git a/py/bin/sanitize_schema_typing.py b/py/bin/sanitize_schema_typing.py index 2ec483c654..6138127e9f 100644 --- a/py/bin/sanitize_schema_typing.py +++ b/py/bin/sanitize_schema_typing.py @@ -129,7 +129,7 @@ def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802 """ # First apply base class transformations recursively node = super().generic_visit(_node) - new_body: list[ ast.stmt | ast.Constant | ast.Assign ] = [] + new_body: list[ast.stmt | ast.Constant | ast.Assign] = [] # Handle Docstrings if not node.body or not isinstance(node.body[0], ast.Expr) or not isinstance(node.body[0].value, ast.Constant): diff --git a/py/packages/genkit/src/genkit/ai/_base.py b/py/packages/genkit/src/genkit/ai/_base.py index da6ddbc2e0..4433597b47 100644 --- a/py/packages/genkit/src/genkit/ai/_base.py +++ b/py/packages/genkit/src/genkit/ai/_base.py @@ -117,8 +117,14 @@ def _initialize_registry(self, model: str | None, plugins: list[Plugin] | None) def resolver(kind, name, plugin=plugin): return plugin.resolve_action(self, kind, name) + def action_resolver(plugin=plugin): + if isinstance(plugin.list_actions, list): + return plugin.list_actions + else: + return plugin.list_actions() + self.registry.register_action_resolver(plugin.plugin_name(), resolver) - self.registry.register_list_actions_resolver(plugin.plugin_name(), plugin.list_actions) + self.registry.register_list_actions_resolver(plugin.plugin_name(), action_resolver) else: raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin`') diff --git a/py/packages/genkit/src/genkit/core/registry.py b/py/packages/genkit/src/genkit/core/registry.py index 09ca6ac766..a821e191d0 100644 --- a/py/packages/genkit/src/genkit/core/registry.py +++ b/py/packages/genkit/src/genkit/core/registry.py @@ -252,13 +252,7 @@ def list_actions( actions = {} for plugin_name in self._list_actions_resolvers: - actions_lister = self._list_actions_resolvers[plugin_name] - - # TODO: Set all the list_actions plugins' methods as cached_properties. - if isinstance(actions_lister, list): - actions_list = actions_lister - else: - actions_list = actions_lister() + actions_list = self._list_actions_resolvers[plugin_name]() for _action in actions_list: kind = _action.kind diff --git a/py/plugins/google-genai/test/test_google_plugin.py b/py/plugins/google-genai/test/test_google_plugin.py index b1d52a14a6..d621f422e6 100644 --- a/py/plugins/google-genai/test/test_google_plugin.py +++ b/py/plugins/google-genai/test/test_google_plugin.py @@ -111,7 +111,7 @@ def test_init_with_credentials(self, mock_genai_client): plugin = GoogleAI(credentials=mock_credentials) mock_genai_client.assert_called_once_with( vertexai=False, - api_key=None, + api_key=ANY, credentials=mock_credentials, debug_config=None, http_options=_inject_attribution_headers(), @@ -122,11 +122,12 @@ def test_init_with_credentials(self, mock_genai_client): def test_init_raises_value_error_no_api_key(self): """Test using credentials parameter.""" - with self.assertRaisesRegex( - ValueError, - 'Gemini api key should be passed in plugin params or as a GEMINI_API_KEY environment variable', - ): - GoogleAI() + with patch.dict(os.environ, {'GEMINI_API_KEY': ''}, clear=True): + with self.assertRaisesRegex( + ValueError, + 'Gemini api key should be passed in plugin params or as a GEMINI_API_KEY environment variable', + ): + GoogleAI() def test_googleai_initialize(): diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py b/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py index c89a328ec6..995336649d 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py @@ -14,9 +14,10 @@ # # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Callable + from pydantic import BaseModel -import ollama as ollama_api from genkit.blocks.embedding import EmbedRequest, EmbedResponse from genkit.types import Embedding @@ -30,10 +31,10 @@ class EmbeddingDefinition(BaseModel): class OllamaEmbedder: def __init__( self, - client: ollama_api.AsyncClient, + client: Callable, embedding_definition: EmbeddingDefinition, ): - self.client = client + self.client = client() self.embedding_definition = embedding_definition async def embed(self, request: EmbedRequest) -> EmbedResponse: diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/models.py b/py/plugins/ollama/src/genkit/plugins/ollama/models.py index 0d9cd600d8..4c69bab33e 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/models.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/models.py @@ -14,6 +14,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Callable from typing import Any, Literal import structlog @@ -56,8 +57,8 @@ class ModelDefinition(BaseModel): class OllamaModel: - def __init__(self, client: ollama_api.AsyncClient, model_definition: ModelDefinition): - self.client = client + def __init__(self, client: Callable, model_definition: ModelDefinition): + self.client = client() self.model_definition = model_definition async def generate(self, request: GenerateRequest, ctx: ActionRunContext | None = None) -> GenerateResponse: diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py index 9264b3f036..947271885e 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py @@ -16,8 +16,15 @@ """Ollama Plugin for Genkit.""" +import asyncio +from functools import cached_property, partial + +import structlog + import ollama as ollama_api from genkit.ai import GenkitRegistry, Plugin +from genkit.blocks.embedding import embedder_action_metadata +from genkit.blocks.model import model_action_metadata from genkit.core.registry import ActionKind from genkit.plugins.ollama.constants import ( DEFAULT_OLLAMA_SERVER_URL, @@ -34,6 +41,7 @@ from genkit.types import GenerationCommonConfig OLLAMA_PLUGIN_NAME = 'ollama' +logger = structlog.get_logger(__name__) def ollama_name(name: str) -> str: @@ -80,7 +88,7 @@ def __init__( self.server_address = server_address or DEFAULT_OLLAMA_SERVER_URL self.request_headers = request_headers or {} - self.client = ollama_api.AsyncClient(host=self.server_address) + self.client = partial(ollama_api.AsyncClient, host=self.server_address) def initialize(self, ai: GenkitRegistry) -> None: """Initialize the Ollama plugin. @@ -198,3 +206,47 @@ def _define_ollama_embedder(self, ai: GenkitRegistry, embedder_ref: EmbeddingDef }, }, ) + + @cached_property + def list_actions(self) -> list[dict[str, str]]: + """.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + _client = self.client() + response = loop.run_until_complete(_client.list()) + + actions = [] + for model in response.models: + _name = model.model + if 'embed' in _name: + actions.append( + embedder_action_metadata( + name=ollama_name(_name), + config_schema=ollama_api.Options, + info={ + 'label': f'Ollama Embedding - {_name}', + 'dimensions': None, + 'supports': { + 'input': ['text'], + }, + }, + ) + ) + else: + actions.append( + model_action_metadata( + name=ollama_name(_name), + config_schema=GenerationCommonConfig, + info={ + 'label': f'Ollama - {_name}', + 'multiturn': True, + 'system_role': True, + 'tools': False, + }, + ) + ) + return actions diff --git a/py/plugins/ollama/tests/test_plugin_api.py b/py/plugins/ollama/tests/test_plugin_api.py index 2d8aec98ec..5d20c0e83e 100644 --- a/py/plugins/ollama/tests/test_plugin_api.py +++ b/py/plugins/ollama/tests/test_plugin_api.py @@ -17,10 +17,11 @@ """Unit tests for Ollama Plugin.""" import unittest -from unittest.mock import ANY, MagicMock, patch +from unittest.mock import ANY, AsyncMock, MagicMock import ollama as ollama_api import pytest +from pydantic import BaseModel from genkit.ai import ActionKind, Genkit from genkit.plugins.ollama import Ollama, ollama_name @@ -33,30 +34,21 @@ class TestOllamaInit(unittest.TestCase): """Test cases for Ollama.__init__ plugin.""" - @patch('ollama.AsyncClient') - def test_init_with_models(self, ollama_aclient): + def test_init_with_models(self): """Test correct propagation of models param.""" model_ref = ModelDefinition(name='test_model') plugin = Ollama(models=[model_ref]) assert plugin.models[0] == model_ref - ollama_aclient.assert_called_once_with( - host=DEFAULT_OLLAMA_SERVER_URL, - ) - @patch('ollama.AsyncClient') - def test_init_with_embedders(self, ollama_aclient): + def test_init_with_embedders(self): """Test correct propagation of embedders param.""" embedder_ref = EmbeddingDefinition(name='test_embedder') plugin = Ollama(embedders=[embedder_ref]) assert plugin.embedders[0] == embedder_ref - ollama_aclient.assert_called_once_with( - host=DEFAULT_OLLAMA_SERVER_URL, - ) - @patch('ollama.AsyncClient') - def test_init_with_options(self, ollama_aclient): + def test_init_with_options(self): """Test correct propagation of other options param.""" model_ref = ModelDefinition(name='test_model') embedder_ref = EmbeddingDefinition(name='test_embedder') @@ -75,10 +67,6 @@ def test_init_with_options(self, ollama_aclient): assert plugin.server_address == server_address assert plugin.request_headers == headers - ollama_aclient.assert_called_once_with( - host=server_address, - ) - def test_initialize(ollama_plugin_instance): """Test initialize method of Ollama plugin.""" @@ -240,3 +228,49 @@ def test_define_ollama_embedder(name, expected_name, clean_name, ollama_plugin_i }, }, ) + + +def test_list_actions(ollama_plugin_instance): + """Unit tests for list_actions method.""" + + class MockModelResponse(BaseModel): + model: str + + class MockListResponse(BaseModel): + models: list[MockModelResponse] + + _client_mock = MagicMock() + list_method_mock = AsyncMock() + _client_mock.list = list_method_mock + + list_method_mock.return_value = MockListResponse( + models=[ + MockModelResponse(model='test_model'), + MockModelResponse(model='test_embedder'), + ] + ) + + def mock_client(): + return _client_mock + + ollama_plugin_instance.client = mock_client + + actions = ollama_plugin_instance.list_actions + + assert len(actions) == 2 + + has_model = False + for action in actions: + if action.kind == ActionKind.MODEL: + has_model = True + break + + assert has_model + + has_embedder = False + for action in actions: + if action.kind == ActionKind.EMBEDDER: + has_embedder = True + break + + assert has_embedder