Skip to content

Commit e8da4cb

Browse files
Amir LayeghAmir Layegh
authored andcommitted
feat: add ConstraintType to GraphSchema for constraint extraction
1 parent 3353643 commit e8da4cb

File tree

2 files changed

+93
-2
lines changed

2 files changed

+93
-2
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,19 @@ def property_type_from_name(self, name: str) -> Optional[PropertyType]:
160160
return prop
161161
return None
162162

163+
class ConstraintType(BaseModel):
164+
"""
165+
Represents a constraint on a node in the graph.
166+
"""
167+
168+
type: Literal["UNIQUENESS"] #TODO: add other constraint types ["propertyExistence", "propertyType", "key"]
169+
node_type: str
170+
property_name: str
171+
172+
model_config = ConfigDict(
173+
frozen=True,
174+
)
175+
163176

164177
class GraphSchema(DataModel):
165178
"""This model represents the expected
@@ -177,6 +190,7 @@ class GraphSchema(DataModel):
177190
node_types: Tuple[NodeType, ...]
178191
relationship_types: Tuple[RelationshipType, ...] = tuple()
179192
patterns: Tuple[Tuple[str, str, str], ...] = tuple()
193+
constraints: Tuple[ConstraintType, ...] = tuple()
180194

181195
additional_node_types: bool = Field(
182196
default_factory=default_additional_item("node_types")
@@ -239,6 +253,21 @@ def validate_additional_parameters(self) -> Self:
239253
)
240254
return self
241255

256+
@model_validator(mode="after")
257+
def validate_constraints_against_node_types(self) -> Self:
258+
if not self.constraints:
259+
return self
260+
for constraint in self.constraints:
261+
if not constraint.get('property_name'):
262+
raise SchemaValidationError(
263+
f"Constraint has no property name: {constraint}. Property name is required."
264+
)
265+
if constraint.get('node_type') not in self._node_type_index:
266+
raise SchemaValidationError(
267+
f"Constraint references undefined node type: {constraint.get('node_type')}"
268+
)
269+
return self
270+
242271
def node_type_from_label(self, label: str) -> Optional[NodeType]:
243272
return self._node_type_index.get(label)
244273

@@ -382,6 +411,7 @@ def create_schema_model(
382411
node_types: Sequence[NodeType],
383412
relationship_types: Optional[Sequence[RelationshipType]] = None,
384413
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
414+
constraints: Optional[Sequence[ConstraintType]] = None,
385415
**kwargs: Any,
386416
) -> GraphSchema:
387417
"""
@@ -403,6 +433,7 @@ def create_schema_model(
403433
node_types=node_types,
404434
relationship_types=relationship_types or (),
405435
patterns=patterns or (),
436+
constraints=constraints or (),
406437
**kwargs,
407438
)
408439
)
@@ -415,6 +446,7 @@ async def run(
415446
node_types: Sequence[NodeType],
416447
relationship_types: Optional[Sequence[RelationshipType]] = None,
417448
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
449+
constraints: Optional[Sequence[ConstraintType]] = None,
418450
**kwargs: Any,
419451
) -> GraphSchema:
420452
"""
@@ -432,6 +464,7 @@ async def run(
432464
node_types,
433465
relationship_types,
434466
patterns,
467+
constraints,
435468
**kwargs,
436469
)
437470

@@ -554,6 +587,42 @@ def _filter_relationships_without_labels(
554587
return self._filter_items_without_labels(
555588
relationship_types, "relationship type"
556589
)
590+
591+
def _filter_invalid_constraints(
592+
self,
593+
constraints: List[Dict[str, Any]], node_types: List[Dict[str, Any]]
594+
) -> List[Dict[str, Any]]:
595+
"""Filter out constraints that reference undefined node types or have no property name."""
596+
if not constraints:
597+
return []
598+
599+
if not node_types:
600+
logging.info(
601+
"Filtering out all constraints because no node types are defined. "
602+
"Constraints reference node types that must be defined."
603+
)
604+
return []
605+
606+
valid_node_labels = {node_type.get('label') for node_type in node_types}
607+
608+
filtered_constraints = []
609+
for constraint in constraints:
610+
# check if the property_name is provided
611+
if not constraint.get('property_name'):
612+
logging.info(
613+
f"Filtering out constraint: {constraint}. "
614+
f"Property name is not provided."
615+
)
616+
continue
617+
# check if the node_type is valid
618+
if constraint.get('node_type') not in valid_node_labels:
619+
logging.info(
620+
f"Filtering out constraint: {constraint}. "
621+
f"Node type '{constraint.get('node_type')}' is not valid. Valid node types: {valid_node_labels}"
622+
)
623+
continue
624+
filtered_constraints.append(constraint)
625+
return filtered_constraints
557626

558627
def _clean_json_content(self, content: str) -> str:
559628
content = content.strip()
@@ -624,6 +693,9 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
624693
extracted_patterns: Optional[List[Tuple[str, str, str]]] = extracted_schema.get(
625694
"patterns"
626695
)
696+
extracted_constraints: Optional[List[Dict[str, Any]]] = extracted_schema.get(
697+
"constraints"
698+
)
627699

628700
# Filter out nodes and relationships without labels
629701
extracted_node_types = self._filter_nodes_without_labels(extracted_node_types)
@@ -637,12 +709,19 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
637709
extracted_patterns = self._filter_invalid_patterns(
638710
extracted_patterns, extracted_node_types, extracted_relationship_types
639711
)
712+
713+
# Filter out invalid constraints
714+
if extracted_constraints:
715+
extracted_constraints = self._filter_invalid_constraints(
716+
extracted_constraints, extracted_node_types
717+
)
640718

641719
return GraphSchema.model_validate(
642720
{
643721
"node_types": extracted_node_types,
644722
"relationship_types": extracted_relationship_types,
645723
"patterns": extracted_patterns,
724+
"constraints": extracted_constraints,
646725
}
647726
)
648727

src/neo4j_graphrag/generation/prompts.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,11 @@ class SchemaExtractionTemplate(PromptTemplate):
217217
4. Include property definitions only when the type can be confidently inferred, otherwise omit them.
218218
5. When defining patterns, ensure that every node label and relationship label mentioned exists in your lists of node types and relationship types.
219219
6. Do not create node types that aren't clearly mentioned in the text.
220-
7. Keep your schema minimal and focused on clearly identifiable patterns in the text.
220+
7. For each node type, identify a unique identifier property and add it as a UNIQUENESS constraint to the list of constraints.
221+
8. Constraints must reference a node_type label that exists in the list of node types.
222+
9. Each constraint must have a property_name having a name that indicates it is a unique identifier for the node type (e.g., person_id for Person, company_id for Company)
223+
10. Keep your schema minimal and focused on clearly identifiable patterns in the text.
224+
221225
222226
Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST,
223227
LOCAL_DATETIME, LOCAL_TIME, POINT, STRING, ZONED_DATETIME, ZONED_TIME.
@@ -231,7 +235,7 @@ class SchemaExtractionTemplate(PromptTemplate):
231235
{{
232236
"name": "name",
233237
"type": "STRING"
234-
}}
238+
}},
235239
]
236240
}},
237241
...
@@ -245,6 +249,14 @@ class SchemaExtractionTemplate(PromptTemplate):
245249
"patterns": [
246250
["Person", "WORKS_FOR", "Company"],
247251
...
252+
],
253+
"constraints": [
254+
{{
255+
"type": "UNIQUENESS",
256+
"node_type": "Person",
257+
"property_name": "person_id"
258+
}},
259+
...
248260
]
249261
}}
250262

0 commit comments

Comments
 (0)