25
25
from executorch .devtools import generate_etrecord , parse_etrecord
26
26
from executorch .devtools .debug_format .et_schema import OperatorNode
27
27
from executorch .devtools .etdump .schema_flatcc import ProfileEvent
28
- from executorch .devtools .etrecord ._etrecord import ETRecord
29
28
from executorch .devtools .etrecord .tests .etrecord_test import TestETRecord
30
29
31
30
from executorch .devtools .inspector import (
@@ -480,14 +479,14 @@ def test_populate_debugging_related_fields_passes_for_consistent_events(self):
480
479
events = events ,
481
480
)
482
481
483
- def test_etrecord_populates_correct_aot_intermediate_outputs (self ):
482
+ def test_etrecord_populates_correct_edge_dialect_aot_intermediate_outputs (self ):
484
483
with tempfile .NamedTemporaryFile (suffix = ".bin" ) as tmp_file :
485
484
etrecord_path = tmp_file .name
486
485
mod = model_registry ["ConvLinearModel" ]()
487
486
input_tensor = torch .tensor (
488
487
[[[[1.0 , 2.0 ], [3.0 , 4.0 ]]]], requires_grad = True
489
488
)
490
- aten_model : ExportedProgram = export (mod , (input_tensor ,), strict = True )
489
+ aten_model : ExportedProgram = export (mod , (input_tensor ,))
491
490
edge_program_manager : EdgeProgramManager = to_edge (
492
491
aten_model , compile_config = EdgeCompileConfig (_check_ir_validity = True )
493
492
)
@@ -513,15 +512,11 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self):
513
512
etdump_path = ETDUMP_PATH ,
514
513
etrecord = etrecord_path ,
515
514
)
516
- etrecord = ETRecord (
517
- edge_dialect_program = inspector_instance ._etrecord .edge_dialect_program ,
518
- graph_map = inspector_instance ._etrecord .graph_map ,
519
- _debug_handle_map = inspector_instance ._etrecord ._debug_handle_map ,
520
- _delegate_map = inspector_instance ._etrecord ._delegate_map ,
521
- _reference_outputs = inspector_instance ._etrecord ._reference_outputs ,
522
- _representative_inputs = aten_model .example_inputs [0 ],
515
+
516
+ inspector_instance ._etrecord ._representative_inputs = (
517
+ aten_model .example_inputs [0 ]
523
518
)
524
- inspector_instance . _etrecord = etrecord
519
+
525
520
aot_intermediate_outputs , aot_debug_handle_to_op_names = (
526
521
inspector_instance ._get_aot_intermediate_outputs_and_op_names ()
527
522
)
@@ -534,7 +529,61 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self):
534
529
535
530
self .assertTrue (
536
531
check_if_debug_handle_to_op_names_match (
537
- "ConvLinearModel" , aot_debug_handle_to_op_names
532
+ aot_debug_handle_to_op_names ,
533
+ mod .get_edge_dialect_expected_debug_handle_to_op_names (),
534
+ )
535
+ )
536
+
537
+ def test_etrecord_populates_correct_export_program_aot_intermediate_outputs (self ):
538
+ with tempfile .NamedTemporaryFile (suffix = ".bin" ) as tmp_file :
539
+ etrecord_path = tmp_file .name
540
+ mod = model_registry ["ConvLinearModel" ]()
541
+ input_tensor = mod .get_input ()
542
+ aten_model : ExportedProgram = export (mod , (input_tensor ,), strict = True )
543
+ edge_program_manager : EdgeProgramManager = to_edge (aten_model )
544
+ edge_program_manager_copy = copy .deepcopy (edge_program_manager )
545
+ et_program_manager : ExecutorchProgramManager = (
546
+ edge_program_manager .to_executorch ()
547
+ )
548
+ # Generate ETRecord with the exported program
549
+ generate_etrecord (
550
+ etrecord_path ,
551
+ edge_program_manager_copy ,
552
+ et_program_manager ,
553
+ exported_program = aten_model ,
554
+ )
555
+ with patch .object (
556
+ Inspector , "_consume_etrecord" , return_value = None
557
+ ), patch .object (
558
+ _inspector , "gen_etdump_object" , return_value = None
559
+ ), patch .object (
560
+ EventBlock , "_gen_from_etdump"
561
+ ), patch .object (
562
+ _inspector , "gen_graphs_from_etrecord"
563
+ ):
564
+ # Call the constructor of Inspector
565
+ inspector_instance = Inspector (
566
+ etdump_path = ETDUMP_PATH ,
567
+ etrecord = etrecord_path ,
568
+ )
569
+
570
+ inspector_instance ._etrecord ._representative_inputs = (
571
+ aten_model .example_inputs [0 ]
572
+ )
573
+
574
+ aot_intermediate_outputs , aot_debug_handle_to_op_names = (
575
+ inspector_instance ._get_aot_intermediate_outputs_and_op_names ()
576
+ )
577
+ self .assertTrue (
578
+ check_if_intermediate_outputs_match (
579
+ aot_intermediate_outputs ,
580
+ mod .get_exported_program_expected_intermediate_outputs (),
581
+ )
582
+ )
583
+ self .assertTrue (
584
+ check_if_debug_handle_to_op_names_match (
585
+ aot_debug_handle_to_op_names ,
586
+ mod .get_exported_program_expected_debug_handle_to_op_names (),
538
587
)
539
588
)
540
589
0 commit comments