diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 85d19c5e952..432397347a5 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -92,6 +92,17 @@ def check_graph_closeness(self, graph_a, graph_b): self.assertEqual( node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle") ) + from_node_a = node_a.meta.get("from_node") + from_node_b = node_b.meta.get("from_node") + + if from_node_a is None: + self.assertIsNone(from_node_b) + else: + self.assertIsNotNone(from_node_b) + for node_source_a, node_source_b in zip(from_node_a, from_node_b): + self.assertEqual( + node_source_a.to_dict(), node_source_b.to_dict() + ) def test_etrecord_generation(self): captured_output, edge_output, et_output = self.get_test_model() diff --git a/exir/serde/serialize.py b/exir/serde/serialize.py index b587813c72c..ef969d0f2af 100644 --- a/exir/serde/serialize.py +++ b/exir/serde/serialize.py @@ -41,6 +41,7 @@ ) from torch._export.verifier import load_verifier from torch.fx.experimental import symbolic_shapes +from torch.fx.traceback import NodeSource, NodeSourceAction log: logging.Logger = logging.getLogger(__name__) @@ -141,8 +142,24 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: debug_handle = node.meta["debug_handle"] meta["debug_handle"] = str(debug_handle) + if "from_node" in node.meta: + from_node = node.meta["from_node"] + # Serialize from_node as JSON since it's a complex nested structure + meta["from_node"] = json.dumps(self._make_from_node_json_acceptable(from_node)) + return meta + def _make_from_node_json_acceptable(self, from_node: Optional[List[NodeSource]]): + """ + Recursively serialize from_node metadata which can be a list of NodeSource objects. + """ + if from_node is None: + return None + + json_acceptable_from_node = [node_source.to_dict() for node_source in from_node if isinstance(node_source, NodeSource)] + + return json_acceptable_from_node + def serialize_alloc_inputs( self, inputs # pyre-ignore ) -> List[schema.NamedArgument]: @@ -473,8 +490,22 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: if debug_handle := metadata.get("debug_handle"): res["debug_handle"] = int(debug_handle) + if from_node_str := metadata.get("from_node"): + res["from_node"] = self._deserialize_from_node(json.loads(from_node_str)) + return res + def _deserialize_from_node(self, from_node_data: Optional[List[Dict[str, Any]]]) -> Optional[List[NodeSource]]: + """ + Recursively deserialize from_node metadata from JSON data. + """ + if from_node_data is None: + return None + + assert isinstance(from_node_data, list) + + return [NodeSource._from_dict(fn_dict) for fn_dict in from_node_data] + # pyre-ignore def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]): def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec: diff --git a/exir/tests/test_serde.py b/exir/tests/test_serde.py index 67821d0bffb..d6a4ae235ba 100644 --- a/exir/tests/test_serde.py +++ b/exir/tests/test_serde.py @@ -275,3 +275,37 @@ def forward(self, x): ) self.assertEqual(metadata[0], metadata_serde[0]) self.assertEqual(list(metadata[1].keys()), list(metadata_serde[1].keys())) + + def test_meta_debug_handle_and_from_node(self) -> None: + class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv_layer = nn.Conv2d( + in_channels=1, out_channels=64, kernel_size=3, padding=1 + ) + + def forward(self, x): + return self.conv_layer(x) + + m = Model() + inputs = (torch.randn(1, 1, 32, 32),) + + edge = to_edge(export(m, inputs, strict=True)) + edge_new = deserialize(serialize(edge.exported_program())) + for node, node_new in zip( + edge.exported_program().graph_module.graph.nodes, + edge_new.graph_module.graph.nodes, + ): + if node.op not in {"placeholder", "output"}: + self.assertIsNotNone(node.meta.get("debug_handle")) + self.assertIsNotNone(node.meta.get("from_node")) + self.assertEqual( + node.meta.get("debug_handle"), node_new.meta.get("debug_handle") + ) + self.assertEqual( + len(node.meta.get("from_node")), len(node_new.meta.get("from_node")) + ) + for node_source, node_source_new in zip( + node.meta.get("from_node"), node_new.meta.get("from_node") + ): + self.assertEqual(node_source.to_dict(), node_source_new.to_dict())