Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions graphrag_neo4j/custom_qdrant_neo4j_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

import logging
from typing import Any, Optional

import neo4j
from neo4j_graphrag.exceptions import EmbeddingRequiredError, SearchValidationError
from neo4j_graphrag.retrievers import QdrantNeo4jRetriever
from neo4j_graphrag.retrievers.external.utils import get_match_query
from neo4j_graphrag.types import RawSearchResult, VectorSearchModel
from pydantic import ValidationError

logger = logging.getLogger(__name__)


class CustomQdrantNeo4jRetriever(QdrantNeo4jRetriever):
"""
Custom retriever inheriting from QdrantNeo4jRetriever.
Handles cases where the external ID in Qdrant payload might be a list.

Inherits initialization and other methods from QdrantNeo4jRetriever.
Only overrides the get_search_results method for custom logic.
"""

def get_search_results(
self,
query_vector: Optional[list[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
**kwargs: Any,
) -> RawSearchResult:
try:
validated_data = VectorSearchModel(
query_vector=query_vector,
query_text=query_text,
top_k=top_k,
)
except ValidationError as e:
raise SearchValidationError(e.errors()) from e

if validated_data.query_text:
if self.embedder:
query_vector = self.embedder.embed_query(validated_data.query_text)
logger.debug("Locally generated query vector: %s", query_vector)
else:
logger.error("No embedder provided for query_text.")
raise EmbeddingRequiredError("No embedder provided for query_text.")

points = self.client.query_points(
collection_name=self.collection_name,
query=query_vector,
limit=top_k,
with_payload=[self.id_property_external],
**kwargs,
).points

# Custom logic
result_tuples = []
for point in points:
assert point.payload is not None
target_ids = point.payload.get(self.id_property_external, [point.id])
result_tuples = [[target_id, point.score] for target_id in target_ids]

search_query = get_match_query(
return_properties=self.return_properties,
retrieval_query=self.retrieval_query,
)

parameters = {
"match_params": result_tuples,
"id_property": self.id_property_neo4j,
}

logger.debug("Qdrant Store Cypher parameters: %s", parameters)
logger.debug("Qdrant Store Cypher query: %s", search_query)

records, _, _ = self.driver.execute_query(
search_query,
parameters,
database_=self.neo4j_database,
routing_=neo4j.RoutingControl.READ,
)

return RawSearchResult(records=records)
132 changes: 71 additions & 61 deletions graphrag_neo4j/graphrag.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from neo4j import GraphDatabase
from qdrant_client import QdrantClient, models
import os
import uuid

from custom_qdrant_neo4j_retriever import CustomQdrantNeo4jRetriever
from dotenv import load_dotenv
from pydantic import BaseModel
from neo4j import GraphDatabase
from openai import OpenAI
from collections import defaultdict
from neo4j_graphrag.retrievers import QdrantNeo4jRetriever
import uuid
import os
from pydantic import BaseModel
from qdrant_client import QdrantClient, models

# Load environment variables
load_dotenv()
Expand Down Expand Up @@ -71,62 +71,70 @@ def openai_llm_parser(prompt):

return GraphComponents.model_validate_json(completion.choices[0].message.content)

def extract_graph_components(raw_data):
prompt = f"Extract nodes and relationships from the following text:\n{raw_data}"

parsed_response = openai_llm_parser(prompt) # Assuming this returns a list of dictionaries
parsed_response = parsed_response.graph # Assuming the 'graph' structure is a key in the parsed response

nodes = {}
relationships = []

for entry in parsed_response:
node = entry.node
target_node = entry.target_node # Get target node if available
relationship = entry.relationship # Get relationship if available

# Add nodes to the dictionary with a unique ID
if node not in nodes:
nodes[node] = str(uuid.uuid4())

if target_node and target_node not in nodes:
nodes[target_node] = str(uuid.uuid4())

# Add relationship to the relationships list with node IDs
if target_node and relationship:
relationships.append({
"source": nodes[node],
"target": nodes[target_node],
"type": relationship
})
def extract_graph_components(chunks):
nodes_list = []
relationships_list = []
for chunk in chunks:
prompt = f"Extract nodes and relationships from the following text:\n{chunk}"

parsed_response = openai_llm_parser(prompt) # Assuming this returns a list of dictionaries
parsed_response = parsed_response.graph # Assuming the 'graph' structure is a key in the parsed response

nodes = {}
relationships = []

for entry in parsed_response:
node = entry.node
target_node = entry.target_node # Get target node if available
relationship = entry.relationship # Get relationship if available

# Add nodes to the dictionary with a unique ID
if node not in nodes:
nodes[node] = str(uuid.uuid4())

if target_node and target_node not in nodes:
nodes[target_node] = str(uuid.uuid4())

# Add relationship to the relationships list with node IDs
if target_node and relationship:
relationships.append({
"source": nodes[node],
"target": nodes[target_node],
"type": relationship
})

nodes_list.append(nodes)
relationships_list.append(relationships)

return nodes, relationships
return nodes_list, relationships_list

def ingest_to_neo4j(nodes, relationships):
def ingest_to_neo4j(nodes_list, relationships_list):
"""
Ingest nodes and relationships into Neo4j.
"""

with neo4j_driver.session() as session:
# Create nodes in Neo4j
for name, node_id in nodes.items():
session.run(
"CREATE (n:Entity {id: $id, name: $name})",
id=node_id,
name=name
)
for nodes in nodes_list:
for name, node_id in nodes.items():
session.run(
"CREATE (n:Entity {id: $id, name: $name})",
id=node_id,
name=name
)

# Create relationships in Neo4j
for relationship in relationships:
session.run(
"MATCH (a:Entity {id: $source_id}), (b:Entity {id: $target_id}) "
"CREATE (a)-[:RELATIONSHIP {type: $type}]->(b)",
source_id=relationship["source"],
target_id=relationship["target"],
type=relationship["type"]
for relationships in relationships_list:
for relationship in relationships:
session.run(
"MATCH (a:Entity {id: $source_id}), (b:Entity {id: $target_id}) "
"CREATE (a)-[:RELATIONSHIP {type: $type}]->(b)",
source_id=relationship["source"],
target_id=relationship["target"],
type=relationship["type"]
)

return nodes
return nodes_list

def create_collection(client, collection_name, vector_dimension):
# Try to fetch the collection status
Expand Down Expand Up @@ -155,23 +163,23 @@ def openai_embeddings(text):

return response.data[0].embedding

def ingest_to_qdrant(collection_name, raw_data, node_id_mapping):
embeddings = [openai_embeddings(paragraph) for paragraph in raw_data.split("\n")]
def ingest_to_qdrant(collection_name, chunks, node_id_mapping_list):
embeddings = [openai_embeddings(chunk) for chunk in chunks]

qdrant_client.upsert(
collection_name=collection_name,
points=[
{
"id": str(uuid.uuid4()),
"vector": embedding,
"payload": {"id": node_id}
"payload": {"id": list(node_id_mapping.values())}
}
for node_id, embedding in zip(node_id_mapping.values(), embeddings)
for node_id_mapping, embedding in zip(node_id_mapping_list, embeddings)
]
)

def retriever_search(neo4j_driver, qdrant_client, collection_name, query):
retriever = QdrantNeo4jRetriever(
retriever = CustomQdrantNeo4jRetriever(
driver=neo4j_driver,
client=qdrant_client,
collection_name=collection_name,
Expand Down Expand Up @@ -302,16 +310,18 @@ def graphRAG_run(graph_context, user_query):
Carol's team grew significantly after moving to New York.
Seattle remains the technology hub for TechCorp."""

nodes, relationships = extract_graph_components(raw_data)
print("Nodes:", nodes)
print("Relationships:", relationships)
chunks = raw_data.split("\n")

nodes_list, relationships_list = extract_graph_components(chunks)
print("Nodes:", nodes_list)
print("Relationships:", relationships_list)

print("Ingesting to Neo4j...")
node_id_mapping = ingest_to_neo4j(nodes, relationships)
node_id_mapping_list = ingest_to_neo4j(nodes_list, relationships_list)
print("Neo4j ingestion complete")

print("Ingesting to Qdrant...")
ingest_to_qdrant(collection_name, raw_data, node_id_mapping)
ingest_to_qdrant(collection_name, chunks, node_id_mapping_list)
print("Qdrant ingestion complete")

query = "How is Bob connected to New York?"
Expand Down