     Context,
     DenseElementsAttr,
     DenseResourceElementsAttr,
+    FlatSymbolRefAttr,
     FloatAttr,
     BF16Type,
     ComplexType,
@@ -834,8 +835,50 @@ def import_program(
         node_importer.return_node_values(loc, user_outputs, constant_output_values)
 
         self.symbol_table.insert(func_op)
+
+        # Import all child graph modules recursively for HOPs.
+        # Even though import_stateless_graph is deprecated as an entrypoint
+        # mechanism, HOP operator graphs are stateless graphs with no mutation,
+        # and it is correct to import them as stateless graphs.
+        self._import_all_child_modules(
+            prog,
+            func_name,
+            import_symbolic_shape_expressions,
+        )
+
         return func_op
 
+    def _import_all_child_modules(
+        self,
+        prog: torch.export.ExportedProgram,
+        parent_name: str,
+        import_symbolic_shape_expressions: bool = False,
+    ):
+        """Recursively import all child modules that have graphs.
+
+        This simple approach imports all submodules recursively, which is
+        sufficient for HOP operations since they only reference existing
+        submodules.
+        """
+        for child_name, child_module in prog.graph.owning_module.named_children():
+            if isinstance(child_module, GraphModule) and hasattr(child_module, "graph"):
+                # Generate a unique function name: parent_childname_id.
+                child_func_name = f"{parent_name}_{child_name}_{id(child_module)}"
+
+                # Import the child as a stateless graph (private function).
+                self.import_stateless_graph(
+                    child_module.graph,
+                    func_name=child_func_name,
+                    func_visibility="private",
+                    import_symbolic_shape_expressions=import_symbolic_shape_expressions,
+                )
+
+                # Recursively import its children.
+                self._import_all_child_modules(
+                    child_module,
+                    child_func_name,
+                    import_symbolic_shape_expressions,
+                )
+
     def import_frozen_program(
         self,
         prog: torch.export.ExportedProgram,
@@ -996,9 +1039,17 @@ def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
             if node.op == "placeholder":
                 input_types.append(self._cc.node_val_to_type(node))
             elif node.op == "output":
-                # An output node's args[0] is the return value. This seems to
-                # always be "boxed" as a tuple, which we emit as multi-results.
-                for result_node in node.args[0]:
+                # An output node's args[0] is the return value. This is usually
+                # "boxed" as a tuple, which we emit as multi-results. However,
+                # for single returns it might be a single Node.
+                output_arg = node.args[0]
+                # Handle both a single Node and a tuple/list of Nodes.
+                if isinstance(output_arg, (list, tuple)):
+                    result_nodes = output_arg
+                else:
+                    result_nodes = [output_arg]
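+                # Illustrative example (hypothetical subgraph): a graph whose
+                # forward ends with `return x` may surface a bare Node here,
+                # while `return (x, y)` surfaces a tuple; both forms are
+                # normalized to a list of result nodes.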
+
+                for result_node in result_nodes:
                     if result_node is None:
                         result_types.append(
                             IrType.parse("!torch.none", context=self._c)
@@ -1509,7 +1560,14 @@ def import_nodes(
                 elif op == "output" and not skip_placeholders_outputs:
                     # args[0] is a singleton tuple that we flatten into multiple
                     # results.
-                    operands = [self._import_argument(loc, arg) for arg in node.args[0]]
+                    output_arg = node.args[0]
+                    # Handle both a single Node and a tuple/list of Nodes.
+                    if isinstance(output_arg, (list, tuple)):
+                        result_nodes = output_arg
+                    else:
+                        result_nodes = [output_arg]
+
+                    operands = [self._import_argument(loc, arg) for arg in result_nodes]
                     func_dialect.ReturnOp(operands, loc=loc)
 
         if import_symbolic_shape_expressions:
@@ -1612,6 +1670,139 @@ def _import_hop(self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperat
             )
         handler(loc, node, hop)
 
+    def _import_hop_while_loop(
+        self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator
+    ):
+        """Imports the torch._higher_order_ops.while_loop HOP.
+
+        Args format: (cond_fn, body_fn, carries)
+        The cond_fn and body_fn are get_attr nodes pointing to submodule graphs
+        that have already been imported by import_program().
+
+        Emits torch.prim.Loop with proper control flow structure.
+        """
+        # while_loop HOP args: (cond_fn, body_fn, carries...)
+        # Unpack the first two args and treat the rest as carries.
+        cond_fn_node, body_fn_node, *carries = node.args
+
+        # Extract function names from the get_attr nodes.
+        # The subgraphs were imported with names like "main_{target}_{id}".
+        assert (
+            cond_fn_node.op == "get_attr"
+        ), f"Expected get_attr for cond_fn, got {cond_fn_node.op}"
+        assert (
+            body_fn_node.op == "get_attr"
+        ), f"Expected get_attr for body_fn, got {body_fn_node.op}"
+
+        root_module = node.graph.owning_module
+        cond_fn_module = getattr(root_module, cond_fn_node.target, None)
+        body_fn_module = getattr(root_module, body_fn_node.target, None)
+
+        # Generate function names with module IDs for uniqueness.
+        cond_fn_name = f"main_{cond_fn_node.target}_{id(cond_fn_module)}"
+        body_fn_name = f"main_{body_fn_node.target}_{id(body_fn_module)}"
+
+        # Import the carries (loop state variables).
+        carry_values = []
+        for carry in carries:
+            if isinstance(carry, tuple):
+                # Handle tuple carries by importing each element.
+                carry_values.extend(self._import_tuple_argument(loc, carry, None))
+            else:
+                carry_values.append(self._import_argument(loc, carry))
+
+        # Determine result types from node metadata.
+        node_val = node.meta.get("val")
+        if isinstance(node_val, (list, tuple)) and len(node_val) > 1:
+            result_types = [self._cc.value_info_to_type(v) for v in node_val]
+            self._multi_result_nodes.add(node)
+        else:
+            result_types = [self._cc.node_val_to_type(node)]
+
+        # Call the condition function with the initial carries to get the
+        # initial loop condition.
+        cond_result_type = self._cc.get_vtensor_type(torch.Size([]), torch.bool)
+
+        initial_cond_call = Operation.create(
+            "func.call",
+            attributes={"callee": FlatSymbolRefAttr.get(cond_fn_name)},
+            results=[cond_result_type],
+            operands=carry_values,
+            loc=loc,
+        )
+
+        # Convert the scalar vtensor<bool> result to !torch.bool.
+        bool_conv = Operation.create(
+            name="torch.aten.Bool.Tensor",
+            results=[self._cc.torch_bool_type],
+            operands=[initial_cond_call.results[0]],
+            loc=loc,
+        )
+
+        # Create the max-iterations constant (INT64_MAX).
+        with loc:
+            max_iter = _make_constant_op(
+                "torch.constant.int",
+                self._cc.integer_attr(9223372036854775807, 64),
+                self._cc.torch_int_type,
+            )
+
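+        # Rough sketch of the IR assembled below (value names and types are
+        # illustrative; actual types depend on the carried values):
+        #
+        #   %r = torch.prim.Loop %int_max, %init_cond, init(%carry0, ...) {
+        #   ^bb0(%iv: !torch.int, %c0: !torch.vtensor<...>, ...):
+        #     %new0, ... = func.call @body_fn(%c0, ...)
+        #     %cond = func.call @cond_fn(%new0, ...)
+        #     %b = torch.aten.Bool.Tensor %cond
+        #     torch.prim.Loop.condition %b, iter(%new0, ...)
+        #   }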
+        # Create the torch.prim.Loop operation with a single body region.
+        loop_op = Operation.create(
+            name="torch.prim.Loop",
+            results=result_types,
+            operands=[max_iter.results[0], bool_conv.results[0]] + carry_values,
+            regions=1,
+            loc=loc,
+        )
+
+        # Create the loop body region with block arguments.
+        # Block args: iteration counter (!torch.int) + all carry values.
+        loop_region = loop_op.regions[0]
+        block_arg_types = [self._cc.torch_int_type] + result_types
+        with loc:
+            loop_block = Block.create_at_start(loop_region, block_arg_types)
+
+        # Inside the loop body, call the body function and then the condition
+        # function on the updated carries.
+        with InsertionPoint(loop_block):
+            # Call the body function with the current carry values (skip the
+            # iteration counter).
+            body_results_op = Operation.create(
+                name="func.call",
+                attributes={"callee": FlatSymbolRefAttr.get(body_fn_name)},
+                results=result_types,
+                operands=list(loop_block.arguments[1:]),  # Skip iteration counter.
+                loc=loc,
+            )
+            body_results = list(body_results_op.results)
+
+            # Call the condition function with the updated carries.
+            cond_result_loop = Operation.create(
+                name="func.call",
+                attributes={"callee": FlatSymbolRefAttr.get(cond_fn_name)},
+                results=[IrType.parse("!torch.vtensor<[],i1>", context=self._c)],
+                operands=body_results,
+                loc=loc,
+            ).result
+
+            # Convert to !torch.bool.
+            cond_bool = Operation.create(
+                name="torch.aten.Bool.Tensor",
+                results=[self._cc.torch_bool_type],
+                operands=[cond_result_loop],
+                loc=loc,
+            ).result
+
+            # Emit the loop-condition terminator with the updated carries.
+            Operation.create(
+                name="torch.prim.Loop.condition",
+                results=[],
+                operands=[cond_bool] + body_results,
+                loc=loc,
+            )
+
+        # Bind the loop results to the node.
+        if len(result_types) > 1:
+            self._multi_result_nodes.add(node)
+            for i, value in enumerate(loop_op.results):
+                self.bind_node_value(node, value, i)
+        else:
+            self.bind_node_value(node, loop_op.results[0])
+
     def _import_hop_auto_functionalized(
         self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator
     ):
@@ -1823,6 +2014,9 @@ def _import_argument(
             argument_value = self.resolve_node_value(arg)
         elif isinstance(arg, torch_fx.immutable_collections.immutable_list):
             argument_value = self._import_list_argument(loc, arg, expected_jit_type)
+        elif isinstance(arg, tuple):
+            # Handle tuples of tensors (common in while_loop carries).
+            argument_value = self._import_tuple_argument(loc, arg, expected_jit_type)
         elif isinstance(expected_jit_type, torch.TensorType) and not isinstance(
             arg, torch.Tensor
         ):
@@ -1930,6 +2124,13 @@ def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value:
             loc=loc,
         ).result
 
+    def _import_tuple_argument(
+        self, loc: Location, arg: tuple, expected_jit_type
+    ) -> List[Value]:
+        """Import a tuple argument by importing each element separately."""
+        # For tuples in while_loop carries, treat each element as a separate
+        # argument.
+        return [self._import_argument(loc, elem, expected_jit_type) for elem in arg]
+
     def _import_list_argument(
         self, loc: Location, arg: Sequence[NodeArgument], expected_jit_type
     ) -> Value:
@@ -2040,6 +2241,8 @@ def _import_getitem(self, loc: Location, node: torch.fx.Node):
             # NOTE: the length of the list must be knowable at compile time.
             if ref_node not in self._unpack_list_values:
                 node_result = self.resolve_node_value(ref_node, 0)
+                node_val = ref_node.meta.get("val")
+
                 if str(node_result.type) in TORCH_LIST_TYPES:
                     result_types = [
                         self._cc.value_info_to_type(v) for v in ref_node.meta["val"]
@@ -2510,4 +2713,4 @@ def aten__embedding_bag_forward_only_default(node: torch_fx.Node):
 def node_canonicalize(node: torch_fx.Node):
     if node.target in NODE_CANONICALIZE:
         return NODE_CANONICALIZE[node.target](node)
-    return node
+    return node