diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/constants.py b/py/plugins/ollama/src/genkit/plugins/ollama/constants.py index 7848f0f8ad..4060d414a5 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/constants.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/constants.py @@ -25,5 +25,7 @@ class OllamaAPITypes(StrEnum): + """Generation types for Ollama API.""" + CHAT = 'chat' GENERATE = 'generate' diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py b/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py index 995336649d..834674ec20 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/embedders.py @@ -23,21 +23,57 @@ class EmbeddingDefinition(BaseModel): + """Defines an embedding model for Ollama. + + This class specifies the characteristics of an embedding model that + can be used with the Ollama plugin. While Ollama models have fixed + output dimensions, this definition can specify the expected + dimensionality for informational purposes or for future truncation + support. + """ + name: str - # Ollama do not support changing dimensionality, but it can be truncated dimensions: int | None = None class OllamaEmbedder: + """Handles embedding requests using an Ollama embedding model. + + This class provides the necessary logic to interact with a specific + Ollama embedding model, processing input text into vector embeddings. + """ + def __init__( self, client: Callable, embedding_definition: EmbeddingDefinition, - ): + ) -> None: + """Initializes the OllamaEmbedder. + + Sets up the client for communicating with the Ollama server and stores + the definition of the embedding model. + + Args: + client: A callable that returns an asynchronous Ollama client instance. + embedding_definition: The definition describing the specific Ollama + embedding model to be used. + """ self.client = client() self.embedding_definition = embedding_definition async def embed(self, request: EmbedRequest) -> EmbedResponse: + """Generates embeddings for the provided input text. + + Converts the input documents from the Genkit EmbedRequest into a raw + list of strings, sends them to the Ollama server for embedding, and then + formats the response into a Genkit EmbedResponse. + + Args: + request: The embedding request containing the input documents. + + Returns: + An EmbedResponse containing the generated vector embeddings. 
+ """ input_raw = [] for doc in request.input: input_raw.extend([content.root.text for content in doc.content]) diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/models.py b/py/plugins/ollama/src/genkit/plugins/ollama/models.py index 4c69bab33e..8dea7c8ae5 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/models.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/models.py @@ -14,11 +14,13 @@ # # SPDX-License-Identifier: Apache-2.0 +"""Models package for Ollama plugin.""" + from collections.abc import Callable from typing import Any, Literal import structlog -from pydantic import BaseModel, Field, HttpUrl +from pydantic import BaseModel import ollama as ollama_api from genkit.ai import ActionRunContext @@ -47,17 +49,37 @@ class OllamaSupports(BaseModel): + """Supports for Ollama models.""" + tools: bool = False class ModelDefinition(BaseModel): + """Meta definition for Ollama models.""" + name: str api_type: OllamaAPITypes = 'chat' supports: OllamaSupports = OllamaSupports() class OllamaModel: - def __init__(self, client: Callable, model_definition: ModelDefinition): + """Represents an Ollama language model for use with Genkit. + + This class encapsulates the interaction logic for a specific Ollama model, + allowing it to be integrated into the Genkit framework for generative tasks. + """ + + def __init__(self, client: Callable, model_definition: ModelDefinition) -> None: + """Initializes the OllamaModel. + + Sets up the client for communicating with the Ollama server and stores + the definition of the model. + + Args: + client: A callable that returns an asynchronous Ollama client instance. + model_definition: The definition describing the specific Ollama model + to be used (e.g., its name, API type, supported features). + """ self.client = client() self.model_definition = model_definition @@ -110,7 +132,6 @@ async def generate(self, request: GenerateRequest, ctx: ActionRunContext | None ), ) - # FIXME: Missing return statement. async def _chat_with_ollama( self, request: GenerateRequest, ctx: ActionRunContext | None = None ) -> ollama_api.ChatResponse | None: @@ -167,6 +188,7 @@ async def _chat_with_ollama( content=self._build_multimodal_chat_response(chat_response=chunk), ) ) + return chat_response else: return chat_response @@ -182,7 +204,6 @@ async def _generate_ollama_response( Returns: The generated response. """ - # FIXME: Missing return statement. prompt = self.build_prompt(request) streaming_request = self.is_streaming_request(ctx=ctx) @@ -341,6 +362,7 @@ def _to_ollama_role( @staticmethod def is_streaming_request(ctx: ActionRunContext | None) -> bool: + """Determines if streaming mode is requested.""" return ctx and ctx.is_streaming @staticmethod @@ -348,6 +370,19 @@ def get_usage_info( basic_generation_usage: GenerationUsage, api_response: ollama_api.GenerateResponse | ollama_api.ChatResponse, ) -> GenerationUsage: + """Extracts and calculates token usage information from an Ollama API response. + + Updates a basic generation usage object with input, output, and total token counts + based on the details provided in the Ollama API response. + + Args: + basic_generation_usage: An existing GenerationUsage object to update. + api_response: The response object received from the Ollama API, + containing token count details. + + Returns: + The updated GenerationUsage object with token counts populated. 
+ """ if api_response: basic_generation_usage.input_tokens = api_response.prompt_eval_count or 0 basic_generation_usage.output_tokens = api_response.eval_count or 0 @@ -372,10 +407,10 @@ def _convert_parameters(input_schema: dict[str, Any]) -> ollama_api.Tool.Functio if schema_type == 'object': schema.properties = {} - properties = input_schema['properties'] + properties = input_schema.get('properties', []) for key in properties: schema.properties[key] = ollama_api.Tool.Function.Parameters.Property( - type=properties[key]['type'], description=properties[key]['description'] or '' + type=properties[key]['type'], description=properties[key].get('description', '') ) return schema diff --git a/py/plugins/ollama/tests/models/test_embedders.py b/py/plugins/ollama/tests/models/test_embedders.py new file mode 100644 index 0000000000..860256c735 --- /dev/null +++ b/py/plugins/ollama/tests/models/test_embedders.py @@ -0,0 +1,141 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Ollama embedders package.""" + +import unittest +from unittest.mock import AsyncMock, MagicMock + +import ollama as ollama_api + +from genkit.plugins.ollama.embedders import EmbeddingDefinition, OllamaEmbedder +from genkit.types import ( + Document, + Embedding, + EmbedRequest, + EmbedResponse, + TextPart, +) + + +class TestOllamaEmbedderEmbed(unittest.IsolatedAsyncioTestCase): + """Unit tests for OllamaEmbedder.embed method.""" + + async def asyncSetUp(self): + """Common setup.""" + self.mock_ollama_client_instance = AsyncMock() + self.mock_ollama_client_factory = MagicMock(return_value=self.mock_ollama_client_instance) + + self.mock_embedding_definition = EmbeddingDefinition(name='test-embed-model', dimensions=1536) + self.ollama_embedder = OllamaEmbedder( + client=self.mock_ollama_client_factory, embedding_definition=self.mock_embedding_definition + ) + + async def test_embed_single_document_single_content(self): + """Test embed with a single document containing single text content.""" + request = EmbedRequest( + input=[ + Document.from_text(text='hello world'), + ] + ) + expected_ollama_embeddings = [[0.1, 0.2, 0.3]] + self.mock_ollama_client_instance.embed.return_value = ollama_api.EmbedResponse( + embeddings=expected_ollama_embeddings + ) + + response = await self.ollama_embedder.embed(request) + + # Assertions + self.mock_ollama_client_instance.embed.assert_awaited_once_with( + model='test-embed-model', + input=['hello world'], + ) + expected_genkit_embeddings = [Embedding(embedding=[0.1, 0.2, 0.3])] + self.assertEqual(response, EmbedResponse(embeddings=expected_genkit_embeddings)) + + async def test_embed_multiple_documents_multiple_content(self): + """Test embed with multiple documents, each with multiple text contents.""" + request = EmbedRequest( + input=[ + Document(content=[TextPart(text='doc1_part1'), TextPart(text='doc1_part2')]), + Document(content=[TextPart(text='doc2_part1')]), + ] + ) + expected_ollama_embeddings = 
[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] + self.mock_ollama_client_instance.embed.return_value = ollama_api.EmbedResponse( + embeddings=expected_ollama_embeddings + ) + + response = await self.ollama_embedder.embed(request) + + # Assertions + self.mock_ollama_client_instance.embed.assert_awaited_once_with( + model='test-embed-model', + input=['doc1_part1', 'doc1_part2', 'doc2_part1'], + ) + expected_genkit_embeddings = [ + Embedding(embedding=[0.1, 0.2]), + Embedding(embedding=[0.3, 0.4]), + Embedding(embedding=[0.5, 0.6]), + ] + self.assertEqual(response, EmbedResponse(embeddings=expected_genkit_embeddings)) + + async def test_embed_empty_input(self): + """Test embed with an empty input request.""" + request = EmbedRequest(input=[]) + self.mock_ollama_client_instance.embed.return_value = ollama_api.EmbedResponse(embeddings=[]) + + response = await self.ollama_embedder.embed(request) + + # Assertions + self.mock_ollama_client_instance.embed.assert_awaited_once_with( + model='test-embed-model', + input=[], + ) + self.assertEqual(response, EmbedResponse(embeddings=[])) + + async def test_embed_api_raises_exception(self): + """Test embed method handles exception from client.embed.""" + request = EmbedRequest(input=[Document(content=[TextPart(text='error text')])]) + self.mock_ollama_client_instance.embed.side_effect = Exception('Ollama Embed API Error') + + with self.assertRaisesRegex(Exception, 'Ollama Embed API Error'): + await self.ollama_embedder.embed(request) + + self.mock_ollama_client_instance.embed.assert_awaited_once() + + async def test_embed_response_mismatch_input_count(self): + """Test embed when client returns fewer embeddings than input texts (edge case).""" + request = EmbedRequest( + input=[ + Document(content=[TextPart(text='text1')]), + Document(content=[TextPart(text='text2')]), + ] + ) + # Simulate Ollama returning only one embedding for two inputs + expected_ollama_embeddings = [[1.0, 2.0]] + self.mock_ollama_client_instance.embed.return_value = ollama_api.EmbedResponse( + embeddings=expected_ollama_embeddings + ) + + response = await self.ollama_embedder.embed(request) + + # The current implementation will just use whatever embeddings are returned. + # It's up to the caller or a higher layer to decide if this is an error. + # This test ensures it doesn't crash and correctly maps the available embeddings. + expected_genkit_embeddings = [Embedding(embedding=[1.0, 2.0])] + self.assertEqual(response, EmbedResponse(embeddings=expected_genkit_embeddings)) + self.assertEqual(len(response.embeddings), 1) # Confirm only one embedding was processed diff --git a/py/plugins/ollama/tests/models/test_models.py b/py/plugins/ollama/tests/models/test_models.py new file mode 100644 index 0000000000..af90b047f7 --- /dev/null +++ b/py/plugins/ollama/tests/models/test_models.py @@ -0,0 +1,621 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Ollama models package.""" + +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +import ollama as ollama_api +import pytest + +from genkit.plugins.ollama.constants import OllamaAPITypes +from genkit.plugins.ollama.models import ModelDefinition, OllamaModel, _convert_parameters +from genkit.types import ( + ActionRunContext, + GenerateRequest, + GenerateResponseChunk, + GenerationUsage, + Message, + OutputConfig, + Role, + TextPart, +) + + +class TestOllamaModelGenerate(unittest.IsolatedAsyncioTestCase): + """Tests for Generate method of OllamaModel.""" + + async def asyncSetUp(self): + """Common setup for all async tests.""" + self.mock_client = MagicMock() + self.request = GenerateRequest(messages=[Message(role=Role.USER, content=[TextPart(text='Hello')])]) + self.ctx = ActionRunContext() + + @patch( + 'genkit.blocks.model.get_basic_usage_stats', + return_value=GenerationUsage( + input_tokens=10, + output_tokens=20, + total_tokens=30, + ), + ) + async def test_generate_chat_non_streaming(self, mock_get_basic_usage_stats): + """Test generate method with CHAT API type in non-streaming mode.""" + model_def = ModelDefinition( + name='chat-model', + api_type=OllamaAPITypes.CHAT, + ) + ollama_model = OllamaModel( + client=self.mock_client, + model_definition=model_def, + ) + + # Mock internal methods + mock_chat_response = ollama_api.ChatResponse( + message=ollama_api.Message( + role='', + content='Generated chat text', + ), + ) + ollama_model._chat_with_ollama = AsyncMock( + return_value=mock_chat_response, + ) + ollama_model._generate_ollama_response = AsyncMock() + ollama_model._build_multimodal_chat_response = MagicMock( + return_value=[TextPart(text='Parsed chat content')], + ) + ollama_model.get_usage_info = MagicMock( + return_value=GenerationUsage( + input_tokens=5, + output_tokens=10, + total_tokens=15, + ), + ) + ollama_model.is_streaming_request = MagicMock(return_value=False) + + response = await ollama_model.generate(self.request, self.ctx) + + # Assertions + ollama_model._chat_with_ollama.assert_awaited_once_with(request=self.request, ctx=self.ctx) + ollama_model._generate_ollama_response.assert_not_awaited() + ollama_model._build_multimodal_chat_response.assert_called_once_with(chat_response=mock_chat_response) + ollama_model.is_streaming_request.assert_called_once_with(ctx=self.ctx) + ollama_model.get_usage_info.assert_called_once() + + self.assertEqual(response.message.role, Role.MODEL) + self.assertEqual(len(response.message.content), 1) + self.assertEqual(response.message.content[0].root.text, 'Parsed chat content') + self.assertEqual(response.usage.input_tokens, 5) + self.assertEqual(response.usage.output_tokens, 10) + + @patch( + 'genkit.blocks.model.get_basic_usage_stats', + return_value=GenerationUsage( + input_tokens=10, + output_tokens=20, + total_tokens=30, + ), + ) + async def test_generate_generate_non_streaming(self, mock_get_basic_usage_stats): + """Test generate method with GENERATE API type in non-streaming mode.""" + model_def = ModelDefinition( + name='generate-model', + api_type=OllamaAPITypes.GENERATE, + ) + ollama_model = OllamaModel( + client=self.mock_client, + model_definition=model_def, + ) + + # Mock internal methods + mock_generate_response = ollama_api.GenerateResponse( + response='Generated text', + ) + ollama_model._generate_ollama_response = AsyncMock( + return_value=mock_generate_response, + ) + ollama_model._chat_with_ollama = AsyncMock() + 
ollama_model.is_streaming_request = MagicMock(return_value=False) + ollama_model.get_usage_info = MagicMock( + return_value=GenerationUsage( + input_tokens=7, + output_tokens=14, + total_tokens=21, + ), + ) + + response = await ollama_model.generate(self.request, self.ctx) + + # Assertions + ollama_model._generate_ollama_response.assert_awaited_once_with(request=self.request, ctx=self.ctx) + ollama_model._chat_with_ollama.assert_not_called() + ollama_model.is_streaming_request.assert_called_once_with(ctx=self.ctx) + ollama_model.get_usage_info.assert_called_once() + + self.assertEqual(response.message.role, Role.MODEL) + self.assertEqual(len(response.message.content), 1) + self.assertEqual(response.message.content[0].root.text, 'Generated text') + self.assertEqual(response.usage.input_tokens, 7) + self.assertEqual(response.usage.output_tokens, 14) + + @patch( + 'genkit.blocks.model.get_basic_usage_stats', + return_value=GenerationUsage(), + ) + async def test_generate_chat_streaming(self, mock_get_basic_usage_stats): + """Test generate method with CHAT API type in streaming mode.""" + model_def = ModelDefinition(name='chat-model', api_type=OllamaAPITypes.CHAT) + ollama_model = OllamaModel(client=self.mock_client, model_definition=model_def) + streaming_ctx = ActionRunContext(on_chunk=True) + + # Mock internal methods + mock_chat_response = ollama_api.ChatResponse( + message=ollama_api.Message( + role='', + content='Generated chat text', + ), + ) + ollama_model._chat_with_ollama = AsyncMock( + return_value=mock_chat_response, + ) + ollama_model._build_multimodal_chat_response = MagicMock( + return_value=[TextPart(text='Parsed chat content')], + ) + ollama_model.is_streaming_request = MagicMock(return_value=True) + ollama_model.get_usage_info = MagicMock( + return_value=GenerationUsage( + input_tokens=0, + output_tokens=0, + total_tokens=0, + ), + ) + + response = await ollama_model.generate(self.request, streaming_ctx) + + # Assertions for streaming behavior + ollama_model._chat_with_ollama.assert_awaited_once_with( + request=self.request, + ctx=streaming_ctx, + ) + ollama_model.is_streaming_request.assert_called_once_with( + ctx=streaming_ctx, + ) + self.assertEqual(response.message.content, []) + + @patch( + 'genkit.blocks.model.get_basic_usage_stats', + return_value=GenerationUsage(), + ) + async def test_generate_generate_streaming(self, mock_get_basic_usage_stats): + """Test generate method with GENERATE API type in streaming mode.""" + model_def = ModelDefinition( + name='generate-model', + api_type=OllamaAPITypes.GENERATE, + ) + ollama_model = OllamaModel(client=self.mock_client, model_definition=model_def) + streaming_ctx = ActionRunContext(on_chunk=True) + + # Mock internal methods + mock_generate_response = ollama_api.GenerateResponse( + response='Generated text', + ) + ollama_model._generate_ollama_response = AsyncMock( + return_value=mock_generate_response, + ) + ollama_model.is_streaming_request = MagicMock(return_value=True) + ollama_model.get_usage_info = MagicMock( + return_value=GenerationUsage( + input_tokens=0, + output_tokens=0, + total_tokens=0, + ), + ) + + response = await ollama_model.generate(self.request, streaming_ctx) + + # Assertions for streaming behavior + ollama_model._generate_ollama_response.assert_awaited_once_with( + request=self.request, + ctx=streaming_ctx, + ) + ollama_model.is_streaming_request.assert_called_once_with( + ctx=streaming_ctx, + ) + self.assertEqual(response.message.content, []) + + @patch( + 'genkit.blocks.model.get_basic_usage_stats', + 
return_value=GenerationUsage(), + ) + async def test_generate_chat_api_response_none(self, mock_get_basic_usage_stats): + """Test generate method when _chat_with_ollama returns None.""" + model_def = ModelDefinition(name='chat-model', api_type=OllamaAPITypes.CHAT) + ollama_model = OllamaModel(client=self.mock_client, model_definition=model_def) + + ollama_model._chat_with_ollama = AsyncMock(return_value=None) + ollama_model._build_multimodal_chat_response = MagicMock() + ollama_model.is_streaming_request = MagicMock(return_value=False) + ollama_model.get_usage_info = MagicMock(return_value=GenerationUsage()) + + response = await ollama_model.generate(self.request, self.ctx) + + ollama_model._chat_with_ollama.assert_awaited_once() + ollama_model._build_multimodal_chat_response.assert_not_called() + self.assertEqual(response.message.content[0].root.text, 'Failed to get response from Ollama API') + self.assertEqual(response.usage.input_tokens, None) + self.assertEqual(response.usage.output_tokens, None) + + @patch( + 'genkit.blocks.model.get_basic_usage_stats', + return_value=GenerationUsage(), + ) + async def test_generate_generate_api_response_none(self, mock_get_basic_usage_stats): + """Test generate method when _generate_ollama_response returns None.""" + model_def = ModelDefinition(name='generate-model', api_type=OllamaAPITypes.GENERATE) + ollama_model = OllamaModel(client=self.mock_client, model_definition=model_def) + + ollama_model._generate_ollama_response = AsyncMock(return_value=None) + ollama_model.is_streaming_request = MagicMock(return_value=False) + ollama_model.get_usage_info = MagicMock(return_value=GenerationUsage()) + + response = await ollama_model.generate(self.request, self.ctx) + + ollama_model._generate_ollama_response.assert_awaited_once() + self.assertEqual(response.message.content[0].root.text, 'Failed to get response from Ollama API') + self.assertEqual(response.usage.input_tokens, None) + self.assertEqual(response.usage.output_tokens, None) + + +class TestOllamaModelChatWithOllama(unittest.IsolatedAsyncioTestCase): + """Unit tests for OllamaModel._chat_with_ollama method.""" + + async def asyncSetUp(self): + """Common setup.""" + self.mock_ollama_client_instance = AsyncMock() + self.mock_ollama_client_factory = MagicMock(return_value=self.mock_ollama_client_instance) + self.model_definition = ModelDefinition(name='test-chat-model', api_type=OllamaAPITypes.CHAT) + self.ollama_model = OllamaModel(client=self.mock_ollama_client_factory, model_definition=self.model_definition) + self.request = GenerateRequest(messages=[Message(role=Role.USER, content=[TextPart(text='Hello')])]) + self.ctx = ActionRunContext(on_chunk=False) + self.ctx.send_chunk = MagicMock() + + # Mocking methods of ollama_model that are called + self.ollama_model.build_chat_messages = MagicMock(return_value=[{}]) + self.ollama_model.is_streaming_request = MagicMock( + return_value=False, + ) + self.ollama_model.build_request_options = MagicMock( + return_value={'temperature': 0.7}, + ) + self.ollama_model._build_multimodal_chat_response = MagicMock( + return_value=[TextPart(text='mocked content')], + ) + + self.mock_convert_parameters = MagicMock(return_value={'type': 'string'}) + + async def test_non_streaming_chat_success(self): + """Test _chat_with_ollama in non-streaming mode with successful response.""" + expected_response = ollama_api.ChatResponse( + message=ollama_api.Message( + role='', + content='Ollama non-stream response', + ), + ) + self.mock_ollama_client_instance.chat.return_value = 
expected_response + + response = await self.ollama_model._chat_with_ollama(self.request, self.ctx) + + self.assertIsNotNone(response) + self.assertEqual(response.message.content, 'Ollama non-stream response') + self.ollama_model.build_chat_messages.assert_called_once_with(self.request) + self.ollama_model.is_streaming_request.assert_called_once_with(ctx=self.ctx) + self.ctx.send_chunk.assert_not_called() + self.mock_ollama_client_instance.chat.assert_awaited_once_with( + model=self.model_definition.name, + messages=self.ollama_model.build_chat_messages.return_value, + tools=[], + options=self.ollama_model.build_request_options.return_value, + format='', + stream=False, + ) + + self.ollama_model._build_multimodal_chat_response.assert_not_called() + + async def test_streaming_chat_success(self): + """Test _chat_with_ollama in streaming mode with multiple chunks.""" + self.ollama_model.is_streaming_request.return_value = True + self.ctx.is_streaming = True + + # Simulate an async iterator of chunks + async def mock_streaming_chunks(): + yield ollama_api.ChatResponse( + message=ollama_api.Message( + role='', + content='chunk1', + ), + ) + yield ollama_api.ChatResponse( + message=ollama_api.Message( + role='', + content='chunk2', + ), + ) + + self.mock_ollama_client_instance.chat.return_value = mock_streaming_chunks() + + response = await self.ollama_model._chat_with_ollama(self.request, self.ctx) + + self.assertIsNotNone(response) + self.ollama_model.build_chat_messages.assert_called_once_with(self.request) + self.ollama_model.is_streaming_request.assert_called_once_with(ctx=self.ctx) + self.mock_ollama_client_instance.chat.assert_awaited_once_with( + model=self.model_definition.name, + messages=self.ollama_model.build_chat_messages.return_value, + tools=[], + options=self.ollama_model.build_request_options.return_value, + format='', + stream=True, + ) + self.assertEqual(self.ctx.send_chunk.call_count, 2) + self.assertEqual(self.ollama_model._build_multimodal_chat_response.call_count, 2) + self.ctx.send_chunk.assert_any_call(chunk=unittest.mock.ANY) + self.ollama_model._build_multimodal_chat_response.assert_any_call(chat_response=unittest.mock.ANY) + + async def test_chat_with_output_format_string(self): + """Test _chat_with_ollama with request.output.format string.""" + self.request.output = OutputConfig(format='json') + + expected_response = ollama_api.ChatResponse( + message=ollama_api.Message( + role='', + content='json output', + ), + ) + self.mock_ollama_client_instance.chat.return_value = expected_response + + await self.ollama_model._chat_with_ollama(self.request, self.ctx) + + call_args, call_kwargs = self.mock_ollama_client_instance.chat.call_args + self.assertIn('format', call_kwargs) + self.assertEqual(call_kwargs['format'], 'json') + + async def test_chat_with_output_format_schema(self): + """Test _chat_with_ollama with request.output.schema_ dictionary.""" + schema_dict = {'type': 'object', 'properties': {'name': {'type': 'string'}}} + self.request.output = OutputConfig(schema_=schema_dict) + + expected_response = ollama_api.ChatResponse( + message=ollama_api.Message( + role='', + content='schema output', + ), + ) + self.mock_ollama_client_instance.chat.return_value = expected_response + + await self.ollama_model._chat_with_ollama(self.request, self.ctx) + + call_args, call_kwargs = self.mock_ollama_client_instance.chat.call_args + self.assertIn('format', call_kwargs) + self.assertEqual(call_kwargs['format'], schema_dict) + + async def test_chat_with_no_output_format(self): + 
"""Test _chat_with_ollama with no output format specified.""" + self.request.output = OutputConfig(format=None, schema_=None) + + expected_response = ollama_api.ChatResponse( + message=ollama_api.Message( + role='', + content='normal output', + ), + ) + self.mock_ollama_client_instance.chat.return_value = expected_response + + await self.ollama_model._chat_with_ollama(self.request, self.ctx) + + call_args, call_kwargs = self.mock_ollama_client_instance.chat.call_args + self.assertIn('format', call_kwargs) + self.assertEqual(call_kwargs['format'], '') + + async def test_chat_api_raises_exception(self): + """Test _chat_with_ollama handles exception from client.chat.""" + self.mock_ollama_client_instance.chat.side_effect = Exception('Ollama API Error') + + with self.assertRaisesRegex(Exception, 'Ollama API Error'): + await self.ollama_model._chat_with_ollama(self.request, self.ctx) + + self.mock_ollama_client_instance.chat.assert_awaited_once() + self.ctx.send_chunk.assert_not_called() + + +class TestOllamaModelGenerateOllamaResponse(unittest.IsolatedAsyncioTestCase): + """Unit tests for OllamaModel._generate_ollama_response.""" + + async def asyncSetUp(self): + """Common setup.""" + self.mock_ollama_client_instance = AsyncMock() + self.mock_ollama_client_factory = MagicMock(return_value=self.mock_ollama_client_instance) + + self.model_definition = ModelDefinition(name='test-generate-model', api_type=OllamaAPITypes.GENERATE) + self.ollama_model = OllamaModel(client=self.mock_ollama_client_factory, model_definition=self.model_definition) + self.request = GenerateRequest( + messages=[ + Message( + role=Role.USER, + content=[TextPart(text='Test generate message')], + ) + ], + config={'temperature': 0.8}, + ) + self.ctx = ActionRunContext(on_chunk=False) + self.ctx.send_chunk = MagicMock() + + # Patch internal methods of OllamaModel that _generate_ollama_response calls + self.ollama_model.build_prompt = MagicMock(return_value='Mocked prompt from build_prompt') + self.ollama_model.is_streaming_request = MagicMock(return_value=False) + self.ollama_model.build_request_options = MagicMock(return_value={'temperature': 0.8}) + + async def test_non_streaming_generate_success(self): + """Test _generate_ollama_response in non-streaming mode with successful response.""" + expected_response = ollama_api.GenerateResponse(response='Full generated text') + self.mock_ollama_client_instance.generate.return_value = expected_response + + response = await self.ollama_model._generate_ollama_response(self.request, self.ctx) + + self.assertIsNotNone(response) + self.assertEqual(response.response, 'Full generated text') + + self.ollama_model.build_prompt.assert_called_once_with(self.request) + self.ollama_model.is_streaming_request.assert_called_once_with(ctx=self.ctx) + self.ollama_model.build_request_options.assert_called_once_with(config=self.request.config) + self.mock_ollama_client_instance.generate.assert_awaited_once_with( + model=self.model_definition.name, + prompt=self.ollama_model.build_prompt.return_value, + options=self.ollama_model.build_request_options.return_value, + stream=False, + ) + self.ctx.send_chunk.assert_not_called() + + async def test_streaming_generate_success(self): + """Test _generate_ollama_response in streaming mode with multiple chunks.""" + self.ollama_model.is_streaming_request.return_value = True + + # Simulate an async iterator of chunks + async def mock_streaming_chunks(): + yield ollama_api.GenerateResponse(response='chunk1 ') + yield ollama_api.GenerateResponse(response='chunk2') 
+ + self.mock_ollama_client_instance.generate.return_value = mock_streaming_chunks() + + response = await self.ollama_model._generate_ollama_response(self.request, self.ctx) + + self.assertIsNotNone(response) + + self.ollama_model.build_prompt.assert_called_once_with(self.request) + self.ollama_model.is_streaming_request.assert_called_once_with(ctx=self.ctx) + self.mock_ollama_client_instance.generate.assert_awaited_once_with( + model=self.model_definition.name, + prompt=self.ollama_model.build_prompt.return_value, + options=self.ollama_model.build_request_options.return_value, + stream=True, + ) + self.assertEqual(self.ctx.send_chunk.call_count, 2) + self.ctx.send_chunk.assert_any_call( + chunk=GenerateResponseChunk(role=Role.MODEL, index=1, content=[TextPart(text='chunk1 ')]) + ) + self.ctx.send_chunk.assert_any_call( + chunk=GenerateResponseChunk(role=Role.MODEL, index=2, content=[TextPart(text='chunk2')]) + ) + + async def test_generate_api_raises_exception(self): + """Test _generate_ollama_response handles exception from client.generate.""" + self.mock_ollama_client_instance.generate.side_effect = Exception('Ollama generate API Error') + + with self.assertRaisesRegex(Exception, 'Ollama generate API Error'): + await self.ollama_model._generate_ollama_response(self.request, self.ctx) + + self.mock_ollama_client_instance.generate.assert_awaited_once() + self.ctx.send_chunk.assert_not_called() + + +@pytest.mark.parametrize( + 'input_schema, expected_output', + [ + ({}, None), + ({'properties': {'name': {'type': 'string'}}}, None), + ( + {'type': 'object'}, + ollama_api.Tool.Function.Parameters(type='object', properties={}), + ), + ( + { + 'type': 'object', + 'properties': { + 'name': {'type': 'string', 'description': 'User name'}, + 'age': {'type': 'integer', 'description': 'User age'}, + }, + 'required': ['name'], + }, + ollama_api.Tool.Function.Parameters( + type='object', + required=['name'], + properties={ + 'name': ollama_api.Tool.Function.Parameters.Property(type='string', description='User name'), + 'age': ollama_api.Tool.Function.Parameters.Property(type='integer', description='User age'), + }, + ), + ), + ( + { + 'type': 'object', + 'properties': { + 'city': {'type': 'string', 'description': 'City name'}, + }, + }, + ollama_api.Tool.Function.Parameters( + type='object', + required=None, + properties={ + 'city': ollama_api.Tool.Function.Parameters.Property( + type='string', + description='City name', + ), + }, + ), + ), + ( + { + 'type': 'object', + 'properties': {}, + }, + ollama_api.Tool.Function.Parameters( + type='object', + required=None, + properties={}, + ), + ), + # Test 8: Object schema with nested properties + ( + { + 'type': 'object', + 'properties': { + 'address': {'type': 'object', 'properties': {'street': {'type': 'string'}}}, + 'zip': {'type': 'string'}, + }, + }, + ollama_api.Tool.Function.Parameters( + type='object', + required=None, + properties={ + 'address': ollama_api.Tool.Function.Parameters.Property(type='object', description=''), + 'zip': ollama_api.Tool.Function.Parameters.Property(type='string', description=''), + }, + ), + ), + ( + {'type': 'object', 'description': 'A general description'}, + ollama_api.Tool.Function.Parameters( + type='object', + required=None, + properties={}, + ), + ), + ], +) +def test_convert_parameters(input_schema, expected_output): + """Unit Tests for _convert_parameters function with various input schemas.""" + result = _convert_parameters(input_schema) + assert result == expected_output diff --git 
a/py/plugins/ollama/tests/test_plugin_api.py b/py/plugins/ollama/tests/test_plugin_api.py
index 5d20c0e83e..de2fd9d42a 100644
--- a/py/plugins/ollama/tests/test_plugin_api.py
+++ b/py/plugins/ollama/tests/test_plugin_api.py
@@ -25,7 +25,6 @@
 
 from genkit.ai import ActionKind, Genkit
 from genkit.plugins.ollama import Ollama, ollama_name
-from genkit.plugins.ollama.constants import DEFAULT_OLLAMA_SERVER_URL
 from genkit.plugins.ollama.embedders import EmbeddingDefinition
 from genkit.plugins.ollama.models import ModelDefinition
 from genkit.types import GenerationCommonConfig
diff --git a/py/samples/ollama-hello/README.md b/py/samples/ollama-hello/README.md
index de2f117ebc..5e08179571 100644
--- a/py/samples/ollama-hello/README.md
+++ b/py/samples/ollama-hello/README.md
@@ -1,18 +1,31 @@
 # Hello Ollama
 
-## NOTE
+## Prerequisites
 
-Before running the sample make sure to install the model and start ollama
-serving. In case of questions, please refer to `./py/plugins/ollama/README.md`
+- **Ollama** - a local AI model server, which is used to generate responses.
 
-## Installation
+### Step 1: Install Ollama
+
+1. Go to the [Ollama website](https://ollama.com/download) to download and install Ollama for your operating system.
+2. Once installed, start the Ollama server by running:
+
+```bash
+ollama serve
+```
+
+The server will run at http://localhost:11434 by default.
+
+### Step 2: Pull the Required Models
+
+In this example, we use two models with Ollama.
+Run the following commands in your terminal to pull these models:
 
 ```bash
 ollama pull mistral-nemo:latest
 ollama pull gemma3:latest
 ```
 
-## Execute "Hello World" Sample
+### Step 3: Run the Sample
 
 ```bash
 genkit start -- uv run src/ollama_hello.py
diff --git a/py/samples/ollama-simple-embed/README.md b/py/samples/ollama-simple-embed/README.md
index 06764678f0..ecaa0b3b2a 100644
--- a/py/samples/ollama-simple-embed/README.md
+++ b/py/samples/ollama-simple-embed/README.md
@@ -1,11 +1,35 @@
 # Ollama Simple Embed Sample
 
-## NOTE
+## Prerequisites
 
-Before running the sample make sure to install the model and start ollama serving.
-In case of questions, please refer to `./py/plugins/ollama/README.md`
+- **Ollama** - a local AI model server, which is used to handle embeddings and generate responses.
 
-## Execute "Ollama Embed" Sample
+### Step 1: Install Ollama
+
+1. Go to the [Ollama website](https://ollama.com/download) to download and install Ollama for your operating system.
+2. Once installed, start the Ollama server by running:
+
+```bash
+ollama serve
+```
+
+The server will run at http://localhost:11434 by default.
+
+### Step 2: Pull the Required Models
+
+In this example, we use two models with Ollama:
+
+- An embedding model (`nomic-embed-text`)
+- A generation model (`phi3.5:latest`)
+
+Run the following commands in your terminal to pull these models:
+
+```bash
+ollama pull nomic-embed-text
+ollama pull phi3.5:latest
+```
+
+### Step 3: Run the Sample
 
 ```bash
 genkit start -- uv run src/pokemon_glossary.py
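
For orientation (not part of the patch): below is a minimal sketch of how a sample like these might register the Ollama plugin with Genkit. It reuses only names that appear in this diff (`Genkit`, `Ollama`, `ollama_name`, `ModelDefinition`, `EmbeddingDefinition`); the keyword arguments `models=` and `embedders=` are assumptions about the plugin constructor, not code taken from `src/ollama_hello.py` or `src/pokemon_glossary.py`.

```python
# Hypothetical wiring sketch -- the Ollama(...) keyword arguments are assumptions,
# not verified against the actual sample sources.
from genkit.ai import Genkit
from genkit.plugins.ollama import Ollama, ollama_name
from genkit.plugins.ollama.embedders import EmbeddingDefinition
from genkit.plugins.ollama.models import ModelDefinition

ai = Genkit(
    plugins=[
        Ollama(
            # Generation model pulled in Step 2 of the simple-embed README.
            models=[ModelDefinition(name='phi3.5:latest')],
            # Embedding model pulled in Step 2 of the simple-embed README.
            embedders=[EmbeddingDefinition(name='nomic-embed-text')],
        )
    ],
    # Default model, addressed as 'ollama/phi3.5:latest' via ollama_name().
    model=ollama_name('phi3.5:latest'),
)
```

Once the plugin is registered, flows defined on `ai` can use the registered model and embedder, and `genkit start -- uv run ...` serves those flows in the Genkit Dev UI.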