Skip to content

Commit b5424e5

Browse files
committed
Support export program in intermediate numeric discrepancy detector
Pull Request resolved: #12581 This diff enables the intermediate numeric discrepancy detector to use the exported program as the label. More specifically, if the user creates an ETRecord with an exported program, and that exported program is one of the exported programs in the export flow, then the numeric discrepancy detector will use it as the label. Otherwise, it will continue to use the edge dialect graph as the label. ghstack-source-id: 297184685 @exported-using-ghexport Differential Revision: [D78298935](https://our.internmc.facebook.com/intern/diff/D78298935/)
1 parent 5945bf8 commit b5424e5

File tree

4 files changed

+101
-20
lines changed

4 files changed

+101
-20
lines changed

devtools/inspector/_inspector.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
map_runtime_aot_intermediate_outputs,
6363
merge_runtime_overlapping_debug_handles,
6464
ProgramOutput,
65+
propagate_back_debug_handle,
6566
RESERVED_FRAMEWORK_EVENT_NAMES,
6667
TimeScale,
6768
verify_debug_data_equivalence,
@@ -1166,7 +1167,18 @@ def _get_aot_intermediate_outputs_and_op_names(
11661167
"""
11671168
if self._etrecord._representative_inputs is None:
11681169
return {}, {}
1169-
export_program = self._etrecord.edge_dialect_program
1170+
1171+
export_program = None
1172+
1173+
# Use the exported program to extract intermediate outputs if and only if an exported_program has been provided and it is the greatest ancestor of the edge_dialect_program; otherwise fall back to the edge_dialect_program.
1174+
if self._etrecord.exported_program and propagate_back_debug_handle(
1175+
self._etrecord.exported_program,
1176+
self._etrecord.export_graph_id,
1177+
self._etrecord.edge_dialect_program,
1178+
):
1179+
export_program = self._etrecord.exported_program
1180+
else:
1181+
export_program = self._etrecord.edge_dialect_program
11701182
graph_module = export_program.module()
11711183
aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(
11721184
graph_module

devtools/inspector/_intermediate_output_capturer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def run_and_capture(self, *args, **kwargs) -> Dict[DebugHandle, Any]:
3535

3636
def capture_run_node(n: torch.fx.Node) -> Any:
3737
result = super(IntermediateOutputCapturer, self).run_node(n)
38+
print(f"n: {n}, result: {result}")
3839
if all(filter.matches(n) for filter in self.node_filters):
40+
print("matched")
3941
debug_handle = n.meta["debug_handle"]
4042
# Convert the debug handle to a tuple to use as a dictionary key
4143
key = (

devtools/inspector/tests/inspector_test.py

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from executorch.devtools import generate_etrecord, parse_etrecord
2626
from executorch.devtools.debug_format.et_schema import OperatorNode
2727
from executorch.devtools.etdump.schema_flatcc import ProfileEvent
28-
from executorch.devtools.etrecord._etrecord import ETRecord
2928
from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord
3029

3130
from executorch.devtools.inspector import (
@@ -480,14 +479,14 @@ def test_populate_debugging_related_fields_passes_for_consistent_events(self):
480479
events=events,
481480
)
482481

483-
def test_etrecord_populates_correct_aot_intermediate_outputs(self):
482+
def test_etrecord_populates_correct_edge_dialect_aot_intermediate_outputs(self):
484483
with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
485484
etrecord_path = tmp_file.name
486485
mod = model_registry["ConvLinearModel"]()
487486
input_tensor = torch.tensor(
488487
[[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True
489488
)
490-
aten_model: ExportedProgram = export(mod, (input_tensor,), strict=True)
489+
aten_model: ExportedProgram = export(mod, (input_tensor,))
491490
edge_program_manager: EdgeProgramManager = to_edge(
492491
aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
493492
)
@@ -513,15 +512,11 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self):
513512
etdump_path=ETDUMP_PATH,
514513
etrecord=etrecord_path,
515514
)
516-
etrecord = ETRecord(
517-
edge_dialect_program=inspector_instance._etrecord.edge_dialect_program,
518-
graph_map=inspector_instance._etrecord.graph_map,
519-
_debug_handle_map=inspector_instance._etrecord._debug_handle_map,
520-
_delegate_map=inspector_instance._etrecord._delegate_map,
521-
_reference_outputs=inspector_instance._etrecord._reference_outputs,
522-
_representative_inputs=aten_model.example_inputs[0],
515+
516+
inspector_instance._etrecord._representative_inputs = (
517+
aten_model.example_inputs[0]
523518
)
524-
inspector_instance._etrecord = etrecord
519+
525520
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
526521
inspector_instance._get_aot_intermediate_outputs_and_op_names()
527522
)
@@ -534,7 +529,61 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self):
534529

535530
self.assertTrue(
536531
check_if_debug_handle_to_op_names_match(
537-
"ConvLinearModel", aot_debug_handle_to_op_names
532+
aot_debug_handle_to_op_names,
533+
mod.get_edge_dialect_expected_debug_handle_to_op_names(),
534+
)
535+
)
536+
537+
def test_etrecord_populates_correct_export_program_aot_intermediate_outputs(self):
538+
with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
539+
etrecord_path = tmp_file.name
540+
mod = model_registry["ConvLinearModel"]()
541+
input_tensor = mod.get_input()
542+
aten_model: ExportedProgram = export(mod, (input_tensor,), strict=True)
543+
edge_program_manager: EdgeProgramManager = to_edge(aten_model)
544+
edge_program_manager_copy = copy.deepcopy(edge_program_manager)
545+
et_program_manager: ExecutorchProgramManager = (
546+
edge_program_manager.to_executorch()
547+
)
548+
# Generate ETRecord with the exported program
549+
generate_etrecord(
550+
etrecord_path,
551+
edge_program_manager_copy,
552+
et_program_manager,
553+
exported_program=aten_model,
554+
)
555+
with patch.object(
556+
Inspector, "_consume_etrecord", return_value=None
557+
), patch.object(
558+
_inspector, "gen_etdump_object", return_value=None
559+
), patch.object(
560+
EventBlock, "_gen_from_etdump"
561+
), patch.object(
562+
_inspector, "gen_graphs_from_etrecord"
563+
):
564+
# Call the constructor of Inspector
565+
inspector_instance = Inspector(
566+
etdump_path=ETDUMP_PATH,
567+
etrecord=etrecord_path,
568+
)
569+
570+
inspector_instance._etrecord._representative_inputs = (
571+
aten_model.example_inputs[0]
572+
)
573+
574+
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
575+
inspector_instance._get_aot_intermediate_outputs_and_op_names()
576+
)
577+
self.assertTrue(
578+
check_if_intermediate_outputs_match(
579+
aot_intermediate_outputs,
580+
mod.get_exported_program_expected_intermediate_outputs(),
581+
)
582+
)
583+
self.assertTrue(
584+
check_if_debug_handle_to_op_names_match(
585+
aot_debug_handle_to_op_names,
586+
mod.get_exported_program_expected_debug_handle_to_op_names(),
538587
)
539588
)
540589

devtools/inspector/tests/inspector_test_utils.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def get_edge_dialect_expected_intermediate_outputs():
7979
}
8080

8181
@staticmethod
82-
def get_expected_debug_handle_to_op_names():
82+
def get_edge_dialect_expected_debug_handle_to_op_names():
8383
"""
8484
Returns the expected debug handle and op names mapping for this model for the given input.
8585
"""
@@ -100,7 +100,7 @@ def get_expected_debug_handle_to_op_names():
100100
@staticmethod
101101
def get_exported_program_expected_intermediate_outputs():
102102
"""
103-
Returns the expected outputs of the debug handles and intermediate output mapping for edge dialect graph of this model for the given input.
103+
Returns the expected outputs of the debug handles and intermediate output mapping for export graph of this model for the given input.
104104
"""
105105
return {
106106
(UNSET_DEBUG_HANDLE,): torch.tensor([[5.4000, 13.5200]]),
@@ -117,6 +117,26 @@ def get_exported_program_expected_intermediate_outputs():
117117
(11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
118118
}
119119

120+
@staticmethod
121+
def get_exported_program_expected_debug_handle_to_op_names():
122+
"""
123+
Returns the expected debug handle and op name mapping for this model for the given input.
124+
"""
125+
return {
126+
(UNSET_DEBUG_HANDLE,): ["_assert_tensor_metadata_default", "to"],
127+
(1,): ["conv2d"],
128+
(2,): ["view"],
129+
(3,): ["linear"],
130+
(4,): ["add"],
131+
(5,): ["sub"],
132+
(6,): ["mul"],
133+
(7,): ["add_1"],
134+
(8,): ["div"],
135+
(9,): ["relu"],
136+
(10,): ["sigmoid"],
137+
(11,): ["split"],
138+
}
139+
120140

121141
# Global model registry
122142
model_registry = {
@@ -153,15 +173,13 @@ def check_if_intermediate_outputs_match(
153173
return True
154174

155175

156-
def check_if_debug_handle_to_op_names_match(model_name, actual_debug_handle_to_op_name):
176+
def check_if_debug_handle_to_op_names_match(
177+
actual_debug_handle_to_op_name, expected_debug_handle_to_op_name
178+
):
157179
"""
158180
Checks if the actual debug-handle-to-op-names mapping matches the expected mapping passed in by the caller.
159181
Returns True if all match, otherwise returns False.
160182
"""
161-
model_instance = model_registry[model_name]
162-
expected_debug_handle_to_op_name = (
163-
model_instance.get_expected_debug_handle_to_op_names()
164-
)
165183
if len(actual_debug_handle_to_op_name) != len(expected_debug_handle_to_op_name):
166184
return False
167185
for debug_handle, expected_op_name in expected_debug_handle_to_op_name.items():

0 commit comments

Comments
 (0)