diff --git a/pylint/pyreverse/diagrams.py b/pylint/pyreverse/diagrams.py index a4fb8ce130..93e5896b36 100644 --- a/pylint/pyreverse/diagrams.py +++ b/pylint/pyreverse/diagrams.py @@ -226,6 +226,7 @@ def extract_relationships(self) -> None: obj.attrs = self.get_attrs(node) obj.methods = self.get_methods(node) obj.shape = "class" + # inheritance link for par_node in node.ancestors(recurs=False): try: @@ -234,7 +235,18 @@ def extract_relationships(self) -> None: except KeyError: continue - # associations & aggregations links + # Track processed attributes to avoid duplicates + processed_attrs = set() + + # Composition links + for name, values in list(node.compositions_type.items()): + for value in values: + self.assign_association_relationship( + value, obj, name, "composition" + ) + processed_attrs.add(name) + + # Aggregation links for name, values in list(node.aggregations_type.items()): for value in values: if not self.show_attr(name): @@ -244,8 +256,8 @@ def extract_relationships(self) -> None: value, obj, name, "aggregation" ) + # Association links associations = node.associations_type.copy() - for name, values in node.locals_type.items(): if name not in associations: associations[name] = values diff --git a/pylint/pyreverse/dot_printer.py b/pylint/pyreverse/dot_printer.py index 4baed6c3c2..b4aba94ea5 100644 --- a/pylint/pyreverse/dot_printer.py +++ b/pylint/pyreverse/dot_printer.py @@ -30,12 +30,18 @@ class HTMLLabels(Enum): # pylint: disable-next=consider-using-namedtuple-or-dataclass ARROWS: dict[EdgeType, dict[str, str]] = { EdgeType.INHERITS: {"arrowtail": "none", "arrowhead": "empty"}, - EdgeType.ASSOCIATION: { + EdgeType.COMPOSITION: { "fontcolor": "green", "arrowtail": "none", "arrowhead": "diamond", "style": "solid", }, + EdgeType.ASSOCIATION: { + "fontcolor": "green", + "arrowtail": "none", + "arrowhead": "normal", + "style": "solid", + }, EdgeType.AGGREGATION: { "fontcolor": "green", "arrowtail": "none", diff --git a/pylint/pyreverse/inspector.py b/pylint/pyreverse/inspector.py index 8e69e94470..ea8723998d 100644 --- a/pylint/pyreverse/inspector.py +++ b/pylint/pyreverse/inspector.py @@ -113,6 +113,9 @@ class Linker(IdGeneratorMixIn, utils.LocalsVisitor): * aggregations_type as instance_attrs_type but for aggregations relationships + + * compositions_type + as instance_attrs_type but for compositions relationships """ def __init__(self, project: Project, tag: bool = False) -> None: @@ -122,8 +125,14 @@ def __init__(self, project: Project, tag: bool = False) -> None: self.tag = tag # visited project self.project = project - self.associations_handler = AggregationsHandler() - self.associations_handler.set_next(OtherAssociationsHandler()) + + # Chain: Composition → Aggregation → Association + self.associations_handler = CompositionsHandler() + aggregation_handler = AggregationsHandler() + association_handler = AssociationsHandler() + + self.associations_handler.set_next(aggregation_handler) + aggregation_handler.set_next(association_handler) def visit_project(self, node: Project) -> None: """Visit a pyreverse.utils.Project node. @@ -167,6 +176,7 @@ def visit_classdef(self, node: nodes.ClassDef) -> None: specializations.append(node) baseobj.specializations = specializations # resolve instance attributes + node.compositions_type = collections.defaultdict(list) node.instance_attrs_type = collections.defaultdict(list) node.aggregations_type = collections.defaultdict(list) node.associations_type = collections.defaultdict(list) @@ -327,16 +337,39 @@ def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None: self._next_handler.handle(node, parent) +class CompositionsHandler(AbstractAssociationHandler): + """Handle composition relationships where parent creates child objects.""" + + def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None: + if not isinstance(node.parent, (nodes.AnnAssign, nodes.Assign)): + super().handle(node, parent) + return + + value = node.parent.value + + # Composition: parent creates child (self.x = P()) + if isinstance(value, nodes.Call): + current = set(parent.compositions_type[node.attrname]) + parent.compositions_type[node.attrname] = list( + current | utils.infer_node(node) + ) + return + + # Not a composition, pass to next handler + super().handle(node, parent) + + class AggregationsHandler(AbstractAssociationHandler): + """Handle aggregation relationships where parent receives child objects.""" + def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None: - # Check if we're not in an assignment context if not isinstance(node.parent, (nodes.AnnAssign, nodes.Assign)): super().handle(node, parent) return value = node.parent.value - # Handle direct name assignments + # Aggregation: parent receives child (self.x = x) if isinstance(value, astroid.node_classes.Name): current = set(parent.aggregations_type[node.attrname]) parent.aggregations_type[node.attrname] = list( @@ -344,11 +377,10 @@ def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None: ) return - # Handle comprehensions + # Aggregation: comprehensions (self.x = [P() for ...]) if isinstance( value, (nodes.ListComp, nodes.DictComp, nodes.SetComp, nodes.GeneratorExp) ): - # Determine the type of the element in the comprehension if isinstance(value, nodes.DictComp): element_type = safe_infer(value.value) else: @@ -358,12 +390,23 @@ def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None: parent.aggregations_type[node.attrname] = list(current | {element_type}) return - # Fallback to parent handler + # Type annotation only (x: P) defaults to aggregation + if isinstance(node.parent, nodes.AnnAssign) and node.parent.value is None: + current = set(parent.aggregations_type[node.attrname]) + parent.aggregations_type[node.attrname] = list( + current | utils.infer_node(node) + ) + return + + # Not an aggregation, pass to next handler super().handle(node, parent) -class OtherAssociationsHandler(AbstractAssociationHandler): +class AssociationsHandler(AbstractAssociationHandler): + """Handle regular association relationships.""" + def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None: + # Everything else is a regular association current = set(parent.associations_type[node.attrname]) parent.associations_type[node.attrname] = list(current | utils.infer_node(node)) diff --git a/pylint/pyreverse/mermaidjs_printer.py b/pylint/pyreverse/mermaidjs_printer.py index 0f1ebd04f0..45ad91f763 100644 --- a/pylint/pyreverse/mermaidjs_printer.py +++ b/pylint/pyreverse/mermaidjs_printer.py @@ -21,7 +21,8 @@ class MermaidJSPrinter(Printer): } ARROWS: dict[EdgeType, str] = { EdgeType.INHERITS: "--|>", - EdgeType.ASSOCIATION: "--*", + EdgeType.COMPOSITION: "--*", + EdgeType.ASSOCIATION: "-->", EdgeType.AGGREGATION: "--o", EdgeType.USES: "-->", EdgeType.TYPE_DEPENDENCY: "..>", diff --git a/pylint/pyreverse/plantuml_printer.py b/pylint/pyreverse/plantuml_printer.py index 379d57a4c6..98013224c4 100644 --- a/pylint/pyreverse/plantuml_printer.py +++ b/pylint/pyreverse/plantuml_printer.py @@ -21,7 +21,8 @@ class PlantUmlPrinter(Printer): } ARROWS: dict[EdgeType, str] = { EdgeType.INHERITS: "--|>", - EdgeType.ASSOCIATION: "--*", + EdgeType.ASSOCIATION: "-->", + EdgeType.COMPOSITION: "--*", EdgeType.AGGREGATION: "--o", EdgeType.USES: "-->", EdgeType.TYPE_DEPENDENCY: "..>", diff --git a/pylint/pyreverse/printer.py b/pylint/pyreverse/printer.py index caa7917ca0..3ec1804897 100644 --- a/pylint/pyreverse/printer.py +++ b/pylint/pyreverse/printer.py @@ -22,6 +22,7 @@ class NodeType(Enum): class EdgeType(Enum): INHERITS = "inherits" + COMPOSITION = "composition" ASSOCIATION = "association" AGGREGATION = "aggregation" USES = "uses" diff --git a/pylint/pyreverse/writer.py b/pylint/pyreverse/writer.py index e822f67096..28fb7ea095 100644 --- a/pylint/pyreverse/writer.py +++ b/pylint/pyreverse/writer.py @@ -146,6 +146,14 @@ def write_classes(self, diagram: ClassDiagram) -> None: label=rel.name, type_=EdgeType.ASSOCIATION, ) + # generate compositions + for rel in diagram.get_relationships("composition"): + self.printer.emit_edge( + rel.from_object.fig_id, + rel.to_object.fig_id, + label=rel.name, + type_=EdgeType.COMPOSITION, + ) # generate aggregations for rel in diagram.get_relationships("aggregation"): if rel.to_object.fig_id in associations[rel.from_object.fig_id]: diff --git a/tests/pyreverse/functional/class_diagrams/aggregation/fields.mmd b/tests/pyreverse/functional/class_diagrams/aggregation/fields.mmd index 9901b175c8..96e0defda1 100644 --- a/tests/pyreverse/functional/class_diagrams/aggregation/fields.mmd +++ b/tests/pyreverse/functional/class_diagrams/aggregation/fields.mmd @@ -16,8 +16,8 @@ classDiagram } class P { } - P --* A : x - P --* C : x + P --> A : x P --* D : x P --* E : x P --o B : x + P --o C : x diff --git a/tests/pyreverse/functional/class_diagrams/aggregation/fields.py b/tests/pyreverse/functional/class_diagrams/aggregation/fields.py index a2afb89913..dd812c5b7b 100644 --- a/tests/pyreverse/functional/class_diagrams/aggregation/fields.py +++ b/tests/pyreverse/functional/class_diagrams/aggregation/fields.py @@ -4,24 +4,24 @@ class P: pass class A: - x: P + x: P # just type hint, no ownership, soassociation class B: def __init__(self, x: P): - self.x = x + self.x = x # not instantiated, so aggregation class C: x: P def __init__(self, x: P): - self.x = x + self.x = x # not instantiated, so aggregation class D: x: P def __init__(self): - self.x = P() + self.x = P() # instantiated, so composition class E: def __init__(self): - self.x = P() + self.x = P() # instantiated, so composition