Skip to content

Commit f502ca4

Browse files
AbeJLazaroAbraham Lazaro Martinez
andauthored
fix(py): type and tests coverage Ollama plugin (#3011)
Co-authored-by: Abraham Lazaro Martinez <[email protected]>
1 parent a05d65b commit f502ca4

File tree

8 files changed

+888
-18
lines changed

8 files changed

+888
-18
lines changed

py/plugins/ollama/src/genkit/plugins/ollama/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,7 @@
2525

2626

2727
class OllamaAPITypes(StrEnum):
28+
"""Generation types for Ollama API."""
29+
2830
CHAT = 'chat'
2931
GENERATE = 'generate'

py/plugins/ollama/src/genkit/plugins/ollama/embedders.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,57 @@
2323

2424

2525
class EmbeddingDefinition(BaseModel):
26+
"""Defines an embedding model for Ollama.
27+
28+
This class specifies the characteristics of an embedding model that
29+
can be used with the Ollama plugin. While Ollama models have fixed
30+
output dimensions, this definition can specify the expected
31+
dimensionality for informational purposes or for future truncation
32+
support.
33+
"""
34+
2635
name: str
27-
# Ollama do not support changing dimensionality, but it can be truncated
2836
dimensions: int | None = None
2937

3038

3139
class OllamaEmbedder:
40+
"""Handles embedding requests using an Ollama embedding model.
41+
42+
This class provides the necessary logic to interact with a specific
43+
Ollama embedding model, processing input text into vector embeddings.
44+
"""
45+
3246
def __init__(
3347
self,
3448
client: Callable,
3549
embedding_definition: EmbeddingDefinition,
36-
):
50+
) -> None:
51+
"""Initializes the OllamaEmbedder.
52+
53+
Sets up the client for communicating with the Ollama server and stores
54+
the definition of the embedding model.
55+
56+
Args:
57+
client: A callable that returns an asynchronous Ollama client instance.
58+
embedding_definition: The definition describing the specific Ollama
59+
embedding model to be used.
60+
"""
3761
self.client = client()
3862
self.embedding_definition = embedding_definition
3963

4064
async def embed(self, request: EmbedRequest) -> EmbedResponse:
65+
"""Generates embeddings for the provided input text.
66+
67+
Converts the input documents from the Genkit EmbedRequest into a raw
68+
list of strings, sends them to the Ollama server for embedding, and then
69+
formats the response into a Genkit EmbedResponse.
70+
71+
Args:
72+
request: The embedding request containing the input documents.
73+
74+
Returns:
75+
An EmbedResponse containing the generated vector embeddings.
76+
"""
4177
input_raw = []
4278
for doc in request.input:
4379
input_raw.extend([content.root.text for content in doc.content])

py/plugins/ollama/src/genkit/plugins/ollama/models.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
#
1515
# SPDX-License-Identifier: Apache-2.0
1616

17+
"""Models package for Ollama plugin."""
18+
1719
from collections.abc import Callable
1820
from typing import Any, Literal
1921

2022
import structlog
21-
from pydantic import BaseModel, Field, HttpUrl
23+
from pydantic import BaseModel
2224

2325
import ollama as ollama_api
2426
from genkit.ai import ActionRunContext
@@ -47,17 +49,37 @@
4749

4850

4951
class OllamaSupports(BaseModel):
52+
"""Supports for Ollama models."""
53+
5054
tools: bool = False
5155

5256

5357
class ModelDefinition(BaseModel):
58+
"""Meta definition for Ollama models."""
59+
5460
name: str
5561
api_type: OllamaAPITypes = 'chat'
5662
supports: OllamaSupports = OllamaSupports()
5763

5864

5965
class OllamaModel:
60-
def __init__(self, client: Callable, model_definition: ModelDefinition):
66+
"""Represents an Ollama language model for use with Genkit.
67+
68+
This class encapsulates the interaction logic for a specific Ollama model,
69+
allowing it to be integrated into the Genkit framework for generative tasks.
70+
"""
71+
72+
def __init__(self, client: Callable, model_definition: ModelDefinition) -> None:
73+
"""Initializes the OllamaModel.
74+
75+
Sets up the client for communicating with the Ollama server and stores
76+
the definition of the model.
77+
78+
Args:
79+
client: A callable that returns an asynchronous Ollama client instance.
80+
model_definition: The definition describing the specific Ollama model
81+
to be used (e.g., its name, API type, supported features).
82+
"""
6183
self.client = client()
6284
self.model_definition = model_definition
6385

@@ -110,7 +132,6 @@ async def generate(self, request: GenerateRequest, ctx: ActionRunContext | None
110132
),
111133
)
112134

113-
# FIXME: Missing return statement.
114135
async def _chat_with_ollama(
115136
self, request: GenerateRequest, ctx: ActionRunContext | None = None
116137
) -> ollama_api.ChatResponse | None:
@@ -167,6 +188,7 @@ async def _chat_with_ollama(
167188
content=self._build_multimodal_chat_response(chat_response=chunk),
168189
)
169190
)
191+
return chat_response
170192
else:
171193
return chat_response
172194

@@ -182,7 +204,6 @@ async def _generate_ollama_response(
182204
Returns:
183205
The generated response.
184206
"""
185-
# FIXME: Missing return statement.
186207
prompt = self.build_prompt(request)
187208
streaming_request = self.is_streaming_request(ctx=ctx)
188209

@@ -341,13 +362,27 @@ def _to_ollama_role(
341362

342363
@staticmethod
343364
def is_streaming_request(ctx: ActionRunContext | None) -> bool:
365+
"""Determines if streaming mode is requested."""
344366
return ctx and ctx.is_streaming
345367

346368
@staticmethod
347369
def get_usage_info(
348370
basic_generation_usage: GenerationUsage,
349371
api_response: ollama_api.GenerateResponse | ollama_api.ChatResponse,
350372
) -> GenerationUsage:
373+
"""Extracts and calculates token usage information from an Ollama API response.
374+
375+
Updates a basic generation usage object with input, output, and total token counts
376+
based on the details provided in the Ollama API response.
377+
378+
Args:
379+
basic_generation_usage: An existing GenerationUsage object to update.
380+
api_response: The response object received from the Ollama API,
381+
containing token count details.
382+
383+
Returns:
384+
The updated GenerationUsage object with token counts populated.
385+
"""
351386
if api_response:
352387
basic_generation_usage.input_tokens = api_response.prompt_eval_count or 0
353388
basic_generation_usage.output_tokens = api_response.eval_count or 0
@@ -372,10 +407,10 @@ def _convert_parameters(input_schema: dict[str, Any]) -> ollama_api.Tool.Functio
372407

373408
if schema_type == 'object':
374409
schema.properties = {}
375-
properties = input_schema['properties']
410+
properties = input_schema.get('properties', [])
376411
for key in properties:
377412
schema.properties[key] = ollama_api.Tool.Function.Parameters.Property(
378-
type=properties[key]['type'], description=properties[key]['description'] or ''
413+
type=properties[key]['type'], description=properties[key].get('description', '')
379414
)
380415

381416
return schema
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# SPDX-License-Identifier: Apache-2.0
16+
17+
"""Unit tests for Ollama embedders package."""
18+
19+
import unittest
20+
from unittest.mock import AsyncMock, MagicMock
21+
22+
import ollama as ollama_api
23+
24+
from genkit.plugins.ollama.embedders import EmbeddingDefinition, OllamaEmbedder
25+
from genkit.types import (
26+
Document,
27+
Embedding,
28+
EmbedRequest,
29+
EmbedResponse,
30+
TextPart,
31+
)
32+
33+
34+
class TestOllamaEmbedderEmbed(unittest.IsolatedAsyncioTestCase):
35+
"""Unit tests for OllamaEmbedder.embed method."""
36+
37+
async def asyncSetUp(self):
38+
"""Common setup."""
39+
self.mock_ollama_client_instance = AsyncMock()
40+
self.mock_ollama_client_factory = MagicMock(return_value=self.mock_ollama_client_instance)
41+
42+
self.mock_embedding_definition = EmbeddingDefinition(name='test-embed-model', dimensions=1536)
43+
self.ollama_embedder = OllamaEmbedder(
44+
client=self.mock_ollama_client_factory, embedding_definition=self.mock_embedding_definition
45+
)
46+
47+
async def test_embed_single_document_single_content(self):
48+
"""Test embed with a single document containing single text content."""
49+
request = EmbedRequest(
50+
input=[
51+
Document.from_text(text='hello world'),
52+
]
53+
)
54+
expected_ollama_embeddings = [[0.1, 0.2, 0.3]]
55+
self.mock_ollama_client_instance.embed.return_value = ollama_api.EmbedResponse(
56+
embeddings=expected_ollama_embeddings
57+
)
58+
59+
response = await self.ollama_embedder.embed(request)
60+
61+
# Assertions
62+
self.mock_ollama_client_instance.embed.assert_awaited_once_with(
63+
model='test-embed-model',
64+
input=['hello world'],
65+
)
66+
expected_genkit_embeddings = [Embedding(embedding=[0.1, 0.2, 0.3])]
67+
self.assertEqual(response, EmbedResponse(embeddings=expected_genkit_embeddings))
68+
69+
async def test_embed_multiple_documents_multiple_content(self):
70+
"""Test embed with multiple documents, each with multiple text contents."""
71+
request = EmbedRequest(
72+
input=[
73+
Document(content=[TextPart(text='doc1_part1'), TextPart(text='doc1_part2')]),
74+
Document(content=[TextPart(text='doc2_part1')]),
75+
]
76+
)
77+
expected_ollama_embeddings = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
78+
self.mock_ollama_client_instance.embed.return_value = ollama_api.EmbedResponse(
79+
embeddings=expected_ollama_embeddings
80+
)
81+
82+
response = await self.ollama_embedder.embed(request)
83+
84+
# Assertions
85+
self.mock_ollama_client_instance.embed.assert_awaited_once_with(
86+
model='test-embed-model',
87+
input=['doc1_part1', 'doc1_part2', 'doc2_part1'],
88+
)
89+
expected_genkit_embeddings = [
90+
Embedding(embedding=[0.1, 0.2]),
91+
Embedding(embedding=[0.3, 0.4]),
92+
Embedding(embedding=[0.5, 0.6]),
93+
]
94+
self.assertEqual(response, EmbedResponse(embeddings=expected_genkit_embeddings))
95+
96+
async def test_embed_empty_input(self):
97+
"""Test embed with an empty input request."""
98+
request = EmbedRequest(input=[])
99+
self.mock_ollama_client_instance.embed.return_value = ollama_api.EmbedResponse(embeddings=[])
100+
101+
response = await self.ollama_embedder.embed(request)
102+
103+
# Assertions
104+
self.mock_ollama_client_instance.embed.assert_awaited_once_with(
105+
model='test-embed-model',
106+
input=[],
107+
)
108+
self.assertEqual(response, EmbedResponse(embeddings=[]))
109+
110+
async def test_embed_api_raises_exception(self):
111+
"""Test embed method handles exception from client.embed."""
112+
request = EmbedRequest(input=[Document(content=[TextPart(text='error text')])])
113+
self.mock_ollama_client_instance.embed.side_effect = Exception('Ollama Embed API Error')
114+
115+
with self.assertRaisesRegex(Exception, 'Ollama Embed API Error'):
116+
await self.ollama_embedder.embed(request)
117+
118+
self.mock_ollama_client_instance.embed.assert_awaited_once()
119+
120+
async def test_embed_response_mismatch_input_count(self):
121+
"""Test embed when client returns fewer embeddings than input texts (edge case)."""
122+
request = EmbedRequest(
123+
input=[
124+
Document(content=[TextPart(text='text1')]),
125+
Document(content=[TextPart(text='text2')]),
126+
]
127+
)
128+
# Simulate Ollama returning only one embedding for two inputs
129+
expected_ollama_embeddings = [[1.0, 2.0]]
130+
self.mock_ollama_client_instance.embed.return_value = ollama_api.EmbedResponse(
131+
embeddings=expected_ollama_embeddings
132+
)
133+
134+
response = await self.ollama_embedder.embed(request)
135+
136+
# The current implementation will just use whatever embeddings are returned.
137+
# It's up to the caller or a higher layer to decide if this is an error.
138+
# This test ensures it doesn't crash and correctly maps the available embeddings.
139+
expected_genkit_embeddings = [Embedding(embedding=[1.0, 2.0])]
140+
self.assertEqual(response, EmbedResponse(embeddings=expected_genkit_embeddings))
141+
self.assertEqual(len(response.embeddings), 1) # Confirm only one embedding was processed

0 commit comments

Comments
 (0)