11
11
from collections .abc import Sequence
12
12
from dataclasses import dataclass
13
13
from enum import Enum
14
- from typing import Any , Dict , IO , List , Mapping , Optional , Tuple , TypeAlias , Union
14
+ from typing import Any , Dict , IO , List , Mapping , Optional , Set , Tuple , TypeAlias , Union
15
15
16
16
import executorch .devtools .etdump .schema_flatcc as flatcc
17
17
37
37
38
38
from executorch .exir .debug_handle_utils import (
39
39
DEBUG_HANDLE_KEY ,
40
- get_greatest_ancestor_node_identifier ,
40
+ FROM_NODE_KEY ,
41
41
UNSET_DEBUG_HANDLE ,
42
42
)
43
43
46
46
from tabulate import tabulate
47
47
48
48
from torch .export import ExportedProgram
49
+ from torch .fx import Node
49
50
50
51
FORWARD = "forward"
51
52
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"
@@ -936,6 +937,133 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
936
937
)
937
938
938
939
940
+ def get_ancestor_node_identifiers (node : Node ) -> List [str ]:
941
+ """Get the identifier of the ancestor node of the given node, with the graph id the ancestor node lives in.
942
+
943
+ The identifier is the concatenation of the node name and graph id of the
944
+ greatest ancestor node, where the graph id is the unique id for every graph
945
+ module in the export flow and node name is unique within the same graph module.
946
+
947
+ Returns: the identifiers of all its ancestor nodes
948
+ """
949
+
950
+ node_source = node .meta [FROM_NODE_KEY ]
951
+ node_source = node_source [- 1 ]
952
+ ancestor_node_ids : List [str ] = [f"{ node_source .name } .{ str (node_source .graph_id )} " ]
953
+
954
+ while len (node_source .from_node ) > 0 :
955
+ node_source = node_source .from_node [- 1 ]
956
+ ancestor_node_ids .append (f"{ node_source .name } .{ str (node_source .graph_id )} " )
957
+
958
+ return ancestor_node_ids
959
+
960
+
961
+ def get_parent_node_identifier (node : Node ) -> Optional [str ]:
962
+ """Get the identifier of the parent node of the given node, with the graph id the parent node lives in.
963
+
964
+ The identifier is the concatenation of the node name and graph id of the
965
+ greatest parent node, where the graph id is the unique id for every graph
966
+ module in the export flow and node name is unique within the same graph module.
967
+
968
+ Returns: the identifier of the parent node, or None if can not find the parent
969
+ """
970
+
971
+ if FROM_NODE_KEY not in node .meta :
972
+ return None
973
+
974
+ node_source = node .meta [FROM_NODE_KEY ][- 1 ]
975
+ return f"{ node_source .name } .{ str (node_source .graph_id )} "
976
+
977
+
978
+ def _extract_ancestor_debug_handles (
979
+ edge_dialect_program : ExportedProgram ,
980
+ ) -> Dict [str , int ]:
981
+ """Extract mapping from ancestor node identifiers to debug handles."""
982
+ ancestors_node_id_to_debug_handle : Dict [str , int ] = {}
983
+
984
+ def _extract_node_id_to_debug_handle (node : Node ) -> None :
985
+ if node .op in ("placeholder" , "output" ):
986
+ return
987
+ for ancestor_node_id in get_ancestor_node_identifiers (node ):
988
+ if ancestor_node_id not in ancestors_node_id_to_debug_handle :
989
+ ancestors_node_id_to_debug_handle [ancestor_node_id ] = node .meta [
990
+ DEBUG_HANDLE_KEY
991
+ ]
992
+ else :
993
+ assert (
994
+ ancestors_node_id_to_debug_handle [ancestor_node_id ]
995
+ == node .meta [DEBUG_HANDLE_KEY ]
996
+ )
997
+
998
+ bfs_trace_with_node_process (
999
+ edge_dialect_program .graph_module , _extract_node_id_to_debug_handle
1000
+ )
1001
+ return ancestors_node_id_to_debug_handle
1002
+
1003
+
1004
+ def _find_matched_debug_handles (
1005
+ exported_program : ExportedProgram ,
1006
+ exported_program_graph_id : int ,
1007
+ ancestors_node_id_to_debug_handle : Dict [str , int ],
1008
+ ) -> Set [int ]:
1009
+ """Find debug handles that have corresponding nodes in the exported program."""
1010
+ matched_debug_handles : Set [int ] = set ()
1011
+
1012
+ def _find_n_match_node (node : Node ) -> None :
1013
+ if node .op in ("output" , "placeholder" ):
1014
+ return
1015
+ node_id = f"{ node .name } .{ exported_program_graph_id } "
1016
+ parent_node_id = get_parent_node_identifier (node )
1017
+ if node_id in ancestors_node_id_to_debug_handle :
1018
+ matched_debug_handles .add (ancestors_node_id_to_debug_handle [node_id ])
1019
+ elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle :
1020
+ matched_debug_handles .add (ancestors_node_id_to_debug_handle [parent_node_id ])
1021
+
1022
+ bfs_trace_with_node_process (exported_program .graph_module , _find_n_match_node )
1023
+ return matched_debug_handles
1024
+
1025
+
1026
+ def _verify_graph_match (
1027
+ edge_dialect_program : ExportedProgram , matched_debug_handles : Set [int ]
1028
+ ) -> bool :
1029
+ """Verify if every debug handle in edge dialect program has a corresponding node."""
1030
+ graph_matched = True
1031
+
1032
+ def _check_graph_match (node : Node ) -> None :
1033
+ nonlocal graph_matched
1034
+ if node .op in ("output" , "placeholder" ):
1035
+ return
1036
+ if node .meta [DEBUG_HANDLE_KEY ] not in matched_debug_handles :
1037
+ graph_matched = False
1038
+
1039
+ bfs_trace_with_node_process (edge_dialect_program .graph_module , _check_graph_match )
1040
+ return graph_matched
1041
+
1042
+
1043
+ def _apply_debug_handles (
1044
+ exported_program : ExportedProgram ,
1045
+ exported_program_graph_id : int ,
1046
+ ancestors_node_id_to_debug_handle : Dict [str , int ],
1047
+ ) -> None :
1048
+ """Apply debug handles to the exported program nodes."""
1049
+
1050
+ def _equip_debug_handle (node : Node ) -> None :
1051
+ if node .op in ("output" , "placeholder" ):
1052
+ return
1053
+ node_id = f"{ node .name } .{ exported_program_graph_id } "
1054
+ parent_node_id = get_parent_node_identifier (node )
1055
+ if node_id in ancestors_node_id_to_debug_handle :
1056
+ node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [node_id ]
1057
+ elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle :
1058
+ node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [
1059
+ parent_node_id
1060
+ ]
1061
+ else :
1062
+ node .meta [DEBUG_HANDLE_KEY ] = UNSET_DEBUG_HANDLE
1063
+
1064
+ bfs_trace_with_node_process (exported_program .graph_module , _equip_debug_handle )
1065
+
1066
+
939
1067
def propagate_back_debug_handle (
940
1068
exported_program : ExportedProgram ,
941
1069
exported_program_graph_id : int ,
@@ -953,47 +1081,24 @@ def propagate_back_debug_handle(
953
1081
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
954
1082
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
955
1083
956
- Return: True if:
957
- a. every debug handle in the edge dialect program has a corresponding node in the exported program
958
- b. the exported program is the greatest ancestor of the edge dialect program
959
-
960
- Otherwise, return False.
1084
+ Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
961
1085
"""
1086
+ # 1. Extract mapping from ancestor node identifiers to debug handles
1087
+ ancestors_node_id_to_debug_handle = _extract_ancestor_debug_handles (
1088
+ edge_dialect_program
1089
+ )
962
1090
963
- # 1. set up a mapping from debug handle to identifier of export program's node
964
- # using edge dialect program nodes' debug handles and from_node info
965
- export_graph_node_id_to_debug_handle = {
966
- get_greatest_ancestor_node_identifier (node ): node .meta [DEBUG_HANDLE_KEY ]
967
- for node in edge_dialect_program .graph .nodes
968
- if node .op not in ("placeholder" , "output" )
969
- }
970
-
971
- # 2. equip debug handle to the exported program's nodes using the mapping
972
- # number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle
973
- n_matched_node = 0
974
-
975
- def _find_n_match_node (node : torch .fx .Node ) -> None :
976
- nonlocal n_matched_node
977
- if node .name in ("output" , "placeholder" ):
978
- return
979
- node_id = f"{ node .name } .{ exported_program_graph_id } "
980
- if node_id in export_graph_node_id_to_debug_handle :
981
- n_matched_node += 1
982
-
983
- def _equip_debug_handle (node : torch .fx .Node ) -> None :
984
- if node .name in ("output" , "placeholder" ):
985
- return
986
- node_id = f"{ node .name } .{ exported_program_graph_id } "
987
- if node_id in export_graph_node_id_to_debug_handle :
988
- node .meta [DEBUG_HANDLE_KEY ] = export_graph_node_id_to_debug_handle [node_id ]
989
- else :
990
- node .meta [DEBUG_HANDLE_KEY ] = UNSET_DEBUG_HANDLE
991
-
992
- bfs_trace_with_node_process (exported_program .graph_module , _find_n_match_node )
1091
+ # 2. Find debug handles that have corresponding nodes in the exported program
1092
+ matched_debug_handles = _find_matched_debug_handles (
1093
+ exported_program , exported_program_graph_id , ancestors_node_id_to_debug_handle
1094
+ )
993
1095
994
- # if any node in the edge dialect program has no corresponding node in the exported program, match failed
995
- if n_matched_node != len ( export_graph_node_id_to_debug_handle ):
1096
+ # 3. Verify if every debug handle in edge dialect program has a corresponding node
1097
+ if not _verify_graph_match ( edge_dialect_program , matched_debug_handles ):
996
1098
return False
997
1099
998
- bfs_trace_with_node_process (exported_program .graph_module , _equip_debug_handle )
1100
+ # 4. Apply debug handles to the exported program
1101
+ _apply_debug_handles (
1102
+ exported_program , exported_program_graph_id , ancestors_node_id_to_debug_handle
1103
+ )
999
1104
return True
0 commit comments