diff --git a/py/packages/genkit/src/genkit/ai/_aio.py b/py/packages/genkit/src/genkit/ai/_aio.py index 5f3c9e920a..ed27e8236a 100644 --- a/py/packages/genkit/src/genkit/ai/_aio.py +++ b/py/packages/genkit/src/genkit/ai/_aio.py @@ -37,6 +37,7 @@ class while customizing it with any plugins. ModelMiddleware, ) from genkit.blocks.prompt import PromptConfig, to_generate_action_options +from genkit.blocks.retriever import IndexerRef, IndexerRequest, RetrieverRef from genkit.core.action import ActionRunContext from genkit.core.action.types import ActionKind from genkit.core.typing import EmbedRequest, EmbedResponse @@ -294,63 +295,75 @@ def generate_stream( return stream, stream.closed - async def embed( + async def retrieve( self, - embedder: str | EmbedderRef | None = None, - documents: list[Document] | None = None, + retriever: str | RetrieverRef | None = None, + query: str | DocumentData | None = None, options: dict[str, Any] | None = None, - ) -> EmbedResponse: - embedder_name: str - embedder_config: dict[str, Any] = {} - """Calculates embeddings for documents. + ) -> RetrieverResponse: + """Retrieves documents based on query. Args: - embedder: Optional embedder model name to use. - documents: Texts to embed. - options: embedding options + retriever: Optional retriever name or reference to use. + query: Text query or a DocumentData containing query text. + options: retriever options Returns: - The generated response with embeddings. + The generated response with documents. """ - if isinstance(embedder, EmbedderRef): - embedder_name = embedder.name - embedder_config = embedder.config or {} - if embedder.version: - embedder_config['version'] = embedder.version # Handle version from ref - elif isinstance(embedder, str): - embedder_name = embedder + retriever_name: str + retriever_config: dict[str, Any] = {} + + if isinstance(retriever, RetrieverRef): + retriever_name = retriever.name + retriever_config = retriever.config or {} + if retriever.version: + retriever_config['version'] = retriever.version + elif isinstance(retriever, str): + retriever_name = retriever else: - # Handle case where embedder is None - raise ValueError('Embedder must be specified as a string name or an EmbedderRef.') + raise ValueError('Retriever must be specified as a string name or a RetrieverRef.') - # Merge options passed to embed() with config from EmbedderRef - final_options = {**(embedder_config or {}), **(options or {})} - embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder_name) + if isinstance(query, str): + query = Document.from_text(query) - return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response + final_options = {**(retriever_config or {}), **(options or {})} - async def retrieve( + retrieve_action = self.registry.lookup_action(ActionKind.RETRIEVER, retriever_name) + + return (await retrieve_action.arun(RetrieverRequest(query=query, options=final_options))).response + + async def index( self, - retriever: str | None = None, - query: str | DocumentData | None = None, + indexer: str | IndexerRef | None = None, + documents: list[Document] | None = None, options: dict[str, Any] | None = None, - ) -> RetrieverResponse: - """Retrieves documents based on query. + ) -> None: + """Indexes documents. Args: - retriever: Optional retriever name to use. - query: Text query or a DocumentData containing query text. - options: retriever options - - Returns: - The generated response with embeddings. + indexer: Optional indexer name or reference to use. + documents: Documents to index. + options: indexer options """ - if isinstance(query, str): - query = Document.from_text(query) + indexer_name: str + indexer_config: dict[str, Any] = {} + + if isinstance(indexer, IndexerRef): + indexer_name = indexer.name + indexer_config = indexer.config or {} + if indexer.version: + indexer_config['version'] = indexer.version + elif isinstance(indexer, str): + indexer_name = indexer + else: + raise ValueError('Indexer must be specified as a string name or an IndexerRef.') + + final_options = {**(indexer_config or {}), **(options or {})} - retrieve_action = self.registry.lookup_action(ActionKind.RETRIEVER, retriever) + index_action = self.registry.lookup_action(ActionKind.INDEXER, indexer_name) - return (await retrieve_action.arun(RetrieverRequest(query=query, options=options))).response + await index_action.arun(IndexerRequest(documents=documents, options=final_options)) async def embed( self, diff --git a/py/packages/genkit/src/genkit/ai/_registry.py b/py/packages/genkit/src/genkit/ai/_registry.py index a055bad30c..9295cd33a6 100644 --- a/py/packages/genkit/src/genkit/ai/_registry.py +++ b/py/packages/genkit/src/genkit/ai/_registry.py @@ -52,7 +52,7 @@ from genkit.blocks.formats.types import FormatDef from genkit.blocks.model import ModelFn, ModelMiddleware from genkit.blocks.prompt import define_prompt -from genkit.blocks.retriever import RetrieverFn +from genkit.blocks.retriever import IndexerFn, RetrieverFn from genkit.blocks.tools import ToolRunContext from genkit.codec import dump_dict from genkit.core.action import Action @@ -278,6 +278,40 @@ def define_retriever( description=retriever_description, ) + def define_indexer( + self, + name: str, + fn: IndexerFn, + config_schema: BaseModel | dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + description: str | None = None, + ) -> Callable[[Callable], Callable]: + """Define an indexer action. + + Args: + name: Name of the indexer. + fn: Function implementing the indexer behavior. + config_schema: Optional schema for indexer configuration. + metadata: Optional metadata for the indexer. + description: Optional description for the indexer. + """ + indexer_meta = metadata if metadata else {} + if 'indexer' not in indexer_meta: + indexer_meta['indexer'] = {} + if 'label' not in indexer_meta['indexer'] or not indexer_meta['indexer']['label']: + indexer_meta['indexer']['label'] = name + if config_schema: + indexer_meta['indexer']['customOptions'] = to_json_schema(config_schema) + + indexer_description = get_func_description(fn, description) + return self.registry.register_action( + name=name, + kind=ActionKind.INDEXER, + fn=fn, + metadata=indexer_meta, + description=indexer_description, + ) + def define_evaluator( self, name: str, diff --git a/py/packages/genkit/src/genkit/blocks/retriever.py b/py/packages/genkit/src/genkit/blocks/retriever.py index d7738c3229..6564bc92ff 100644 --- a/py/packages/genkit/src/genkit/blocks/retriever.py +++ b/py/packages/genkit/src/genkit/blocks/retriever.py @@ -23,10 +23,15 @@ """ from collections.abc import Callable -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar + +from pydantic import BaseModel, ConfigDict, Field from genkit.blocks.document import Document -from genkit.core.typing import RetrieverResponse +from genkit.core.action import ActionMetadata +from genkit.core.action.types import ActionKind +from genkit.core.schema import to_json_schema +from genkit.core.typing import DocumentData, RetrieverResponse T = TypeVar('T') # type RetrieverFn[T] = Callable[[Document, T], RetrieverResponse] @@ -39,3 +44,201 @@ def __init__( retriever_fn: RetrieverFn[T], ): self.retriever_fn = retriever_fn + + +class RetrieverRequest(BaseModel): + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + query: DocumentData + options: Any | None = None + + +class RetrieverSupports(BaseModel): + """Retriever capability support.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + media: bool | None = None + + +class RetrieverInfo(BaseModel): + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + label: str | None = None + supports: RetrieverSupports | None = None + + +class RetrieverOptions(BaseModel): + """Configuration options for a retriever.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + config_schema: dict[str, Any] | None = Field(None, alias='configSchema') + label: str | None = None + supports: RetrieverSupports | None = None + + +class RetrieverRef(BaseModel): + """Reference to a retriever with configuration.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + name: str + config: Any | None = None + version: str | None = None + info: RetrieverInfo | None = None + + +def retriever_action_metadata( + name: str, + options: RetrieverOptions | None = None, +) -> ActionMetadata: + """Creates action metadata for a retriever.""" + options = options if options is not None else RetrieverOptions() + retriever_metadata_dict = {'retriever': {}} + + if options.label: + retriever_metadata_dict['retriever']['label'] = options.label + + if options.supports: + retriever_metadata_dict['retriever']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True) + + retriever_metadata_dict['retriever']['customOptions'] = options.config_schema if options.config_schema else None + + return ActionMetadata( + kind=ActionKind.RETRIEVER, + name=name, + input_json_schema=to_json_schema(RetrieverRequest), + output_json_schema=to_json_schema(RetrieverResponse), + metadata=retriever_metadata_dict, + ) + + +def create_retriever_ref( + name: str, + config: dict[str, Any] | None = None, + version: str | None = None, + info: RetrieverInfo | None = None, +) -> RetrieverRef: + """Creates a RetrieverRef instance.""" + return RetrieverRef(name=name, config=config, version=version, info=info) + + +class IndexerRequest(BaseModel): + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + documents: list[DocumentData] + options: Any | None = None + + +class IndexerInfo(BaseModel): + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + label: str | None = None + supports: RetrieverSupports | None = None + + +class IndexerOptions(BaseModel): + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + config_schema: dict[str, Any] | None = Field(None, alias='configSchema') + label: str | None = None + supports: RetrieverSupports | None = None + + +class IndexerRef(BaseModel): + """Reference to an indexer with configuration.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + name: str + config: Any | None = None + version: str | None = None + info: IndexerInfo | None = None + + +def indexer_action_metadata( + name: str, + options: IndexerOptions | None = None, +) -> ActionMetadata: + """Creates action metadata for an indexer.""" + options = options if options is not None else IndexerOptions() + indexer_metadata_dict = {'indexer': {}} + + if options.label: + indexer_metadata_dict['indexer']['label'] = options.label + + if options.supports: + indexer_metadata_dict['indexer']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True) + + indexer_metadata_dict['indexer']['customOptions'] = options.config_schema if options.config_schema else None + + return ActionMetadata( + kind=ActionKind.INDEXER, + name=name, + input_json_schema=to_json_schema(IndexerRequest), + output_json_schema=to_json_schema(None), + metadata=indexer_metadata_dict, + ) + + +def create_indexer_ref( + name: str, + config: dict[str, Any] | None = None, + version: str | None = None, + info: IndexerInfo | None = None, +) -> IndexerRef: + """Creates a IndexerRef instance.""" + return IndexerRef(name=name, config=config, version=version, info=info) + + +def define_retriever( + registry: Any, + name: str, + fn: RetrieverFn, + options: RetrieverOptions | None = None, +) -> None: + """Defines and registers a retriever action.""" + metadata = retriever_action_metadata(name, options) + + async def wrapper( + request: RetrieverRequest, + ctx: Any, + ) -> RetrieverResponse: + return await fn(request.query, request.options) + + registry.register_action( + kind=ActionKind.RETRIEVER, + name=name, + fn=wrapper, + metadata=metadata.metadata, + span_metadata=metadata.metadata, + ) + + +IndexerFn = Callable[[list[Document], T], None] + + +def define_indexer( + registry: Any, + name: str, + fn: IndexerFn, + options: IndexerOptions | None = None, +) -> None: + """Defines and registers an indexer action.""" + metadata = indexer_action_metadata(name, options) + + async def wrapper( + request: IndexerRequest, + ctx: Any, + ) -> None: + docs = [Document.from_data(d) for d in request.documents] + await fn(docs, request.options) + + registry.register_action( + kind=ActionKind.INDEXER, + name=name, + fn=wrapper, + metadata=metadata.metadata, + span_metadata=metadata.metadata, + ) diff --git a/py/packages/genkit/tests/genkit/blocks/retriever_test.py b/py/packages/genkit/tests/genkit/blocks/retriever_test.py new file mode 100644 index 0000000000..e8edc3d511 --- /dev/null +++ b/py/packages/genkit/tests/genkit/blocks/retriever_test.py @@ -0,0 +1,180 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from pydantic import BaseModel + +from genkit.blocks.retriever import ( + IndexerOptions, + RetrieverOptions, + RetrieverResponse, + RetrieverSupports, + create_indexer_ref, + create_retriever_ref, + define_indexer, + define_retriever, + indexer_action_metadata, + retriever_action_metadata, +) +from genkit.core.action import ActionMetadata +from genkit.core.schema import to_json_schema + + +def test_retriever_action_metadata(): + """Test for retriever_action_metadata with basic options.""" + options = RetrieverOptions(label='Test Retriever') + action_metadata = retriever_action_metadata( + name='test_retriever', + options=options, + ) + + assert isinstance(action_metadata, ActionMetadata) + assert action_metadata.input_json_schema is not None + assert action_metadata.output_json_schema is not None + assert action_metadata.metadata == { + 'retriever': { + 'label': options.label, + 'customOptions': None, + } + } + + +def test_retriever_action_metadata_with_supports_and_config_schema(): + """Test for retriever_action_metadata with supports and config_schema.""" + + class CustomConfig(BaseModel): + k: int + + options = RetrieverOptions( + label='Advanced Retriever', + supports=RetrieverSupports(media=True), + config_schema=to_json_schema(CustomConfig), + ) + action_metadata = retriever_action_metadata( + name='advanced_retriever', + options=options, + ) + assert isinstance(action_metadata, ActionMetadata) + assert action_metadata.metadata['retriever']['label'] == 'Advanced Retriever' + assert action_metadata.metadata['retriever']['supports'] == { + 'media': True, + } + assert action_metadata.metadata['retriever']['customOptions'] == { + 'title': 'CustomConfig', + 'type': 'object', + 'properties': { + 'k': {'title': 'K', 'type': 'integer'}, + }, + 'required': ['k'], + } + + +def test_retriever_action_metadata_no_options(): + """Test retriever_action_metadata when no options are provided.""" + action_metadata = retriever_action_metadata(name='default_retriever') + assert isinstance(action_metadata, ActionMetadata) + assert action_metadata.metadata == {'retriever': {'customOptions': None}} + + +def test_create_retriever_ref_basic(): + """Test basic creation of RetrieverRef.""" + ref = create_retriever_ref('my-retriever') + assert ref.name == 'my-retriever' + assert ref.config is None + assert ref.version is None + + +def test_create_retriever_ref_with_config(): + """Test creation of RetrieverRef with configuration.""" + config = {'k': 5} + ref = create_retriever_ref('configured-retriever', config=config) + assert ref.name == 'configured-retriever' + assert ref.config == config + assert ref.version is None + + +def test_create_retriever_ref_with_version(): + """Test creation of RetrieverRef with a version.""" + ref = create_retriever_ref('versioned-retriever', version='v1.0') + assert ref.name == 'versioned-retriever' + assert ref.config is None + assert ref.version == 'v1.0' + + +def test_create_retriever_ref_with_config_and_version(): + """Test creation of RetrieverRef with both config and version.""" + config = {'k': 10} + ref = create_retriever_ref('full-retriever', config=config, version='beta') + assert ref.name == 'full-retriever' + assert ref.config == config + assert ref.version == 'beta' + + +@pytest.mark.asyncio +async def test_define_retriever(): + """Test define_retriever registration.""" + registry = MagicMock() + fn = AsyncMock(return_value=RetrieverResponse(documents=[])) + + define_retriever(registry, 'test_retriever', fn) + + registry.register_action.assert_called_once() + call_args = registry.register_action.call_args + assert call_args.kwargs['kind'] == 'retriever' + assert call_args.kwargs['name'] == 'test_retriever' + + +@pytest.mark.asyncio +async def test_define_indexer(): + """Test define_indexer registration.""" + registry = MagicMock() + fn = AsyncMock() + + define_indexer(registry, 'test_indexer', fn) + + registry.register_action.assert_called_once() + call_args = registry.register_action.call_args + assert call_args.kwargs['kind'] == 'indexer' + assert call_args.kwargs['name'] == 'test_indexer' + + +def test_indexer_action_metadata(): + """Test for indexer_action_metadata with basic options.""" + options = IndexerOptions(label='Test Indexer') + action_metadata = indexer_action_metadata( + name='test_indexer', + options=options, + ) + + assert isinstance(action_metadata, ActionMetadata) + assert action_metadata.input_json_schema is not None + assert action_metadata.output_json_schema is not None + assert action_metadata.metadata == { + 'indexer': { + 'label': options.label, + 'customOptions': None, + } + } + + +def test_create_indexer_ref_basic(): + """Test basic creation of IndexerRef.""" + ref = create_indexer_ref('my-indexer') + assert ref.name == 'my-indexer' + assert ref.config is None + assert ref.version is None diff --git a/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/plugin_api.py b/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/plugin_api.py index 770c4f23b7..21cea58314 100644 --- a/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/plugin_api.py +++ b/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/plugin_api.py @@ -38,7 +38,6 @@ class DevLocalVectorStore(Plugin): """ name = 'devLocalVectorstore' - _indexers: dict[str, DevLocalVectorStoreIndexer] = {} def __init__(self, name: str, embedder: str, embedder_options: dict[str, Any] | None = None): self.index_name = name @@ -100,27 +99,7 @@ def _configure_dev_local_indexer(self, ai: GenkitRegistry) -> Action: embedder_options=self.embedder_options, ) - DevLocalVectorStore._indexers[self.index_name] = indexer - - @classmethod - async def index(cls, index_name: str, documents: Docs) -> None: - """Lookups the Local Vector Store indexer for provided index name. - - If matching indexer found - invokes indexing for provided documents - - Args: - index_name: name of the indexer to look up - documents: list of documents to index - - Returns: - None - - Raises: - KeyError: if index name is not found among registered indexers. - """ - matching_indexer = cls._indexers.get(index_name) - if not matching_indexer: - raise KeyError( - f'Failed to find indexer matching name: {index_name}!r\nRegistered indexers: {cls._indexers.keys()}' - ) - return await matching_indexer.index(documents) + return ai.define_indexer( + name=self.index_name, + fn=indexer.index, + )