diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 25ae85e75..d7c45694d 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -76,7 +76,6 @@ class PropertyType(BaseModel): ] description: str = "" required: bool = False - model_config = ConfigDict( frozen=True, ) @@ -161,6 +160,22 @@ def property_type_from_name(self, name: str) -> Optional[PropertyType]: return None +class ConstraintType(BaseModel): + """ + Represents a constraint on a node in the graph. + """ + + type: Literal[ + "UNIQUENESS" + ] # TODO: add other constraint types ["propertyExistence", "propertyType", "key"] + node_type: str + property_name: str + + model_config = ConfigDict( + frozen=True, + ) + + class GraphSchema(DataModel): """This model represents the expected node and relationship types in the graph. @@ -177,6 +192,7 @@ class GraphSchema(DataModel): node_types: Tuple[NodeType, ...] relationship_types: Tuple[RelationshipType, ...] = tuple() patterns: Tuple[Tuple[str, str, str], ...] = tuple() + constraints: Tuple[ConstraintType, ...] = tuple() additional_node_types: bool = Field( default_factory=default_additional_item("node_types") @@ -239,6 +255,34 @@ def validate_additional_parameters(self) -> Self: ) return self + @model_validator(mode="after") + def validate_constraints_against_node_types(self) -> Self: + if not self.constraints: + return self + for constraint in self.constraints: + # Only validate UNIQUENESS constraints (other types will be added) + if constraint.type != "UNIQUENESS": + continue + + if not constraint.property_name: + raise SchemaValidationError( + f"Constraint has no property name: {constraint}. Property name is required." 
+ ) + if constraint.node_type not in self._node_type_index: + raise SchemaValidationError( + f"Constraint references undefined node type: {constraint.node_type}" + ) + # Check if property_name exists on the node type + node_type = self._node_type_index[constraint.node_type] + valid_property_names = {p.name for p in node_type.properties} + if constraint.property_name not in valid_property_names: + raise SchemaValidationError( + f"Constraint references undefined property '{constraint.property_name}' " + f"on node type '{constraint.node_type}'. " + f"Valid properties: {valid_property_names}" + ) + return self + def node_type_from_label(self, label: str) -> Optional[NodeType]: return self._node_type_index.get(label) @@ -382,6 +426,7 @@ def create_schema_model( node_types: Sequence[NodeType], relationship_types: Optional[Sequence[RelationshipType]] = None, patterns: Optional[Sequence[Tuple[str, str, str]]] = None, + constraints: Optional[Sequence[ConstraintType]] = None, **kwargs: Any, ) -> GraphSchema: """ @@ -403,6 +448,7 @@ def create_schema_model( node_types=node_types, relationship_types=relationship_types or (), patterns=patterns or (), + constraints=constraints or (), **kwargs, ) ) @@ -415,6 +461,7 @@ async def run( node_types: Sequence[NodeType], relationship_types: Optional[Sequence[RelationshipType]] = None, patterns: Optional[Sequence[Tuple[str, str, str]]] = None, + constraints: Optional[Sequence[ConstraintType]] = None, **kwargs: Any, ) -> GraphSchema: """ @@ -432,6 +479,7 @@ async def run( node_types, relationship_types, patterns, + constraints, **kwargs, ) @@ -555,6 +603,69 @@ def _filter_relationships_without_labels( relationship_types, "relationship type" ) + def _filter_invalid_constraints( + self, constraints: List[Dict[str, Any]], node_types: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Filter out constraints that reference undefined node types, have no property name, are not UNIQUENESS type + or reference a property that doesn't exist on 
the node type.""" + if not constraints: + return [] + + if not node_types: + logging.info( + "Filtering out all constraints because no node types are defined. " + "Constraints reference node types that must be defined." + ) + return [] + + # Build a mapping of node_type label -> set of property names + node_type_properties: Dict[str, set[str]] = {} + for node_type_dict in node_types: + label = node_type_dict.get("label") + if label: + properties = node_type_dict.get("properties", []) + property_names = {p.get("name") for p in properties if p.get("name")} + node_type_properties[label] = property_names + + valid_node_labels = set(node_type_properties.keys()) + + filtered_constraints = [] + for constraint in constraints: + # Only process UNIQUENESS constraints (other types will be added) + if constraint.get("type") != "UNIQUENESS": + logging.info( + f"Filtering out constraint: {constraint}. " + f"Only UNIQUENESS constraints are supported." + ) + continue + + # check if the property_name is provided + if not constraint.get("property_name"): + logging.info( + f"Filtering out constraint: {constraint}. " + f"Property name is not provided." + ) + continue + # check if the node_type is valid + node_type = constraint.get("node_type") + if node_type not in valid_node_labels: + logging.info( + f"Filtering out constraint: {constraint}. " + f"Node type '{node_type}' is not valid. Valid node types: {valid_node_labels}" + ) + continue + # check if the property_name exists on the node type + property_name = constraint.get("property_name") + if property_name not in node_type_properties.get(node_type, set()): + logging.info( + f"Filtering out constraint: {constraint}. " + f"Property '{property_name}' does not exist on node type '{node_type}'. 
" + f"Valid properties: {node_type_properties.get(node_type, set())}" + ) + continue + filtered_constraints.append(constraint) + return filtered_constraints + def _clean_json_content(self, content: str) -> str: content = content.strip() @@ -624,6 +735,9 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema extracted_patterns: Optional[List[Tuple[str, str, str]]] = extracted_schema.get( "patterns" ) + extracted_constraints: Optional[List[Dict[str, Any]]] = extracted_schema.get( + "constraints" + ) # Filter out nodes and relationships without labels extracted_node_types = self._filter_nodes_without_labels(extracted_node_types) @@ -638,11 +752,18 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema extracted_patterns, extracted_node_types, extracted_relationship_types ) + # Filter out invalid constraints + if extracted_constraints: + extracted_constraints = self._filter_invalid_constraints( + extracted_constraints, extracted_node_types + ) + return GraphSchema.model_validate( { "node_types": extracted_node_types, "relationship_types": extracted_relationship_types, "patterns": extracted_patterns, + "constraints": extracted_constraints or [], } ) diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index d9045a944..6fedb511b 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -218,6 +218,12 @@ class SchemaExtractionTemplate(PromptTemplate): 5. When defining patterns, ensure that every node label and relationship label mentioned exists in your lists of node types and relationship types. 6. Do not create node types that aren't clearly mentioned in the text. 7. Keep your schema minimal and focused on clearly identifiable patterns in the text. +8. UNIQUENESS CONSTRAINTS: +8.1 UNIQUENESS is optional; each node_type may or may not have exactly one uniqueness constraint. 
+8.2 Only use properties that seem to not have too many missing values in the sample. +8.3 Constraints reference node_types by label and specify which property is unique. +8.4 If a property appears in a uniqueness constraint it MUST also appear in the corresponding node_type as a property. + Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST, LOCAL_DATETIME, LOCAL_TIME, POINT, STRING, ZONED_DATETIME, ZONED_TIME. @@ -233,18 +239,26 @@ class SchemaExtractionTemplate(PromptTemplate): "type": "STRING" }} ] - }}, + }} ... ], "relationship_types": [ {{ "label": "WORKS_FOR" - }}, + }} ... ], "patterns": [ ["Person", "WORKS_FOR", "Company"], ... + ], + "constraints": [ + {{ + "type": "UNIQUENESS", + "node_type": "Person", + "property_name": "name" + }} + ... ] }} diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 9ff440557..98bb3fe58 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -27,6 +27,7 @@ NodeType, PropertyType, RelationshipType, + ConstraintType, SchemaFromTextExtractor, GraphSchema, SchemaFromExistingGraphExtractor, @@ -119,6 +120,30 @@ def test_relationship_type_additional_properties_default() -> None: assert relationship_type.additional_properties is True +def test_constraint_type_initialization() -> None: + constraint = ConstraintType( + type="UNIQUENESS", node_type="Person", property_name="name" + ) + assert constraint.type == "UNIQUENESS" + assert constraint.node_type == "Person" + assert constraint.property_name == "name" + + +def test_constraint_type_is_frozen() -> None: + constraint = ConstraintType( + type="UNIQUENESS", node_type="Person", property_name="name" + ) + + with pytest.raises(ValidationError): + constraint.type = "UNIQUENESS" + + with pytest.raises(ValidationError): + constraint.node_type = "Organization" + + with pytest.raises(ValidationError): + constraint.property_name = "id" 
+ + def test_schema_additional_node_types_default() -> None: schema_dict: dict[str, Any] = { "node_types": [], @@ -200,6 +225,105 @@ def test_schema_additional_parameter_validation() -> None: GraphSchema.model_validate(schema_dict) +def test_schema_constraint_validation_property_not_in_node_type() -> None: + schema_dict: dict[str, Any] = { + "node_types": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ], + "constraints": [ + {"type": "UNIQUENESS", "node_type": "Person", "property_name": "email"} + ], + } + + with pytest.raises(SchemaValidationError) as exc_info: + GraphSchema.model_validate(schema_dict) + + assert "Constraint references undefined property" in str(exc_info.value) + assert "on node type 'Person'" in str(exc_info.value) + + +def test_schema_constraint_with_additional_properties_with_allows_unknown_property() -> ( + None +): + # even if additional_properties is True, a constraint cannot reference a property that is not declared on the node_type + schema_dict: dict[str, Any] = { + "node_types": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + "additional_properties": True, + } + ], + "constraints": [ + {"type": "UNIQUENESS", "node_type": "Person", "property_name": "email"} + ], + } + + # Should raise - email is not allowed because the property is not defined in the node + with pytest.raises(SchemaValidationError) as exc_info: + GraphSchema.model_validate(schema_dict) + + assert "Constraint references undefined property 'email'" in str(exc_info.value) + + +def test_schema_with_valid_constraints() -> None: + schema_dict: dict[str, Any] = { + "node_types": [ + {"label": "Person", "properties": [{"name": "name", "type": "STRING"}]} + ], + "constraints": [ + {"type": "UNIQUENESS", "node_type": "Person", "property_name": "name"} + ], + } + schema = GraphSchema.model_validate(schema_dict) + + assert len(schema.constraints) == 1 + assert schema.constraints[0].type == "UNIQUENESS" + assert 
schema.constraints[0].node_type == "Person" + assert schema.constraints[0].property_name == "name" + + +def test_schema_constraint_validation_invalid_node_type() -> None: + schema_dict: dict[str, Any] = { + "node_types": [ + {"label": "Person", "properties": [{"name": "name", "type": "STRING"}]} + ], + "constraints": [ + { + "type": "UNIQUENESS", + "node_type": "NonExistentNode", + "property_name": "id", + } + ], + } + + with pytest.raises(SchemaValidationError) as exc_info: + GraphSchema.model_validate(schema_dict) + + assert "Constraint references undefined node type: NonExistentNode" in str( + exc_info.value + ) + + +def test_schema_constraint_validation_missing_property_name() -> None: + schema_dict: dict[str, Any] = { + "node_types": [ + {"label": "Person", "properties": [{"name": "name", "type": "STRING"}]} + ], + "constraints": [ + {"type": "UNIQUENESS", "node_type": "Person", "property_name": ""} + ], + } + + with pytest.raises(SchemaValidationError) as exc_info: + GraphSchema.model_validate(schema_dict) + + assert "Constraint has no property name" in str(exc_info.value) + + @pytest.fixture def valid_node_types() -> tuple[NodeType, ...]: return ( @@ -258,6 +382,13 @@ def patterns_with_invalid_entity() -> tuple[tuple[str, str, str], ...]: ) +@pytest.fixture +def valid_constraints() -> tuple[ConstraintType, ...]: + return ( + ConstraintType(type="UNIQUENESS", node_type="PERSON", property_name="name"), + ) + + @pytest.fixture def patterns_with_invalid_relation() -> tuple[tuple[str, str, str], ...]: return (("PERSON", "NON_EXISTENT_RELATION", "ORGANIZATION"),) @@ -298,6 +429,24 @@ def test_create_schema_model_valid_data( assert schema.additional_patterns is False +def test_create_schema_model_with_constraints( + schema_builder: SchemaBuilder, + valid_node_types: Tuple[NodeType, ...], + valid_constraints: Tuple[ConstraintType, ...], +) -> None: + schema = schema_builder.create_schema_model( + list(valid_node_types), + constraints=list(valid_constraints), + ) + + 
assert schema.node_types == valid_node_types + assert schema.constraints == valid_constraints + assert len(schema.constraints) == 1 + assert schema.constraints[0].type == "UNIQUENESS" + assert schema.constraints[0].node_type == "PERSON" + assert schema.constraints[0].property_name == "name" + + @pytest.mark.asyncio async def test_run_method( schema_builder: SchemaBuilder, @@ -326,6 +475,25 @@ async def test_run_method( assert schema.additional_patterns is False +@pytest.mark.asyncio +async def test_run_method_with_constraints( + schema_builder: SchemaBuilder, + valid_node_types: Tuple[NodeType, ...], + valid_constraints: Tuple[ConstraintType, ...], +) -> None: + schema = await schema_builder.run( + list(valid_node_types), + constraints=list(valid_constraints), + ) + + assert schema.node_types == valid_node_types + assert schema.constraints == valid_constraints + assert len(schema.constraints) == 1 + assert schema.constraints[0].type == "UNIQUENESS" + assert schema.constraints[0].node_type == "PERSON" + assert schema.constraints[0].property_name == "name" + + def test_create_schema_model_invalid_entity( schema_builder: SchemaBuilder, valid_node_types: Tuple[NodeType, ...], @@ -452,6 +620,116 @@ def valid_schema_json() -> str: """ +@pytest.fixture +def schema_json_with_valid_constraints() -> str: + return """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "email", "type": "STRING"} + ] + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relationship_types": [ + { + "label": "WORKS_FOR", + "properties": [ + {"name": "since", "type": "DATE"} + ] + } + ], + "patterns": [ + ["Person", "WORKS_FOR", "Organization"] + ], + "constraints": [ + {"type": "UNIQUENESS", "node_type": "Person", "property_name": "name"} + ] + } + """ + + +@pytest.fixture +def schema_json_with_invalid_constraints() -> str: + return """ + { + "node_types": [ + { + "label": 
"Person", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relationship_types": [ + { + "label": "WORKS_FOR", + "properties": [ + {"name": "since", "type": "DATE"} + ] + } + ], + "patterns": [ + ["Person", "WORKS_FOR", "Organization"] + ], + "constraints": [ + {"type": "UNIQUENESS", "node_type": "Person", "property_name": "name"}, + {"type": "UNIQUENESS", "node_type": "Person", "property_name": "email"}, + {"type": "UNIQUENESS", "node_type": "NonExistentNode", "property_name": "id"}, + {"type": "UNIQUENESS", "node_type": "Person", "property_name": ""} + ] + } + """ + + +@pytest.fixture +def schema_json_with_null_constraints() -> str: + return """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relationship_types": [ + { + "label": "WORKS_FOR", + "properties": [ + {"name": "since", "type": "DATE"} + ] + } + ], + "patterns": [ + ["Person", "WORKS_FOR", "Organization"] + ], + "constraints": null + } + """ + + @pytest.fixture def invalid_schema_json() -> str: return """ @@ -866,6 +1144,28 @@ def schema_json_with_relationships_without_labels() -> str: """ +@pytest.fixture +def schema_json_with_nonexistent_property_constraint() -> str: + return """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relationship_types": [], + "patterns": [], + "constraints": [ + {"type": "UNIQUENESS", "node_type": "Person", "property_name": "name"}, + {"type": "UNIQUENESS", "node_type": "Person", "property_name": "nonexistent_property"} + ] + } + """ + + @pytest.mark.asyncio async def test_schema_from_text_filters_invalid_node_patterns( schema_from_text: SchemaFromTextExtractor, @@ -960,6 +1260,103 @@ async def 
test_schema_from_text_filters_relationships_without_labels( assert ("Person", "MANAGES", "Organization") in schema.patterns +@pytest.mark.asyncio +async def test_schema_from_text_with_valid_constraints( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_valid_constraints: str, +) -> None: + # configure the mock LLM to return schema with valid constraints + mock_llm.ainvoke.return_value = LLMResponse( + content=schema_json_with_valid_constraints + ) + + # run the schema extraction + schema = await schema_from_text.run(text="Sample text for extraction") + + assert len(schema.constraints) == 1 + assert schema.constraints[0].type == "UNIQUENESS" + assert schema.constraints[0].node_type == "Person" + assert schema.constraints[0].property_name == "name" + + +@pytest.mark.asyncio +async def test_schema_from_text_filters_invalid_constraints( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_invalid_constraints: str, +) -> None: + # configure the mock LLM to return schema with invalid constraints + mock_llm.ainvoke.return_value = LLMResponse( + content=schema_json_with_invalid_constraints + ) + + # run the schema extraction + schema = await schema_from_text.run(text="Sample text for extraction") + + # verify that invalid constraints were filtered out: + # constraints on NonExistentNode or on a property the node type does not declare (email) should be removed + # constraint with empty property_name should be removed + # only the valid constraint should remain + assert len(schema.constraints) == 1 + assert schema.constraints[0].node_type == "Person" + assert schema.constraints[0].property_name == "name" + + +@pytest.mark.asyncio +async def test_schema_from_text_filters_constraint_with_nonexistent_property( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_nonexistent_property_constraint: str, +) -> None: + # configure the mock LLM to return schema with constraint on nonexistent property + mock_llm.ainvoke.return_value = 
LLMResponse( + content=schema_json_with_nonexistent_property_constraint + ) + + # run the schema extraction + schema = await schema_from_text.run(text="Sample text for extraction") + + # verify that only the valid constraint (with "name" property) remains + # the constraint with "nonexistent_property" should be filtered out + assert len(schema.constraints) == 1 + assert schema.constraints[0].property_name == "name" + + +@pytest.mark.asyncio +async def test_schema_from_text_handles_null_constraints( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_null_constraints: str, +) -> None: + # configure the mock LLM to return schema with null constraints + mock_llm.ainvoke.return_value = LLMResponse( + content=schema_json_with_null_constraints + ) + + # run the schema extraction - should not crash + schema = await schema_from_text.run(text="Sample text for extraction") + + # verify schema was created with empty constraints + assert len(schema.constraints) == 0 + + +@pytest.mark.asyncio +async def test_schema_from_text_handles_missing_constraints( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + valid_schema_json: str, +) -> None: + # configure the mock LLM to return schema without constraints field + mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json) + + # run the schema extraction - should not crash + schema = await schema_from_text.run(text="Sample text for extraction") + + # verify schema was created with empty constraints + assert len(schema.constraints) == 0 + + def test_clean_json_content_markdown_with_json_language( schema_from_text: SchemaFromTextExtractor, ) -> None: