Skip to content

support back propagate debug handle to arbitrary ancestor export graph #12580

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: gh/gasoonjia/26/base
Choose a base branch
from
185 changes: 145 additions & 40 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
from typing import Any, Dict, IO, List, Mapping, Optional, Set, Tuple, TypeAlias, Union

import executorch.devtools.etdump.schema_flatcc as flatcc

Expand All @@ -37,7 +37,7 @@

from executorch.exir.debug_handle_utils import (
DEBUG_HANDLE_KEY,
get_greatest_ancestor_node_identifier,
FROM_NODE_KEY,
UNSET_DEBUG_HANDLE,
)

Expand All @@ -46,6 +46,7 @@
from tabulate import tabulate

from torch.export import ExportedProgram
from torch.fx import Node

FORWARD = "forward"
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"
Expand Down Expand Up @@ -936,6 +937,133 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
)


def get_ancestor_node_identifiers(node: Node) -> List[str]:
"""Get the identifier of the ancestor node of the given node, with the graph id the ancestor node lives in.

The identifier is the concatenation of the node name and graph id of the
greatest ancestor node, where the graph id is the unique id for every graph
module in the export flow and node name is unique within the same graph module.

Returns: the identifiers of all its ancestor nodes
"""

node_source = node.meta[FROM_NODE_KEY]
node_source = node_source[-1]
ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"]

while len(node_source.from_node) > 0:
node_source = node_source.from_node[-1]
ancestor_node_ids.append(f"{node_source.name}.{str(node_source.graph_id)}")

return ancestor_node_ids


def get_parent_node_identifier(node: Node) -> Optional[str]:
"""Get the identifier of the parent node of the given node, with the graph id the parent node lives in.

The identifier is the concatenation of the node name and graph id of the
greatest parent node, where the graph id is the unique id for every graph
module in the export flow and node name is unique within the same graph module.

Returns: the identifier of the parent node, or None if can not find the parent
"""

if FROM_NODE_KEY not in node.meta:
return None

node_source = node.meta[FROM_NODE_KEY][-1]
return f"{node_source.name}.{str(node_source.graph_id)}"


def _extract_ancestor_debug_handles(
edge_dialect_program: ExportedProgram,
) -> Dict[str, int]:
"""Extract mapping from ancestor node identifiers to debug handles."""
ancestors_node_id_to_debug_handle: Dict[str, int] = {}

def _extract_node_id_to_debug_handle(node: Node) -> None:
if node.op in ("placeholder", "output"):
return
for ancestor_node_id in get_ancestor_node_identifiers(node):
if ancestor_node_id not in ancestors_node_id_to_debug_handle:
ancestors_node_id_to_debug_handle[ancestor_node_id] = node.meta[
DEBUG_HANDLE_KEY
]
else:
assert (
ancestors_node_id_to_debug_handle[ancestor_node_id]
== node.meta[DEBUG_HANDLE_KEY]
)

bfs_trace_with_node_process(
edge_dialect_program.graph_module, _extract_node_id_to_debug_handle
)
return ancestors_node_id_to_debug_handle


def _find_matched_debug_handles(
exported_program: ExportedProgram,
exported_program_graph_id: int,
ancestors_node_id_to_debug_handle: Dict[str, int],
) -> Set[int]:
"""Find debug handles that have corresponding nodes in the exported program."""
matched_debug_handles: Set[int] = set()

def _find_n_match_node(node: Node) -> None:
if node.op in ("output", "placeholder"):
return
node_id = f"{node.name}.{exported_program_graph_id}"
parent_node_id = get_parent_node_identifier(node)
if node_id in ancestors_node_id_to_debug_handle:
matched_debug_handles.add(ancestors_node_id_to_debug_handle[node_id])
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id])

bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
return matched_debug_handles


def _verify_graph_match(
edge_dialect_program: ExportedProgram, matched_debug_handles: Set[int]
) -> bool:
"""Verify if every debug handle in edge dialect program has a corresponding node."""
graph_matched = True

def _check_graph_match(node: Node) -> None:
nonlocal graph_matched
if node.op in ("output", "placeholder"):
return
if node.meta[DEBUG_HANDLE_KEY] not in matched_debug_handles:
graph_matched = False

bfs_trace_with_node_process(edge_dialect_program.graph_module, _check_graph_match)
return graph_matched


def _apply_debug_handles(
exported_program: ExportedProgram,
exported_program_graph_id: int,
ancestors_node_id_to_debug_handle: Dict[str, int],
) -> None:
"""Apply debug handles to the exported program nodes."""

def _equip_debug_handle(node: Node) -> None:
if node.op in ("output", "placeholder"):
return
node_id = f"{node.name}.{exported_program_graph_id}"
parent_node_id = get_parent_node_identifier(node)
if node_id in ancestors_node_id_to_debug_handle:
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id]
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[
parent_node_id
]
else:
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE

bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)


def propagate_back_debug_handle(
exported_program: ExportedProgram,
exported_program_graph_id: int,
Expand All @@ -953,47 +1081,24 @@ def propagate_back_debug_handle(
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.

Return: True if:
a. every debug handle in the edge dialect program has a corresponding node in the exported program
b. the exported program is the greatest ancestor of the edge dialect program

Otherwise, return False.
Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
"""
# 1. Extract mapping from ancestor node identifiers to debug handles
ancestors_node_id_to_debug_handle = _extract_ancestor_debug_handles(
edge_dialect_program
)

# 1. set up a mapping from debug handle to identifier of export program's node
# using edge dialect program nodes' debug handles and from_node info
export_graph_node_id_to_debug_handle = {
get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY]
for node in edge_dialect_program.graph.nodes
if node.op not in ("placeholder", "output")
}

# 2. equip debug handle to the exported program's nodes using the mapping
# number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle
n_matched_node = 0

def _find_n_match_node(node: torch.fx.Node) -> None:
nonlocal n_matched_node
if node.name in ("output", "placeholder"):
return
node_id = f"{node.name}.{exported_program_graph_id}"
if node_id in export_graph_node_id_to_debug_handle:
n_matched_node += 1

def _equip_debug_handle(node: torch.fx.Node) -> None:
if node.name in ("output", "placeholder"):
return
node_id = f"{node.name}.{exported_program_graph_id}"
if node_id in export_graph_node_id_to_debug_handle:
node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id]
else:
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE

bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
# 2. Find debug handles that have corresponding nodes in the exported program
matched_debug_handles = _find_matched_debug_handles(
exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle
)

# if any node in the edge dialect program has no corresponding node in the exported program, match failed
if n_matched_node != len(export_graph_node_id_to_debug_handle):
# 3. Verify if every debug handle in edge dialect program has a corresponding node
if not _verify_graph_match(edge_dialect_program, matched_debug_handles):
return False

bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
# 4. Apply debug handles to the exported program
_apply_debug_handles(
exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle
)
return True
89 changes: 89 additions & 0 deletions devtools/inspector/tests/inspector_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,95 @@ def test_equip_debug_handle_to_export_program_success(self):
exported_program_debug_handles[0], edge_dialect_program_debug_handles[1]
)

def test_equip_debug_handle_to_strict_export_program_success(self):
"""Test that propagate_back_debug_handle returns True and properly equips debug handles."""
# Create a test model
model = models.FeedForwardBlock(5, 10)
inputs = (torch.rand(5, 5),)

# Export the model
exported_program = export(model, inputs, strict=True)
export_graph_id = id(exported_program.graph)

# Convert to edge dialect
edge_dialect_program = to_edge(exported_program).exported_program()

# Call propagate_back_debug_handle
result = propagate_back_debug_handle(
exported_program, export_graph_id, edge_dialect_program
)

self.assertTrue(result)

# Check that debug handles are properly equipped in the exported program
exported_program_debug_handles = []
for node in exported_program.graph.nodes:
if node.op not in ("placeholder", "output"):
self.assertIn(DEBUG_HANDLE_KEY, node.meta)
self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY])

edge_dialect_program_debug_handles = []
for node in edge_dialect_program.graph.nodes:
if node.op not in ("placeholder", "output"):
self.assertIn(DEBUG_HANDLE_KEY, node.meta)
self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY])

# The 0th operator in the exported program (layer_norm) has been decomposed into 0th and 1st ops in edge dialect graph (native_layer_norm and getitem)
# So they should have the same debug handle
self.assertEqual(
exported_program_debug_handles[0], edge_dialect_program_debug_handles[0]
)
self.assertEqual(
exported_program_debug_handles[0], edge_dialect_program_debug_handles[1]
)

def test_equip_debug_handle_to_reexport_program_success(self):
"""Test that propagate_back_debug_handle returns True and properly equips debug handles."""
# Create a test model
model = models.FeedForwardBlock(5, 10)
inputs = (torch.rand(5, 5),)

# Export the model
init_export_program = export(model, inputs)
exported_program = export(init_export_program.module(), inputs)
export_graph_id = id(exported_program.graph)

# Convert to edge dialect
edge_dialect_program = to_edge(exported_program).exported_program()

# Call propagate_back_debug_handle
result = propagate_back_debug_handle(
exported_program, export_graph_id, edge_dialect_program
)

self.assertTrue(result)

# Check that debug handles are properly equipped in the exported program
exported_program_debug_handles = []
for node in exported_program.graph.nodes:
if node.op not in ("placeholder", "output"):
self.assertIn(DEBUG_HANDLE_KEY, node.meta)
self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY])

edge_dialect_program_debug_handles = []
for node in edge_dialect_program.graph.nodes:
if node.op not in ("placeholder", "output"):
self.assertIn(DEBUG_HANDLE_KEY, node.meta)
self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY])

# The 0th operator in the exported program (layer_norm) has been decomposed into 0th and 1st ops in edge dialect graph (native_layer_norm and getitem)
# So they should have the same debug handle
self.assertEqual(
exported_program_debug_handles[0], edge_dialect_program_debug_handles[0]
)
self.assertEqual(
exported_program_debug_handles[0], edge_dialect_program_debug_handles[1]
)

def test_equip_debug_handle_to_export_program_failure(self):
"""Test that propagate_back_debug_handle returns False when there's a mismatch."""
# Create a test model
Expand Down
Loading