Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion py/bin/sanitize_schema_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802
"""
# First apply base class transformations recursively
node = super().generic_visit(_node)
new_body: list[ ast.stmt | ast.Constant | ast.Assign ] = []
new_body: list[ast.stmt | ast.Constant | ast.Assign] = []

# Handle Docstrings
if not node.body or not isinstance(node.body[0], ast.Expr) or not isinstance(node.body[0].value, ast.Constant):
Expand Down
8 changes: 7 additions & 1 deletion py/packages/genkit/src/genkit/ai/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,14 @@ def _initialize_registry(self, model: str | None, plugins: list[Plugin] | None)
def resolver(kind, name, plugin=plugin):
return plugin.resolve_action(self, kind, name)

def action_resolver(plugin=plugin):
if isinstance(plugin.list_actions, list):
return plugin.list_actions
else:
return plugin.list_actions()

self.registry.register_action_resolver(plugin.plugin_name(), resolver)
self.registry.register_list_actions_resolver(plugin.plugin_name(), plugin.list_actions)
self.registry.register_list_actions_resolver(plugin.plugin_name(), action_resolver)
else:
raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin`')

Expand Down
8 changes: 1 addition & 7 deletions py/packages/genkit/src/genkit/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,7 @@ def list_actions(
actions = {}

for plugin_name in self._list_actions_resolvers:
actions_lister = self._list_actions_resolvers[plugin_name]

# TODO: Set all the list_actions plugins' methods as cached_properties.
if isinstance(actions_lister, list):
actions_list = actions_lister
else:
actions_list = actions_lister()
actions_list = self._list_actions_resolvers[plugin_name]()

for _action in actions_list:
kind = _action.kind
Expand Down
13 changes: 7 additions & 6 deletions py/plugins/google-genai/test/test_google_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_init_with_credentials(self, mock_genai_client):
plugin = GoogleAI(credentials=mock_credentials)
mock_genai_client.assert_called_once_with(
vertexai=False,
api_key=None,
api_key=ANY,
credentials=mock_credentials,
debug_config=None,
http_options=_inject_attribution_headers(),
Expand All @@ -122,11 +122,12 @@ def test_init_with_credentials(self, mock_genai_client):

def test_init_raises_value_error_no_api_key(self):
"""Test using credentials parameter."""
with self.assertRaisesRegex(
ValueError,
'Gemini api key should be passed in plugin params or as a GEMINI_API_KEY environment variable',
):
GoogleAI()
with patch.dict(os.environ, {'GEMINI_API_KEY': ''}, clear=True):
with self.assertRaisesRegex(
ValueError,
'Gemini api key should be passed in plugin params or as a GEMINI_API_KEY environment variable',
):
GoogleAI()


def test_googleai_initialize():
Expand Down
7 changes: 4 additions & 3 deletions py/plugins/ollama/src/genkit/plugins/ollama/embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Callable

from pydantic import BaseModel

import ollama as ollama_api
from genkit.blocks.embedding import EmbedRequest, EmbedResponse
from genkit.types import Embedding

Expand All @@ -30,10 +31,10 @@ class EmbeddingDefinition(BaseModel):
class OllamaEmbedder:
def __init__(
self,
client: ollama_api.AsyncClient,
client: Callable,
embedding_definition: EmbeddingDefinition,
):
self.client = client
self.client = client()
self.embedding_definition = embedding_definition

async def embed(self, request: EmbedRequest) -> EmbedResponse:
Expand Down
5 changes: 3 additions & 2 deletions py/plugins/ollama/src/genkit/plugins/ollama/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Callable
from typing import Any, Literal

import structlog
Expand Down Expand Up @@ -56,8 +57,8 @@ class ModelDefinition(BaseModel):


class OllamaModel:
def __init__(self, client: ollama_api.AsyncClient, model_definition: ModelDefinition):
self.client = client
def __init__(self, client: Callable, model_definition: ModelDefinition):
self.client = client()
self.model_definition = model_definition

async def generate(self, request: GenerateRequest, ctx: ActionRunContext | None = None) -> GenerateResponse:
Expand Down
54 changes: 53 additions & 1 deletion py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@

"""Ollama Plugin for Genkit."""

import asyncio
from functools import cached_property, partial

import structlog

import ollama as ollama_api
from genkit.ai import GenkitRegistry, Plugin
from genkit.blocks.embedding import embedder_action_metadata
from genkit.blocks.model import model_action_metadata
from genkit.core.registry import ActionKind
from genkit.plugins.ollama.constants import (
DEFAULT_OLLAMA_SERVER_URL,
Expand All @@ -34,6 +41,7 @@
from genkit.types import GenerationCommonConfig

OLLAMA_PLUGIN_NAME = 'ollama'
logger = structlog.get_logger(__name__)


def ollama_name(name: str) -> str:
Expand Down Expand Up @@ -80,7 +88,7 @@ def __init__(
self.server_address = server_address or DEFAULT_OLLAMA_SERVER_URL
self.request_headers = request_headers or {}

self.client = ollama_api.AsyncClient(host=self.server_address)
self.client = partial(ollama_api.AsyncClient, host=self.server_address)

def initialize(self, ai: GenkitRegistry) -> None:
"""Initialize the Ollama plugin.
Expand Down Expand Up @@ -198,3 +206,47 @@ def _define_ollama_embedder(self, ai: GenkitRegistry, embedder_ref: EmbeddingDef
},
},
)

@cached_property
def list_actions(self) -> list[dict[str, str]]:
"""."""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

_client = self.client()
response = loop.run_until_complete(_client.list())

actions = []
for model in response.models:
_name = model.model
if 'embed' in _name:
actions.append(
embedder_action_metadata(
name=ollama_name(_name),
config_schema=ollama_api.Options,
info={
'label': f'Ollama Embedding - {_name}',
'dimensions': None,
'supports': {
'input': ['text'],
},
},
)
)
else:
actions.append(
model_action_metadata(
name=ollama_name(_name),
config_schema=GenerationCommonConfig,
info={
'label': f'Ollama - {_name}',
'multiturn': True,
'system_role': True,
'tools': False,
},
)
)
return actions
68 changes: 51 additions & 17 deletions py/plugins/ollama/tests/test_plugin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
"""Unit tests for Ollama Plugin."""

import unittest
from unittest.mock import ANY, MagicMock, patch
from unittest.mock import ANY, AsyncMock, MagicMock

import ollama as ollama_api
import pytest
from pydantic import BaseModel

from genkit.ai import ActionKind, Genkit
from genkit.plugins.ollama import Ollama, ollama_name
Expand All @@ -33,30 +34,21 @@
class TestOllamaInit(unittest.TestCase):
"""Test cases for Ollama.__init__ plugin."""

@patch('ollama.AsyncClient')
def test_init_with_models(self, ollama_aclient):
def test_init_with_models(self):
"""Test correct propagation of models param."""
model_ref = ModelDefinition(name='test_model')
plugin = Ollama(models=[model_ref])

assert plugin.models[0] == model_ref
ollama_aclient.assert_called_once_with(
host=DEFAULT_OLLAMA_SERVER_URL,
)

@patch('ollama.AsyncClient')
def test_init_with_embedders(self, ollama_aclient):
def test_init_with_embedders(self):
"""Test correct propagation of embedders param."""
embedder_ref = EmbeddingDefinition(name='test_embedder')
plugin = Ollama(embedders=[embedder_ref])

assert plugin.embedders[0] == embedder_ref
ollama_aclient.assert_called_once_with(
host=DEFAULT_OLLAMA_SERVER_URL,
)

@patch('ollama.AsyncClient')
def test_init_with_options(self, ollama_aclient):
def test_init_with_options(self):
"""Test correct propagation of other options param."""
model_ref = ModelDefinition(name='test_model')
embedder_ref = EmbeddingDefinition(name='test_embedder')
Expand All @@ -75,10 +67,6 @@ def test_init_with_options(self, ollama_aclient):
assert plugin.server_address == server_address
assert plugin.request_headers == headers

ollama_aclient.assert_called_once_with(
host=server_address,
)


def test_initialize(ollama_plugin_instance):
"""Test initialize method of Ollama plugin."""
Expand Down Expand Up @@ -240,3 +228,49 @@ def test_define_ollama_embedder(name, expected_name, clean_name, ollama_plugin_i
},
},
)


def test_list_actions(ollama_plugin_instance):
"""Unit tests for list_actions method."""

class MockModelResponse(BaseModel):
model: str

class MockListResponse(BaseModel):
models: list[MockModelResponse]

_client_mock = MagicMock()
list_method_mock = AsyncMock()
_client_mock.list = list_method_mock

list_method_mock.return_value = MockListResponse(
models=[
MockModelResponse(model='test_model'),
MockModelResponse(model='test_embedder'),
]
)

def mock_client():
return _client_mock

ollama_plugin_instance.client = mock_client

actions = ollama_plugin_instance.list_actions

assert len(actions) == 2

has_model = False
for action in actions:
if action.kind == ActionKind.MODEL:
has_model = True
break

assert has_model

has_embedder = False
for action in actions:
if action.kind == ActionKind.EMBEDDER:
has_embedder = True
break

assert has_embedder
Loading