"""This example demonstrates how to use the SchemaFromExistingGraphExtractor component
to automatically extract a schema from an existing Neo4j database.
"""

import asyncio

import neo4j

from neo4j_graphrag.experimental.components.schema import (
    GraphSchema,
    SchemaFromExistingGraphExtractor,
)


URI = "neo4j+s://demo.neo4jlabs.com"
AUTH = ("recommendations", "recommendations")
DATABASE = "recommendations"


async def main() -> None:
    """Connect to the demo database, extract its schema and print it."""
    with neo4j.GraphDatabase.driver(
        URI,
        auth=AUTH,
    ) as driver:
        # Target the database explicitly: without `neo4j_database` the
        # extractor reads the schema of the driver's default database.
        extractor = SchemaFromExistingGraphExtractor(
            driver,
            neo4j_database=DATABASE,
        )
        schema: GraphSchema = await extractor.run()
        # The schema can also be persisted for later reuse, e.g.:
        #   schema.store_as_json("my_schema.json")
        print(schema)


if __name__ == "__main__":
    asyncio.run(main())
class BaseSchemaBuilder(Component):
    """Common base class for the components that produce a GraphSchema.

    The base `run` always raises `NotImplementedError`; subclasses override it.
    """

    async def run(self, *args: Any, **kwargs: Any) -> GraphSchema:
        raise NotImplementedError()


class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
    """A class to build a GraphSchema object from an existing graph.

    Uses the get_structured_schema function to extract existing node labels,
    relationship types, properties and existence constraints.

    By default, the built schema does not allow any additional item (property,
    node label, relationship type or pattern).

    Args:
        driver (neo4j.Driver): connection to the neo4j database.
        additional_properties (bool, default False): see GraphSchema
        additional_node_types (bool, default False): see GraphSchema
        additional_relationship_types (bool, default False): see GraphSchema
        additional_patterns (bool, default False): see GraphSchema
        neo4j_database (Optional[str]): name of the neo4j database to use
    """

    def __init__(
        self,
        driver: neo4j.Driver,
        additional_properties: bool = False,
        additional_node_types: bool = False,
        additional_relationship_types: bool = False,
        additional_patterns: bool = False,
        neo4j_database: Optional[str] = None,
    ) -> None:
        self.driver = driver
        self.database = neo4j_database

        self.additional_properties = additional_properties
        self.additional_node_types = additional_node_types
        self.additional_relationship_types = additional_relationship_types
        self.additional_patterns = additional_patterns

    @staticmethod
    def _extract_required_properties(
        structured_schema: dict[str, Any],
    ) -> list[tuple[str, str]]:
        """Extract a list of (node label (or rel type), property name) for which
        an "EXISTENCE" or "KEY" constraint is defined in the DB.

        Args:

            structured_schema (dict[str, Any]): the result of the `get_structured_schema()` function.

        Returns:

            list of tuples of (node label (or rel type), property name)

        """
        required_constraint_types = (
            "NODE_PROPERTY_EXISTENCE",
            "NODE_KEY",
            "RELATIONSHIP_PROPERTY_EXISTENCE",
            "RELATIONSHIP_KEY",
        )
        schema_metadata = structured_schema.get("metadata", {})
        required: list[tuple[str, str]] = []
        for constraint in schema_metadata.get("constraints", []):
            if constraint["type"] not in required_constraint_types:
                continue
            # Existence constraints cover a single property, but KEY
            # constraints may be composite: every property that is part of
            # the key must exist, so all of them are reported. Iterating
            # (instead of indexing [0]) also tolerates empty lists.
            labels = constraint.get("labelsOrTypes") or []
            properties = constraint.get("properties") or []
            required.extend((lab, prop) for lab in labels for prop in properties)
        return required

    def _build_item_type(
        self,
        label: str,
        properties: list[dict[str, Any]],
        required: set[tuple[str, str]],
    ) -> dict[str, Any]:
        """Build the dict describing one node or relationship type that has properties."""
        return {
            "label": label,
            "properties": [
                {
                    "name": p["property"],
                    "type": p["type"],
                    "required": (label, p["property"]) in required,
                }
                for p in properties
            ],
            "additional_properties": self.additional_properties,
        }

    async def run(self) -> GraphSchema:
        """Read the schema of the configured database and return it as a GraphSchema.

        Returns:
            GraphSchema: node types, relationship types and patterns found in
            the database, validated with `GraphSchema.model_validate`.
        """
        structured_schema = get_structured_schema(self.driver, database=self.database)
        # A set makes the per-property "required" lookups constant time.
        required = set(self._extract_required_properties(structured_schema))

        node_props = structured_schema["node_props"]
        rel_props = structured_schema["rel_props"]
        node_types = [
            self._build_item_type(label, properties, required)
            for label, properties in node_props.items()
        ]
        # Relationship types honour `additional_properties` exactly like node
        # types, so the "no additional item by default" contract holds for both.
        relationship_types = [
            self._build_item_type(label, properties, required)
            for label, properties in rel_props.items()
        ]
        patterns = [
            (s["start"], s["type"], s["end"])
            for s in structured_schema["relationships"]
        ]

        # Deal with nodes and relationships without properties: they only show
        # up in the patterns, so they are declared here with their label alone.
        node_labels = set(node_props)
        rel_labels = set(rel_props)
        for source, rel, target in patterns:
            for label in (source, target):
                if label not in node_labels:
                    node_labels.add(label)
                    node_types.append({"label": label})
            if rel not in rel_labels:
                rel_labels.add(rel)
                relationship_types.append({"label": rel})

        return GraphSchema.model_validate(
            {
                "node_types": node_types,
                "relationship_types": relationship_types,
                "patterns": patterns,
                "additional_node_types": self.additional_node_types,
                "additional_relationship_types": self.additional_relationship_types,
                "additional_patterns": self.additional_patterns,
            }
        )