Skip to content

Commit

Permalink
simplified insertion
Browse files Browse the repository at this point in the history
  • Loading branch information
epinzur committed Oct 14, 2024
1 parent 4c6c46e commit 8097a09
Showing 1 changed file with 41 additions and 20 deletions.
61 changes: 41 additions & 20 deletions libs/astradb/langchain_astradb/graph_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,40 +359,62 @@ def _restore_links(self, doc: Document) -> Document:
del doc.metadata[self.metadata_incoming_links_key]
return doc

# TODO: Async (aadd_nodes)
@override
def add_nodes(
self,
nodes: Iterable[Node],
**kwargs: Any,
) -> Iterable[str]:
"""Add nodes to the graph store.
def _get_node_metadata_for_insertion(self, node: Node) -> dict[str, Any]:
metadata = node.metadata.copy()
metadata[METADATA_LINKS_KEY] = _serialize_links(node.links)
metadata[self.metadata_incoming_links_key] = [
_metadata_link_key(link=link) for link in _incoming_links(node=node)
]
return metadata

Args:
nodes: the nodes to add.
**kwargs: Additional keyword arguments.
"""
def _get_docs_for_insertion(
self, nodes: Iterable[Node]
) -> tuple[list[Document], list[str]]:
docs = []
ids = []
for node in nodes:
node_id = secrets.token_hex(8) if not node.id else node.id

combined_metadata = node.metadata.copy()
combined_metadata[METADATA_LINKS_KEY] = _serialize_links(node.links)
combined_metadata[self.metadata_incoming_links_key] = [
_metadata_link_key(link=link) for link in _incoming_links(node=node)
]

doc = Document(
page_content=node.text,
metadata=combined_metadata,
metadata=self._get_node_metadata_for_insertion(node=node),
id=node_id,
)
docs.append(doc)
ids.append(node_id)
return (docs, ids)

@override
def add_nodes(
self,
nodes: Iterable[Node],
**kwargs: Any,
) -> Iterable[str]:
"""Add nodes to the graph store.
Args:
nodes: the nodes to add.
**kwargs: Additional keyword arguments.
"""
(docs, ids) = self._get_docs_for_insertion(nodes=nodes)
return self.vector_store.add_documents(docs, ids=ids)

@override
async def aadd_nodes(
self,
nodes: Iterable[Node],
**kwargs: Any,
) -> AsyncIterable[str]:
"""Add nodes to the graph store.
Args:
nodes: the nodes to add.
**kwargs: Additional keyword arguments.
"""
(docs, ids) = self._get_docs_for_insertion(nodes=nodes)
for inserted_id in await self.vector_store.aadd_documents(docs, ids=ids):
yield inserted_id

@classmethod
@override
def from_texts(
Expand Down Expand Up @@ -597,7 +619,6 @@ async def aget_by_document_id(self, document_id: str) -> Document | None:
Returns:
The the document if it exists. Otherwise None.
"""
await self.astra_env.aensure_db_setup()
doc = await self.vector_store.aget_by_document_id(document_id=document_id)
return self._restore_links(doc) if doc is not None else None

Expand Down

0 comments on commit 8097a09

Please sign in to comment.