Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 52 additions & 39 deletions py/packages/genkit/src/genkit/ai/_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class while customizing it with any plugins.
ModelMiddleware,
)
from genkit.blocks.prompt import PromptConfig, to_generate_action_options
from genkit.blocks.retriever import IndexerRef, IndexerRequest, RetrieverRef
from genkit.core.action import ActionRunContext
from genkit.core.action.types import ActionKind
from genkit.core.typing import EmbedRequest, EmbedResponse
Expand Down Expand Up @@ -294,63 +295,75 @@ def generate_stream(

return stream, stream.closed

async def embed(
async def retrieve(
self,
embedder: str | EmbedderRef | None = None,
documents: list[Document] | None = None,
retriever: str | RetrieverRef | None = None,
query: str | DocumentData | None = None,
options: dict[str, Any] | None = None,
) -> EmbedResponse:
embedder_name: str
embedder_config: dict[str, Any] = {}
"""Calculates embeddings for documents.
) -> RetrieverResponse:
"""Retrieves documents based on query.

Args:
embedder: Optional embedder model name to use.
documents: Texts to embed.
options: embedding options
retriever: Optional retriever name or reference to use.
query: Text query or a DocumentData containing query text.
options: retriever options

Returns:
The generated response with embeddings.
The generated response with documents.
"""
if isinstance(embedder, EmbedderRef):
embedder_name = embedder.name
embedder_config = embedder.config or {}
if embedder.version:
embedder_config['version'] = embedder.version # Handle version from ref
elif isinstance(embedder, str):
embedder_name = embedder
retriever_name: str
retriever_config: dict[str, Any] = {}

if isinstance(retriever, RetrieverRef):
retriever_name = retriever.name
retriever_config = retriever.config or {}
if retriever.version:
retriever_config['version'] = retriever.version
elif isinstance(retriever, str):
retriever_name = retriever
else:
# Handle case where embedder is None
raise ValueError('Embedder must be specified as a string name or an EmbedderRef.')
raise ValueError('Retriever must be specified as a string name or a RetrieverRef.')

# Merge options passed to embed() with config from EmbedderRef
final_options = {**(embedder_config or {}), **(options or {})}
embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder_name)
if isinstance(query, str):
query = Document.from_text(query)

return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response
final_options = {**(retriever_config or {}), **(options or {})}

async def retrieve(
retrieve_action = self.registry.lookup_action(ActionKind.RETRIEVER, retriever_name)

return (await retrieve_action.arun(RetrieverRequest(query=query, options=final_options))).response

async def index(
self,
retriever: str | None = None,
query: str | DocumentData | None = None,
indexer: str | IndexerRef | None = None,
documents: list[Document] | None = None,
options: dict[str, Any] | None = None,
) -> RetrieverResponse:
"""Retrieves documents based on query.
) -> None:
"""Indexes documents.

Args:
retriever: Optional retriever name to use.
query: Text query or a DocumentData containing query text.
options: retriever options

Returns:
The generated response with embeddings.
indexer: Optional indexer name or reference to use.
documents: Documents to index.
options: indexer options
"""
if isinstance(query, str):
query = Document.from_text(query)
indexer_name: str
indexer_config: dict[str, Any] = {}

if isinstance(indexer, IndexerRef):
indexer_name = indexer.name
indexer_config = indexer.config or {}
if indexer.version:
indexer_config['version'] = indexer.version
elif isinstance(indexer, str):
indexer_name = indexer
else:
raise ValueError('Indexer must be specified as a string name or an IndexerRef.')

final_options = {**(indexer_config or {}), **(options or {})}

retrieve_action = self.registry.lookup_action(ActionKind.RETRIEVER, retriever)
index_action = self.registry.lookup_action(ActionKind.INDEXER, indexer_name)

return (await retrieve_action.arun(RetrieverRequest(query=query, options=options))).response
await index_action.arun(IndexerRequest(documents=documents, options=final_options))

async def embed(
self,
Expand Down
36 changes: 35 additions & 1 deletion py/packages/genkit/src/genkit/ai/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from genkit.blocks.formats.types import FormatDef
from genkit.blocks.model import ModelFn, ModelMiddleware
from genkit.blocks.prompt import define_prompt
from genkit.blocks.retriever import RetrieverFn
from genkit.blocks.retriever import IndexerFn, RetrieverFn
from genkit.blocks.tools import ToolRunContext
from genkit.codec import dump_dict
from genkit.core.action import Action
Expand Down Expand Up @@ -278,6 +278,40 @@ def define_retriever(
description=retriever_description,
)

def define_indexer(
    self,
    name: str,
    fn: IndexerFn,
    config_schema: BaseModel | dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
    description: str | None = None,
) -> Callable[[Callable], Callable]:
    """Define an indexer action.

    The supplied ``metadata`` is never modified: the indexer-specific
    entries are written into copies before registration.

    Args:
        name: Name of the indexer.
        fn: Function implementing the indexer behavior.
        config_schema: Optional schema for indexer configuration.
        metadata: Optional metadata for the indexer.
        description: Optional description for the indexer.

    Returns:
        The result of ``self.registry.register_action`` for the new indexer.
    """
    # Copy both levels so neither the caller's dict nor its nested
    # 'indexer' dict is mutated as a side effect of registration.
    indexer_meta = dict(metadata) if metadata else {}
    indexer_info = dict(indexer_meta.get('indexer') or {})

    # Fall back to the action name when no (non-empty) label was supplied.
    if not indexer_info.get('label'):
        indexer_info['label'] = name
    if config_schema:
        indexer_info['customOptions'] = to_json_schema(config_schema)
    indexer_meta['indexer'] = indexer_info

    indexer_description = get_func_description(fn, description)
    return self.registry.register_action(
        name=name,
        kind=ActionKind.INDEXER,
        fn=fn,
        metadata=indexer_meta,
        description=indexer_description,
    )

def define_evaluator(
self,
name: str,
Expand Down
207 changes: 205 additions & 2 deletions py/packages/genkit/src/genkit/blocks/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,15 @@
"""

from collections.abc import Callable
from typing import Generic, TypeVar
from typing import Any, Generic, TypeVar

from pydantic import BaseModel, ConfigDict, Field

from genkit.blocks.document import Document
from genkit.core.typing import RetrieverResponse
from genkit.core.action import ActionMetadata
from genkit.core.action.types import ActionKind
from genkit.core.schema import to_json_schema
from genkit.core.typing import DocumentData, RetrieverResponse

T = TypeVar('T')
# type RetrieverFn[T] = Callable[[Document, T], RetrieverResponse]
Expand All @@ -39,3 +44,201 @@ def __init__(
retriever_fn: RetrieverFn[T],
):
self.retriever_fn = retriever_fn


class RetrieverRequest(BaseModel):
    """Input payload for a retriever action: a query plus optional options."""

    # Unknown fields are rejected; fields may be populated by name or alias.
    model_config = ConfigDict(extra='forbid', populate_by_name=True)

    # The query to retrieve documents for.
    query: DocumentData
    # Free-form retriever options; not validated by this model.
    options: Any | None = None


class RetrieverSupports(BaseModel):
    """Retriever capability support."""

    # Unknown fields are rejected; fields may be populated by name or alias.
    model_config = ConfigDict(extra='forbid', populate_by_name=True)

    # Media capability flag; None means unspecified.
    # NOTE(review): presumably "can handle media content in documents" - confirm.
    media: bool | None = None


class RetrieverInfo(BaseModel):
    """Descriptive information about a retriever (label and capabilities)."""

    # Unknown fields are rejected; fields may be populated by name or alias.
    model_config = ConfigDict(extra='forbid', populate_by_name=True)

    # Optional display label.
    label: str | None = None
    # Optional capability flags.
    supports: RetrieverSupports | None = None


class RetrieverOptions(BaseModel):
    """Configuration options for a retriever."""

    # Unknown fields are rejected; fields may be populated by name or alias.
    model_config = ConfigDict(extra='forbid', populate_by_name=True)

    # JSON schema of the retriever's custom options; accepted as
    # ``config_schema`` or under the ``configSchema`` alias.
    config_schema: dict[str, Any] | None = Field(None, alias='configSchema')
    # Optional display label.
    label: str | None = None
    # Optional capability flags.
    supports: RetrieverSupports | None = None


class RetrieverRef(BaseModel):
    """Reference to a retriever with configuration."""

    # Unknown fields are rejected; fields may be populated by name or alias.
    model_config = ConfigDict(extra='forbid', populate_by_name=True)

    # Name of the referenced retriever (the only required field).
    name: str
    # Optional configuration carried along with the reference.
    config: Any | None = None
    # Optional version string.
    version: str | None = None
    # Optional descriptive information about the retriever.
    info: RetrieverInfo | None = None


def retriever_action_metadata(
    name: str,
    options: RetrieverOptions | None = None,
) -> ActionMetadata:
    """Build the ActionMetadata describing a retriever action.

    Args:
        name: Name of the retriever action.
        options: Optional label, capability flags and config schema. A
            default ``RetrieverOptions()`` is used when omitted.

    Returns:
        ActionMetadata of kind RETRIEVER whose ``metadata`` holds a single
        ``'retriever'`` entry.
    """
    if options is None:
        options = RetrieverOptions()

    retriever_info = {}
    if options.label:
        retriever_info['label'] = options.label
    if options.supports:
        retriever_info['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True)
    # The key is always present; a missing or empty schema is stored as None.
    retriever_info['customOptions'] = options.config_schema or None

    return ActionMetadata(
        kind=ActionKind.RETRIEVER,
        name=name,
        input_json_schema=to_json_schema(RetrieverRequest),
        output_json_schema=to_json_schema(RetrieverResponse),
        metadata={'retriever': retriever_info},
    )


def create_retriever_ref(
    name: str,
    config: dict[str, Any] | None = None,
    version: str | None = None,
    info: RetrieverInfo | None = None,
) -> RetrieverRef:
    """Build a RetrieverRef from its individual parts.

    Args:
        name: Name of the retriever being referenced.
        config: Optional configuration to attach to the reference.
        version: Optional version string.
        info: Optional descriptive information.

    Returns:
        A RetrieverRef populated with the given values.
    """
    ref_fields = {
        'name': name,
        'config': config,
        'version': version,
        'info': info,
    }
    return RetrieverRef(**ref_fields)


class IndexerRequest(BaseModel):
    """Input payload for an indexer action: documents plus optional options."""

    # Unknown fields are rejected; fields may be populated by name or alias.
    model_config = ConfigDict(extra='forbid', populate_by_name=True)

    # The documents to index.
    documents: list[DocumentData]
    # Free-form indexer options; not validated by this model.
    options: Any | None = None


class IndexerInfo(BaseModel):
    """Descriptive information about an indexer (label and capabilities)."""

    # Unknown fields are rejected; fields may be populated by name or alias.
    model_config = ConfigDict(extra='forbid', populate_by_name=True)

    # Optional display label.
    label: str | None = None
    # Optional capability flags (shares the retriever capability model).
    supports: RetrieverSupports | None = None


class IndexerOptions(BaseModel):
    """Configuration options for an indexer."""

    # Unknown fields are rejected; fields may be populated by name or alias.
    model_config = ConfigDict(extra='forbid', populate_by_name=True)

    # JSON schema of the indexer's custom options; accepted as
    # ``config_schema`` or under the ``configSchema`` alias.
    config_schema: dict[str, Any] | None = Field(None, alias='configSchema')
    # Optional display label.
    label: str | None = None
    # Optional capability flags (shares the retriever capability model).
    supports: RetrieverSupports | None = None


class IndexerRef(BaseModel):
    """Reference to an indexer with configuration."""

    # Unknown fields are rejected; fields may be populated by name or alias.
    model_config = ConfigDict(extra='forbid', populate_by_name=True)

    # Name of the referenced indexer (the only required field).
    name: str
    # Optional configuration carried along with the reference.
    config: Any | None = None
    # Optional version string.
    version: str | None = None
    # Optional descriptive information about the indexer.
    info: IndexerInfo | None = None


def indexer_action_metadata(
    name: str,
    options: IndexerOptions | None = None,
) -> ActionMetadata:
    """Build the ActionMetadata describing an indexer action.

    Args:
        name: Name of the indexer action.
        options: Optional label, capability flags and config schema. A
            default ``IndexerOptions()`` is used when omitted.

    Returns:
        ActionMetadata of kind INDEXER whose ``metadata`` holds a single
        ``'indexer'`` entry.
    """
    if options is None:
        options = IndexerOptions()

    indexer_info = {}
    if options.label:
        indexer_info['label'] = options.label
    if options.supports:
        indexer_info['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True)
    # The key is always present; a missing or empty schema is stored as None.
    indexer_info['customOptions'] = options.config_schema or None

    return ActionMetadata(
        kind=ActionKind.INDEXER,
        name=name,
        input_json_schema=to_json_schema(IndexerRequest),
        # The output schema is derived from None (no declared output type).
        output_json_schema=to_json_schema(None),
        metadata={'indexer': indexer_info},
    )


def create_indexer_ref(
    name: str,
    config: dict[str, Any] | None = None,
    version: str | None = None,
    info: IndexerInfo | None = None,
) -> IndexerRef:
    """Build an IndexerRef from its individual parts.

    Args:
        name: Name of the indexer being referenced.
        config: Optional configuration to attach to the reference.
        version: Optional version string.
        info: Optional descriptive information.

    Returns:
        An IndexerRef populated with the given values.
    """
    ref_fields = {
        'name': name,
        'config': config,
        'version': version,
        'info': info,
    }
    return IndexerRef(**ref_fields)


def define_retriever(
    registry: Any,
    name: str,
    fn: RetrieverFn,
    options: RetrieverOptions | None = None,
) -> None:
    """Defines and registers a retriever action.

    The user function is wrapped so that it receives the query and options
    unpacked from the incoming ``RetrieverRequest``. Both plain and async
    functions are accepted.

    Args:
        registry: Registry exposing ``register_action``.
        name: Name to register the retriever under.
        fn: Retriever implementation, called as ``fn(query, options)``.
        options: Optional label, capability flags and config schema used to
            build the action metadata.
    """
    # Function-scope import keeps this block self-contained.
    import inspect

    metadata = retriever_action_metadata(name, options)

    async def wrapper(
        request: RetrieverRequest,
        ctx: Any,
    ) -> RetrieverResponse:
        # NOTE(review): request.query is a DocumentData and is forwarded
        # unchanged, while RetrieverFn is described as taking a Document and
        # define_indexer converts its inputs - confirm whether a conversion
        # is wanted here too.
        result = fn(request.query, request.options)
        # Only await when the implementation is async; a plain function's
        # return value is used directly instead of raising TypeError.
        if inspect.isawaitable(result):
            result = await result
        return result

    registry.register_action(
        kind=ActionKind.RETRIEVER,
        name=name,
        fn=wrapper,
        metadata=metadata.metadata,
        span_metadata=metadata.metadata,
    )


# Signature of a user-supplied indexer callback: it receives the documents to
# index and the indexer options, and is declared as returning nothing.
IndexerFn = Callable[[list[Document], T], None]


def define_indexer(
    registry: Any,
    name: str,
    fn: IndexerFn,
    options: IndexerOptions | None = None,
) -> None:
    """Defines and registers an indexer action.

    The user function is wrapped so that it receives ``Document`` objects
    built from the request's documents, plus the request options. Both plain
    and async functions are accepted.

    Args:
        registry: Registry exposing ``register_action``.
        name: Name to register the indexer under.
        fn: Indexer implementation, called as ``fn(documents, options)``.
        options: Optional label, capability flags and config schema used to
            build the action metadata.
    """
    # Function-scope import keeps this block self-contained.
    import inspect

    metadata = indexer_action_metadata(name, options)

    async def wrapper(
        request: IndexerRequest,
        ctx: Any,
    ) -> None:
        # NOTE(review): confirm Document.from_data accepts a DocumentData
        # instance (as opposed to raw data) - not verifiable from this file.
        docs = [Document.from_data(d) for d in request.documents]
        result = fn(docs, request.options)
        # IndexerFn is declared as returning None, so only await when the
        # implementation is actually async; awaiting None would raise.
        if inspect.isawaitable(result):
            await result

    registry.register_action(
        kind=ActionKind.INDEXER,
        name=name,
        fn=wrapper,
        metadata=metadata.metadata,
        span_metadata=metadata.metadata,
    )
Loading
Loading