2 changes: 2 additions & 0 deletions py/plugins/ollama/src/genkit/plugins/ollama/constants.py
@@ -25,5 +25,7 @@


class OllamaAPITypes(StrEnum):
"""Generation types for Ollama API."""

CHAT = 'chat'
GENERATE = 'generate'
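Since OllamaAPITypes is a StrEnum, its members double as plain strings; a quick sketch of the semantics the plugin relies on:

    from genkit.plugins.ollama.constants import OllamaAPITypes

    # StrEnum members are str instances and compare equal to their values,
    # which is why pydantic fields elsewhere in this PR can default to the
    # plain string 'chat' and still validate as OllamaAPITypes.
    assert OllamaAPITypes.CHAT == 'chat'
    assert OllamaAPITypes('generate') is OllamaAPITypes.GENERATE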
40 changes: 38 additions & 2 deletions py/plugins/ollama/src/genkit/plugins/ollama/embedders.py
@@ -23,21 +23,57 @@


class EmbeddingDefinition(BaseModel):
"""Defines an embedding model for Ollama.

This class specifies the characteristics of an embedding model that
can be used with the Ollama plugin. While Ollama models have fixed
output dimensions, this definition can specify the expected
dimensionality for informational purposes or for future truncation
support.
"""

name: str
# Ollama does not support changing embedding dimensionality, but embeddings can be truncated
dimensions: int | None = None
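By way of example, a definition might look like this (the model name and dimension count are illustrative assumptions, not values from this PR):

    # Illustrative values; 'nomic-embed-text' and 768 are assumptions.
    embed_def = EmbeddingDefinition(name='nomic-embed-text', dimensions=768)

    # dimensions is optional; omit it to rely on the model's native size.
    embed_def_minimal = EmbeddingDefinition(name='nomic-embed-text')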


class OllamaEmbedder:
"""Handles embedding requests using an Ollama embedding model.

This class provides the necessary logic to interact with a specific
Ollama embedding model, processing input text into vector embeddings.
"""

def __init__(
self,
client: Callable,
embedding_definition: EmbeddingDefinition,
):
) -> None:
"""Initializes the OllamaEmbedder.

Sets up the client for communicating with the Ollama server and stores
the definition of the embedding model.

Args:
client: A callable that returns an asynchronous Ollama client instance.
embedding_definition: The definition describing the specific Ollama
embedding model to be used.
"""
self.client = client()
self.embedding_definition = embedding_definition

async def embed(self, request: EmbedRequest) -> EmbedResponse:
"""Generates embeddings for the provided input text.

Converts the input documents from the Genkit EmbedRequest into a raw
list of strings, sends them to the Ollama server for embedding, and then
formats the response into a Genkit EmbedResponse.

Args:
request: The embedding request containing the input documents.

Returns:
An EmbedResponse containing the generated vector embeddings.
"""
input_raw = []
for doc in request.input:
input_raw.extend([content.root.text for content in doc.content])
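Taken together, a minimal usage sketch for the embedder (the host URL and model name are assumptions; the zero-argument factory matches the `client: Callable` parameter above):

    import asyncio

    import ollama
    from genkit.plugins.ollama.embedders import EmbeddingDefinition, OllamaEmbedder
    from genkit.types import Document, EmbedRequest

    async def main() -> None:
        # The embedder expects a factory returning an async Ollama client.
        embedder = OllamaEmbedder(
            client=lambda: ollama.AsyncClient(host='http://localhost:11434'),
            embedding_definition=EmbeddingDefinition(name='nomic-embed-text'),
        )
        response = await embedder.embed(
            EmbedRequest(input=[Document.from_text(text='hello world')])
        )
        print(response.embeddings[0].embedding[:3])

    asyncio.run(main())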
47 changes: 41 additions & 6 deletions py/plugins/ollama/src/genkit/plugins/ollama/models.py
@@ -14,11 +14,13 @@
#
# SPDX-License-Identifier: Apache-2.0

"""Models package for Ollama plugin."""

from collections.abc import Callable
from typing import Any, Literal

import structlog
from pydantic import BaseModel, Field, HttpUrl
from pydantic import BaseModel

import ollama as ollama_api
from genkit.ai import ActionRunContext
@@ -47,17 +49,37 @@


class OllamaSupports(BaseModel):
"""Supports for Ollama models."""

tools: bool = False


class ModelDefinition(BaseModel):
"""Meta definition for Ollama models."""

name: str
api_type: OllamaAPITypes = 'chat'
supports: OllamaSupports = OllamaSupports()
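For illustration, a tool-capable chat model and a completion-style model could be described as follows (the model name is a placeholder, not from this PR):

    # Illustrative definitions; 'llama3.2' is a placeholder model name.
    tool_chat_model = ModelDefinition(
        name='llama3.2',
        api_type=OllamaAPITypes.CHAT,
        supports=OllamaSupports(tools=True),
    )
    generate_model = ModelDefinition(name='llama3.2', api_type=OllamaAPITypes.GENERATE)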


class OllamaModel:
def __init__(self, client: Callable, model_definition: ModelDefinition):
"""Represents an Ollama language model for use with Genkit.

This class encapsulates the interaction logic for a specific Ollama model,
allowing it to be integrated into the Genkit framework for generative tasks.
"""

def __init__(self, client: Callable, model_definition: ModelDefinition) -> None:
"""Initializes the OllamaModel.

Sets up the client for communicating with the Ollama server and stores
the definition of the model.

Args:
client: A callable that returns an asynchronous Ollama client instance.
model_definition: The definition describing the specific Ollama model
to be used (e.g., its name, API type, supported features).
"""
self.client = client()
self.model_definition = model_definition

@@ -110,7 +132,6 @@ async def generate(self, request: GenerateRequest, ctx: ActionRunContext | None
),
)

# FIXME: Missing return statement.
async def _chat_with_ollama(
self, request: GenerateRequest, ctx: ActionRunContext | None = None
) -> ollama_api.ChatResponse | None:
Expand Down Expand Up @@ -167,6 +188,7 @@ async def _chat_with_ollama(
content=self._build_multimodal_chat_response(chat_response=chunk),
)
)
return chat_response
else:
return chat_response
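The added `return chat_response` matters for streaming callers: previously the streaming branch fell through and implicitly returned None, despite the `ollama_api.ChatResponse | None` annotation. A sketch of the underlying ollama-python streaming pattern this branch wraps (model and messages are placeholders):

    import ollama

    async def stream_chat() -> ollama.ChatResponse | None:
        client = ollama.AsyncClient()
        chat_response = None
        # stream=True yields incremental ChatResponse chunks; the final
        # chunk carries completion stats (eval counts), so returning it
        # mirrors the fix above.
        async for chunk in await client.chat(
            model='llama3.2',  # placeholder model name
            messages=[{'role': 'user', 'content': 'hi'}],
            stream=True,
        ):
            chat_response = chunk
        return chat_response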

@@ -182,7 +204,6 @@ async def _generate_ollama_response(
Returns:
The generated response.
"""
# FIXME: Missing return statement.
prompt = self.build_prompt(request)
streaming_request = self.is_streaming_request(ctx=ctx)

@@ -341,13 +362,27 @@ def _to_ollama_role(

@staticmethod
def is_streaming_request(ctx: ActionRunContext | None) -> bool:
"""Determines if streaming mode is requested."""
return ctx is not None and ctx.is_streaming

@staticmethod
def get_usage_info(
basic_generation_usage: GenerationUsage,
api_response: ollama_api.GenerateResponse | ollama_api.ChatResponse,
) -> GenerationUsage:
"""Extracts and calculates token usage information from an Ollama API response.

Updates a basic generation usage object with input, output, and total token counts
based on the details provided in the Ollama API response.

Args:
basic_generation_usage: An existing GenerationUsage object to update.
api_response: The response object received from the Ollama API,
containing token count details.

Returns:
The updated GenerationUsage object with token counts populated.
"""
if api_response:
basic_generation_usage.input_tokens = api_response.prompt_eval_count or 0
basic_generation_usage.output_tokens = api_response.eval_count or 0
@@ -372,10 +407,10 @@ def _convert_parameters(input_schema: dict[str, Any]) -> ollama_api.Tool.Function.Parameters:

if schema_type == 'object':
schema.properties = {}
properties = input_schema['properties']
properties = input_schema.get('properties', {})
for key in properties:
schema.properties[key] = ollama_api.Tool.Function.Parameters.Property(
type=properties[key]['type'], description=properties[key]['description'] or ''
type=properties[key]['type'], description=properties[key].get('description', '')
)

return schema
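The switch from indexed access to `.get(...)` makes the conversion tolerant of sparse JSON schemas. A sketch of the shapes involved (the schemas are illustrative; only the first would have converted without a KeyError before this change):

    # Fully specified: worked before and after the change.
    schema_full = {
        'type': 'object',
        'properties': {'city': {'type': 'string', 'description': 'City name'}},
    }

    # Missing 'description' on a property: previously raised KeyError.
    schema_no_description = {
        'type': 'object',
        'properties': {'city': {'type': 'string'}},
    }

    # Missing 'properties' entirely: previously raised KeyError.
    schema_no_properties = {'type': 'object'}

    # With the .get fallbacks, each shape now yields a valid
    # ollama_api.Tool.Function.Parameters object.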
141 changes: 141 additions & 0 deletions py/plugins/ollama/tests/models/test_embedders.py
@@ -0,0 +1,141 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Unit tests for Ollama embedders package."""

import unittest
from unittest.mock import AsyncMock, MagicMock

import ollama as ollama_api

from genkit.plugins.ollama.embedders import EmbeddingDefinition, OllamaEmbedder
from genkit.types import (
Document,
Embedding,
EmbedRequest,
EmbedResponse,
TextPart,
)


class TestOllamaEmbedderEmbed(unittest.IsolatedAsyncioTestCase):
"""Unit tests for OllamaEmbedder.embed method."""

async def asyncSetUp(self):
"""Common setup."""
self.mock_ollama_client_instance = AsyncMock()
self.mock_ollama_client_factory = MagicMock(return_value=self.mock_ollama_client_instance)

self.mock_embedding_definition = EmbeddingDefinition(name='test-embed-model', dimensions=1536)
self.ollama_embedder = OllamaEmbedder(
client=self.mock_ollama_client_factory, embedding_definition=self.mock_embedding_definition
)

async def test_embed_single_document_single_content(self):
"""Test embed with a single document containing single text content."""
request = EmbedRequest(
input=[
Document.from_text(text='hello world'),
]
)
expected_ollama_embeddings = [[0.1, 0.2, 0.3]]
self.mock_ollama_client_instance.embed.return_value = ollama_api.EmbedResponse(
embeddings=expected_ollama_embeddings
)

response = await self.ollama_embedder.embed(request)

# Assertions
self.mock_ollama_client_instance.embed.assert_awaited_once_with(
model='test-embed-model',
input=['hello world'],
)
expected_genkit_embeddings = [Embedding(embedding=[0.1, 0.2, 0.3])]
self.assertEqual(response, EmbedResponse(embeddings=expected_genkit_embeddings))

async def test_embed_multiple_documents_multiple_content(self):
"""Test embed with multiple documents, each with multiple text contents."""
request = EmbedRequest(
input=[
Document(content=[TextPart(text='doc1_part1'), TextPart(text='doc1_part2')]),
Document(content=[TextPart(text='doc2_part1')]),
]
)
expected_ollama_embeddings = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
self.mock_ollama_client_instance.embed.return_value = ollama_api.EmbedResponse(
embeddings=expected_ollama_embeddings
)

response = await self.ollama_embedder.embed(request)

# Assertions
self.mock_ollama_client_instance.embed.assert_awaited_once_with(
model='test-embed-model',
input=['doc1_part1', 'doc1_part2', 'doc2_part1'],
)
expected_genkit_embeddings = [
Embedding(embedding=[0.1, 0.2]),
Embedding(embedding=[0.3, 0.4]),
Embedding(embedding=[0.5, 0.6]),
]
self.assertEqual(response, EmbedResponse(embeddings=expected_genkit_embeddings))

async def test_embed_empty_input(self):
"""Test embed with an empty input request."""
request = EmbedRequest(input=[])
self.mock_ollama_client_instance.embed.return_value = ollama_api.EmbedResponse(embeddings=[])

response = await self.ollama_embedder.embed(request)

# Assertions
self.mock_ollama_client_instance.embed.assert_awaited_once_with(
model='test-embed-model',
input=[],
)
self.assertEqual(response, EmbedResponse(embeddings=[]))

async def test_embed_api_raises_exception(self):
"""Test embed method handles exception from client.embed."""
request = EmbedRequest(input=[Document(content=[TextPart(text='error text')])])
self.mock_ollama_client_instance.embed.side_effect = Exception('Ollama Embed API Error')

with self.assertRaisesRegex(Exception, 'Ollama Embed API Error'):
await self.ollama_embedder.embed(request)

self.mock_ollama_client_instance.embed.assert_awaited_once()

async def test_embed_response_mismatch_input_count(self):
"""Test embed when client returns fewer embeddings than input texts (edge case)."""
request = EmbedRequest(
input=[
Document(content=[TextPart(text='text1')]),
Document(content=[TextPart(text='text2')]),
]
)
# Simulate Ollama returning only one embedding for two inputs
expected_ollama_embeddings = [[1.0, 2.0]]
self.mock_ollama_client_instance.embed.return_value = ollama_api.EmbedResponse(
embeddings=expected_ollama_embeddings
)

response = await self.ollama_embedder.embed(request)

# The current implementation will just use whatever embeddings are returned.
# It's up to the caller or a higher layer to decide if this is an error.
# This test ensures it doesn't crash and correctly maps the available embeddings.
expected_genkit_embeddings = [Embedding(embedding=[1.0, 2.0])]
self.assertEqual(response, EmbedResponse(embeddings=expected_genkit_embeddings))
self.assertEqual(len(response.embeddings), 1) # Confirm only one embedding was processed