diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp
index 27e0a61f4b31..6f970de06592 100644
--- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp
+++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp
@@ -150,6 +150,10 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern {
         targetType = Torch::IntType::get(op->getContext());
         torchArg = typeConverter->materializeSourceConversion(
             rewriter, scfWhileOp.getLoc(), targetType, {to});
+      } else if (auto tty = dyn_cast(targetType)) {
+        targetType = op.getIterArgsInit()[barg.index()].getType();
+        torchArg = typeConverter->materializeSourceConversion(
+            rewriter, scfWhileOp.getLoc(), targetType, {to});
       }
       if (!torchArg)
         return rewriter.notifyMatchFailure(op,
@@ -173,14 +177,6 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern {
                                            "unsupported type of the operand");
     loopConditionIterArgs.push_back(shouldContinue);
     for (auto torchArg : primLoopConditionOp.getIterArgs()) {
-      Type torchType = torchArg.getType();
-
-      // If the argument is a torch tensor, directly add it in the list of
-      // iter args.
-      if (isa(torchType)) {
-        loopConditionIterArgs.push_back(torchArg);
-        continue;
-      }
       Value arg = typeConverter->materializeTargetConversion(
           rewriter, scfWhileOp->getLoc(),
           typeConverter->convertType(torchArg.getType()), {torchArg});
diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py
index ae4986b7f96c..567b00ef7a2f 100644
--- a/projects/pt1/e2e_testing/main.py
+++ b/projects/pt1/e2e_testing/main.py
@@ -40,6 +40,7 @@
 from .xfail_sets import (
     LINALG_XFAIL_SET,
     LINALG_CRASHING_SET,
+    TORCHSCRIPT_XFAIL_SET,
     STABLEHLO_PASS_SET,
     STABLEHLO_CRASHING_SET,
     TOSA_PASS_SET,
@@ -167,7 +168,7 @@ def main():
         crashing_set = set()
     elif args.config == "torchscript":
         config = TorchScriptTestConfig()
-        xfail_set = set()
+        xfail_set = TORCHSCRIPT_XFAIL_SET
         crashing_set = set()
     elif args.config == "lazy_tensor_core":
         config = LazyTensorCoreTestConfig()
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index f494bc8574e6..cdb9316a7c8d 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -89,6 +89,11 @@
     "AtenMmInt8Types_basic",
 }
 
+TORCHSCRIPT_XFAIL_SET = {
+    # Compilation Error: torch.jit.frontend.UnsupportedNodeError: import statements aren't supported:
+    "TorchPrimLoopWhileLikeHOPModule_basic",
+}
+
 TORCHDYNAMO_XFAIL_SET = {
     #### General TorchDynamo/PyTorch errors
     # torch._dynamo.exc.Unsupported: Tensor.item
@@ -246,6 +251,8 @@
     "IsFloatingPointInt_False",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
+    # torch._dynamo.exc.BackendCompilerFailed: Unsupported op: get_attr
+    "TorchPrimLoopWhileLikeHOPModule_basic",
     "ScalarConstantTupleModule_basic",
     # END tests failing due to: empty graph in dynamo
     # ERROR due to: backend never runs because of empty frame
@@ -481,6 +488,7 @@
     "TensorToBoolZeroRank_basic",
     "TensorToBool_basic",
     "ThresholdBackward2dMixedModule_basic",
+    "TorchPrimLoopWhileLikeHOPModule_basic",  # Compilation error: failed to legalize operation 'func.call'
     "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
     "UpSampleNearest2dDynamicFactor_basic",
     "ViewCollapseDynamicWithAtenSizeIntModule_basic",
@@ -993,6 +1001,8 @@
     "ElementwiseClampMinModule_bfloat16",
     "ElementwiseClampModule_bfloat16",
     "ElementwiseReluModule_bfloat16",
+    # Runtime error: failed to legalize operation 'torch.constant.int'
+    "TorchPrimLoopWhileLikeHOPModule_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -2575,6 +2585,7 @@ LTC_XFAIL_SET = {
     "TorchPrimLoopForLikeTensorArgModule_basic"
     "CollapseAllDimensionsModule_basic",
+    "TorchPrimLoopWhileLikeHOPModule_basic",
     "CollapseRank1DynamicModule_basic",
     "CollapseStaticModule_basic",
     "CollapsePartialDynamicModule_basic",
@@ -3261,6 +3272,8 @@
     "ToCopyWithDTypeModule_basic",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
+    # RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function
+    "TorchPrimLoopWhileLikeHOPModule_basic",
     "TraceModule_basic",
     "TraceModule_empty",
     "TraceModule_nonsquare",
@@ -3957,6 +3970,8 @@
     "ThresholdBackward2dMixedModule_basic",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
+    # Runtime error: failed to legalize operation 'torch.aten.Bool.Tensor'
+    "TorchPrimLoopWhileLikeHOPModule_basic",
     "TraceModule_empty",
     "TraceUnsignedIntModule_empty",
     "TransposedConv1dNegativePadding_basic",
@@ -5036,6 +5051,8 @@
     "ToDtypeFloatFromIntModule_basic",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
+    # RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function
+    "TorchPrimLoopWhileLikeHOPModule_basic",
     "TraceModule_basic",
     "TraceModule_empty",
     "TraceModule_nonsquare",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py
index a04114043583..5bb907615e70 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py
@@ -10,6 +10,7 @@
 from torch_mlir_e2e_test.framework import TestUtils
 from torch_mlir_e2e_test.registry import register_test_case
 from torch_mlir_e2e_test.annotations import annotate_args, export
+from torch._higher_order_ops.while_loop import while_loop
 
 # ==============================================================================
 
@@ -78,3 +79,36 @@ def TorchPrimLoopForLikeTensorArgModule_basic(module, tu: TestUtils):
     x_test = torch.zeros([7, 9]).float()
 
     module.forward(x_test)
+
+
+# ==============================================================================
+
+
+class TorchPrimLoopWhileLikeHOPModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def body_fn(self, i, x):
+        return i + 1, x + 1
+
+    def cond_fn(self, i, x):
+        return i < 3
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([7, 9], torch.float32, True),
+        ]
+    )
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        i0 = torch.tensor(0)
+        out_i, out_x = while_loop(self.cond_fn, self.body_fn, (i0, x))
+        return out_i, out_x
+
+
+@register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeHOPModule())
+def TorchPrimLoopWhileLikeHOPModule_basic(module, tu: TestUtils):
+    x_test = torch.zeros([7, 9]).float()
+
+    module.forward(x_test)
diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 35501741149d..c5b9ffdad851 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -106,6 +106,7 @@
     Context,
     DenseElementsAttr,
     DenseResourceElementsAttr,
+    FlatSymbolRefAttr,
     FloatAttr,
     BF16Type,
     ComplexType,
@@ -536,6 +537,8 @@ class FxImporter:
         "_py_attr_tracker",
         "_hooks",
         "symbol_table",
+        "_graph_module_to_func_name",
+        "_func_name_counter",
     ]
 
     def __init__(
@@ -563,6 +566,10 @@ def __init__(
         self._hooks = hooks or FxImporterHooks()
         self.symbol_table = SymbolTable(self._m.operation)
         self._hooks.prepare_module(self._m.operation)
+        # Used specifically in HOPs to map module IDs to function names
+        self._graph_module_to_func_name: Dict[int, str] = {}
+        # Handles collision of function names in the same module
+        self._func_name_counter: int = 0
 
     def _config_check(self):
         for dname in REQUIRED_DIALCTS:
@@ -823,6 +830,15 @@ def import_program(
         for node, (buffer_value, info) in buffer_bindings.items():
             node_importer.lazy_import_buffer(loc, node, buffer_value, info)
 
+        # Import all child graph modules recursively for HOPs BEFORE importing nodes.
+        # This is necessary because HOP nodes need to reference these functions.
+        # Even though import_stateless_graph is deprecated as an entrypoint mechanism,
+        # HOP operator graphs are stateless graphs with no mutation, and it is correct
+        # to import them as stateless graphs.
+        self._import_all_child_modules(
+            prog.graph.owning_module, func_name, import_symbolic_shape_expressions
+        )
+
         # Import all nodes and return.
         node_importer.import_nodes(
             all_producer_nodes.values(),
@@ -834,8 +850,32 @@ def import_program(
         node_importer.return_node_values(loc, user_outputs, constant_output_values)
 
         self.symbol_table.insert(func_op)
+
         return func_op
 
+    def _import_all_child_modules(
+        self,
+        module: GraphModule,
+        parent_name: str,
+        import_symbolic_shape_expressions: bool = False,
+    ):
+        """Import all child modules by delegating to import_graph_module.
+
+        This is a thin wrapper that walks the module's children and delegates to
+        import_graph_module for each child GraphModule.
+
+        Note: This only imports children, not the parent module itself.
+        """
+
+        for child_name, child_module in module.named_children():
+            if isinstance(child_module, GraphModule) and hasattr(child_module, "graph"):
+                self.import_graph_module(
+                    child_module,
+                    func_name=child_name,
+                    func_visibility="private",
+                    import_symbolic_shape_expressions=import_symbolic_shape_expressions,
+                )
+
     def import_frozen_program(
         self,
         prog: torch.export.ExportedProgram,
@@ -926,6 +966,14 @@ def import_frozen_program(
                     node.replace_all_uses_with(replacement)
                     g.erase_node(node)
 
+        # Import child modules for HOPs before importing the main graph.
+        # This ensures that any higher-order operations (like while_loop) can
+        # reference the already-imported child module functions.
+        if hasattr(g, "owning_module") and g.owning_module is not None:
+            self._import_all_child_modules(
+                g.owning_module, func_name, import_symbolic_shape_expressions
+            )
+
         return self.import_stateless_graph(
             g,
             func_name=func_name,
@@ -933,13 +981,56 @@ def import_frozen_program(
             import_symbolic_shape_expressions=import_symbolic_shape_expressions,
         )
 
-    def import_graph_module(self, gm: GraphModule) -> Operation:
+    def import_graph_module(
+        self,
+        gm: GraphModule,
+        *,
+        func_name: str = "main",
+        func_visibility: Optional[str] = None,
+        import_symbolic_shape_expressions: bool = False,
+    ) -> Operation:
         """Low-level import of a GraphModule assuming that it has been functionalized.
 
+        This method recursively imports all child GraphModules first, then imports
+        the provided GraphModule itself. This ensures that any higher-order operations
+        that reference child modules will find them already imported.
+
         TODO: This mechanism is deprecated by the `import_program` entry-point and
         it should be removed when no longer required for backwards compatibility.
+
+        Note: New callers should use this method only for importing HOP submodule graphs.
""" - return self.import_stateless_graph(gm.graph) + # Store the mapping for this module itself (HOPs will need to look this up) + module_id = id(gm) + if module_id not in self._graph_module_to_func_name: + # Ensure the func_name is unique + final_func_name = func_name + if func_name in self._graph_module_to_func_name.values(): + final_func_name = f"{func_name}_{self._func_name_counter}" + self._func_name_counter += 1 + self._graph_module_to_func_name[module_id] = final_func_name + else: + # Module already imported, use existing name + final_func_name = self._graph_module_to_func_name[module_id] + + # First, recursively import all child modules + for child_name, child_module in gm.named_children(): + if isinstance(child_module, GraphModule) and hasattr(child_module, "graph"): + # Recursively import this child (which will handle its own mapping) + self.import_graph_module( + child_module, + func_name=child_name, + func_visibility="private", + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ) + + # Then import this module's own graph + return self.import_stateless_graph( + gm.graph, + func_name=final_func_name, + func_visibility=func_visibility, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ) def import_stateless_graph( self, @@ -972,6 +1063,11 @@ def import_stateless_graph( self._cc, entry_block, ) + + # Note: Child module importing is handled by import_graph_module, which is + # the recommended entry point. This method is deprecated and should only be + # used for stateless graphs that truly have no child modules. + node_importer.import_nodes( g.nodes, import_symbolic_shape_expressions=import_symbolic_shape_expressions ) @@ -996,9 +1092,18 @@ def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]: if node.op == "placeholder": input_types.append(self._cc.node_val_to_type(node)) elif node.op == "output": - # An output node's args[0] is the return value. This seems to - # always be "boxed" as a tuple, which we emit as multi-results. - for result_node in node.args[0]: + # An output node's args[0] is the return value. This is usually + # "boxed" as a tuple, which we emit as multi-results. However, + # for single returns it might be a single Node. + output_arg = node.args[0] + # Handle both single Node and tuple/list of Nodes + result_nodes = ( + output_arg + if isinstance(output_arg, (list, tuple)) + else [output_arg] + ) + + for result_node in result_nodes: if result_node is None: result_types.append( IrType.parse("!torch.none", context=self._c) @@ -1509,7 +1614,14 @@ def import_nodes( elif op == "output" and not skip_placeholders_outputs: # args[0] is a singleton tuple that we flatten into multiple # results. - operands = [self._import_argument(loc, arg) for arg in node.args[0]] + output_arg = node.args[0] + # Handle both single Node and tuple/list of Nodes + result_nodes = ( + output_arg + if isinstance(output_arg, (list, tuple)) + else [output_arg] + ) + operands = [self._import_argument(loc, arg) for arg in result_nodes] func_dialect.ReturnOp(operands, loc=loc) if import_symbolic_shape_expressions: @@ -1612,6 +1724,142 @@ def _import_hop(self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperat ) handler(loc, node, hop) + def _import_hop_while_loop( + self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator + ): + """Imports the torch._higher_order_ops.while_loop HOP. 
+
+        Args format: (cond_fn, body_fn, carries)
+        The cond_fn and body_fn are get_attr nodes pointing to submodule graphs
+        that have already been imported by import_program().
+
+        Emits torch.prim.Loop with proper control flow structure.
+        """
+        # while_loop HOP args: (cond_fn, body_fn, carries...)
+        # Unpack the first two args and the rest as carries
+        cond_fn_node, body_fn_node, *carries = node.args
+
+        # Extract function names from get_attr nodes
+        assert (
+            cond_fn_node.op == "get_attr"
+        ), f"Expected get_attr for cond_fn, got {cond_fn_node.op}"
+        assert (
+            body_fn_node.op == "get_attr"
+        ), f"Expected get_attr for body_fn, got {body_fn_node.op}"
+
+        root_module = node.graph.owning_module
+        cond_fn_module = getattr(root_module, cond_fn_node.target, None)
+        body_fn_module = getattr(root_module, body_fn_node.target, None)
+
+        # Look up the function names registered for these submodules during import
+        cond_fn_name = self.fx_importer._graph_module_to_func_name[id(cond_fn_module)]
+        body_fn_name = self.fx_importer._graph_module_to_func_name[id(body_fn_module)]
+
+        # Import the carries (loop state variables)
+        carry_values = []
+        for carry in carries:
+            if isinstance(carry, tuple):
+                # Handle tuple carries by importing each element
+                carry_values.extend(self._import_tuple_argument(loc, carry, None))
+            else:
+                carry_values.append(self._import_argument(loc, carry))
+
+        # Determine result types from node metadata
+        node_val = node.meta.get("val")
+        if isinstance(node_val, (list, tuple)) and len(node_val) > 1:
+            result_types = [self._cc.value_info_to_type(v) for v in node_val]
+            self._multi_result_nodes.add(node)
+        else:
+            result_types = [self._cc.node_val_to_type(node)]
+
+        # Call the condition function with initial carries to get initial condition
+        cond_result_type = self._cc.get_vtensor_type(torch.Size([]), torch.bool)
+
+        initial_cond_call = Operation.create(
+            "func.call",
+            attributes={"callee": FlatSymbolRefAttr.get(cond_fn_name)},
+            results=[cond_result_type],
+            operands=carry_values,
+            loc=loc,
+        )
+
+        # Convert vtensor to torch.bool
+        bool_conv = Operation.create(
+            name="torch.aten.Bool.Tensor",
+            results=[self._cc.torch_bool_type],
+            operands=[initial_cond_call.results[0]],
+            loc=loc,
+        )
+
+        # Create max iterations constant (INT64_MAX)
+        with loc:
+            max_iter = _make_constant_op(
+                "torch.constant.int",
+                self._cc.integer_attr(torch.iinfo(torch.int64).max, 64),
+                self._cc.torch_int_type,
+            )
+
+        # Create torch.prim.Loop operation with region
+        loop_op = Operation.create(
+            name="torch.prim.Loop",
+            results=result_types,
+            operands=[max_iter.results[0], bool_conv.results[0]] + carry_values,
+            regions=1,
+            loc=loc,
+        )
+
+        # Create loop body region with block arguments
+        # Block args: iteration counter (!torch.int) + all carry values
+        loop_region = loop_op.regions[0]
+        block_arg_types = [self._cc.torch_int_type] + result_types
+        with loc:
+            loop_block = Block.create_at_start(loop_region, block_arg_types)
+
+        # Inside the loop body, call body function and condition function
+        with InsertionPoint(loop_block):
+            # Call body function with current carry values (skip iteration counter)
+            body_results_op = Operation.create(
+                name="func.call",
+                attributes={"callee": FlatSymbolRefAttr.get(body_fn_name)},
+                results=result_types,
+                operands=list(loop_block.arguments[1:]),  # Skip iteration counter
+                loc=loc,
+            )
+            body_results = list(body_results_op.results)
+
+            # Call condition function with updated carries
+            cond_result_loop = Operation.create(
+                name="func.call",
+                attributes={"callee": FlatSymbolRefAttr.get(cond_fn_name)},
+                results=[IrType.parse("!torch.vtensor<[],i1>", context=self._c)],
+                operands=body_results,
+                loc=loc,
+            ).result
+
+            # Convert to bool
+            cond_bool = Operation.create(
+                name="torch.aten.Bool.Tensor",
+                results=[self._cc.torch_bool_type],
+                operands=[cond_result_loop],
+                loc=loc,
+            ).result
+
+            # Emit loop condition with updated carries
+            Operation.create(
+                name="torch.prim.Loop.condition",
+                results=[],
+                operands=[cond_bool] + body_results,
+                loc=loc,
+            )
+
+        # Bind the loop results to the node
+        if len(result_types) > 1:
+            self._multi_result_nodes.add(node)
+            for i, value in enumerate(loop_op.results):
+                self.bind_node_value(node, value, i)
+        else:
+            self.bind_node_value(node, loop_op.results[0])
+
     def _import_hop_auto_functionalized(
         self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator
     ):
@@ -1823,6 +2071,9 @@ def _import_argument(
             argument_value = self.resolve_node_value(arg)
         elif isinstance(arg, torch_fx.immutable_collections.immutable_list):
             argument_value = self._import_list_argument(loc, arg, expected_jit_type)
+        elif isinstance(arg, tuple):
+            # Handle tuples of tensors (common in while_loop carries)
+            argument_value = self._import_tuple_argument(loc, arg, expected_jit_type)
         elif isinstance(expected_jit_type, torch.TensorType) and not isinstance(
             arg, torch.Tensor
         ):
@@ -1930,6 +2181,13 @@ def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value:
             loc=loc,
         ).result
 
+    def _import_tuple_argument(
+        self, loc: Location, arg: tuple, expected_jit_type
+    ) -> list[Value]:
+        """Import a tuple argument by importing each element separately."""
+        # For tuples in while_loop carries, treat each element as a separate argument
+        return [self._import_argument(loc, elem, expected_jit_type) for elem in arg]
+
     def _import_list_argument(
         self, loc: Location, arg: Sequence[NodeArgument], expected_jit_type
     ) -> Value:
@@ -2040,6 +2298,7 @@ def _import_getitem(self, loc: Location, node: torch.fx.Node):
         # NOTE: the length of the list must be knowable at compile time.
         if ref_node not in self._unpack_list_values:
             node_result = self.resolve_node_value(ref_node, 0)
+
             if str(node_result.type) in TORCH_LIST_TYPES:
                 result_types = [
                     self._cc.value_info_to_type(v) for v in ref_node.meta["val"]
diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py
index 7a5660b028b3..e490a8d3636c 100644
--- a/test/python/fx_importer/basic_test.py
+++ b/test/python/fx_importer/basic_test.py
@@ -206,6 +206,51 @@ def forward(self):
     print(m)
 
 
+@run
+# CHECK-LABEL: test_while_loop_two_returns
+# Check that helper functions are emitted first
+# CHECK: func.func private @while_loop_cond_graph_{{[0-9]+}}(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[],i1>
+# CHECK: torch.aten.lt.Scalar
+# CHECK: func.func private @while_loop_body_graph_{{[0-9]+}}(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[4,4],f32>) -> (!torch.vtensor<[],si64>, !torch.vtensor<[4,4],f32>)
+# CHECK: torch.aten.add.Scalar
+# CHECK: torch.aten.mul.Tensor
+# Then check the main function
+# CHECK: func.func @test_while_loop_two_returns(%arg0: !torch.vtensor<[4,4],f32>)
+# CHECK-SAME: -> (!torch.vtensor<[],si64>, !torch.vtensor<[4,4],f32>)
+# Validate literal/init plumbing:
+# CHECK: %[[ZERO:.*]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
+# CHECK: %[[NONE:.*]] = torch.constant.none
+# CHECK: %[[CLONE:.*]] = torch.aten.clone %[[ZERO]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+# CHECK: %[[COND:.*]] = call @while_loop_cond_graph_{{[0-9]+}}(%[[CLONE]], %arg0)
+# CHECK: %[[BOOL:.*]] = torch.aten.Bool.Tensor %[[COND]]
+# CHECK: %[[MAX_ITER:.*]] = torch.constant.int 9223372036854775807
+# CHECK: %[[RESULT:.*]]:2 = torch.prim.Loop %[[MAX_ITER]], %[[BOOL]], init(%[[CLONE]], %arg0)
+# CHECK: ^bb0(%arg1: !torch.int, %arg2: !torch.vtensor<[],si64>, %arg3: !torch.vtensor<[4,4],f32>):
+# CHECK: %[[BODY_RESULT:.*]]:2 = func.call @while_loop_body_graph_{{[0-9]+}}(%arg2, %arg3)
+# CHECK: %[[COND_RESULT:.*]] = func.call @while_loop_cond_graph_{{[0-9]+}}(%[[BODY_RESULT]]#0, %[[BODY_RESULT]]#1)
+# CHECK: %[[BOOL_RESULT:.*]] = torch.aten.Bool.Tensor %[[COND_RESULT]]
+# CHECK: torch.prim.Loop.condition %[[BOOL_RESULT]], iter(%[[BODY_RESULT]]#0, %[[BODY_RESULT]]#1 : !torch.vtensor<[],si64>, !torch.vtensor<[4,4],f32>)
+# CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+def test_while_loop_two_returns():
+    class M(nn.Module):
+        def forward(self, x):
+            # Simple while_loop that carries a scalar and a tensor.
+            def body(i, x):
+                return i + 1, x * x
+
+            i0 = torch.tensor(0)
+            from torch._higher_order_ops.while_loop import while_loop
+
+            out_i, out_x = while_loop(lambda i, x: i < 3, body, (i0, x))
+            return out_i, out_x
+
+    # Export -> import to Torch-MLIR
+    m = fx.export_and_import(
+        M(), torch.randn(4, 4), func_name="test_while_loop_two_returns"
+    )
+    print(m)
+
+
 @run
 # CHECK-LABEL: test_stack_trace
 # CHECK: #loc[[LOC1:.+]] = loc(