From bc669bc2f6bba2bb29c7d1135fbc73ef0c37a1a2 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 14 Jul 2025 14:21:58 -0700 Subject: [PATCH] Serializing from_node info in et serializer We need to use the from_node informaton in deserialzied exported graph for operator tracing in et.inspector. this diff updates the serizalier to support serde from_node info. Differential Revision: [D78293986](https://our.internmc.facebook.com/intern/diff/D78293986/) [ghstack-poisoned] --- devtools/etrecord/tests/etrecord_test.py | 11 ++++ exir/serde/serialize.py | 68 ++++++++++++++++++++++++ exir/tests/test_serde.py | 34 ++++++++++++ 3 files changed, 113 insertions(+) diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 85d19c5e952..432397347a5 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -92,6 +92,17 @@ def check_graph_closeness(self, graph_a, graph_b): self.assertEqual( node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle") ) + from_node_a = node_a.meta.get("from_node") + from_node_b = node_b.meta.get("from_node") + + if from_node_a is None: + self.assertIsNone(from_node_b) + else: + self.assertIsNotNone(from_node_b) + for node_source_a, node_source_b in zip(from_node_a, from_node_b): + self.assertEqual( + node_source_a.to_dict(), node_source_b.to_dict() + ) def test_etrecord_generation(self): captured_output, edge_output, et_output = self.get_test_model() diff --git a/exir/serde/serialize.py b/exir/serde/serialize.py index b587813c72c..a2ac115be55 100644 --- a/exir/serde/serialize.py +++ b/exir/serde/serialize.py @@ -41,6 +41,7 @@ ) from torch._export.verifier import load_verifier from torch.fx.experimental import symbolic_shapes +from torch.fx.traceback import NodeSource, NodeSourceAction log: logging.Logger = logging.getLogger(__name__) @@ -141,8 +142,24 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: debug_handle = node.meta["debug_handle"] meta["debug_handle"] = str(debug_handle) + if "from_node" in node.meta: + from_node = node.meta["from_node"] + # Serialize from_node as JSON since it's a complex nested structure + meta["from_node"] = json.dumps(self._make_from_node_json_acceptable(from_node)) + return meta + def _make_from_node_json_acceptable(self, from_node: Optional[List[NodeSource]]): + """ + Recursively serialize from_node metadata which can be a list of NodeSource objects. + """ + if from_node is None: + return None + + json_acceptable_from_node = [node_source.to_dict() for node_source in from_node if isinstance(node_source, NodeSource)] + + return json_acceptable_from_node + def serialize_alloc_inputs( self, inputs # pyre-ignore ) -> List[schema.NamedArgument]: @@ -473,8 +490,59 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: if debug_handle := metadata.get("debug_handle"): res["debug_handle"] = int(debug_handle) + if from_node_str := metadata.get("from_node"): + res["from_node"] = self._deserialize_from_node(json.loads(from_node_str)) + return res + def _deserialize_from_node(self, from_node_data): + """ + Recursively deserialize from_node metadata from JSON data. + """ + if from_node_data is None: + return None + + if isinstance(from_node_data, list): + return [self._deserialize_from_node(item) for item in from_node_data] + + if isinstance(from_node_data, dict): + # Create a NodeSource object directly without going through the constructor + # to avoid issues with graph ID and node creation + node_source = NodeSource.__new__(NodeSource) + + # Set the basic attributes + node_source.pass_name = from_node_data.get('pass_name', '') + + # Parse action string back to NodeSourceAction enum list + action_str = from_node_data.get('action', '') + actions = [] + if action_str: + for action_name in action_str.split('+'): + if action_name.upper() == 'CREATE': + actions.append(NodeSourceAction.CREATE) + elif action_name.upper() == 'REPLACE': + actions.append(NodeSourceAction.REPLACE) + node_source.action = actions + + # Create the NodeInfo object directly + if 'name' in from_node_data and 'target' in from_node_data and 'graph_id' in from_node_data: + node_info = NodeSource.NodeInfo( + from_node_data.get('name', ''), + from_node_data.get('target', ''), + from_node_data.get('graph_id', -1) + ) + node_source.node_info = node_info + else: + node_source.node_info = None + + # Recursively deserialize nested from_node + node_source.from_node = self._deserialize_from_node(from_node_data.get('from_node', [])) + + return node_source + + # Fallback for primitive types + return from_node_data + # pyre-ignore def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]): def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec: diff --git a/exir/tests/test_serde.py b/exir/tests/test_serde.py index 67821d0bffb..d6a4ae235ba 100644 --- a/exir/tests/test_serde.py +++ b/exir/tests/test_serde.py @@ -275,3 +275,37 @@ def forward(self, x): ) self.assertEqual(metadata[0], metadata_serde[0]) self.assertEqual(list(metadata[1].keys()), list(metadata_serde[1].keys())) + + def test_meta_debug_handle_and_from_node(self) -> None: + class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv_layer = nn.Conv2d( + in_channels=1, out_channels=64, kernel_size=3, padding=1 + ) + + def forward(self, x): + return self.conv_layer(x) + + m = Model() + inputs = (torch.randn(1, 1, 32, 32),) + + edge = to_edge(export(m, inputs, strict=True)) + edge_new = deserialize(serialize(edge.exported_program())) + for node, node_new in zip( + edge.exported_program().graph_module.graph.nodes, + edge_new.graph_module.graph.nodes, + ): + if node.op not in {"placeholder", "output"}: + self.assertIsNotNone(node.meta.get("debug_handle")) + self.assertIsNotNone(node.meta.get("from_node")) + self.assertEqual( + node.meta.get("debug_handle"), node_new.meta.get("debug_handle") + ) + self.assertEqual( + len(node.meta.get("from_node")), len(node_new.meta.get("from_node")) + ) + for node_source, node_source_new in zip( + node.meta.get("from_node"), node_new.meta.get("from_node") + ): + self.assertEqual(node_source.to_dict(), node_source_new.to_dict())