diff --git a/src/fed_rag/base/knowledge_store.py b/src/fed_rag/base/knowledge_store.py
index 622542d2..2271769c 100644
--- a/src/fed_rag/base/knowledge_store.py
+++ b/src/fed_rag/base/knowledge_store.py
@@ -2,7 +2,7 @@
 import asyncio
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, TypedDict, Union
 
 from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
 
@@ -14,6 +14,19 @@
 DEFAULT_KNOWLEDGE_STORE_NAME = "default"
 
 
+class MultiModalEmbedding(TypedDict, total=False):
+    """Type definition for multimodal embeddings supporting different modalities."""
+
+    text: list[float] | None
+    image: list[float] | None
+    audio: list[float] | None
+    video: list[float] | None
+
+
+# Union type for backward compatibility
+QueryEmbedding = Union[list[float], MultiModalEmbedding]
+
+
 class BaseKnowledgeStore(BaseModel, ABC):
     """Base Knowledge Store Class.
 
@@ -47,12 +60,29 @@ def load_nodes(self, nodes: list["KnowledgeNode"]) -> None:
 
     @abstractmethod
     def retrieve(
-        self, query_emb: list[float], top_k: int
+        self, query_emb: QueryEmbedding, top_k: int
     ) -> list[tuple[float, "KnowledgeNode"]]:
         """Retrieve top-k nodes from KnowledgeStore against a provided user query.
 
         Args:
-            query_emb (list[float]): the query represented as an encoded vector.
+            query_emb (QueryEmbedding): the query represented as an encoded vector
+                or multimodal embedding dictionary.
+            top_k (int): the number of knowledge nodes to retrieve.
+
+        Returns:
+            A list of tuples where the first element represents the similarity score
+            of the node to the query, and the second element is the node itself.
+        """
+
+    @abstractmethod
+    def retrieve_by_modality(
+        self, modality: str, query_emb: list[float], top_k: int
+    ) -> list[tuple[float, "KnowledgeNode"]]:
+        """Retrieve top-k nodes from a specific modality collection.
+
+        Args:
+            modality (str): the modality to search in ("text", "image", "audio", "video").
+            query_emb (list[float]): the query embedding for this modality.
             top_k (int): the number of knowledge nodes to retrieve.
 
         Returns:
@@ -62,12 +92,12 @@ def retrieve(
 
     @abstractmethod
     def batch_retrieve(
-        self, query_embs: list[list[float]], top_k: int
+        self, query_embs: list[QueryEmbedding], top_k: int
     ) -> list[list[tuple[float, "KnowledgeNode"]]]:
         """Batch retrieve top-k nodes from KnowledgeStore against provided user queries.
 
         Args:
-            query_embs (list[list[float]]): the list of encoded queries.
+            query_embs (list[QueryEmbedding]): the list of encoded queries.
             top_k (int): the number of knowledge nodes to retrieve.
 
         Returns:
@@ -103,6 +133,32 @@ def persist(self) -> None:
 
     def load(self) -> None:
         """Load the KnowledgeStore nodes from a permanent storage using `name`."""
 
+    # Helper methods for multimodal support
+    def _is_multimodal_embedding(self, query_emb: QueryEmbedding) -> bool:
+        """Check if the query embedding is multimodal."""
+        return isinstance(query_emb, dict)
+
+    def _extract_text_embedding(
+        self, query_emb: QueryEmbedding
+    ) -> list[float]:
+        """Extract text embedding for backward compatibility."""
+        if isinstance(query_emb, list):
+            return query_emb
+        text_emb = query_emb.get("text")
+        return text_emb if text_emb is not None else []
+
+    def _get_modality_embeddings(
+        self, query_emb: QueryEmbedding
+    ) -> dict[str, list[float]]:
+        """Get all available modality embeddings."""
+        if isinstance(query_emb, list):
+            return {"text": query_emb}
+        result: dict[str, list[float]] = {}
+        for k, v in query_emb.items():
+            if v is not None:
+                result[k] = v  # type: ignore[assignment]
+        return result
+
 
 class BaseAsyncKnowledgeStore(BaseModel, ABC):
     """Base Asynchronous Knowledge Store Class."""
@@ -131,12 +187,29 @@ async def load_nodes(self, nodes: list["KnowledgeNode"]) -> None:
 
     @abstractmethod
     async def retrieve(
-        self, query_emb: list[float], top_k: int
+        self, query_emb: QueryEmbedding, top_k: int
     ) -> list[tuple[float, "KnowledgeNode"]]:
         """Asynchronously retrieve top-k nodes from KnowledgeStore against a provided user query.
 
         Args:
-            query_emb (list[float]): the query represented as an encoded vector.
+            query_emb (QueryEmbedding): the query represented as an encoded vector
+                or multimodal embedding dictionary.
+            top_k (int): the number of knowledge nodes to retrieve.
+
+        Returns:
+            A list of tuples where the first element represents the similarity score
+            of the node to the query, and the second element is the node itself.
+        """
+
+    @abstractmethod
+    async def retrieve_by_modality(
+        self, modality: str, query_emb: list[float], top_k: int
+    ) -> list[tuple[float, "KnowledgeNode"]]:
+        """Asynchronously retrieve top-k nodes from a specific modality collection.
+
+        Args:
+            modality (str): the modality to search in ("text", "image", "audio", "video").
+            query_emb (list[float]): the query embedding for this modality.
             top_k (int): the number of knowledge nodes to retrieve.
 
         Returns:
@@ -146,12 +219,12 @@ async def retrieve(
 
     @abstractmethod
     async def batch_retrieve(
-        self, query_embs: list[list[float]], top_k: int
+        self, query_embs: list[QueryEmbedding], top_k: int
     ) -> list[list[tuple[float, "KnowledgeNode"]]]:
         """Asynchronously batch retrieve top-k nodes from KnowledgeStore against provided user queries.
 
         Args:
-            query_embs (list[list[float]]): the list of encoded queries.
+            query_embs (list[QueryEmbedding]): the list of encoded queries.
             top_k (int): the number of knowledge nodes to retrieve.
 
         Returns:
@@ -186,6 +259,32 @@ def persist(self) -> None:
 
     def load(self) -> None:
         """Load the KnowledgeStore nodes from a permanent storage using `name`."""
 
+    # Helper methods for multimodal support
+    def _is_multimodal_embedding(self, query_emb: QueryEmbedding) -> bool:
+        """Check if the query embedding is multimodal."""
+        return isinstance(query_emb, dict)
+
+    def _extract_text_embedding(
+        self, query_emb: QueryEmbedding
+    ) -> list[float]:
+        """Extract text embedding for backward compatibility."""
+        if isinstance(query_emb, list):
+            return query_emb
+        text_emb = query_emb.get("text")
+        return text_emb if text_emb is not None else []
+
+    def _get_modality_embeddings(
+        self, query_emb: QueryEmbedding
+    ) -> dict[str, list[float]]:
+        """Get all available modality embeddings."""
+        if isinstance(query_emb, list):
+            return {"text": query_emb}
+        result: dict[str, list[float]] = {}
+        for k, v in query_emb.items():
+            if v is not None:
+                result[k] = v  # type: ignore[assignment]
+        return result
+
     class _SyncConvertedKnowledgeStore(BaseKnowledgeStore):
         """A nested class for converting this store to a sync version."""
@@ -220,13 +319,19 @@ def load_nodes(self, nodes: list["KnowledgeNode"]) -> None:
             asyncio_run(self._async_ks.load_nodes(nodes))
 
         def retrieve(
-            self, query_emb: list[float], top_k: int
+            self, query_emb: QueryEmbedding, top_k: int
         ) -> list[tuple[float, "KnowledgeNode"]]:
             """Implements retrieve."""
             return asyncio_run(self._async_ks.retrieve(query_emb=query_emb, top_k=top_k))  # type: ignore [no-any-return]
 
+        def retrieve_by_modality(
+            self, modality: str, query_emb: list[float], top_k: int
+        ) -> list[tuple[float, "KnowledgeNode"]]:
+            """Implements retrieve_by_modality."""
+            return asyncio_run(self._async_ks.retrieve_by_modality(modality=modality, query_emb=query_emb, top_k=top_k))  # type: ignore [no-any-return]
+
         def batch_retrieve(
-            self, query_embs: list[list[float]], top_k: int
+            self, query_embs: list[QueryEmbedding], top_k: int
         ) -> list[list[tuple[float, "KnowledgeNode"]]]:
             """Implements batch_retrieve."""
             return asyncio_run(self._async_ks.batch_retrieve(query_embs=query_embs, top_k=top_k))  # type: ignore [no-any-return]
@@ -241,7 +346,7 @@ def clear(self) -> None:
 
         @property
         def count(self) -> int:
-            """Returns the number of nodes in the knowledge store."""
+            """Implements count."""
            return self._async_ks.count
 
         def persist(self) -> None:
@@ -252,7 +357,6 @@ def load(self) -> None:
             """Implements load."""
             self._async_ks.load()
 
-    def to_sync(self) -> BaseKnowledgeStore:
-        """Convert this async knowledge store to a sync version."""
-
-        return BaseAsyncKnowledgeStore._SyncConvertedKnowledgeStore(self)
+    def to_sync(self) -> "BaseKnowledgeStore":
+        """Convert this async knowledge store to a synchronous version."""
+        return self._SyncConvertedKnowledgeStore(self)
diff --git a/src/fed_rag/core/rag_system/_asynchronous.py b/src/fed_rag/core/rag_system/_asynchronous.py
index 39e16c9b..ab20a9c7 100644
--- a/src/fed_rag/core/rag_system/_asynchronous.py
+++ b/src/fed_rag/core/rag_system/_asynchronous.py
@@ -3,10 +3,12 @@
 import asyncio
 from typing import TYPE_CHECKING
 
+import torch
 from pydantic import BaseModel, ConfigDict
 
 from fed_rag.base.bridge import BridgeRegistryMixin
 from fed_rag.data_structures import RAGConfig, RAGResponse, SourceNode
+from fed_rag.data_structures.rag import Query
 from fed_rag.exceptions import RAGSystemError
 
 if TYPE_CHECKING:  # pragma: no cover
@@ -32,14 +34,16 @@ class _AsyncRAGSystem(BridgeRegistryMixin, BaseModel):
 
     knowledge_store: "BaseAsyncKnowledgeStore"
     rag_config: RAGConfig
 
-    async def query(self, query: str) -> RAGResponse:
+    async def query(self, query: str | Query) -> RAGResponse:
         """Query the RAG system."""
         source_nodes = await self.retrieve(query)
         context = self._format_context(source_nodes)
         response = await self.generate(query=query, context=context)
         return RAGResponse(source_nodes=source_nodes, response=response)
 
-    async def batch_query(self, queries: list[str]) -> list[RAGResponse]:
+    async def batch_query(
+        self, queries: list[str | Query]
+    ) -> list[RAGResponse]:
         """Batch query the RAG system."""
         source_nodes_list = await self.batch_retrieve(queries)
         contexts = [
@@ -52,47 +56,115 @@ async def batch_query(self, queries: list[str]) -> list[RAGResponse]:
             for source_nodes, response in zip(source_nodes_list, responses)
         ]
 
-    async def retrieve(self, query: str) -> list[SourceNode]:
-        """Retrieve from KnowledgeStore."""
-        query_emb: list[float] = self.retriever.encode_query(query).tolist()
-        raw_retrieval_result = await self.knowledge_store.retrieve(
-            query_emb=query_emb, top_k=self.rag_config.top_k
+    async def retrieve(self, query: str | Query) -> list[SourceNode]:
+        """Retrieve from multiple collections based on query modalities."""
+        # Get multimodal embeddings from retriever
+        query_emb_tensor = self.retriever.encode_query(query)
+
+        # Convert to separate embeddings by modality
+        modality_embeddings = self._prepare_modality_embeddings(
+            query_emb_tensor, query
         )
-        return [
-            SourceNode(score=el[0], node=el[1]) for el in raw_retrieval_result
-        ]
 
-    async def batch_retrieve(
-        self, queries: list[str]
-    ) -> list[list[SourceNode]]:
-        """Batch retrieve from KnowledgeStore."""
-        query_embs: list[list[float]] = self.retriever.encode_query(
-            queries
-        ).tolist()
-        try:
-            raw_retrieval_results = await self.knowledge_store.batch_retrieve(
-                query_embs=query_embs, top_k=self.rag_config.top_k
-            )
-        except NotImplementedError:
-            raw_retrieval_tasks = [
-                self.knowledge_store.retrieve(
-                    query_emb=query_emb, top_k=self.rag_config.top_k
+        # Retrieve from each modality collection concurrently; create_task
+        # schedules each call immediately rather than awaiting it in turn
+        retrieve_tasks = []
+        for modality, embedding in modality_embeddings.items():
+            if embedding is not None:
+                task = asyncio.create_task(
+                    self.knowledge_store.retrieve_by_modality(
+                        modality=modality,
+                        query_emb=embedding,
+                        top_k=self.rag_config.top_k,
+                    )
                 )
-                for query_emb in query_embs
-            ]
-            raw_retrieval_results = await asyncio.gather(*raw_retrieval_tasks)
+                retrieve_tasks.append((modality, task))
 
-        return [
-            [SourceNode(score=el[0], node=el[1]) for el in raw_result]
-            for raw_result in raw_retrieval_results
-        ]
+        # Wait for all retrievals to complete
+        all_results = []
+        for modality, task in retrieve_tasks:
+            modality_results = await task
+            for score, node in modality_results:
+                source_node = SourceNode(score=score, node=node)
+                source_node.modality = modality
+                all_results.append(source_node)
+
+        all_results.sort(key=lambda x: x.score, reverse=True)
+        return all_results[: self.rag_config.top_k]
 
-    async def generate(self, query: str, context: str) -> str:
+    async def batch_retrieve(
+        self, queries: list[str | Query]
+    ) -> list[list[SourceNode]]:
+        """Batch retrieve from multiple collections."""
+        retrieve_tasks = [self.retrieve(query) for query in queries]
+        return await asyncio.gather(*retrieve_tasks)
+
+    def _prepare_modality_embeddings(
+        self, embedding_tensor: torch.Tensor, query: str | Query
+    ) -> dict[str, list[float]]:
+        """Extract embeddings for each modality present in the query."""
+        modality_embeddings: dict[str, list[float]] = {}
+
+        if isinstance(query, str):
+            # Text-only query
+            modality_embeddings["text"] = embedding_tensor.squeeze().tolist()
+        elif isinstance(query, Query):
+            # Check what modalities are present in the query
+            available_modalities = []
+            if query.text is not None:
+                available_modalities.append("text")
+            if query.images is not None and len(query.images) > 0:
+                available_modalities.append("image")
+            if query.audios is not None and len(query.audios) > 0:
+                available_modalities.append("audio")
+            if query.videos is not None and len(query.videos) > 0:
+                available_modalities.append("video")
+
+            # Map tensor outputs to modalities
+            if embedding_tensor.dim() == 1:
+                primary_modality = (
+                    available_modalities[0] if available_modalities else "text"
+                )
+                modality_embeddings[
+                    primary_modality
+                ] = embedding_tensor.tolist()
+            elif embedding_tensor.dim() == 2:
+                for i, modality in enumerate(available_modalities):
+                    if i < embedding_tensor.shape[0]:
+                        modality_embeddings[modality] = embedding_tensor[
+                            i
+                        ].tolist()
+            else:
+                # Handle unexpected tensor dimensions
+                if embedding_tensor.dim() > 2:
+                    # Flatten to 2D and try again
+                    flattened = embedding_tensor.view(
+                        embedding_tensor.shape[0], -1
+                    )
+                    if flattened.shape[0] == len(available_modalities):
+                        for i, modality in enumerate(available_modalities):
+                            modality_embeddings[modality] = flattened[
+                                i
+                            ].tolist()
+                    else:
+                        modality_embeddings[
+                            "text"
+                        ] = embedding_tensor.flatten().tolist()
+                else:
+                    # dim() == 0, treat as single text embedding
+                    modality_embeddings["text"] = (
+                        [embedding_tensor.item()]
+                        if embedding_tensor.numel() == 1
+                        else embedding_tensor.flatten().tolist()
+                    )
+        else:
+            modality_embeddings["text"] = embedding_tensor.squeeze().tolist()
+
+        return modality_embeddings
+
+    async def generate(self, query: str | Query, context: str) -> str:
         """Generate response to query with context."""
         return self.generator.generate(query=query, context=context)  # type: ignore
 
     async def batch_generate(
-        self, queries: list[str], contexts: list[str]
+        self, queries: list[str | Query], contexts: list[str]
     ) -> list[str]:
         """Batch generate responses to queries with contexts."""
         if len(queries) != len(contexts):
@@ -102,13 +174,63 @@
         return self.generator.generate(query=queries, context=contexts)  # type: ignore
 
     def _format_context(self, source_nodes: list[SourceNode]) -> str:
-        """Format the context from the source nodes."""
-        # TODO: how to format image context
-        return str(
-            self.rag_config.context_separator.join(
-                [node.get_content()["text_content"] for node in source_nodes]
-            )
-        )
+        """Format context from nodes retrieved from different modality collections."""
+        # Group nodes by modality for better organization
+        modality_groups: dict[str, list[SourceNode]] = {}
+        for node in source_nodes:
+            modality = getattr(node, "modality", "text")
+            if modality not in modality_groups:
+                modality_groups[modality] = []
+            modality_groups[modality].append(node)
+
+        # Modality-specific content extraction rules
+        modality_config = {
+            "text": {
+                "title": "Text Context",
+                "content_keys": ["text_content"],
+                "prefix": "",
+            },
+            "image": {
+                "title": "Image Context",
+                "content_keys": ["text_content", "image_description"],
+                "prefix": "Image: ",
+            },
+            "audio": {
+                "title": "Audio Context",
+                "content_keys": ["text_content", "audio_transcript"],
+                "prefix": "Audio: ",
+            },
+            "video": {
+                "title": "Video Context",
+                "content_keys": ["text_content", "video_description"],
+                "prefix": "Video: ",
+            },
+        }
+
+        context_parts = []
+        # Process modalities in preferred order
+        for modality in ["text", "image", "audio", "video"]:
+            if modality in modality_groups:
+                config = modality_config[modality]
+                descriptions = []
+
+                for node in modality_groups[modality]:
+                    content = node.get_content()
+                    # Try each content key until we find one
+                    for key in config["content_keys"]:
+                        if key in content and content[key]:
+                            descriptions.append(
+                                f"{config['prefix']}{content[key]}"
+                            )
+                            break
+
+                if descriptions:
+                    section = self.rag_config.context_separator.join(
+                        descriptions
+                    )
+                    context_parts.append(f"{config['title']}:\n{section}")
+
+        return "\n\n".join(context_parts)
 
 
 def _resolve_forward_refs() -> None:
diff --git a/src/fed_rag/core/rag_system/_synchronous.py b/src/fed_rag/core/rag_system/_synchronous.py
index 25502603..59e4bc15 100644
--- a/src/fed_rag/core/rag_system/_synchronous.py
+++ b/src/fed_rag/core/rag_system/_synchronous.py
@@ -2,10 +2,12 @@
 from typing import TYPE_CHECKING
 
+import torch
 from pydantic import BaseModel, ConfigDict
 
 from fed_rag.base.bridge import BridgeRegistryMixin
 from fed_rag.data_structures import RAGConfig, RAGResponse, SourceNode
+from fed_rag.data_structures.rag import Query
 from fed_rag.exceptions import RAGSystemError
 
 if TYPE_CHECKING:  # pragma: no cover
@@ -31,14 +33,14 @@ class _RAGSystem(BridgeRegistryMixin, BaseModel):
 
     knowledge_store: "BaseKnowledgeStore"
     rag_config: RAGConfig
 
-    def query(self, query: str) -> RAGResponse:
+    def query(self, query: str | Query) -> RAGResponse:
         """Query the RAG system."""
         source_nodes = self.retrieve(query)
         context = self._format_context(source_nodes)
         response = self.generate(query=query, context=context)
         return RAGResponse(source_nodes=source_nodes, response=response)
 
-    def batch_query(self, queries: list[str]) -> list[RAGResponse]:
+    def batch_query(self, queries: list[str | Query]) -> list[RAGResponse]:
         """Batch query the RAG system."""
         source_nodes_list = self.batch_retrieve(queries)
         contexts = [
@@ -51,44 +53,110 @@ def batch_query(self, queries: list[str]) -> list[RAGResponse]:
             for source_nodes, response in zip(source_nodes_list, responses)
         ]
 
-    def retrieve(self, query: str) -> list[SourceNode]:
-        """Retrieve from KnowledgeStore."""
-        query_emb: list[float] = self.retriever.encode_query(query).tolist()
-        raw_retrieval_result = self.knowledge_store.retrieve(
-            query_emb=query_emb, top_k=self.rag_config.top_k
+    def retrieve(self, query: str | Query) -> list[SourceNode]:
+        """Retrieve from multiple collections based on query modalities."""
+        # Get multimodal embeddings from retriever
+        query_emb_tensor = self.retriever.encode_query(query)
+
+        # Convert to separate embeddings by modality
+        modality_embeddings = self._prepare_modality_embeddings(
+            query_emb_tensor, query
         )
-        return [
-            SourceNode(score=el[0], node=el[1]) for el in raw_retrieval_result
-        ]
 
-    def batch_retrieve(self, queries: list[str]) -> list[list[SourceNode]]:
-        """Batch retrieve from KnowledgeStore."""
-        query_embs: list[list[float]] = self.retriever.encode_query(
-            queries
-        ).tolist()
-        try:
-            raw_retrieval_results = self.knowledge_store.batch_retrieve(
-                query_embs=query_embs, top_k=self.rag_config.top_k
-            )
-        except NotImplementedError:
-            raw_retrieval_results = [
-                self.knowledge_store.retrieve(
-                    query_emb=query_emb, top_k=self.rag_config.top_k
+        # Retrieve from each modality collection separately
+        all_results = []
+        for modality, embedding in modality_embeddings.items():
+            if embedding is not None:
+                # Use modality-specific collection in knowledge store
+                modality_results = self.knowledge_store.retrieve_by_modality(
+                    modality=modality,
+                    query_emb=embedding,
+                    top_k=self.rag_config.top_k,
                 )
-                for query_emb in query_embs
-            ]
-
-        return [
-            [SourceNode(score=el[0], node=el[1]) for el in raw_result]
-            for raw_result in raw_retrieval_results
-        ]
-
-    def generate(self, query: str, context: str) -> str:
+                # Add modality info to source nodes
+                for score, node in modality_results:
+                    source_node = SourceNode(score=score, node=node)
+                    source_node.modality = modality
+                    all_results.append(source_node)
+
+        all_results.sort(key=lambda x: x.score, reverse=True)
+        return all_results[: self.rag_config.top_k]
+
+    def batch_retrieve(
+        self, queries: list[str | Query]
+    ) -> list[list[SourceNode]]:
+        """Batch retrieve from multiple collections."""
+        return [self.retrieve(query) for query in queries]
+
+    def _prepare_modality_embeddings(
+        self, embedding_tensor: torch.Tensor, query: str | Query
+    ) -> dict[str, list[float]]:
+        """Extract embeddings for each modality present in the query."""
+        modality_embeddings: dict[str, list[float]] = {}
+
+        if isinstance(query, str):
+            # Text-only query
+            modality_embeddings["text"] = embedding_tensor.squeeze().tolist()
+        elif isinstance(query, Query):
+            # Check what modalities are present in the query
+            available_modalities = []
+            if query.text is not None:
+                available_modalities.append("text")
+            if query.images is not None and len(query.images) > 0:
+                available_modalities.append("image")
+            if query.audios is not None and len(query.audios) > 0:
+                available_modalities.append("audio")
+            if query.videos is not None and len(query.videos) > 0:
+                available_modalities.append("video")
+
+            # Map tensor outputs to modalities
+            if embedding_tensor.dim() == 1:
+                primary_modality = (
+                    available_modalities[0] if available_modalities else "text"
+                )
+                modality_embeddings[
+                    primary_modality
+                ] = embedding_tensor.tolist()
+            elif embedding_tensor.dim() == 2:
+                for i, modality in enumerate(available_modalities):
+                    if i < embedding_tensor.shape[0]:
+                        modality_embeddings[modality] = embedding_tensor[
+                            i
+                        ].tolist()
+            else:
+                # Handle unexpected tensor dimensions
+                if embedding_tensor.dim() > 2:
+                    # Flatten to 2D and try again
+                    flattened = embedding_tensor.view(
+                        embedding_tensor.shape[0], -1
+                    )
+                    if flattened.shape[0] == len(available_modalities):
+                        for i, modality in enumerate(available_modalities):
+                            modality_embeddings[modality] = flattened[
+                                i
+                            ].tolist()
+                    else:
+                        modality_embeddings[
+                            "text"
+                        ] = embedding_tensor.flatten().tolist()
+                else:
+                    # dim() == 0, treat as single text embedding
+                    modality_embeddings["text"] = (
+                        [embedding_tensor.item()]
+                        if embedding_tensor.numel() == 1
+                        else embedding_tensor.flatten().tolist()
+                    )
+        else:
+            modality_embeddings["text"] = embedding_tensor.squeeze().tolist()
+
+        return modality_embeddings
+
+    def generate(self, query: str | Query, context: str) -> str:
         """Generate response to query with context."""
         return self.generator.generate(query=query, context=context)  # type: ignore
 
     def batch_generate(
-        self, queries: list[str], contexts: list[str]
+        self, queries: list[str | Query], contexts: list[str]
     ) -> list[str]:
         """Batch generate responses to queries with contexts."""
         if len(queries) != len(contexts):
@@ -98,13 +166,60 @@
         return self.generator.generate(query=queries, context=contexts)  # type: ignore
 
     def _format_context(self, source_nodes: list[SourceNode]) -> str:
-        """Format the context from the source nodes."""
-        # TODO: how to format image context
-        return str(
-            self.rag_config.context_separator.join(
-                [node.get_content()["text_content"] for node in source_nodes]
-            )
-        )
+        """Format context from nodes retrieved from different modality collections."""
+        modality_groups: dict[str, list[SourceNode]] = {}
+        for node in source_nodes:
+            modality = getattr(node, "modality", "text")
+            if modality not in modality_groups:
+                modality_groups[modality] = []
+            modality_groups[modality].append(node)
+
+        # Modality-specific content extraction rules
+        modality_config = {
+            "text": {
+                "title": "Text Context",
+                "content_keys": ["text_content"],
+                "prefix": "",
+            },
+            "image": {
+                "title": "Image Context",
+                "content_keys": ["text_content", "image_description"],
+                "prefix": "Image: ",
+            },
+            "audio": {
+                "title": "Audio Context",
+                "content_keys": ["text_content", "audio_transcript"],
+                "prefix": "Audio: ",
+            },
+            "video": {
+                "title": "Video Context",
+                "content_keys": ["text_content", "video_description"],
+                "prefix": "Video: ",
+            },
+        }
+
+        context_parts = []
+        for modality in ["text", "image", "audio", "video"]:
+            if modality in modality_groups:
+                config = modality_config[modality]
+                descriptions = []
+
+                for node in modality_groups[modality]:
+                    content = node.get_content()
+                    for key in config["content_keys"]:
+                        if key in content and content[key]:
+                            descriptions.append(
+                                f"{config['prefix']}{content[key]}"
+                            )
+                            break
+
+                if descriptions:
+                    section = self.rag_config.context_separator.join(
+                        descriptions
+                    )
+                    context_parts.append(f"{config['title']}:\n{section}")
+
+        return "\n\n".join(context_parts)
 
 
 def _resolve_forward_refs() -> None:
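
A few usage sketches for the API changes above; none of this code is part of the patch. The QueryEmbedding union keeps the old single-vector call path working while allowing per-modality dictionaries. A minimal sketch, assuming a concrete BaseKnowledgeStore implementation bound to my_store (hypothetical name):

    from fed_rag.base.knowledge_store import MultiModalEmbedding

    legacy_emb: list[float] = [0.12, -0.48, 0.33]  # pre-existing call style
    mm_emb: MultiModalEmbedding = {
        "text": [0.12, -0.48, 0.33],
        "image": [0.91, 0.05, -0.22],
    }

    # Both shapes are accepted by retrieve(); my_store is hypothetical.
    hits = my_store.retrieve(query_emb=legacy_emb, top_k=5)
    mm_hits = my_store.retrieve(query_emb=mm_emb, top_k=5)

    # Target a single collection directly via the new abstract method.
    image_hits = my_store.retrieve_by_modality(
        modality="image", query_emb=[0.91, 0.05, -0.22], top_k=5
    )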
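What a concrete retrieve_by_modality might look like, as a self-contained in-memory sketch with one collection per modality and cosine similarity. It deliberately does not subclass BaseKnowledgeStore, and the node attributes `embedding` and `modality` are illustrative, not the actual KnowledgeNode schema:

    import math
    from collections import defaultdict
    from typing import Any

    class InMemoryMultiModalStore:
        """Illustrative store keeping a separate node list per modality."""

        def __init__(self) -> None:
            self._collections: dict[str, list[Any]] = defaultdict(list)

        def load_nodes(self, nodes: list[Any]) -> None:
            # route each node into the collection for its modality
            for node in nodes:
                self._collections[getattr(node, "modality", "text")].append(node)

        @staticmethod
        def _cosine(a: list[float], b: list[float]) -> float:
            dot = sum(x * y for x, y in zip(a, b))
            norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
            return dot / norm if norm else 0.0

        def retrieve_by_modality(
            self, modality: str, query_emb: list[float], top_k: int
        ) -> list[tuple[float, Any]]:
            # score only the nodes in the requested modality's collection
            scored = [
                (self._cosine(query_emb, node.embedding), node)
                for node in self._collections.get(modality, [])
            ]
            scored.sort(key=lambda pair: pair[0], reverse=True)
            return scored[:top_k]

        def retrieve(self, query_emb: Any, top_k: int) -> list[tuple[float, Any]]:
            # fan a multimodal embedding out to per-modality collections,
            # then merge and keep the overall top_k
            if isinstance(query_emb, list):
                return self.retrieve_by_modality("text", query_emb, top_k)
            merged: list[tuple[float, Any]] = []
            for modality, emb in query_emb.items():
                if emb is not None:
                    merged.extend(self.retrieve_by_modality(modality, emb, top_k))
            merged.sort(key=lambda pair: pair[0], reverse=True)
            return merged[:top_k]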
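For callers that need the synchronous surface, to_sync() now returns the nested wrapper, which forwards retrieve_by_modality like the other methods. A sketch, assuming a concrete async subclass named MyAsyncStore (hypothetical):

    async_store = MyAsyncStore(name="default")
    sync_store = async_store.to_sync()

    # Runs the async retrieve_by_modality to completion on an event loop.
    hits = sync_store.retrieve_by_modality(
        modality="text", query_emb=[0.1, 0.2, 0.3], top_k=3
    )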
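The row-to-modality mapping in _prepare_modality_embeddings assumes the retriever emits one embedding row per modality, in the same text/image/audio/video order used to build available_modalities. A worked example under that assumption:

    import torch

    # Suppose a Query carries text plus one image, so
    # available_modalities == ["text", "image"].
    emb = torch.tensor(
        [
            [0.1, 0.2, 0.3],  # row 0 -> "text"
            [0.7, 0.8, 0.9],  # row 1 -> "image"
        ]
    )
    # _prepare_modality_embeddings then yields:
    #   {"text": [0.1, 0.2, 0.3], "image": [0.7, 0.8, 0.9]}
    # A 1-D tensor is instead assigned to the first available modality only.

Note that retrieve() then sorts the merged hits on raw scores, which implicitly assumes the per-modality collections return comparable score scales.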
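End to end, a multimodal query flows through encode_query, the per-modality retrievals, and _format_context. A sketch, assuming Query exposes the text/images fields referenced by _prepare_modality_embeddings and that rag_system is an already-wired _AsyncRAGSystem instance:

    import asyncio

    from fed_rag.data_structures.rag import Query

    async def main() -> None:
        # an_image is a placeholder for whatever image type Query accepts
        q = Query(text="What landmark is shown?", images=[an_image])
        response = await rag_system.query(q)
        print(response.response)
        for sn in response.source_nodes:
            # modality is attached to each SourceNode during retrieve()
            print(sn.score, getattr(sn, "modality", "text"))

    asyncio.run(main())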
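For reference, the string _format_context assembles for a mixed result set, assuming two text nodes, one image node whose content carries only an image_description, and context_separator == "\n":

    expected = (
        "Text Context:\n"
        "The Eiffel Tower opened in 1889.\n"
        "It is about 330 m tall.\n"
        "\n"
        "Image Context:\n"
        "Image: The Eiffel Tower lit up at night."
    )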