Skip to content

Commit fee25ba

Browse files
committed
support back propagate debug handle to arbitrary ancestor export graph
Pull Request resolved: #12580 Currently propagate_back_debug_handle function can only support propagating debug handle back to the greatest ancestor export graph. This diff update algo to support every possible ancestor export graph on the flow. ghstack-source-id: 296898644 Differential Revision: [D78464992](https://our.internmc.facebook.com/intern/diff/D78464992/)
1 parent 1af653c commit fee25ba

File tree

2 files changed

+234
-40
lines changed

2 files changed

+234
-40
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 145 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Sequence
1212
from dataclasses import dataclass
1313
from enum import Enum
14-
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
14+
from typing import Any, Dict, IO, List, Mapping, Optional, Set, Tuple, TypeAlias, Union
1515

1616
import executorch.devtools.etdump.schema_flatcc as flatcc
1717

@@ -37,7 +37,7 @@
3737

3838
from executorch.exir.debug_handle_utils import (
3939
DEBUG_HANDLE_KEY,
40-
get_greatest_ancestor_node_identifier,
40+
FROM_NODE_KEY,
4141
UNSET_DEBUG_HANDLE,
4242
)
4343

@@ -46,6 +46,7 @@
4646
from tabulate import tabulate
4747

4848
from torch.export import ExportedProgram
49+
from torch.fx import Node
4950

5051
FORWARD = "forward"
5152
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"
@@ -936,6 +937,133 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
936937
)
937938

938939

940+
def get_ancestor_node_identifiers(node: Node) -> List[str]:
941+
"""Get the identifier of the ancestor node of the given node, with the graph id the ancestor node lives in.
942+
943+
The identifier is the concatenation of the node name and graph id of the
944+
greatest ancestor node, where the graph id is the unique id for every graph
945+
module in the export flow and node name is unique within the same graph module.
946+
947+
Returns: the identifiers of all its ancestor nodes
948+
"""
949+
950+
node_source = node.meta[FROM_NODE_KEY]
951+
node_source = node_source[-1]
952+
ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"]
953+
954+
while len(node_source.from_node) > 0:
955+
node_source = node_source.from_node[-1]
956+
ancestor_node_ids.append(f"{node_source.name}.{str(node_source.graph_id)}")
957+
958+
return ancestor_node_ids
959+
960+
961+
def get_parent_node_identifier(node: Node) -> Optional[str]:
962+
"""Get the identifier of the parent node of the given node, with the graph id the parent node lives in.
963+
964+
The identifier is the concatenation of the node name and graph id of the
965+
greatest parent node, where the graph id is the unique id for every graph
966+
module in the export flow and node name is unique within the same graph module.
967+
968+
Returns: the identifier of the parent node, or None if can not find the parent
969+
"""
970+
971+
if FROM_NODE_KEY not in node.meta:
972+
return None
973+
974+
node_source = node.meta[FROM_NODE_KEY][-1]
975+
return f"{node_source.name}.{str(node_source.graph_id)}"
976+
977+
978+
def _extract_ancestor_debug_handles(
979+
edge_dialect_program: ExportedProgram,
980+
) -> Dict[str, int]:
981+
"""Extract mapping from ancestor node identifiers to debug handles."""
982+
ancestors_node_id_to_debug_handle: Dict[str, int] = {}
983+
984+
def _extract_node_id_to_debug_handle(node: Node) -> None:
985+
if node.op in ("placeholder", "output"):
986+
return
987+
for ancestor_node_id in get_ancestor_node_identifiers(node):
988+
if ancestor_node_id not in ancestors_node_id_to_debug_handle:
989+
ancestors_node_id_to_debug_handle[ancestor_node_id] = node.meta[
990+
DEBUG_HANDLE_KEY
991+
]
992+
else:
993+
assert (
994+
ancestors_node_id_to_debug_handle[ancestor_node_id]
995+
== node.meta[DEBUG_HANDLE_KEY]
996+
)
997+
998+
bfs_trace_with_node_process(
999+
edge_dialect_program.graph_module, _extract_node_id_to_debug_handle
1000+
)
1001+
return ancestors_node_id_to_debug_handle
1002+
1003+
1004+
def _find_matched_debug_handles(
1005+
exported_program: ExportedProgram,
1006+
exported_program_graph_id: int,
1007+
ancestors_node_id_to_debug_handle: Dict[str, int],
1008+
) -> Set[int]:
1009+
"""Find debug handles that have corresponding nodes in the exported program."""
1010+
matched_debug_handles: Set[int] = set()
1011+
1012+
def _find_n_match_node(node: Node) -> None:
1013+
if node.op in ("output", "placeholder"):
1014+
return
1015+
node_id = f"{node.name}.{exported_program_graph_id}"
1016+
parent_node_id = get_parent_node_identifier(node)
1017+
if node_id in ancestors_node_id_to_debug_handle:
1018+
matched_debug_handles.add(ancestors_node_id_to_debug_handle[node_id])
1019+
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
1020+
matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id])
1021+
1022+
bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
1023+
return matched_debug_handles
1024+
1025+
1026+
def _verify_graph_match(
1027+
edge_dialect_program: ExportedProgram, matched_debug_handles: Set[int]
1028+
) -> bool:
1029+
"""Verify if every debug handle in edge dialect program has a corresponding node."""
1030+
graph_matched = True
1031+
1032+
def _check_graph_match(node: Node) -> None:
1033+
nonlocal graph_matched
1034+
if node.op in ("output", "placeholder"):
1035+
return
1036+
if node.meta[DEBUG_HANDLE_KEY] not in matched_debug_handles:
1037+
graph_matched = False
1038+
1039+
bfs_trace_with_node_process(edge_dialect_program.graph_module, _check_graph_match)
1040+
return graph_matched
1041+
1042+
1043+
def _apply_debug_handles(
1044+
exported_program: ExportedProgram,
1045+
exported_program_graph_id: int,
1046+
ancestors_node_id_to_debug_handle: Dict[str, int],
1047+
) -> None:
1048+
"""Apply debug handles to the exported program nodes."""
1049+
1050+
def _equip_debug_handle(node: Node) -> None:
1051+
if node.op in ("output", "placeholder"):
1052+
return
1053+
node_id = f"{node.name}.{exported_program_graph_id}"
1054+
parent_node_id = get_parent_node_identifier(node)
1055+
if node_id in ancestors_node_id_to_debug_handle:
1056+
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id]
1057+
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
1058+
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[
1059+
parent_node_id
1060+
]
1061+
else:
1062+
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
1063+
1064+
bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
1065+
1066+
9391067
def propagate_back_debug_handle(
9401068
exported_program: ExportedProgram,
9411069
exported_program_graph_id: int,
@@ -953,47 +1081,24 @@ def propagate_back_debug_handle(
9531081
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
9541082
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
9551083
956-
Return: True if:
957-
a. every debug handle in the edge dialect program has a corresponding node in the exported program
958-
b. the exported program is the greatest ancestor of the edge dialect program
959-
960-
Otherwise, return False.
1084+
Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
9611085
"""
1086+
# 1. Extract mapping from ancestor node identifiers to debug handles
1087+
ancestors_node_id_to_debug_handle = _extract_ancestor_debug_handles(
1088+
edge_dialect_program
1089+
)
9621090

963-
# 1. set up a mapping from debug handle to identifier of export program's node
964-
# using edge dialect program nodes' debug handles and from_node info
965-
export_graph_node_id_to_debug_handle = {
966-
get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY]
967-
for node in edge_dialect_program.graph.nodes
968-
if node.op not in ("placeholder", "output")
969-
}
970-
971-
# 2. equip debug handle to the exported program's nodes using the mapping
972-
# number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle
973-
n_matched_node = 0
974-
975-
def _find_n_match_node(node: torch.fx.Node) -> None:
976-
nonlocal n_matched_node
977-
if node.name in ("output", "placeholder"):
978-
return
979-
node_id = f"{node.name}.{exported_program_graph_id}"
980-
if node_id in export_graph_node_id_to_debug_handle:
981-
n_matched_node += 1
982-
983-
def _equip_debug_handle(node: torch.fx.Node) -> None:
984-
if node.name in ("output", "placeholder"):
985-
return
986-
node_id = f"{node.name}.{exported_program_graph_id}"
987-
if node_id in export_graph_node_id_to_debug_handle:
988-
node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id]
989-
else:
990-
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
991-
992-
bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
1091+
# 2. Find debug handles that have corresponding nodes in the exported program
1092+
matched_debug_handles = _find_matched_debug_handles(
1093+
exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle
1094+
)
9931095

994-
# if any node in the edge dialect program has no corresponding node in the exported program, match failed
995-
if n_matched_node != len(export_graph_node_id_to_debug_handle):
1096+
# 3. Verify if every debug handle in edge dialect program has a corresponding node
1097+
if not _verify_graph_match(edge_dialect_program, matched_debug_handles):
9961098
return False
9971099

998-
bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
1100+
# 4. Apply debug handles to the exported program
1101+
_apply_debug_handles(
1102+
exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle
1103+
)
9991104
return True

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,95 @@ def test_equip_debug_handle_to_export_program_success(self):
654654
exported_program_debug_handles[0], edge_dialect_program_debug_handles[1]
655655
)
656656

657+
def test_equip_debug_handle_to_strict_export_program_success(self):
658+
"""Test that propagate_back_debug_handle returns True and properly equips debug handles."""
659+
# Create a test model
660+
model = models.FeedForwardBlock(5, 10)
661+
inputs = (torch.rand(5, 5),)
662+
663+
# Export the model
664+
exported_program = export(model, inputs, strict=True)
665+
export_graph_id = id(exported_program.graph)
666+
667+
# Convert to edge dialect
668+
edge_dialect_program = to_edge(exported_program).exported_program()
669+
670+
# Call propagate_back_debug_handle
671+
result = propagate_back_debug_handle(
672+
exported_program, export_graph_id, edge_dialect_program
673+
)
674+
675+
self.assertTrue(result)
676+
677+
# Check that debug handles are properly equipped in the exported program
678+
exported_program_debug_handles = []
679+
for node in exported_program.graph.nodes:
680+
if node.op not in ("placeholder", "output"):
681+
self.assertIn(DEBUG_HANDLE_KEY, node.meta)
682+
self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
683+
exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY])
684+
685+
edge_dialect_program_debug_handles = []
686+
for node in edge_dialect_program.graph.nodes:
687+
if node.op not in ("placeholder", "output"):
688+
self.assertIn(DEBUG_HANDLE_KEY, node.meta)
689+
self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
690+
edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY])
691+
692+
# The 0th operator in the exported program (layer_norm) has been decomposed into 0th and 1st ops in edge dialect graph (native_layer_norm and getitem)
693+
# So they should have the same debug handle
694+
self.assertEqual(
695+
exported_program_debug_handles[0], edge_dialect_program_debug_handles[0]
696+
)
697+
self.assertEqual(
698+
exported_program_debug_handles[0], edge_dialect_program_debug_handles[1]
699+
)
700+
701+
def test_equip_debug_handle_to_reexport_program_success(self):
702+
"""Test that propagate_back_debug_handle returns True and properly equips debug handles."""
703+
# Create a test model
704+
model = models.FeedForwardBlock(5, 10)
705+
inputs = (torch.rand(5, 5),)
706+
707+
# Export the model
708+
init_export_program = export(model, inputs)
709+
exported_program = export(init_export_program.module(), inputs)
710+
export_graph_id = id(exported_program.graph)
711+
712+
# Convert to edge dialect
713+
edge_dialect_program = to_edge(exported_program).exported_program()
714+
715+
# Call propagate_back_debug_handle
716+
result = propagate_back_debug_handle(
717+
exported_program, export_graph_id, edge_dialect_program
718+
)
719+
720+
self.assertTrue(result)
721+
722+
# Check that debug handles are properly equipped in the exported program
723+
exported_program_debug_handles = []
724+
for node in exported_program.graph.nodes:
725+
if node.op not in ("placeholder", "output"):
726+
self.assertIn(DEBUG_HANDLE_KEY, node.meta)
727+
self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
728+
exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY])
729+
730+
edge_dialect_program_debug_handles = []
731+
for node in edge_dialect_program.graph.nodes:
732+
if node.op not in ("placeholder", "output"):
733+
self.assertIn(DEBUG_HANDLE_KEY, node.meta)
734+
self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
735+
edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY])
736+
737+
# The 0th operator in the exported program (layer_norm) has been decomposed into 0th and 1st ops in edge dialect graph (native_layer_norm and getitem)
738+
# So they should have the same debug handle
739+
self.assertEqual(
740+
exported_program_debug_handles[0], edge_dialect_program_debug_handles[0]
741+
)
742+
self.assertEqual(
743+
exported_program_debug_handles[0], edge_dialect_program_debug_handles[1]
744+
)
745+
657746
def test_equip_debug_handle_to_export_program_failure(self):
658747
"""Test that propagate_back_debug_handle returns False when there's a mismatch."""
659748
# Create a test model

0 commit comments

Comments
 (0)