Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 120 additions & 16 deletions src/fed_rag/base/knowledge_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypedDict, Union

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr

Expand All @@ -14,6 +14,19 @@
DEFAULT_KNOWLEDGE_STORE_NAME = "default"


class MultiModalEmbedding(TypedDict, total=False):
    """Type definition for multimodal embeddings supporting different modalities.

    Maps a modality name to the embedding vector for that modality. The class
    is declared with ``total=False``, so every key may be omitted; a key that
    is present may also hold ``None``.
    """

    # Embedding vector for the text modality, or None.
    text: list[float] | None
    # Embedding vector for the image modality, or None.
    image: list[float] | None
    # Embedding vector for the audio modality, or None.
    audio: list[float] | None
    # Embedding vector for the video modality, or None.
    video: list[float] | None


# A query embedding is either a plain vector (kept for backward compatibility)
# or a MultiModalEmbedding dictionary keyed by modality.
QueryEmbedding = Union[list[float], MultiModalEmbedding]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohh, I kind of like how you did this better. Subtle difference, but I think cleaner than what I got.



class BaseKnowledgeStore(BaseModel, ABC):
"""Base Knowledge Store Class.

Expand Down Expand Up @@ -47,12 +60,29 @@ def load_nodes(self, nodes: list["KnowledgeNode"]) -> None:

@abstractmethod
def retrieve(
    self, query_emb: QueryEmbedding, top_k: int
) -> list[tuple[float, "KnowledgeNode"]]:
    """Retrieve top-k nodes from KnowledgeStore against a provided user query.

    Args:
        query_emb (QueryEmbedding): the query represented as an encoded vector
            or multimodal embedding dictionary.
        top_k (int): the number of knowledge nodes to retrieve.

    Returns:
        A list of tuples where the first element represents the similarity score
        of the node to the query, and the second element is the node itself.
    """

@abstractmethod
def retrieve_by_modality(
self, modality: str, query_emb: list[float], top_k: int
) -> list[tuple[float, "KnowledgeNode"]]:
"""Retrieve top-k nodes from a specific modality collection.

Args:
modality (str): The modality to search in ("text", "image", "audio", "video").
query_emb (list[float]): the query embedding for this modality.
top_k (int): the number of knowledge nodes to retrieve.

Returns:
Expand All @@ -62,12 +92,12 @@ def retrieve(

@abstractmethod
def batch_retrieve(
self, query_embs: list[list[float]], top_k: int
self, query_embs: list[QueryEmbedding], top_k: int
) -> list[list[tuple[float, "KnowledgeNode"]]]:
"""Batch retrieve top-k nodes from KnowledgeStore against provided user queries.

Args:
query_embs (list[list[float]]): the list of encoded queries.
query_embs (list[QueryEmbedding]): the list of encoded queries.
top_k (int): the number of knowledge nodes to retrieve.

Returns:
Expand Down Expand Up @@ -103,6 +133,32 @@ def persist(self) -> None:
def load(self) -> None:
"""Load the KnowledgeStore nodes from a permanent storage using `name`."""

# Helper methods for multimodal support
def _is_multimodal_embedding(self, query_emb: QueryEmbedding) -> bool:
"""Check if the query embedding is multimodal."""
return isinstance(query_emb, dict)

def _extract_text_embedding(
self, query_emb: QueryEmbedding
) -> list[float]:
"""Extract text embedding for backward compatibility."""
if isinstance(query_emb, list):
return query_emb
text_emb = query_emb.get("text")
return text_emb if text_emb is not None else []

def _get_modality_embeddings(
self, query_emb: QueryEmbedding
) -> dict[str, list[float]]:
"""Get all available modality embeddings."""
if isinstance(query_emb, list):
return {"text": query_emb}
result: dict[str, list[float]] = {}
for k, v in query_emb.items():
if v is not None:
result[k] = v # type: ignore[assignment]
return result


class BaseAsyncKnowledgeStore(BaseModel, ABC):
"""Base Asynchronous Knowledge Store Class."""
Expand Down Expand Up @@ -131,12 +187,29 @@ async def load_nodes(self, nodes: list["KnowledgeNode"]) -> None:

@abstractmethod
async def retrieve(
    self, query_emb: QueryEmbedding, top_k: int
) -> list[tuple[float, "KnowledgeNode"]]:
    """Asynchronously retrieve top-k nodes from KnowledgeStore against a provided user query.

    Args:
        query_emb (QueryEmbedding): the query represented as an encoded vector
            or multimodal embedding dictionary.
        top_k (int): the number of knowledge nodes to retrieve.

    Returns:
        A list of tuples where the first element represents the similarity score
        of the node to the query, and the second element is the node itself.
    """

@abstractmethod
async def retrieve_by_modality(
self, modality: str, query_emb: list[float], top_k: int
) -> list[tuple[float, "KnowledgeNode"]]:
"""Asynchronously retrieve top-k nodes from a specific modality collection.

Args:
modality (str): The modality to search in ("text", "image", "audio", "video").
query_emb (list[float]): the query embedding for this modality.
top_k (int): the number of knowledge nodes to retrieve.

Returns:
Expand All @@ -146,12 +219,12 @@ async def retrieve(

@abstractmethod
async def batch_retrieve(
self, query_embs: list[list[float]], top_k: int
self, query_embs: list[QueryEmbedding], top_k: int
) -> list[list[tuple[float, "KnowledgeNode"]]]:
"""Asynchronously batch retrieve top-k nodes from KnowledgeStore against provided user queries.

Args:
query_embs (list[list[float]]): the list of encoded queries.
query_embs (list[QueryEmbedding]): the list of encoded queries.
top_k (int): the number of knowledge nodes to retrieve.

Returns:
Expand Down Expand Up @@ -186,6 +259,32 @@ def persist(self) -> None:
def load(self) -> None:
"""Load the KnowledgeStore nodes from a permanent storage using `name`."""

# Helper methods for multimodal support
def _is_multimodal_embedding(self, query_emb: QueryEmbedding) -> bool:
"""Check if the query embedding is multimodal."""
return isinstance(query_emb, dict)

def _extract_text_embedding(
self, query_emb: QueryEmbedding
) -> list[float]:
"""Extract text embedding for backward compatibility."""
if isinstance(query_emb, list):
return query_emb
text_emb = query_emb.get("text")
return text_emb if text_emb is not None else []

def _get_modality_embeddings(
self, query_emb: QueryEmbedding
) -> dict[str, list[float]]:
"""Get all available modality embeddings."""
if isinstance(query_emb, list):
return {"text": query_emb}
result: dict[str, list[float]] = {}
for k, v in query_emb.items():
if v is not None:
result[k] = v # type: ignore[assignment]
return result

class _SyncConvertedKnowledgeStore(BaseKnowledgeStore):
"""A nested class for converting this store to a sync version."""

Expand Down Expand Up @@ -220,13 +319,19 @@ def load_nodes(self, nodes: list["KnowledgeNode"]) -> None:
asyncio_run(self._async_ks.load_nodes(nodes))

def retrieve(
    self, query_emb: QueryEmbedding, top_k: int
) -> list[tuple[float, "KnowledgeNode"]]:
    """Implements retrieve.

    Builds the wrapped async store's `retrieve` coroutine and hands it to
    `asyncio_run`, returning whatever that call returns.

    Args:
        query_emb (QueryEmbedding): the query embedding, forwarded unchanged.
        top_k (int): the number of knowledge nodes to retrieve.

    Returns:
        The result produced by the wrapped async store's `retrieve`.
    """
    return asyncio_run(self._async_ks.retrieve(query_emb=query_emb, top_k=top_k))  # type: ignore [no-any-return]

def retrieve_by_modality(
    self, modality: str, query_emb: list[float], top_k: int
) -> list[tuple[float, "KnowledgeNode"]]:
    """Implements retrieve_by_modality.

    Builds the wrapped async store's `retrieve_by_modality` coroutine and
    hands it to `asyncio_run`, returning whatever that call returns.

    Args:
        modality (str): the modality to search in, forwarded unchanged.
        query_emb (list[float]): the query embedding for this modality.
        top_k (int): the number of knowledge nodes to retrieve.

    Returns:
        The result produced by the wrapped async store's `retrieve_by_modality`.
    """
    pending = self._async_ks.retrieve_by_modality(
        modality=modality, query_emb=query_emb, top_k=top_k
    )
    return asyncio_run(pending)  # type: ignore [no-any-return]

def batch_retrieve(
    self, query_embs: list[QueryEmbedding], top_k: int
) -> list[list[tuple[float, "KnowledgeNode"]]]:
    """Implements batch_retrieve.

    Builds the wrapped async store's `batch_retrieve` coroutine and hands it
    to `asyncio_run`, returning whatever that call returns.

    Args:
        query_embs (list[QueryEmbedding]): the list of encoded queries,
            forwarded unchanged.
        top_k (int): the number of knowledge nodes to retrieve.

    Returns:
        The result produced by the wrapped async store's `batch_retrieve`.
    """
    return asyncio_run(self._async_ks.batch_retrieve(query_embs=query_embs, top_k=top_k))  # type: ignore [no-any-return]
Expand All @@ -241,7 +346,7 @@ def clear(self) -> None:

@property
def count(self) -> int:
    """Implements count.

    Returns:
        The `count` value reported by the wrapped async knowledge store.
    """
    return self._async_ks.count

def persist(self) -> None:
Expand All @@ -252,7 +357,6 @@ def load(self) -> None:
"""Implements load."""
self._async_ks.load()

def to_sync(self) -> "BaseKnowledgeStore":
    """Convert this async knowledge store to a synchronous version.

    Returns:
        A `_SyncConvertedKnowledgeStore` instance constructed with this store
        as its only argument.
    """
    # Looked up through `self`, so the wrapper class is resolved on the instance's own class.
    return self._SyncConvertedKnowledgeStore(self)
Loading
Loading