diff --git a/graphrag_neo4j/custom_qdrant_neo4j_retriever.py b/graphrag_neo4j/custom_qdrant_neo4j_retriever.py new file mode 100644 index 0000000..d49f917 --- /dev/null +++ b/graphrag_neo4j/custom_qdrant_neo4j_retriever.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import logging +from typing import Any, Optional + +import neo4j +from neo4j_graphrag.exceptions import EmbeddingRequiredError, SearchValidationError +from neo4j_graphrag.retrievers import QdrantNeo4jRetriever +from neo4j_graphrag.retrievers.external.utils import get_match_query +from neo4j_graphrag.types import RawSearchResult, VectorSearchModel +from pydantic import ValidationError + +logger = logging.getLogger(__name__) + + +class CustomQdrantNeo4jRetriever(QdrantNeo4jRetriever): + """ + Custom retriever inheriting from QdrantNeo4jRetriever. + Handles cases where the external ID in Qdrant payload might be a list. + + Inherits initialization and other methods from QdrantNeo4jRetriever. + Only overrides the get_search_results method for custom logic. + """ + + def get_search_results( + self, + query_vector: Optional[list[float]] = None, + query_text: Optional[str] = None, + top_k: int = 5, + **kwargs: Any, + ) -> RawSearchResult: + try: + validated_data = VectorSearchModel( + query_vector=query_vector, + query_text=query_text, + top_k=top_k, + ) + except ValidationError as e: + raise SearchValidationError(e.errors()) from e + + if validated_data.query_text: + if self.embedder: + query_vector = self.embedder.embed_query(validated_data.query_text) + logger.debug("Locally generated query vector: %s", query_vector) + else: + logger.error("No embedder provided for query_text.") + raise EmbeddingRequiredError("No embedder provided for query_text.") + + points = self.client.query_points( + collection_name=self.collection_name, + query=query_vector, + limit=top_k, + with_payload=[self.id_property_external], + **kwargs, + ).points + + # Custom logic + result_tuples = [] + for point in points: + assert point.payload is not None + target_ids = point.payload.get(self.id_property_external, [point.id]) + result_tuples = [[target_id, point.score] for target_id in target_ids] + + search_query = get_match_query( + return_properties=self.return_properties, + retrieval_query=self.retrieval_query, + ) + + parameters = { + "match_params": result_tuples, + "id_property": self.id_property_neo4j, + } + + logger.debug("Qdrant Store Cypher parameters: %s", parameters) + logger.debug("Qdrant Store Cypher query: %s", search_query) + + records, _, _ = self.driver.execute_query( + search_query, + parameters, + database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, + ) + + return RawSearchResult(records=records) diff --git a/graphrag_neo4j/graphrag.py b/graphrag_neo4j/graphrag.py index a2702b7..23882d0 100644 --- a/graphrag_neo4j/graphrag.py +++ b/graphrag_neo4j/graphrag.py @@ -1,12 +1,12 @@ -from neo4j import GraphDatabase -from qdrant_client import QdrantClient, models +import os +import uuid + +from custom_qdrant_neo4j_retriever import CustomQdrantNeo4jRetriever from dotenv import load_dotenv -from pydantic import BaseModel +from neo4j import GraphDatabase from openai import OpenAI -from collections import defaultdict -from neo4j_graphrag.retrievers import QdrantNeo4jRetriever -import uuid -import os +from pydantic import BaseModel +from qdrant_client import QdrantClient, models # Load environment variables load_dotenv() @@ -71,62 +71,70 @@ def openai_llm_parser(prompt): return GraphComponents.model_validate_json(completion.choices[0].message.content) -def extract_graph_components(raw_data): - prompt = f"Extract nodes and relationships from the following text:\n{raw_data}" - - parsed_response = openai_llm_parser(prompt) # Assuming this returns a list of dictionaries - parsed_response = parsed_response.graph # Assuming the 'graph' structure is a key in the parsed response - - nodes = {} - relationships = [] - - for entry in parsed_response: - node = entry.node - target_node = entry.target_node # Get target node if available - relationship = entry.relationship # Get relationship if available - - # Add nodes to the dictionary with a unique ID - if node not in nodes: - nodes[node] = str(uuid.uuid4()) - - if target_node and target_node not in nodes: - nodes[target_node] = str(uuid.uuid4()) - - # Add relationship to the relationships list with node IDs - if target_node and relationship: - relationships.append({ - "source": nodes[node], - "target": nodes[target_node], - "type": relationship - }) +def extract_graph_components(chunks): + nodes_list = [] + relationships_list = [] + for chunk in chunks: + prompt = f"Extract nodes and relationships from the following text:\n{chunk}" + + parsed_response = openai_llm_parser(prompt) # Assuming this returns a list of dictionaries + parsed_response = parsed_response.graph # Assuming the 'graph' structure is a key in the parsed response + + nodes = {} + relationships = [] + + for entry in parsed_response: + node = entry.node + target_node = entry.target_node # Get target node if available + relationship = entry.relationship # Get relationship if available + + # Add nodes to the dictionary with a unique ID + if node not in nodes: + nodes[node] = str(uuid.uuid4()) + + if target_node and target_node not in nodes: + nodes[target_node] = str(uuid.uuid4()) + + # Add relationship to the relationships list with node IDs + if target_node and relationship: + relationships.append({ + "source": nodes[node], + "target": nodes[target_node], + "type": relationship + }) + + nodes_list.append(nodes) + relationships_list.append(relationships) - return nodes, relationships + return nodes_list, relationships_list -def ingest_to_neo4j(nodes, relationships): +def ingest_to_neo4j(nodes_list, relationships_list): """ Ingest nodes and relationships into Neo4j. """ with neo4j_driver.session() as session: # Create nodes in Neo4j - for name, node_id in nodes.items(): - session.run( - "CREATE (n:Entity {id: $id, name: $name})", - id=node_id, - name=name - ) + for nodes in nodes_list: + for name, node_id in nodes.items(): + session.run( + "CREATE (n:Entity {id: $id, name: $name})", + id=node_id, + name=name + ) # Create relationships in Neo4j - for relationship in relationships: - session.run( - "MATCH (a:Entity {id: $source_id}), (b:Entity {id: $target_id}) " - "CREATE (a)-[:RELATIONSHIP {type: $type}]->(b)", - source_id=relationship["source"], - target_id=relationship["target"], - type=relationship["type"] + for relationships in relationships_list: + for relationship in relationships: + session.run( + "MATCH (a:Entity {id: $source_id}), (b:Entity {id: $target_id}) " + "CREATE (a)-[:RELATIONSHIP {type: $type}]->(b)", + source_id=relationship["source"], + target_id=relationship["target"], + type=relationship["type"] ) - return nodes + return nodes_list def create_collection(client, collection_name, vector_dimension): # Try to fetch the collection status @@ -155,8 +163,8 @@ def openai_embeddings(text): return response.data[0].embedding -def ingest_to_qdrant(collection_name, raw_data, node_id_mapping): - embeddings = [openai_embeddings(paragraph) for paragraph in raw_data.split("\n")] +def ingest_to_qdrant(collection_name, chunks, node_id_mapping_list): + embeddings = [openai_embeddings(chunk) for chunk in chunks] qdrant_client.upsert( collection_name=collection_name, @@ -164,14 +172,14 @@ def ingest_to_qdrant(collection_name, raw_data, node_id_mapping): { "id": str(uuid.uuid4()), "vector": embedding, - "payload": {"id": node_id} + "payload": {"id": list(node_id_mapping.values())} } - for node_id, embedding in zip(node_id_mapping.values(), embeddings) + for node_id_mapping, embedding in zip(node_id_mapping_list, embeddings) ] ) def retriever_search(neo4j_driver, qdrant_client, collection_name, query): - retriever = QdrantNeo4jRetriever( + retriever = CustomQdrantNeo4jRetriever( driver=neo4j_driver, client=qdrant_client, collection_name=collection_name, @@ -302,16 +310,18 @@ def graphRAG_run(graph_context, user_query): Carol's team grew significantly after moving to New York. Seattle remains the technology hub for TechCorp.""" - nodes, relationships = extract_graph_components(raw_data) - print("Nodes:", nodes) - print("Relationships:", relationships) + chunks = raw_data.split("\n") + + nodes_list, relationships_list = extract_graph_components(chunks) + print("Nodes:", nodes_list) + print("Relationships:", relationships_list) print("Ingesting to Neo4j...") - node_id_mapping = ingest_to_neo4j(nodes, relationships) + node_id_mapping_list = ingest_to_neo4j(nodes_list, relationships_list) print("Neo4j ingestion complete") print("Ingesting to Qdrant...") - ingest_to_qdrant(collection_name, raw_data, node_id_mapping) + ingest_to_qdrant(collection_name, chunks, node_id_mapping_list) print("Qdrant ingestion complete") query = "How is Bob connected to New York?"