From 14e15df0f2797f55f5b54bad0d17824ae080dba5 Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Tue, 15 Jul 2025 16:38:23 -0700
Subject: [PATCH] make operator name consistent before and after serde

Differential Revision: [D78380855](https://our.internmc.facebook.com/intern/diff/D78380855/)

[ghstack-poisoned]
---
 exir/serde/export_serialize.py | 19 +++++++------------
 exir/serde/schema.py           |  1 +
 exir/serde/serialize.py        |  3 +++
 exir/tests/test_serde.py       | 30 ++++++++++++++++++++++++++++--
 4 files changed, 39 insertions(+), 14 deletions(-)

diff --git a/exir/serde/export_serialize.py b/exir/serde/export_serialize.py
index 7a1d35c432e..02d225e95a0 100644
--- a/exir/serde/export_serialize.py
+++ b/exir/serde/export_serialize.py
@@ -504,6 +504,7 @@ def handle_call_function(self, node: torch.fx.Node):
             assert len(node.kwargs) == 0
             meta_val = node.meta["val"]
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_sym_op_inputs(node.target, node.args),
                 outputs=[
@@ -517,6 +518,7 @@ def handle_call_function(self, node: torch.fx.Node):
             assert len(node.kwargs) == 0
             meta_val = node.meta["val"]
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_sym_op_inputs(node.target, node.args),
                 outputs=[
@@ -528,6 +530,7 @@ def handle_call_function(self, node: torch.fx.Node):
             )
         elif isinstance(node.target, torch._ops.OpOverload):
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
                 outputs=self.serialize_outputs(node),
@@ -536,6 +539,7 @@ def handle_call_function(self, node: torch.fx.Node):
             )
         elif isinstance(node.target, torch._ops.HigherOrderOperator):
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_hoo_inputs(node.args, node.kwargs),
                 outputs=self.serialize_hoo_outputs(node),
@@ -1658,7 +1662,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
 
     def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
         if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS:
-            name = serialized_node.outputs[0].value.as_name
+            name = serialized_node.name
             args = self.deserialize_sym_op_inputs(serialized_node.inputs)
 
             fx_node = self.graph.create_node("call_function", target, args, {}, name)
@@ -1671,12 +1675,7 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
             # have names that are consistent with serialized.
             #
             # HOPs don't have schema yet, just check the output lengths and as_tensor attribute
-            name = (
-                serialized_node.outputs[0].as_tensor.name
-                if len(serialized_node.outputs) == 1
-                and hasattr(serialized_node.outputs[0], "as_tensor")
-                else None
-            )
+            name = serialized_node.name
             fx_node = self.graph.create_node(
                 "call_function", target, args, kwargs, name
             )
@@ -1687,16 +1686,12 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
             # For convenience: if this node returns a single tensor, name the
             # newly-created node after it. This ensures that these tensor values
             # have names that are consistent with serialized.
-            name = (
-                serialized_node.outputs[0].as_tensor.name
-                if _is_single_tensor_return(target)
-                else None  # FX will generate a name for us.
-            )
+            name = serialized_node.name
             args, kwargs = self.deserialize_inputs(target, serialized_node)
             fx_node = self.graph.create_node(
                 "call_function", target, args, kwargs, name
             )
             self.deserialize_outputs(serialized_node, fx_node)
         else:
             raise SerializeError(
                 f"Unsupported target type for node {serialized_node}: {target}"
diff --git a/exir/serde/schema.py b/exir/serde/schema.py
index 6d250ee7923..f91526c385f 100644
--- a/exir/serde/schema.py
+++ b/exir/serde/schema.py
@@ -195,6 +195,7 @@ class NamedArgument:
 
 @dataclass
 class Node:
+    name: str
     target: str
     inputs: List[NamedArgument]
     outputs: List[Argument]
diff --git a/exir/serde/serialize.py b/exir/serde/serialize.py
index ef969d0f2af..25cbbd616e8 100644
--- a/exir/serde/serialize.py
+++ b/exir/serde/serialize.py
@@ -89,6 +89,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
 
         if node.target is memory.alloc:
             ex_node = schema.Node(
+                name=node.name,
                 target="memory.alloc",
                 inputs=self.serialize_alloc_inputs(node.args),
                 outputs=self.serialize_arbitrary_outputs(node),
@@ -99,6 +100,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
         elif isinstance(node.target, EdgeOpOverload):
             assert node.target._op is not None
             ex_node = schema.Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 # pyre-ignore Undefined attribute [16]: Item `typing.Callable` of
                 # `typing.Union[typing.Callable[..., typing.Any], str]` has no attribute `_op`.
@@ -111,6 +113,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
             return
         elif node.target is delegate.executorch_call_delegate:
             ex_node = schema.Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_call_delegate_inputs(node.args),
                 outputs=self.serialize_arbitrary_outputs(node),
diff --git a/exir/tests/test_serde.py b/exir/tests/test_serde.py
index d6a4ae235ba..df1e3cc119f 100644
--- a/exir/tests/test_serde.py
+++ b/exir/tests/test_serde.py
@@ -42,6 +42,7 @@ def check_ep(
         ep1: TorchExportedProgram,
         ep2: TorchExportedProgram,
         inputs: Tuple[exir.Value, ...],
+        compare_closeness: bool = False,
     ) -> None:
         """
         Checks if two graphs are equivalent
@@ -55,15 +56,40 @@ def check_ep(
         for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs, strict=True):
             self.assertTrue(torch.allclose(orig, loaded))
 
+        if compare_closeness:
+            self.assertEqual(len(ep1.graph.nodes), len(ep2.graph.nodes))
+            for node_a, node_b in zip(ep1.graph.nodes, ep2.graph.nodes):
+                self.assertEqual(node_a.target, node_b.target)
+                self.assertEqual(node_a.name, node_b.name)
+                self.assertEqual(node_a.type, node_b.type)
+                self.assertEqual(node_a.op, node_b.op)
+                if node_a.op != "call_function":
+                    continue
+
+                self.assertEqual(
+                    node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
+                )
+                from_node_a = node_a.meta.get("from_node")
+                from_node_b = node_b.meta.get("from_node")
+
+                if from_node_a is None:
+                    self.assertIsNone(from_node_b)
+                else:
+                    self.assertIsNotNone(from_node_b)
+                    for node_source_a, node_source_b in zip(from_node_a, from_node_b):
+                        self.assertEqual(
+                            node_source_a.to_dict(), node_source_b.to_dict()
+                        )
+
     # pyre-ignore
     def check_serde(self, m, inputs, check_executorch=True) -> None:
         aten = export(m, inputs, strict=True)
         aten_new = deserialize(serialize(aten))
-        self.check_ep(aten, aten_new, inputs)
+        self.check_ep(aten, aten_new, inputs, compare_closeness=True)
 
         edge = to_edge(aten)
         edge_new = deserialize(serialize(edge.exported_program()))
-        self.check_ep(edge.exported_program(), edge_new, inputs)
+        self.check_ep(edge.exported_program(), edge_new, inputs, compare_closeness=True)
 
         buffer = io.BytesIO()
         exir.save(edge.exported_program(), buffer)
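
A minimal sketch of the round-trip property this patch targets: node names should survive serialize/deserialize unchanged. The toy module and the executorch.exir.serde.serialize import path are assumptions for illustration only; the authoritative coverage is check_ep/check_serde in the test above.

    import torch
    from torch.export import export

    # Assumed import path for the serde helpers exercised in test_serde.py.
    from executorch.exir.serde.serialize import deserialize, serialize


    class Toy(torch.nn.Module):  # hypothetical module, just for illustration
        def forward(self, x, y):
            # split produces a multi-output node, the case where a name cannot
            # be recovered from a single tensor output alone.
            return torch.split(x + y, 2)


    inputs = (torch.randn(4, 4), torch.randn(4, 4))
    ep = export(Toy(), inputs, strict=True)
    ep_loaded = deserialize(serialize(ep))

    # With Node.name carried through the schema, the graphs should agree
    # node-for-node on names and ops.
    for a, b in zip(ep.graph.nodes, ep_loaded.graph.nodes):
        assert a.name == b.name and a.op == b.op, (a.name, b.name)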