Skip to content

Commit 3fa6d44

Browse files
committed
feat: add ConstraintType to GraphSchema for constraint extraction
1 parent cb5e021 commit 3fa6d44

File tree

1 file changed

+14
-12
lines changed
  • src/neo4j_graphrag/experimental/components

1 file changed

+14
-12
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -160,12 +160,15 @@ def property_type_from_name(self, name: str) -> Optional[PropertyType]:
160160
return prop
161161
return None
162162

163+
163164
class ConstraintType(BaseModel):
164165
"""
165166
Represents a constraint on a node in the graph.
166167
"""
167168

168-
type: Literal["UNIQUENESS"] #TODO: add other constraint types ["propertyExistence", "propertyType", "key"]
169+
type: Literal[
170+
"UNIQUENESS"
171+
] # TODO: add other constraint types ["propertyExistence", "propertyType", "key"]
169172
node_type: str
170173
property_name: str
171174

@@ -258,11 +261,11 @@ def validate_constraints_against_node_types(self) -> Self:
258261
if not self.constraints:
259262
return self
260263
for constraint in self.constraints:
261-
if not constraint.get('property_name'):
264+
if not constraint.get("property_name"):
262265
raise SchemaValidationError(
263266
f"Constraint has no property name: {constraint}. Property name is required."
264267
)
265-
if constraint.get('node_type') not in self._node_type_index:
268+
if constraint.get("node_type") not in self._node_type_index:
266269
raise SchemaValidationError(
267270
f"Constraint references undefined node type: {constraint.get('node_type')}"
268271
)
@@ -587,35 +590,34 @@ def _filter_relationships_without_labels(
587590
return self._filter_items_without_labels(
588591
relationship_types, "relationship type"
589592
)
590-
593+
591594
def _filter_invalid_constraints(
592-
self,
593-
constraints: List[Dict[str, Any]], node_types: List[Dict[str, Any]]
595+
self, constraints: List[Dict[str, Any]], node_types: List[Dict[str, Any]]
594596
) -> List[Dict[str, Any]]:
595597
"""Filter out constraints that reference undefined node types or have no property name."""
596598
if not constraints:
597599
return []
598-
600+
599601
if not node_types:
600602
logging.info(
601603
"Filtering out all constraints because no node types are defined. "
602604
"Constraints reference node types that must be defined."
603605
)
604606
return []
605-
606-
valid_node_labels = {node_type.get('label') for node_type in node_types}
607+
608+
valid_node_labels = {node_type.get("label") for node_type in node_types}
607609

608610
filtered_constraints = []
609611
for constraint in constraints:
610612
# check if the property_name is provided
611-
if not constraint.get('property_name'):
613+
if not constraint.get("property_name"):
612614
logging.info(
613615
f"Filtering out constraint: {constraint}. "
614616
f"Property name is not provided."
615617
)
616618
continue
617619
# check if the node_type is valid
618-
if constraint.get('node_type') not in valid_node_labels:
620+
if constraint.get("node_type") not in valid_node_labels:
619621
logging.info(
620622
f"Filtering out constraint: {constraint}. "
621623
f"Node type '{constraint.get('node_type')}' is not valid. Valid node types: {valid_node_labels}"
@@ -709,7 +711,7 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
709711
extracted_patterns = self._filter_invalid_patterns(
710712
extracted_patterns, extracted_node_types, extracted_relationship_types
711713
)
712-
714+
713715
# Filter out invalid constraints
714716
if extracted_constraints:
715717
extracted_constraints = self._filter_invalid_constraints(

0 commit comments

Comments
 (0)