Skip to content

Commit d2b7b2f

Browse files
restucture debug handle
Pull Request resolved: #7197 This diff formats the debug handle generation process in et stack by. extracting bfs graph tracing process. ghstack-source-id: 257880795 @exported-using-ghexport Differential Revision: [D66622890](https://our.internmc.facebook.com/intern/diff/D66622890/) Co-authored-by: gasoonjia <[email protected]>
1 parent 9069017 commit d2b7b2f

File tree

2 files changed

+42
-36
lines changed

2 files changed

+42
-36
lines changed

exir/graph_module.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
from types import FunctionType as function
10-
from typing import Dict, List, Tuple, Union
10+
from typing import Callable, Dict, List, Tuple, Union
1111

1212
import torch
1313

@@ -68,3 +68,23 @@ def get_control_flow_submodules(
6868
control_flow_submodules.append(_get_submodule(graph_module, node, 0))
6969

7070
return control_flow_submodules
71+
72+
73+
def bfs_trace_with_node_process(
74+
gm: torch.fx.GraphModule, node_op: Callable[[torch.fx.Node], None]
75+
) -> None:
76+
"""Traverse the graph module and apply node_op to each node."""
77+
78+
assert isinstance(gm, torch.fx.GraphModule), f"Expected GraphModule, got {type(gm)}"
79+
80+
queue = [gm]
81+
while queue:
82+
current_graph_module = queue.pop(0)
83+
for node in current_graph_module.graph.nodes:
84+
node_op(node)
85+
86+
control_flow_submodules = [
87+
submodule
88+
for _, submodule, _ in get_control_flow_submodules(current_graph_module)
89+
]
90+
queue.extend(control_flow_submodules)

exir/passes/debug_handle_generator_pass.py

+21-35
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from executorch.exir.graph_module import get_control_flow_submodules
7+
from executorch.exir.graph_module import bfs_trace_with_node_process
88
from executorch.exir.pass_base import ExportPass
99
from torch.export import ExportedProgram
1010
from torch.fx import GraphModule
@@ -17,19 +17,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
1717
to executorch backend, that has a canonical set of quantized operators
1818
"""
1919

20-
queue = [graph_module]
2120
index = 1
22-
# bfs to traverse all modules including control flow submodules to attached debug handle id
23-
while queue:
24-
current_graph_module = queue.pop(0)
25-
for node in current_graph_module.graph.nodes:
26-
node.meta["debug_handle"] = index
27-
index += 1
28-
control_flow_submodules = [
29-
submodule
30-
for _, submodule, _ in get_control_flow_submodules(current_graph_module)
31-
]
32-
queue.extend(control_flow_submodules)
21+
22+
def _extract_debug_handles_from_node(node):
23+
nonlocal index
24+
node.meta["debug_handle"] = index
25+
index += 1
26+
27+
bfs_trace_with_node_process(graph_module, _extract_debug_handles_from_node)
28+
3329
return PassResult(graph_module, True)
3430

3531

@@ -38,28 +34,18 @@ def generate_missing_debug_handles(ep: ExportedProgram):
3834
This pass is used to generate missing debug handles for the graph module and its submodules.
3935
"""
4036

41-
def get_control_flow_submodules_list(graph_module):
42-
return [
43-
submodule for _, submodule, _ in get_control_flow_submodules(graph_module)
44-
]
45-
4637
max_handle = 0
47-
queue = [ep.graph_module]
4838

49-
while queue:
50-
current_graph_module = queue.pop(0)
51-
for node in current_graph_module.graph.nodes:
52-
if "debug_handle" in node.meta:
53-
max_handle = max(max_handle, node.meta["debug_handle"])
54-
control_flow_submodules = get_control_flow_submodules_list(current_graph_module)
55-
queue.extend(control_flow_submodules)
39+
def _extract_max_debug_handle(node):
40+
nonlocal max_handle
41+
if "debug_handle" in node.meta:
42+
max_handle = max(max_handle, node.meta["debug_handle"])
43+
44+
def _insert_new_debug_handles(node):
45+
nonlocal max_handle
46+
if node.meta.get("debug_handle", 0) in (0, None):
47+
node.meta["debug_handle"] = max_handle + 1
48+
max_handle += 1
5649

57-
queue = [ep.graph_module]
58-
while queue:
59-
current_graph_module = queue.pop(0)
60-
for node in current_graph_module.graph.nodes:
61-
if node.meta.get("debug_handle", 0) in (0, None):
62-
node.meta["debug_handle"] = max_handle + 1
63-
max_handle += 1
64-
control_flow_submodules = get_control_flow_submodules_list(current_graph_module)
65-
queue.extend(control_flow_submodules)
50+
bfs_trace_with_node_process(ep.graph_module, _extract_max_debug_handle)
51+
bfs_trace_with_node_process(ep.graph_module, _insert_new_debug_handles)

0 commit comments

Comments
 (0)