Commit c8c711c

Modified fx_importer to support hop_while_loop
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1 parent 7e1cd8b commit c8c711c

2 files changed: 247 additions & 6 deletions

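For context, here is a hedged sketch (not part of the commit; the `Counter` module and printed fields are illustrative, and the exact node/argument layout depends on the PyTorch version) of the scenario this change targets: a model using torch._higher_order_ops.while_loop, exported with torch.export. The cond/body subgraphs end up as child GraphModules on the root graph module, referenced through get_attr arguments of the while_loop node, which is what the importer changes below have to handle.

import torch
from torch import nn
from torch._higher_order_ops.while_loop import while_loop


class Counter(nn.Module):
    def forward(self, x):
        def body(i, x):
            return i + 1, x * x

        i0 = torch.tensor(0)
        return while_loop(lambda i, x: i < 3, body, (i0, x))


ep = torch.export.export(Counter(), (torch.randn(4, 4),))
# The cond/body subgraphs are attached to the root GraphModule as children...
for name, child in ep.graph_module.named_children():
    print("child graph:", name)
# ...and the while_loop HOP node refers to them via get_attr args plus a tuple of carries.
for node in ep.graph_module.graph.nodes:
    if node.op == "call_function" and "while_loop" in str(node.target):
        print("HOP node args:", node.args)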
python/torch_mlir/extras/fx_importer.py

Lines changed: 208 additions & 5 deletions
@@ -106,6 +106,7 @@
     Context,
     DenseElementsAttr,
     DenseResourceElementsAttr,
+    FlatSymbolRefAttr,
     FloatAttr,
     BF16Type,
     ComplexType,
@@ -834,8 +835,50 @@ def import_program(
         node_importer.return_node_values(loc, user_outputs, constant_output_values)

         self.symbol_table.insert(func_op)
+
+        # Import all child graph modules recursively for HOPs.
+        # Even though import_stateless_graph is deprecated as an entrypoint mechanism,
+        # HOP operator graphs are stateless graphs with no mutation, and it is correct
+        # to import them as stateless graphs.
+        self._import_all_child_modules(
+            prog,
+            func_name,
+            import_symbolic_shape_expressions,
+        )
+
         return func_op

+    def _import_all_child_modules(
+        self,
+        prog: torch.export.ExportedProgram,
+        parent_name: str,
+        import_symbolic_shape_expressions: bool = False,
+    ):
+        """Recursively import all child modules that have graphs.
+
+        This simple approach imports all submodules recursively, which is sufficient
+        for HOP operations since they only reference existing submodules.
+        """
+        for child_name, child_module in prog.graph.owning_module.named_children():
+            if isinstance(child_module, GraphModule) and hasattr(child_module, "graph"):
+                # Generate a unique function name: {parent_name}_{child_name}_{id(child_module)}.
+                child_func_name = f"{parent_name}_{child_name}_{id(child_module)}"
+
+                # Import the child as a stateless graph (private function).
+                self.import_stateless_graph(
+                    child_module.graph,
+                    func_name=child_func_name,
+                    func_visibility="private",
+                    import_symbolic_shape_expressions=import_symbolic_shape_expressions,
+                )
+
+                # Recursively import its children.
+                self._import_all_child_modules(
+                    child_module,
+                    child_func_name,
+                    import_symbolic_shape_expressions,
+                )
+
     def import_frozen_program(
         self,
         prog: torch.export.ExportedProgram,
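Illustrative only (a hypothetical standalone helper, not code from the patch): the naming scheme above is {parent_name}_{child_name}_{id(child_module)}, and _import_hop_while_loop further down rebuilds the same names with a hard-coded "main" parent prefix, so the callee symbols line up when the entry function is named "main". A sketch of the same walk:

from torch.fx import GraphModule

def child_func_names(root: GraphModule, parent_name: str = "main") -> dict:
    # Mirror of the recursive import: one private func.func symbol per child graph.
    names = {}
    for child_name, child in root.named_children():
        if isinstance(child, GraphModule):
            func_name = f"{parent_name}_{child_name}_{id(child)}"
            names[child_name] = func_name
            names.update(child_func_names(child, func_name))
    return names

# e.g. {"while_loop_cond_graph_0": "main_while_loop_cond_graph_0_<id>", ...}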
@@ -996,9 +1039,17 @@ def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
             if node.op == "placeholder":
                 input_types.append(self._cc.node_val_to_type(node))
             elif node.op == "output":
-                # An output node's args[0] is the return value. This seems to
-                # always be "boxed" as a tuple, which we emit as multi-results.
-                for result_node in node.args[0]:
+                # An output node's args[0] is the return value. This is usually
+                # "boxed" as a tuple, which we emit as multi-results. However,
+                # for single returns it might be a single Node.
+                output_arg = node.args[0]
+                # Handle both single Node and tuple/list of Nodes
+                if isinstance(output_arg, (list, tuple)):
+                    result_nodes = output_arg
+                else:
+                    result_nodes = [output_arg]
+
+                for result_node in result_nodes:
                     if result_node is None:
                         result_types.append(
                             IrType.parse("!torch.none", context=self._c)
@@ -1509,7 +1560,14 @@ def import_nodes(
                 elif op == "output" and not skip_placeholders_outputs:
                     # args[0] is a singleton tuple that we flatten into multiple
                     # results.
-                    operands = [self._import_argument(loc, arg) for arg in node.args[0]]
+                    output_arg = node.args[0]
+                    # Handle both single Node and tuple/list of Nodes
+                    if isinstance(output_arg, (list, tuple)):
+                        result_nodes = output_arg
+                    else:
+                        result_nodes = [output_arg]
+
+                    operands = [self._import_argument(loc, arg) for arg in result_nodes]
                     func_dialect.ReturnOp(operands, loc=loc)

         if import_symbolic_shape_expressions:
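The two output-node changes above exist because FX does not always box the return value: for multi-output graphs args[0] is a tuple of Nodes, while a single-output graph can put the Node there directly. A quick, self-contained illustration (using symbolic_trace, which exhibits the same shape; not part of the patch):

import torch.fx as fx

def two_outputs(x):
    return x + 1, x - 1

def one_output(x):
    return x + 1

for fn in (two_outputs, one_output):
    g = fx.symbolic_trace(fn).graph
    out = next(n for n in g.nodes if n.op == "output")
    # tuple of Nodes for two_outputs, a bare Node for one_output
    print(fn.__name__, type(out.args[0]))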
@@ -1612,6 +1670,139 @@ def _import_hop(self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator):
             )
         handler(loc, node, hop)

+    def _import_hop_while_loop(
+        self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator
+    ):
+        """Imports the torch._higher_order_ops.while_loop HOP.
+
+        Args format: (cond_fn, body_fn, carries)
+        The cond_fn and body_fn are get_attr nodes pointing to submodule graphs
+        that have already been imported by import_program().
+
+        Emits torch.prim.Loop with the proper control flow structure.
+        """
+        # while_loop HOP args: (cond_fn, body_fn, carries...)
+        # Unpack the first two args and the rest as carries.
+        cond_fn_node, body_fn_node, *carries = node.args
+
+        # Extract function names from get_attr nodes.
+        # The subgraphs were imported with names like "main_{target}_{id}".
+        assert cond_fn_node.op == "get_attr", f"Expected get_attr for cond_fn, got {cond_fn_node.op}"
+        assert body_fn_node.op == "get_attr", f"Expected get_attr for body_fn, got {body_fn_node.op}"
+
+        root_module = node.graph.owning_module
+        cond_fn_module = getattr(root_module, cond_fn_node.target, None)
+        body_fn_module = getattr(root_module, body_fn_node.target, None)
+
+        # Generate function names with module IDs for uniqueness.
+        cond_fn_name = f"main_{cond_fn_node.target}_{id(cond_fn_module)}"
+        body_fn_name = f"main_{body_fn_node.target}_{id(body_fn_module)}"
+
+        # Import the carries (loop state variables).
+        carry_values = []
+        for carry in carries:
+            if isinstance(carry, tuple):
+                # Handle tuple carries by importing each element.
+                carry_values.extend(self._import_tuple_argument(loc, carry, None))
+            else:
+                carry_values.append(self._import_argument(loc, carry))
+
+        # Determine result types from node metadata.
+        node_val = node.meta.get("val")
+        if isinstance(node_val, (list, tuple)) and len(node_val) > 1:
+            result_types = [self._cc.value_info_to_type(v) for v in node_val]
+            self._multi_result_nodes.add(node)
+        else:
+            result_types = [self._cc.node_val_to_type(node)]
+
+        # Call the condition function with the initial carries to get the initial condition.
+        cond_result_type = self._cc.get_vtensor_type(torch.Size([]), torch.bool)
+
+        initial_cond_call = Operation.create(
+            "func.call",
+            attributes={"callee": FlatSymbolRefAttr.get(cond_fn_name)},
+            results=[cond_result_type],
+            operands=carry_values,
+            loc=loc,
+        )
+
+        # Convert vtensor<bool> to torch.bool.
+        bool_conv = Operation.create(
+            name="torch.aten.Bool.Tensor",
+            results=[self._cc.torch_bool_type],
+            operands=[initial_cond_call.results[0]],
+            loc=loc,
+        )
+
+        # Create the max iterations constant (INT64_MAX).
+        with loc:
+            max_iter = _make_constant_op(
+                "torch.constant.int",
+                self._cc.integer_attr(9223372036854775807, 64),
+                self._cc.torch_int_type,
+            )
+
+        # Create the torch.prim.Loop operation with a region.
+        loop_op = Operation.create(
+            name="torch.prim.Loop",
+            results=result_types,
+            operands=[max_iter.results[0], bool_conv.results[0]] + carry_values,
+            regions=1,
+            loc=loc,
+        )
+
+        # Create the loop body region with block arguments.
+        # Block args: iteration counter (!torch.int) + all carry values.
+        loop_region = loop_op.regions[0]
+        block_arg_types = [self._cc.torch_int_type] + result_types
+        with loc:
+            loop_block = Block.create_at_start(loop_region, block_arg_types)
+
+        # Inside the loop body, call the body function and then the condition function.
+        with InsertionPoint(loop_block):
+            # Call the body function with the current carry values (skip the iteration counter).
+            body_results_op = Operation.create(
+                name="func.call",
+                attributes={"callee": FlatSymbolRefAttr.get(body_fn_name)},
+                results=result_types,
+                operands=list(loop_block.arguments[1:]),  # Skip iteration counter
+                loc=loc,
+            )
+            body_results = list(body_results_op.results)
+
+            # Call the condition function with the updated carries.
+            cond_result_loop = Operation.create(
+                name="func.call",
+                attributes={"callee": FlatSymbolRefAttr.get(cond_fn_name)},
+                results=[IrType.parse("!torch.vtensor<[],i1>", context=self._c)],
+                operands=body_results,
+                loc=loc,
+            ).result
+
+            # Convert to bool.
+            cond_bool = Operation.create(
+                name="torch.aten.Bool.Tensor",
+                results=[self._cc.torch_bool_type],
+                operands=[cond_result_loop],
+                loc=loc,
+            ).result
+
+            # Emit the loop condition with the updated carries.
+            Operation.create(
+                name="torch.prim.Loop.condition",
+                results=[],
+                operands=[cond_bool] + body_results,
+                loc=loc,
+            )
+
+        # Bind the loop results to the node.
+        if len(result_types) > 1:
+            self._multi_result_nodes.add(node)
+            for i, value in enumerate(loop_op.results):
+                self.bind_node_value(node, value, i)
+        else:
+            self.bind_node_value(node, loop_op.results[0])
+
     def _import_hop_auto_functionalized(
         self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator
     ):
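For readers unfamiliar with torch.prim.Loop, the following is a plain-Python model of the control flow that _import_hop_while_loop emits (illustrative Python, not IR and not part of the patch): the condition is evaluated once on the initial carries before the loop, then re-evaluated inside the region after every body call, with INT64_MAX acting as the trip-count cap.

INT64_MAX = 9223372036854775807  # the torch.constant.int emitted above

def while_loop_as_prim_loop(cond_fn, body_fn, carries, max_iter=INT64_MAX):
    # Initial func.call @cond_fn + torch.aten.Bool.Tensor.
    keep_going = bool(cond_fn(*carries))
    it = 0
    # torch.prim.Loop %max_iter, %keep_going, %carries...
    while keep_going and it < max_iter:
        carries = body_fn(*carries)            # func.call @body_fn inside the region
        keep_going = bool(cond_fn(*carries))   # func.call @cond_fn + Bool.Tensor
        it += 1                                # implicit iteration-counter block argument
    return carries                             # loop results bound to the HOP node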
@@ -1823,6 +2014,9 @@ def _import_argument(
             argument_value = self.resolve_node_value(arg)
         elif isinstance(arg, torch_fx.immutable_collections.immutable_list):
            argument_value = self._import_list_argument(loc, arg, expected_jit_type)
+        elif isinstance(arg, tuple):
+            # Handle tuples of tensors (common in while_loop carries).
+            argument_value = self._import_tuple_argument(loc, arg, expected_jit_type)
         elif isinstance(expected_jit_type, torch.TensorType) and not isinstance(
             arg, torch.Tensor
         ):
@@ -1930,6 +2124,13 @@ def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value:
             loc=loc,
         ).result

+    def _import_tuple_argument(
+        self, loc: Location, arg: tuple, expected_jit_type
+    ) -> List[Value]:
+        """Import a tuple argument by importing each element separately."""
+        # For tuples in while_loop carries, treat each element as a separate argument.
+        return [self._import_argument(loc, elem, expected_jit_type) for elem in arg]
+
     def _import_list_argument(
         self, loc: Location, arg: Sequence[NodeArgument], expected_jit_type
     ) -> Value:
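A minimal, standalone sketch (a hypothetical helper, not the importer's actual code path) of the flattening behavior the two tuple-handling additions provide: a tuple carry such as (i0, x) from while_loop(cond, body, (i0, x)) is imported element by element, yielding one value per carried tensor rather than a single aggregate value.

from typing import Any, List

def flatten_tuple_args(args: Any) -> List[Any]:
    # Tuples are expanded element-wise (recursively, via the mutual recursion of
    # _import_argument and _import_tuple_argument); everything else passes through.
    if isinstance(args, tuple):
        out: List[Any] = []
        for elem in args:
            out.extend(flatten_tuple_args(elem))
        return out
    return [args]

print(flatten_tuple_args((("i0", "x"),)))  # ['i0', 'x']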
@@ -2040,6 +2241,8 @@ def _import_getitem(self, loc: Location, node: torch.fx.Node):
         # NOTE: the length of the list must be knowable at compile time.
         if ref_node not in self._unpack_list_values:
             node_result = self.resolve_node_value(ref_node, 0)
+            node_val = ref_node.meta.get("val")
+
             if str(node_result.type) in TORCH_LIST_TYPES:
                 result_types = [
                     self._cc.value_info_to_type(v) for v in ref_node.meta["val"]
@@ -2510,4 +2713,4 @@ def aten__embedding_bag_forward_only_default(node: torch_fx.Node):
 def node_canonicalize(node: torch_fx.Node):
     if node.target in NODE_CANONICALIZE:
         return NODE_CANONICALIZE[node.target](node)
-    return node
+    return node

test/python/fx_importer/basic_test.py

Lines changed: 39 additions & 1 deletion
@@ -205,6 +205,44 @@ def forward(self):
         )
     print(m)

+@run
+# CHECK-LABEL: test_while_loop_two_returns
+# CHECK: func.func @test_while_loop_two_returns
+# CHECK-SAME: -> (!torch.vtensor<[],si64>, !torch.vtensor<[4,4],f32>)
+
+# Validate literal/init plumbing:
+# CHECK: %[[ZERO:.*]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
+# CHECK: %[[NONE:.*]] = torch.constant.none
+# CHECK: %[[CLONE:.*]] = torch.aten.clone %[[ZERO]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+
+# CHECK: %[[COND:.*]] = call @while_loop_cond_graph_{{[0-9]+}}(%[[CLONE]]
+# CHECK: torch.aten.Bool.Tensor %[[COND]]
+# CHECK: %[[MAX_ITER:.*]] = torch.constant.int 9223372036854775807
+# CHECK: torch.prim.Loop %[[MAX_ITER]]
+
+# CHECK: func.func private @while_loop_cond_graph_{{[0-9]+}}
+# CHECK: torch.aten.lt.Scalar
+
+# CHECK: func.func private @while_loop_body_graph_{{[0-9]+}}
+# CHECK: torch.aten.add.Scalar
+# CHECK: torch.aten.mul.Tensor
+def test_while_loop_two_returns():
+    class M(nn.Module):
+        def forward(self, x):
+            # Simple while_loop that carries a counter tensor and a data tensor.
+            def body(i, x):
+                return i + 1, x * x
+
+            i0 = torch.tensor(0)
+            from torch._higher_order_ops.while_loop import while_loop
+
+            out_i, out_x = while_loop(
+                lambda i, x: i < 3, body, (i0, x)
+            )
+            return out_i, out_x
+
+    # Export -> import to Torch-MLIR.
+    m = fx.export_and_import(M(), torch.randn(4, 4), func_name="test_while_loop_two_returns")
+    print(m)

 @run
 # CHECK-LABEL: test_stack_trace
@@ -229,4 +267,4 @@ def foo(x, y):
     y = torch.randn(128, 128)
     m = fx.export_and_import(Basic(), x, y, func_name="test_stack_trace")
     mlir_asm = m.operation.get_asm(enable_debug_info=True)
-    print(mlir_asm)
+    print(mlir_asm)
