From d3fcfd2c29c8df05506735a06aa789df2889a3e0 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 10 Jul 2025 20:32:04 -0700 Subject: [PATCH] Use same constant in runtime for unset debug handle In the runtime we have a contant number for unset debug handle. This diff bring that to python env for usage. Differential Revision: [D78132322](https://our.internmc.facebook.com/intern/diff/D78132322/) [ghstack-poisoned] --- devtools/inspector/_inspector_utils.py | 10 +++------- devtools/inspector/tests/inspector_utils_test.py | 9 +++++---- exir/debug_handle_utils.py | 1 + 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index d49ce3959a6..eec1d3e2577 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -37,6 +37,7 @@ from executorch.exir.debug_handle_utils import ( DEBUG_HANDLE_KEY, + UNSET_DEBUG_HANDLE, get_greatest_ancestor_node_identifier, ) @@ -914,7 +915,7 @@ def propagate_back_debug_handle( where op1_0 is from op1, op3_0 and op3_1 are from op3, op2 is removed by to_edge pipeline (e.g. RemoveNoopPass). Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1. - The debug handle of op2 will be a non-existing debug handle in edge dialect program for further skipping. + The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping. Return: True if: a. every debug handle in the edge dialect program has a corresponding node in the exported program @@ -935,11 +936,6 @@ def propagate_back_debug_handle( # number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle n_matched_node = 0 - # debug handle for the node in the exported program but not in the edge dialect program - debug_handle_for_removed_node = ( - max(export_graph_node_id_to_debug_handle.values()) + 1 - ) - def _find_n_match_node(node: torch.fx.Node) -> None: nonlocal n_matched_node if node.name in ("output", "placeholder"): @@ -955,7 +951,7 @@ def _equip_debug_handle(node: torch.fx.Node) -> None: if node_id in export_graph_node_id_to_debug_handle: node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id] else: - node.meta[DEBUG_HANDLE_KEY] = debug_handle_for_removed_node + node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node) diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index a77d541cb06..d198262a1c8 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -47,7 +47,7 @@ ) from executorch.devtools.inspector.numerical_comparator import L1Comparator from executorch.exir import to_edge -from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY +from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY, UNSET_DEBUG_HANDLE from torch.export import export @@ -682,8 +682,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) ) - # only two add ops in the exported program will keep in edge dialect program, so the debug handles for removed op will be three - debug_handle_for_removed_node = 3 + n_removed_nodes = 0 for node in exported_program.graph.nodes: if node.name == "add": @@ -691,10 +690,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: elif node.name == "add_1": self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 2) elif node.op not in ("placeholder", "output"): + n_removed_nodes += 1 self.assertEqual( - node.meta[DEBUG_HANDLE_KEY], debug_handle_for_removed_node + node.meta[DEBUG_HANDLE_KEY], UNSET_DEBUG_HANDLE ) + self.assertEqual(n_removed_nodes, 2) def gen_mock_operator_graph_with_expected_map() -> ( Tuple[OperatorGraph, Dict[int, OperatorNode]] diff --git a/exir/debug_handle_utils.py b/exir/debug_handle_utils.py index d1a70fcd213..8ffee71e280 100644 --- a/exir/debug_handle_utils.py +++ b/exir/debug_handle_utils.py @@ -9,6 +9,7 @@ FROM_NODE_KEY = "from_node" DEBUG_HANDLE_KEY = "debug_handle" +UNSET_DEBUG_HANDLE = 0 def get_greatest_ancestor_node_identifier(node: Node) -> str: """Get the identifier of the greatest ancestor node of the given node.