adding MimeNode type #503
Open
mieslep wants to merge 10 commits into datastax:main from mieslep:knowledge_store/mime
Changes from 8 commits
a904165 adding MimeNode type (mieslep)
bb802e5 consolidating MIME headers to Node, adding Cassandra (mieslep)
8fee3fe adding MimeNode type (mieslep)
1b81c2c consolidating MIME headers to Node, adding Cassandra (mieslep)
da62a11 Merge branch 'knowledge_store/mime' of github.com:mieslep/ragstack-ai… (mieslep)
1222740 incorporate EmbeddingModel abstraction (mieslep)
f950df1 fixing text embeddings (mieslep)
43e4f60 Merge branch 'main' into knowledge_store/mime (mieslep)
3046ce7 preserve incoming node_id (mieslep)
62d6b41 Merge remote-tracking branch 'DataStax/main' into knowledge_store/mime (mieslep)
59 changes: 52 additions & 7 deletions
libs/knowledge-store/ragstack_knowledge_store/embedding_model.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,67 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
from abc import ABC | ||
from typing import List, Any, Optional | ||
|
||
class EmbeddingModel(ABC): | ||
"""Embedding model.""" | ||
|
||
@abstractmethod | ||
def __init__(self, embeddings: Any, method_map: Optional[dict] = None, other_methods: Optional[List[str]] = None): | ||
self.embeddings = embeddings | ||
self.method_name = {} | ||
method_map = method_map if method_map else {} | ||
other_methods = other_methods if other_methods else [] | ||
|
||
base_methods = ['embed_texts', 'aembed_texts', 'embed_query', 'aembed_query'] | ||
extended_methods = ['embed_images', 'aembed_images', 'embed_image', 'aembed_image'] | ||
|
||
# Combining all method names, including those mapped | ||
all_methods = set(base_methods + extended_methods + other_methods + list(method_map.values())) | ||
|
||
for method in all_methods: | ||
mapped_method = method_map.get(method) | ||
if hasattr(embeddings, method): | ||
self.method_name[method] = method | ||
elif hasattr(embeddings, mapped_method) if mapped_method else False: | ||
self.method_name[method] = mapped_method | ||
else: | ||
self.method_name[method] = None | ||
|
||
def does_implement(self, method_name: str) -> bool: | ||
"""Check if the method is implemented.""" | ||
return self.method_name.get(method_name) is not None | ||
|
||
def implements(self) -> List[str]: | ||
"""List of methods that are implemented""" | ||
return [method for method, impl in self.method_name.items() if impl is not None] | ||
|
||
def invoke(self, method_name: str, *args, **kwargs): | ||
"""Invoke a synchronous method if it's implemented.""" | ||
target_method = self.method_name.get(method_name) | ||
if target_method and hasattr(self.embeddings, target_method): | ||
return getattr(self.embeddings, target_method)(*args, **kwargs) | ||
else: | ||
raise NotImplementedError(f"{self.embeddings.__class__.__name__} does not implement {target_method}") | ||
|
||
async def ainvoke(self, method_name: str, *args, **kwargs): | ||
"""Invoke an asynchronous method if it's implemented.""" | ||
target_method = self.method_name.get(method_name) | ||
if target_method and hasattr(self.embeddings, target_method): | ||
return await getattr(self.embeddings, target_method)(*args, **kwargs) | ||
else: | ||
raise NotImplementedError(f"{self.embeddings.__class__.__name__} does not implement {target_method}") | ||
|
||
def embed_texts(self, texts: List[str]) -> List[List[float]]: | ||
"""Embed texts.""" | ||
return self.invoke('embed_texts', texts) | ||
|
||
@abstractmethod | ||
def embed_query(self, text: str) -> List[float]: | ||
"""Embed query text.""" | ||
return self.invoke('embed_query', text) | ||
|
||
@abstractmethod | ||
async def aembed_texts(self, texts: List[str]) -> List[List[float]]: | ||
"""Embed texts.""" | ||
return await self.ainvoke('aembed_texts', texts) | ||
|
||
@abstractmethod | ||
async def aembed_query(self, text: str) -> List[float]: | ||
"""Embed query text.""" | ||
return await self.ainvoke('aembed_query', text) | ||
|
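The name resolution done in `EmbeddingModel.__init__` above can be sketched in isolation. This is a minimal sketch, not the PR's code: `MethodResolver` and `FakeLangChainEmbeddings` are hypothetical names introduced here, and only four canonical method names are considered. One difference from the diff is deliberate: in the diff, the `NotImplementedError` message interpolates `target_method`, which is typically `None` on that branch, so the sketch reports the requested name instead.

```python
from typing import Any, Dict, List, Optional


class MethodResolver:
    """Hypothetical stand-in for the name-resolution part of EmbeddingModel."""

    def __init__(self, embeddings: Any, method_map: Optional[Dict[str, str]] = None):
        self.embeddings = embeddings
        self.method_name: Dict[str, Optional[str]] = {}
        method_map = method_map or {}
        canonical = ["embed_texts", "aembed_texts", "embed_query", "aembed_query"]
        for name in canonical:
            mapped = method_map.get(name)
            if hasattr(embeddings, name):
                # The wrapped object already uses the canonical name.
                self.method_name[name] = name
            elif mapped is not None and hasattr(embeddings, mapped):
                # Fall back to the alias supplied in method_map.
                self.method_name[name] = mapped
            else:
                self.method_name[name] = None

    def does_implement(self, name: str) -> bool:
        return self.method_name.get(name) is not None

    def invoke(self, name: str, *args, **kwargs):
        target = self.method_name.get(name)
        if target is None:
            # Report the requested name; the resolved target is None here.
            raise NotImplementedError(
                f"{type(self.embeddings).__name__} does not implement {name}"
            )
        return getattr(self.embeddings, target)(*args, **kwargs)


class FakeLangChainEmbeddings:
    """Toy object exposing LangChain-style method names only."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t))] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text))]


resolver = MethodResolver(
    FakeLangChainEmbeddings(), method_map={"embed_texts": "embed_documents"}
)
# "embed_texts" is routed to embed_documents via the map.
vectors = resolver.invoke("embed_texts", ["ab", "abcd"])
```

Callers only ever use the canonical names; whether a call lands on `embed_texts` or `embed_documents` is decided once, at construction time.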
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,12 +34,21 @@ class Node: | |
"""Metadata for the node. May contain information used to link this node | ||
with other nodes.""" | ||
|
||
content: str = None | ||
"""Encoded content""" | ||
|
||
mime_type: str = None | ||
"""Type of content, e.g. text/plain or image/png.""" | ||
|
||
mime_encoding: str = None | ||
"""Encoding format""" | ||
|
||
@dataclass | ||
class TextNode(Node): | ||
text: str = None | ||
"""Text contained by the node.""" | ||
|
||
mime_type = "text/plain" | ||
|
||
class SetupMode(Enum): | ||
SYNC = 1 | ||
|
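One thing worth checking in the hunk above: `TextNode` sets `mime_type = "text/plain"` without a type annotation. In a dataclass, an un-annotated assignment is a plain class attribute, not a field, so the generated `__init__` still assigns the inherited default. The sketch below uses a cut-down, hypothetical `Node` (two fields only) and hypothetical subclass names to show the difference; it is an illustration of standard dataclass behaviour, not the PR's code.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    # Simplified stand-in for the Node in the diff.
    content: Optional[str] = None
    mime_type: Optional[str] = None


@dataclass
class UnannotatedTextNode(Node):
    text: Optional[str] = None
    # No annotation: this is a plain class attribute, not a dataclass field,
    # so the inherited field default (None) is still what __init__ assigns.
    mime_type = "text/plain"


@dataclass
class AnnotatedTextNode(Node):
    text: Optional[str] = None
    # Annotated: this redefines the inherited field with a new default.
    mime_type: Optional[str] = "text/plain"


unannotated = UnannotatedTextNode(text="hello")
annotated = AnnotatedTextNode(text="hello")
print(unannotated.mime_type)  # None
print(annotated.mime_type)    # text/plain
```

If the intent is for every `TextNode` to default to `text/plain`, the annotated form is the one that achieves it.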
@@ -326,52 +335,71 @@ def add_nodes( | |
self, | ||
nodes: Iterable[Node] = None, | ||
) -> Iterable[str]: | ||
texts = [] | ||
metadatas = [] | ||
for node in nodes: | ||
if not isinstance(node, TextNode): | ||
raise ValueError("Only adding TextNode is supported at the moment") | ||
texts.append(node.text) | ||
metadatas.append(node.metadata) | ||
|
||
text_embeddings = self._embedding.embed_texts(texts) | ||
|
||
# Organize nodes by MIME type | ||
mime_buckets = {} | ||
ids = [] | ||
|
||
tag_to_new_sources: Dict[str, List[Tuple[str, str]]] = {} | ||
tag_to_new_targets: Dict[str, Dict[str, Tuple[str, List[float]]]] = {} | ||
# Prepare nodes based on their type | ||
for node in nodes: | ||
if isinstance(node, TextNode): | ||
if 'text' not in mime_buckets: | ||
mime_buckets['text'] = [] | ||
mime_buckets['text'].append(node) | ||
if isinstance(node, Node) and node.mime_type: | ||
main_mime_type = node.mime_type.split('/')[0] # Split and take the first part, e.g., "image" from "image/png" | ||
if main_mime_type not in mime_buckets: | ||
mime_buckets[main_mime_type] = [] | ||
mime_buckets[main_mime_type].append(node) | ||
else: | ||
raise ValueError("Unsupported node type") | ||
|
||
# Process each MIME bucket | ||
embeddings_dict = {} | ||
for mime_type, nodes_list in mime_buckets.items(): | ||
method_name = f"embed_{mime_type}s" | ||
if self._embedding.does_implement(method_name): | ||
texts = [node.text if isinstance(node, TextNode) else node.content for node in nodes_list] | ||
embeddings_dict[mime_type] = self._embedding.invoke(method_name, texts) | ||
else: | ||
# If no bulk method, try to call a singular method for each content | ||
singular_method_name = f"embed_{mime_type}" | ||
if self._embedding.does_implement(singular_method_name): | ||
embeddings = [] | ||
for node in nodes_list: | ||
embedding = self._embedding.invoke(singular_method_name, node.text if isinstance(node, TextNode) else node.content) | ||
embeddings.append(embedding) | ||
embeddings_dict[mime_type] = embeddings | ||
else: | ||
raise NotImplementedError(f"No embedding method available for MIME type: {mime_type}, implemented methods: {self._embedding.implements()}.") | ||
|
||
|
||
# Step 1: Add the nodes, collecting the tags and new sources / targets. | ||
tag_to_new_sources = {} | ||
Review comment: This has changed significantly from the previous implementation. I think it will need to be reworked if we work on adding this right now. |
||
tag_to_new_targets = {} | ||
with self._concurrent_queries() as cq: | ||
tuples = zip(texts, text_embeddings, metadatas) | ||
for text, text_embedding, metadata in tuples: | ||
if CONTENT_ID not in metadata: | ||
metadata[CONTENT_ID] = secrets.token_hex(8) | ||
id = metadata[CONTENT_ID] | ||
ids.append(id) | ||
|
||
link_to_tags = set() # link to these tags | ||
link_from_tags = set() # link from these tags | ||
|
||
for tag in get_link_tags(metadata): | ||
tag_str = f"{tag.kind}:{tag.tag}" | ||
if tag.direction == "incoming" or tag.direction == "bidir": | ||
# An incoming link should be linked *from* nodes with the given tag. | ||
link_from_tags.add(tag_str) | ||
tag_to_new_targets.setdefault(tag_str, dict())[id] = ( | ||
tag.kind, | ||
text_embedding, | ||
) | ||
if tag.direction == "outgoing" or tag.direction == "bidir": | ||
link_to_tags.add(tag_str) | ||
tag_to_new_sources.setdefault(tag_str, list()).append( | ||
(tag.kind, id) | ||
) | ||
|
||
cq.execute( | ||
self._insert_passage, | ||
(id, text, text_embedding, link_to_tags, link_from_tags), | ||
) | ||
for mime_type, embeddings in embeddings_dict.items(): | ||
for node, embedding in zip(mime_buckets[mime_type], embeddings): | ||
if CONTENT_ID not in node.metadata: | ||
node.metadata[CONTENT_ID] = secrets.token_hex(8) | ||
node_id = node.metadata[CONTENT_ID] | ||
ids.append(node_id) | ||
|
||
link_to_tags = set() | ||
link_from_tags = set() | ||
|
||
for tag in get_link_tags(node.metadata): | ||
tag_str = f"{tag.kind}:{tag.tag}" | ||
if tag.direction in ["incoming", "bidir"]: | ||
link_from_tags.add(tag_str) | ||
tag_to_new_targets.setdefault(tag_str, {})[node_id] = (tag.kind, embedding) | ||
if tag.direction in ["outgoing", "bidir"]: | ||
link_to_tags.add(tag_str) | ||
tag_to_new_sources.setdefault(tag_str, []).append((tag.kind, node_id)) | ||
|
||
cq.execute( | ||
self._insert_passage, | ||
(node_id, node.text if isinstance(node, TextNode) else node.content, embedding, link_to_tags, link_from_tags), | ||
) | ||
|
||
# Step 2: Query information about those tags to determine the edges to add. | ||
# Add edges as needed. | ||
|
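The bucketing step in `add_nodes` above groups nodes by the top-level part of their MIME type and derives the embedding method name from it. A minimal sketch of that idea follows, assuming a hypothetical `SimpleNode` and helper names that are not part of the PR. Each node lands in exactly one bucket here; in the diff as rendered, the `TextNode` check and the `mime_type` check are two separate `if` statements, so it may be worth confirming that a `TextNode` carrying a MIME type is not appended twice.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional


@dataclass
class SimpleNode:
    # Hypothetical cut-down node: just content and a MIME type.
    content: str
    mime_type: Optional[str] = None


def bucket_by_main_type(nodes: Iterable[SimpleNode]) -> Dict[str, List[SimpleNode]]:
    """Group nodes by the part of the MIME type before the slash."""
    buckets: Dict[str, List[SimpleNode]] = defaultdict(list)
    for node in nodes:
        if not node.mime_type:
            raise ValueError("Unsupported node type: missing mime_type")
        main_type = node.mime_type.split("/")[0]  # "image" from "image/png"
        buckets[main_type].append(node)
    return dict(buckets)


def embed_method_names(main_type: str) -> List[str]:
    """Bulk method name first, then the singular fallback, as in the diff."""
    return [f"embed_{main_type}s", f"embed_{main_type}"]


nodes = [
    SimpleNode("hi", "text/plain"),
    SimpleNode("<png bytes>", "image/png"),
    SimpleNode("bye", "text/markdown"),
]
buckets = bucket_by_main_type(nodes)
```

With this naming scheme, the `text` bucket resolves to `embed_texts` and the `image` bucket to `embed_images`, falling back to `embed_image` per node when no bulk method exists.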
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from .base import GraphStore, Node, TextNode | ||
from .base import GraphStore, Node | ||
from .cassandra import CassandraGraphStore | ||
|
||
__all__ = ["CassandraGraphStore", "GraphStore", "Node", "TextNode"] | ||
__all__ = ["CassandraGraphStore", "GraphStore", "Node"] |
9 changes: 9 additions & 0 deletions
libs/langchain/ragstack_langchain/graph_store/embedding_adapter.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from typing import List | ||
from ragstack_knowledge_store import EmbeddingModel | ||
|
||
class EmbeddingAdapter(EmbeddingModel): | ||
def __init__(self, embeddings): | ||
super().__init__(embeddings, | ||
method_map={'embed_texts': 'embed_documents', | ||
'aembed_texts': 'aembed_documents'}) | ||
|
I don't think we should try to add all of these as methods; it's definitely pretty messy.
I think we should just have
embed_mime(self, mime_type: str, content: Union[str, bytes])
or something like that. Then there is only a single abstract method to use for any mime type and the names can be different, etc.
100% but right now, LangChain doesn't have "embed_mime" :)
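For reference, the single-method interface suggested in this thread might look roughly like the sketch below. Everything here is hypothetical: `MimeEmbeddingModel`, `TextOnlyAdapter`, and `FakeEmbeddings` are illustrative names, and the only assumption about the wrapped object is that it exposes a LangChain-style `embed_documents(texts)` method. Since LangChain has no `embed_mime`, the adapter does the dispatch itself.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Union


class MimeEmbeddingModel(ABC):
    """Hypothetical interface: one abstract method for every MIME type."""

    @abstractmethod
    def embed_mime(self, mime_type: str, content: Union[str, bytes]) -> List[float]:
        """Embed one piece of content of the given MIME type."""


class TextOnlyAdapter(MimeEmbeddingModel):
    """Hypothetical adapter over an object with LangChain-style embed_documents."""

    def __init__(self, embeddings: Any):
        self._embeddings = embeddings

    def embed_mime(self, mime_type: str, content: Union[str, bytes]) -> List[float]:
        main_type = mime_type.split("/")[0]
        if main_type == "text" and isinstance(content, str):
            # Reuse the bulk text method for a single document.
            return self._embeddings.embed_documents([content])[0]
        raise NotImplementedError(f"No embedding available for MIME type: {mime_type}")


class FakeEmbeddings:
    """Toy embeddings object used only to exercise the adapter."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t))] for t in texts]


adapter = TextOnlyAdapter(FakeEmbeddings())
vector = adapter.embed_mime("text/plain", "hello")
```

The trade-off discussed above still applies: the interface is simpler, but the mapping from MIME type to the underlying library's method names has to live somewhere, here inside the adapter.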