Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding MimeNode type #503

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
59 changes: 52 additions & 7 deletions libs/knowledge-store/ragstack_knowledge_store/embedding_model.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,67 @@
from abc import ABC, abstractmethod
from typing import List

from abc import ABC
from typing import List, Any, Optional

class EmbeddingModel(ABC):
"""Embedding model."""

@abstractmethod
def __init__(self, embeddings: Any, method_map: Optional[dict] = None, other_methods: Optional[List[str]] = None):
self.embeddings = embeddings
self.method_name = {}
method_map = method_map if method_map else {}
other_methods = other_methods if other_methods else []

base_methods = ['embed_texts', 'aembed_texts', 'embed_query', 'aembed_query']
extended_methods = ['embed_images', 'aembed_images', 'embed_image', 'aembed_image']
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should try to add all of these as methods; it's definitely pretty messy.

I think we should just have `embed_mime(self, mime_type: str, content: Union[str, bytes])` or something like that. Then there is only a single abstract method to use for any MIME type, and the names can be different, etc.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100% but right now, LangChain doesn't have "embed_mime" :)


# Combining all method names, including those mapped
all_methods = set(base_methods + extended_methods + other_methods + list(method_map.values()))

for method in all_methods:
mapped_method = method_map.get(method)
if hasattr(embeddings, method):
self.method_name[method] = method
elif hasattr(embeddings, mapped_method) if mapped_method else False:
self.method_name[method] = mapped_method
else:
self.method_name[method] = None

def does_implement(self, method_name: str) -> bool:
    """Report whether a logical method name resolved to a concrete method.

    Args:
        method_name: Logical name to look up in ``self.method_name``.

    Returns:
        ``True`` when the name is registered with a non-``None`` target,
        ``False`` when it is unknown or registered as ``None``.
    """
    resolved = self.method_name.get(method_name)
    if resolved is None:
        return False
    return True

def implements(self) -> List[str]:
    """Return the logical method names that resolved to a concrete method.

    Names are returned in the iteration order of ``self.method_name``;
    entries whose target is ``None`` are left out.
    """
    available = []
    for logical_name, target in self.method_name.items():
        if target is None:
            continue
        available.append(logical_name)
    return available

def invoke(self, method_name: str, *args, **kwargs):
    """Invoke a synchronous embedding method by its logical name.

    Args:
        method_name: Logical name to look up in ``self.method_name``
            (for example ``'embed_texts'``).
        *args: Positional arguments forwarded to the resolved method.
        **kwargs: Keyword arguments forwarded to the resolved method.

    Returns:
        Whatever the resolved method on ``self.embeddings`` returns.

    Raises:
        NotImplementedError: If ``method_name`` has no resolved target, or
            the wrapped object does not have the resolved attribute.
    """
    target_method = self.method_name.get(method_name)
    if not target_method or not hasattr(self.embeddings, target_method):
        # Report the name the caller asked for. ``target_method`` is None
        # whenever the lookup failed, so interpolating it produced the
        # unhelpful message "... does not implement None".
        raise NotImplementedError(
            f"{self.embeddings.__class__.__name__} does not implement {method_name}"
        )
    return getattr(self.embeddings, target_method)(*args, **kwargs)

async def ainvoke(self, method_name: str, *args, **kwargs):
    """Invoke an asynchronous embedding method by its logical name.

    Args:
        method_name: Logical name to look up in ``self.method_name``
            (for example ``'aembed_texts'``).
        *args: Positional arguments forwarded to the resolved method.
        **kwargs: Keyword arguments forwarded to the resolved method.

    Returns:
        The awaited result of the resolved method on ``self.embeddings``.

    Raises:
        NotImplementedError: If ``method_name`` has no resolved target, or
            the wrapped object does not have the resolved attribute.
    """
    target_method = self.method_name.get(method_name)
    if not target_method or not hasattr(self.embeddings, target_method):
        # Report the name the caller asked for. ``target_method`` is None
        # whenever the lookup failed, so interpolating it produced the
        # unhelpful message "... does not implement None".
        raise NotImplementedError(
            f"{self.embeddings.__class__.__name__} does not implement {method_name}"
        )
    return await getattr(self.embeddings, target_method)(*args, **kwargs)

def embed_texts(self, texts: List[str]) -> List[List[float]]:
    """Embed a batch of texts.

    Delegates to ``self.invoke`` under the logical name ``embed_texts``
    and returns its result unchanged.
    """
    vectors = self.invoke('embed_texts', texts)
    return vectors

@abstractmethod
def embed_query(self, text: str) -> List[float]:
    """Embed a single query string.

    Delegates to ``self.invoke`` under the logical name ``embed_query``
    and returns its result unchanged.
    """
    vector = self.invoke('embed_query', text)
    return vector

@abstractmethod
async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
    """Asynchronously embed a batch of texts.

    Awaits ``self.ainvoke`` under the logical name ``aembed_texts`` and
    returns its result unchanged.
    """
    vectors = await self.ainvoke('aembed_texts', texts)
    return vectors

@abstractmethod
async def aembed_query(self, text: str) -> List[float]:
    """Asynchronously embed a single query string.

    Awaits ``self.ainvoke`` under the logical name ``aembed_query`` and
    returns its result unchanged.
    """
    vector = await self.ainvoke('aembed_query', text)
    return vector

110 changes: 69 additions & 41 deletions libs/knowledge-store/ragstack_knowledge_store/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,21 @@ class Node:
"""Metadata for the node. May contain information used to link this node
with other nodes."""

content: str = None
"""Encoded content"""

mime_type: str = None
"""Type of content, e.g. text/plain or image/png."""

mime_encoding: str = None
"""Encoding format"""

@dataclass
class TextNode(Node):
    """Node whose payload is plain text."""

    text: str = None
    """Text contained by the node."""

    # The annotation is required: the dataclass decorator only turns
    # *annotated* class attributes into fields. A bare
    # ``mime_type = "text/plain"`` is an ordinary class attribute, so a
    # generated __init__ would keep assigning the default inherited from
    # Node's ``mime_type`` field and shadow it on every instance.
    # NOTE(review): assumes Node is itself a dataclass -- confirm.
    mime_type: str = "text/plain"
    """MIME type of the node content; text nodes default to text/plain."""

class SetupMode(Enum):
SYNC = 1
Expand Down Expand Up @@ -326,52 +335,71 @@ def add_nodes(
self,
nodes: Iterable[Node] = None,
) -> Iterable[str]:
texts = []
metadatas = []
for node in nodes:
if not isinstance(node, TextNode):
raise ValueError("Only adding TextNode is supported at the moment")
texts.append(node.text)
metadatas.append(node.metadata)

text_embeddings = self._embedding.embed_texts(texts)

# Organize nodes by MIME type
mime_buckets = {}
ids = []

tag_to_new_sources: Dict[str, List[Tuple[str, str]]] = {}
tag_to_new_targets: Dict[str, Dict[str, Tuple[str, List[float]]]] = {}
# Prepare nodes based on their type
for node in nodes:
if isinstance(node, TextNode):
if 'text' not in mime_buckets:
mime_buckets['text'] = []
mime_buckets['text'].append(node)
if isinstance(node, Node) and node.mime_type:
main_mime_type = node.mime_type.split('/')[0] # Split and take the first part, e.g., "image" from "image/png"
if main_mime_type not in mime_buckets:
mime_buckets[main_mime_type] = []
mime_buckets[main_mime_type].append(node)
else:
raise ValueError("Unsupported node type")

# Process each MIME bucket
embeddings_dict = {}
for mime_type, nodes_list in mime_buckets.items():
method_name = f"embed_{mime_type}s"
if self._embedding.does_implement(method_name):
texts = [node.text if isinstance(node, TextNode) else node.content for node in nodes_list]
embeddings_dict[mime_type] = self._embedding.invoke(method_name, texts)
else:
# If no bulk method, try to call a singular method for each content
singular_method_name = f"embed_{mime_type}"
if self._embedding.does_implement(singular_method_name):
embeddings = []
for node in nodes_list:
embedding = self._embedding.invoke(singular_method_name, node.text if isinstance(node, TextNode) else node.content)
embeddings.append(embedding)
embeddings_dict[mime_type] = embeddings
else:
raise NotImplementedError(f"No embedding method available for MIME type: {mime_type}, implemented methods: {self._embedding.implements()}.")


# Step 1: Add the nodes, collecting the tags and new sources / targets.
tag_to_new_sources = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has changed significantly from the previous implementation. I think it will need to be reworked if we work on adding this right now.

tag_to_new_targets = {}
with self._concurrent_queries() as cq:
tuples = zip(texts, text_embeddings, metadatas)
for text, text_embedding, metadata in tuples:
if CONTENT_ID not in metadata:
metadata[CONTENT_ID] = secrets.token_hex(8)
id = metadata[CONTENT_ID]
ids.append(id)

link_to_tags = set() # link to these tags
link_from_tags = set() # link from these tags

for tag in get_link_tags(metadata):
tag_str = f"{tag.kind}:{tag.tag}"
if tag.direction == "incoming" or tag.direction == "bidir":
# An incoming link should be linked *from* nodes with the given tag.
link_from_tags.add(tag_str)
tag_to_new_targets.setdefault(tag_str, dict())[id] = (
tag.kind,
text_embedding,
)
if tag.direction == "outgoing" or tag.direction == "bidir":
link_to_tags.add(tag_str)
tag_to_new_sources.setdefault(tag_str, list()).append(
(tag.kind, id)
)

cq.execute(
self._insert_passage,
(id, text, text_embedding, link_to_tags, link_from_tags),
)
for mime_type, embeddings in embeddings_dict.items():
for node, embedding in zip(mime_buckets[mime_type], embeddings):
if CONTENT_ID not in node.metadata:
node.metadata[CONTENT_ID] = secrets.token_hex(8)
node_id = node.metadata[CONTENT_ID]
ids.append(node_id)

link_to_tags = set()
link_from_tags = set()

for tag in get_link_tags(node.metadata):
tag_str = f"{tag.kind}:{tag.tag}"
if tag.direction in ["incoming", "bidir"]:
link_from_tags.add(tag_str)
tag_to_new_targets.setdefault(tag_str, {})[node_id] = (tag.kind, embedding)
if tag.direction in ["outgoing", "bidir"]:
link_to_tags.add(tag_str)
tag_to_new_sources.setdefault(tag_str, []).append((tag.kind, node_id))

cq.execute(
self._insert_passage,
(node_id, node.text if isinstance(node, TextNode) else node.content, embedding, link_to_tags, link_from_tags),
)

# Step 2: Query information about those tags to determine the edges to add.
# Add edges as needed.
Expand Down
4 changes: 2 additions & 2 deletions libs/langchain/ragstack_langchain/graph_store/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import GraphStore, Node, TextNode
from .base import GraphStore, Node
from .cassandra import CassandraGraphStore

__all__ = ["CassandraGraphStore", "GraphStore", "Node", "TextNode"]
__all__ = ["CassandraGraphStore", "GraphStore", "Node"]
21 changes: 14 additions & 7 deletions libs/langchain/ragstack_langchain/graph_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@ class Node(Serializable):
"""Metadata for the node. May contain information used to link this node
with other nodes."""

content: str = None
"""Encoded content"""

class TextNode(Node):
text: str
"""Text contained by the node."""
mime_type: str = None
"""Type of content, e.g. text/plain or image/png."""

mime_encoding: str = None
"""Encoding format"""

def _texts_to_nodes(
texts: Iterable[str],
Expand All @@ -61,10 +64,11 @@ def _texts_to_nodes(
_id = next(ids_it) if ids_it else None
except StopIteration:
raise ValueError("texts iterable longer than ids")
yield TextNode(
yield Node(
id=_id,
metadata=_metadata,
text=text,
mime_type="text/plain",
content=text,
)
if ids and _has_next(ids_it):
raise ValueError("ids iterable longer than texts")
Expand All @@ -81,10 +85,13 @@ def _documents_to_nodes(
_id = next(ids_it) if ids_it else None
except StopIteration:
raise ValueError("documents iterable longer than ids")
yield TextNode(

yield Node(
id=_id,
metadata=doc.metadata,
text=doc.page_content,
mime_type=doc.metadata.get('mime_type', 'text/plain'),
mime_encoding=doc.metadata.get('mime_encoding', None),
content=doc.page_content,
)
if ids and _has_next(ids_it):
raise ValueError("ids iterable longer than documents")
Expand Down
33 changes: 8 additions & 25 deletions libs/langchain/ragstack_langchain/graph_store/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,9 @@
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from .base import GraphStore, Node, TextNode
from ragstack_knowledge_store import EmbeddingModel, graph_store


class _EmbeddingModelAdapter(EmbeddingModel):
    """Expose a wrapped ``Embeddings`` object through the EmbeddingModel API.

    The text-batch methods are forwarded to ``embed_documents`` /
    ``aembed_documents``; the query methods are forwarded under the same
    name.
    """

    def __init__(self, embeddings: Embeddings):
        """Store the wrapped embeddings object."""
        self.embeddings = embeddings

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts via ``embed_documents``."""
        vectors = self.embeddings.embed_documents(texts)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string via ``embed_query``."""
        vector = self.embeddings.embed_query(text)
        return vector

    async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed a batch of texts via ``aembed_documents``."""
        vectors = await self.embeddings.aembed_documents(texts)
        return vectors

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single query string via ``aembed_query``."""
        vector = await self.embeddings.aembed_query(text)
        return vector

from .base import GraphStore, Node
from .embedding_adapter import EmbeddingAdapter
from ragstack_knowledge_store import graph_store

def _row_to_document(row) -> Document:
return Document(
Expand Down Expand Up @@ -78,7 +61,7 @@ def __init__(
_setup_mode = getattr(graph_store.SetupMode, setup_mode.name)

self.store = graph_store.GraphStore(
embedding=_EmbeddingModelAdapter(embedding),
embedding=EmbeddingAdapter(embedding),
node_table=node_table,
edge_table=edge_table,
session=session,
Expand All @@ -98,11 +81,11 @@ def add_nodes(
):
_nodes = []
for node in nodes:
if not isinstance(node, TextNode):
raise ValueError("Only adding TextNode is supported at the moment")
if not isinstance(node, Node):
raise ValueError("Only adding Node is supported at the moment")
_nodes.append(
graph_store.TextNode(id=node.id, text=node.text, metadata=node.metadata)
)
graph_store.Node(id=node.id, content=node.content, mime_type=node.mime_type, mime_encoding=node.mime_encoding, metadata=node.metadata)
)
return self.store.add_nodes(_nodes)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import List
from ragstack_knowledge_store import EmbeddingModel

class EmbeddingAdapter(EmbeddingModel):
    """EmbeddingModel wrapper that maps the text-batch method names.

    ``embed_texts`` and ``aembed_texts`` are mapped to ``embed_documents``
    and ``aembed_documents`` on the wrapped object.
    """

    def __init__(self, embeddings):
        """Wrap ``embeddings`` and pass the name mapping to the base class."""
        name_map = {}
        name_map['embed_texts'] = 'embed_documents'
        name_map['aembed_texts'] = 'aembed_documents'
        super().__init__(embeddings, method_map=name_map)

Loading