diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 11b7b4f70e3..2bda03b4873 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -11,7 +11,7 @@ from collections.abc import Sequence from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union +from typing import Any, Dict, IO, List, Mapping, Optional, Set, Tuple, TypeAlias, Union import executorch.devtools.etdump.schema_flatcc as flatcc @@ -37,7 +37,7 @@ from executorch.exir.debug_handle_utils import ( DEBUG_HANDLE_KEY, - get_greatest_ancestor_node_identifier, + FROM_NODE_KEY, UNSET_DEBUG_HANDLE, ) @@ -46,6 +46,7 @@ from tabulate import tabulate from torch.export import ExportedProgram +from torch.fx import Node FORWARD = "forward" EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module" @@ -936,6 +937,133 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]: ) +def get_ancestor_node_identifiers(node: Node) -> List[str]: + """Get the identifier of the ancestor node of the given node, with the graph id the ancestor node lives in. + + The identifier is the concatenation of the node name and graph id of the + greatest ancestor node, where the graph id is the unique id for every graph + module in the export flow and node name is unique within the same graph module. + + Returns: the identifiers of all its ancestor nodes + """ + + node_source = node.meta[FROM_NODE_KEY] + node_source = node_source[-1] + ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"] + + while len(node_source.from_node) > 0: + node_source = node_source.from_node[-1] + ancestor_node_ids.append(f"{node_source.name}.{str(node_source.graph_id)}") + + return ancestor_node_ids + + +def get_parent_node_identifier(node: Node) -> Optional[str]: + """Get the identifier of the parent node of the given node, with the graph id the parent node lives in. + + The identifier is the concatenation of the node name and graph id of the + greatest parent node, where the graph id is the unique id for every graph + module in the export flow and node name is unique within the same graph module. + + Returns: the identifier of the parent node, or None if can not find the parent + """ + + if FROM_NODE_KEY not in node.meta: + return None + + node_source = node.meta[FROM_NODE_KEY][-1] + return f"{node_source.name}.{str(node_source.graph_id)}" + + +def _extract_ancestor_debug_handles( + edge_dialect_program: ExportedProgram, +) -> Dict[str, int]: + """Extract mapping from ancestor node identifiers to debug handles.""" + ancestors_node_id_to_debug_handle: Dict[str, int] = {} + + def _extract_node_id_to_debug_handle(node: Node) -> None: + if node.op in ("placeholder", "output"): + return + for ancestor_node_id in get_ancestor_node_identifiers(node): + if ancestor_node_id not in ancestors_node_id_to_debug_handle: + ancestors_node_id_to_debug_handle[ancestor_node_id] = node.meta[ + DEBUG_HANDLE_KEY + ] + else: + assert ( + ancestors_node_id_to_debug_handle[ancestor_node_id] + == node.meta[DEBUG_HANDLE_KEY] + ) + + bfs_trace_with_node_process( + edge_dialect_program.graph_module, _extract_node_id_to_debug_handle + ) + return ancestors_node_id_to_debug_handle + + +def _find_matched_debug_handles( + exported_program: ExportedProgram, + exported_program_graph_id: int, + ancestors_node_id_to_debug_handle: Dict[str, int], +) -> Set[int]: + """Find debug handles that have corresponding nodes in the exported program.""" + matched_debug_handles: Set[int] = set() + + def _find_n_match_node(node: Node) -> None: + if node.op in ("output", "placeholder"): + return + node_id = f"{node.name}.{exported_program_graph_id}" + parent_node_id = get_parent_node_identifier(node) + if node_id in ancestors_node_id_to_debug_handle: + matched_debug_handles.add(ancestors_node_id_to_debug_handle[node_id]) + elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle: + matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id]) + + bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node) + return matched_debug_handles + + +def _verify_graph_match( + edge_dialect_program: ExportedProgram, matched_debug_handles: Set[int] +) -> bool: + """Verify if every debug handle in edge dialect program has a corresponding node.""" + graph_matched = True + + def _check_graph_match(node: Node) -> None: + nonlocal graph_matched + if node.op in ("output", "placeholder"): + return + if node.meta[DEBUG_HANDLE_KEY] not in matched_debug_handles: + graph_matched = False + + bfs_trace_with_node_process(edge_dialect_program.graph_module, _check_graph_match) + return graph_matched + + +def _apply_debug_handles( + exported_program: ExportedProgram, + exported_program_graph_id: int, + ancestors_node_id_to_debug_handle: Dict[str, int], +) -> None: + """Apply debug handles to the exported program nodes.""" + + def _equip_debug_handle(node: Node) -> None: + if node.op in ("output", "placeholder"): + return + node_id = f"{node.name}.{exported_program_graph_id}" + parent_node_id = get_parent_node_identifier(node) + if node_id in ancestors_node_id_to_debug_handle: + node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id] + elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle: + node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[ + parent_node_id + ] + else: + node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE + + bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle) + + def propagate_back_debug_handle( exported_program: ExportedProgram, exported_program_graph_id: int, @@ -953,47 +1081,24 @@ def propagate_back_debug_handle( Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1. The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping. - Return: True if: - a. every debug handle in the edge dialect program has a corresponding node in the exported program - b. the exported program is the greatest ancestor of the edge dialect program - - Otherwise, return False. + Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False. """ + # 1. Extract mapping from ancestor node identifiers to debug handles + ancestors_node_id_to_debug_handle = _extract_ancestor_debug_handles( + edge_dialect_program + ) - # 1. set up a mapping from debug handle to identifier of export program's node - # using edge dialect program nodes' debug handles and from_node info - export_graph_node_id_to_debug_handle = { - get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY] - for node in edge_dialect_program.graph.nodes - if node.op not in ("placeholder", "output") - } - - # 2. equip debug handle to the exported program's nodes using the mapping - # number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle - n_matched_node = 0 - - def _find_n_match_node(node: torch.fx.Node) -> None: - nonlocal n_matched_node - if node.name in ("output", "placeholder"): - return - node_id = f"{node.name}.{exported_program_graph_id}" - if node_id in export_graph_node_id_to_debug_handle: - n_matched_node += 1 - - def _equip_debug_handle(node: torch.fx.Node) -> None: - if node.name in ("output", "placeholder"): - return - node_id = f"{node.name}.{exported_program_graph_id}" - if node_id in export_graph_node_id_to_debug_handle: - node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id] - else: - node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE - - bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node) + # 2. Find debug handles that have corresponding nodes in the exported program + matched_debug_handles = _find_matched_debug_handles( + exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle + ) - # if any node in the edge dialect program has no corresponding node in the exported program, match failed - if n_matched_node != len(export_graph_node_id_to_debug_handle): + # 3. Verify if every debug handle in edge dialect program has a corresponding node + if not _verify_graph_match(edge_dialect_program, matched_debug_handles): return False - bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle) + # 4. Apply debug handles to the exported program + _apply_debug_handles( + exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle + ) return True diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 69aa6f65dec..ea8c0e653af 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -654,6 +654,95 @@ def test_equip_debug_handle_to_export_program_success(self): exported_program_debug_handles[0], edge_dialect_program_debug_handles[1] ) + def test_equip_debug_handle_to_strict_export_program_success(self): + """Test that propagate_back_debug_handle returns True and properly equips debug handles.""" + # Create a test model + model = models.FeedForwardBlock(5, 10) + inputs = (torch.rand(5, 5),) + + # Export the model + exported_program = export(model, inputs, strict=True) + export_graph_id = id(exported_program.graph) + + # Convert to edge dialect + edge_dialect_program = to_edge(exported_program).exported_program() + + # Call propagate_back_debug_handle + result = propagate_back_debug_handle( + exported_program, export_graph_id, edge_dialect_program + ) + + self.assertTrue(result) + + # Check that debug handles are properly equipped in the exported program + exported_program_debug_handles = [] + for node in exported_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + edge_dialect_program_debug_handles = [] + for node in edge_dialect_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + # The 0th operator in the exported program (layer_norm) has been decomposed into 0th and 1st ops in edge dialect graph (native_layer_norm and getitem) + # So they should have the same debug handle + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[0] + ) + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[1] + ) + + def test_equip_debug_handle_to_reexport_program_success(self): + """Test that propagate_back_debug_handle returns True and properly equips debug handles.""" + # Create a test model + model = models.FeedForwardBlock(5, 10) + inputs = (torch.rand(5, 5),) + + # Export the model + init_export_program = export(model, inputs) + exported_program = export(init_export_program.module(), inputs) + export_graph_id = id(exported_program.graph) + + # Convert to edge dialect + edge_dialect_program = to_edge(exported_program).exported_program() + + # Call propagate_back_debug_handle + result = propagate_back_debug_handle( + exported_program, export_graph_id, edge_dialect_program + ) + + self.assertTrue(result) + + # Check that debug handles are properly equipped in the exported program + exported_program_debug_handles = [] + for node in exported_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + edge_dialect_program_debug_handles = [] + for node in edge_dialect_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + # The 0th operator in the exported program (layer_norm) has been decomposed into 0th and 1st ops in edge dialect graph (native_layer_norm and getitem) + # So they should have the same debug handle + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[0] + ) + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[1] + ) + def test_equip_debug_handle_to_export_program_failure(self): """Test that propagate_back_debug_handle returns False when there's a mismatch.""" # Create a test model