4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
- from executorch .exir .graph_module import get_control_flow_submodules
7
+ from executorch .exir .graph_module import bfs_trace_with_node_process
8
8
from executorch .exir .pass_base import ExportPass
9
9
from torch .export import ExportedProgram
10
10
from torch .fx import GraphModule
@@ -17,19 +17,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
17
17
to executorch backend, that has a canonical set of quantized operators
18
18
"""
19
19
20
- queue = [graph_module ]
21
20
index = 1
22
- # bfs to traverse all modules including control flow submodules to attached debug handle id
23
- while queue :
24
- current_graph_module = queue .pop (0 )
25
- for node in current_graph_module .graph .nodes :
26
- node .meta ["debug_handle" ] = index
27
- index += 1
28
- control_flow_submodules = [
29
- submodule
30
- for _ , submodule , _ in get_control_flow_submodules (current_graph_module )
31
- ]
32
- queue .extend (control_flow_submodules )
21
+
22
+ def _extract_debug_handles_from_node (node ):
23
+ nonlocal index
24
+ node .meta ["debug_handle" ] = index
25
+ index += 1
26
+
27
+ bfs_trace_with_node_process (graph_module , _extract_debug_handles_from_node )
28
+
33
29
return PassResult (graph_module , True )
34
30
35
31
@@ -38,28 +34,18 @@ def generate_missing_debug_handles(ep: ExportedProgram):
38
34
This pass is used to generate missing debug handles for the graph module and its submodules.
39
35
"""
40
36
41
- def get_control_flow_submodules_list (graph_module ):
42
- return [
43
- submodule for _ , submodule , _ in get_control_flow_submodules (graph_module )
44
- ]
45
-
46
37
max_handle = 0
47
- queue = [ep .graph_module ]
48
38
49
- while queue :
50
- current_graph_module = queue .pop (0 )
51
- for node in current_graph_module .graph .nodes :
52
- if "debug_handle" in node .meta :
53
- max_handle = max (max_handle , node .meta ["debug_handle" ])
54
- control_flow_submodules = get_control_flow_submodules_list (current_graph_module )
55
- queue .extend (control_flow_submodules )
39
+ def _extract_max_debug_handle (node ):
40
+ nonlocal max_handle
41
+ if "debug_handle" in node .meta :
42
+ max_handle = max (max_handle , node .meta ["debug_handle" ])
43
+
44
+ def _insert_new_debug_handles (node ):
45
+ nonlocal max_handle
46
+ if node .meta .get ("debug_handle" , 0 ) in (0 , None ):
47
+ node .meta ["debug_handle" ] = max_handle + 1
48
+ max_handle += 1
56
49
57
- queue = [ep .graph_module ]
58
- while queue :
59
- current_graph_module = queue .pop (0 )
60
- for node in current_graph_module .graph .nodes :
61
- if node .meta .get ("debug_handle" , 0 ) in (0 , None ):
62
- node .meta ["debug_handle" ] = max_handle + 1
63
- max_handle += 1
64
- control_flow_submodules = get_control_flow_submodules_list (current_graph_module )
65
- queue .extend (control_flow_submodules )
50
+ bfs_trace_with_node_process (ep .graph_module , _extract_max_debug_handle )
51
+ bfs_trace_with_node_process (ep .graph_module , _insert_new_debug_handles )
0 commit comments