Skip to content

Serializing from_node info in et serializer #12462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: gh/gasoonjia/23/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,17 @@ def check_graph_closeness(self, graph_a, graph_b):
self.assertEqual(
node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
)
from_node_a = node_a.meta.get("from_node")
from_node_b = node_b.meta.get("from_node")

if from_node_a is None:
self.assertIsNone(from_node_b)
else:
self.assertIsNotNone(from_node_b)
for node_source_a, node_source_b in zip(from_node_a, from_node_b):
self.assertEqual(
node_source_a.to_dict(), node_source_b.to_dict()
)

def test_etrecord_generation(self):
captured_output, edge_output, et_output = self.get_test_model()
Expand Down
31 changes: 31 additions & 0 deletions exir/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from torch._export.verifier import load_verifier
from torch.fx.experimental import symbolic_shapes
from torch.fx.traceback import NodeSource, NodeSourceAction

log: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -141,8 +142,24 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
debug_handle = node.meta["debug_handle"]
meta["debug_handle"] = str(debug_handle)

if "from_node" in node.meta:
from_node = node.meta["from_node"]
# Serialize from_node as JSON since it's a complex nested structure
meta["from_node"] = json.dumps(self._make_from_node_json_acceptable(from_node))

return meta

def _make_from_node_json_acceptable(self, from_node: Optional[List[NodeSource]]):
"""
Recursively serialize from_node metadata which can be a list of NodeSource objects.
"""
if from_node is None:
return None

json_acceptable_from_node = [node_source.to_dict() for node_source in from_node if isinstance(node_source, NodeSource)]

return json_acceptable_from_node

def serialize_alloc_inputs(
self, inputs # pyre-ignore
) -> List[schema.NamedArgument]:
Expand Down Expand Up @@ -473,8 +490,22 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
if debug_handle := metadata.get("debug_handle"):
res["debug_handle"] = int(debug_handle)

if from_node_str := metadata.get("from_node"):
res["from_node"] = self._deserialize_from_node(json.loads(from_node_str))

return res

def _deserialize_from_node(self, from_node_data: Optional[List[Dict[str, Any]]]) -> Optional[List[NodeSource]]:
"""
Recursively deserialize from_node metadata from JSON data.
"""
if from_node_data is None:
return None

assert isinstance(from_node_data, list)

return [NodeSource._from_dict(fn_dict) for fn_dict in from_node_data]

# pyre-ignore
def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:
Expand Down
34 changes: 34 additions & 0 deletions exir/tests/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,37 @@ def forward(self, x):
)
self.assertEqual(metadata[0], metadata_serde[0])
self.assertEqual(list(metadata[1].keys()), list(metadata_serde[1].keys()))

def test_meta_debug_handle_and_from_node(self) -> None:
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv_layer = nn.Conv2d(
in_channels=1, out_channels=64, kernel_size=3, padding=1
)

def forward(self, x):
return self.conv_layer(x)

m = Model()
inputs = (torch.randn(1, 1, 32, 32),)

edge = to_edge(export(m, inputs, strict=True))
edge_new = deserialize(serialize(edge.exported_program()))
for node, node_new in zip(
edge.exported_program().graph_module.graph.nodes,
edge_new.graph_module.graph.nodes,
):
if node.op not in {"placeholder", "output"}:
self.assertIsNotNone(node.meta.get("debug_handle"))
self.assertIsNotNone(node.meta.get("from_node"))
self.assertEqual(
node.meta.get("debug_handle"), node_new.meta.get("debug_handle")
)
self.assertEqual(
len(node.meta.get("from_node")), len(node_new.meta.get("from_node"))
)
for node_source, node_source_new in zip(
node.meta.get("from_node"), node_new.meta.get("from_node")
):
self.assertEqual(node_source.to_dict(), node_source_new.to_dict())
Loading