@@ -67,9 +67,6 @@ class InitializerConsumerTracker:
 
 OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"]
 
-# Temporarily block these ops in low precision, as they are not supported yet
-OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(["Scan", "If", "Loop"])
-
 # Mapping of op types to indices of inputs that should not be converted to low precision.
 SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {2}}
 SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}}
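Editor's note: in the ONNX Resize operator, input 1 is `roi` and input 2 is `scales`, which is why those indices are kept in high precision. A minimal sketch of how such a skip mapping could be consulted when choosing which inputs to downcast; the helper name and signature are illustrative, not the converter's actual API:

SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {2}}  # input 2 of Resize is 'scales'

def convertible_input_indices(op_type: str, num_inputs: int) -> list[int]:
    """Return the input indices that are safe to cast to low precision."""
    skip = SKIP_LOW_PRECISION_MAPPING_FP16.get(op_type, set())
    return [i for i in range(num_inputs) if i not in skip]

print(convertible_input_indices("Resize", 4))  # -> [0, 1, 3]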
@@ -244,8 +241,8 @@ def convert(
             tensor_to_producers=tensor_to_producers,
         )
 
-        # Convert initializers to correct precision according to the consumer nodes
-        self._convert_initializers(
+        # Convert initializers to correct precision according to the consumer nodes (main graph + subgraphs)
+        self._convert_initializers_recursive(
             low_precision_nodes=low_precision_nodes, high_precision_nodes=high_precision_nodes
         )
 
@@ -254,17 +251,8 @@ def convert(
             # Populate type information with inferred types
             self.model = self._propagate_types_shapes_custom_ops(self.model)
         else:
-            # Clear type/shape information for intermediates and outputs
-            for vi in self.model.graph.value_info:
-                vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
-                for idx, d in enumerate(vi.type.tensor_type.shape.dim):
-                    if d.dim_value:
-                        vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
-            for out in self.model.graph.output:
-                out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
-                for idx, d in enumerate(out.type.tensor_type.shape.dim):
-                    if d.dim_value:
-                        out.type.tensor_type.shape.dim[idx].dim_param = "unk"
+            # Clear type/shape information for intermediates and outputs (including subgraphs)
+            self._clear_types_and_shapes_recursive(self.model.graph)
             # Populate type information with inferred types
             self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
         self._ensure_types_are_defined()
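Editor's note: the clearing logic relies on protobuf oneof semantics. `TensorShapeProto.Dimension` holds either `dim_value` or `dim_param`, so assigning `dim_param = "unk"` drops the concrete value and leaves a symbolic dimension for `infer_shapes` to recompute. A self-contained illustration:

from onnx import TensorShapeProto

dim = TensorShapeProto.Dimension()
dim.dim_value = 2
assert dim.HasField("dim_value")

dim.dim_param = "unk"  # the oneof switches fields: dim_value is cleared
assert not dim.HasField("dim_value")
assert dim.dim_param == "unk"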
@@ -289,6 +277,47 @@ def _ensure_types_are_defined(self):
             if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
                 vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type
 
+    def _clear_types_and_shapes_recursive(
+        self, graph: onnx.GraphProto, is_subgraph: bool = False
+    ) -> None:
+        """Recursively clear type/shape information for a graph and all its subgraphs.
+
+        This is necessary for control flow operators (Scan, If, Loop) which have subgraphs.
+
+        Args:
+            graph: The ONNX graph to clear types and shapes for.
+            is_subgraph: Whether this is a subgraph (True) or the main graph (False).
+        """
+
+        def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> None:
+            logger.debug(
+                f"Clearing types/shapes in {'subgraph' if is_sub else 'main graph'}: {g.name}"
+            )
+
+            # Clear type/shape information for inputs (only for subgraphs, not main graph inputs)
+            if is_sub:
+                for inp in g.input:
+                    if inp.type.HasField("tensor_type"):
+                        inp.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+                        for idx, d in enumerate(inp.type.tensor_type.shape.dim):
+                            if d.dim_value:
+                                inp.type.tensor_type.shape.dim[idx].dim_param = "unk"
+
+            # Clear type/shape information for intermediates and outputs
+            for vi in g.value_info:
+                vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+                for idx, d in enumerate(vi.type.tensor_type.shape.dim):
+                    if d.dim_value:
+                        vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
+
+            for out in g.output:
+                out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+                for idx, d in enumerate(out.type.tensor_type.shape.dim):
+                    if d.dim_value:
+                        out.type.tensor_type.shape.dim[idx].dim_param = "unk"
+
+        utils.walk_subgraphs_recursive(graph, _clear_callback, is_subgraph=is_subgraph)
+
     def _propagate_types_shapes_custom_ops(self, model):
         """Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
         logger.info("Propagating tensor shapes and types in model with custom ops.")
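Editor's note: both new methods depend on a `utils.walk_subgraphs_recursive` helper that is not part of this hunk. A plausible sketch of such a walker, assuming the callback contract `(graph, parent_node, is_subgraph)` used above; the actual helper in the PR may differ:

import onnx
from collections.abc import Callable

def walk_subgraphs_recursive(
    graph: onnx.GraphProto,
    callback: Callable[[onnx.GraphProto, onnx.NodeProto | None, bool], None],
    parent: onnx.NodeProto | None = None,
    is_subgraph: bool = False,
) -> None:
    """Apply callback to this graph, then recurse into every GRAPH/GRAPHS
    attribute (the bodies of If, Loop, and Scan nodes)."""
    callback(graph, parent, is_subgraph)
    for node in graph.node:
        for attr in node.attribute:
            if attr.type == onnx.AttributeProto.GRAPH:
                walk_subgraphs_recursive(attr.g, callback, parent=node, is_subgraph=True)
            elif attr.type == onnx.AttributeProto.GRAPHS:
                for g in attr.graphs:
                    walk_subgraphs_recursive(g, callback, parent=node, is_subgraph=True)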
@@ -688,59 +717,84 @@ def _convert_initializers(
                 node.node.input[node.node_index] = new_init_name
             self.model.graph.initializer.extend([new_init])
 
-    def _cast_initializer(
+    def _convert_initializers_recursive(
+        self, low_precision_nodes: list[str], high_precision_nodes: list[str]
+    ) -> None:
+        """Convert initializers in main graph and all subgraphs to appropriate precision.
+
+        For the main graph, uses sophisticated consumer tracking to determine precision.
+        For subgraphs, inherits precision from the parent control flow node and converts
+        all initializers to that precision (no runtime casts).
+
+        Args:
+            low_precision_nodes: List of node names in main graph that are low precision.
+            high_precision_nodes: List of node names in main graph that are high precision.
+        """
+        # Convert main graph initializers with full consumer tracking
+        self._convert_initializers(low_precision_nodes, high_precision_nodes)
+
+        # Convert subgraph initializers - walk all subgraphs and convert based on parent node precision
+        low_precision_nodes_set = set(low_precision_nodes)
+
+        def _convert_subgraph_callback(
+            graph: onnx.GraphProto, parent: onnx.NodeProto, is_subgraph: bool
+        ) -> None:
+            if not is_subgraph or parent is None:
+                return
+
+            # Inherit precision from parent control flow node
+            target_type = (
+                self.low_precision_type
+                if parent.name in low_precision_nodes_set
+                else self.high_precision_type
+            )
+
+            # Convert all float initializers to target precision
+            for init in graph.initializer:
+                if init.data_type not in ONNX_TYPES or init.data_type == target_type.onnx_type:
+                    continue
+
+                from_type = (
+                    self.high_precision_type
+                    if init.data_type == self.high_precision_type.onnx_type
+                    else self.low_precision_type
+                    if init.data_type == self.low_precision_type.onnx_type
+                    else None
+                )
+
+                if from_type is None:
+                    logger.debug(
+                        f"Skipping subgraph initializer {init.name} with unsupported type {init.data_type}"
+                    )
+                    continue
+
+                new_init = self._convert_initializer_data(init, from_type, target_type)
+                init.CopyFrom(new_init)
+
+        utils.walk_subgraphs_recursive(self.model.graph, _convert_subgraph_callback)
+
+    def _convert_initializer_data(
         self,
         init: onnx.TensorProto,
         from_type: PrecisionTypes,
         to_type: PrecisionTypes,
-        low_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
-        high_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
-    ) -> onnx.TensorProto | None:
-        """Cast an initializer to a new precision based on its consumer nodes.
+    ) -> onnx.TensorProto:
+        """Convert initializer data to a new precision.
 
-        This method converts an initializer to a new precision while handling special cases like bfloat16 conversion
-        and providing warnings when values are clamped or replaced due to precision limits.
+        This is the core conversion logic extracted for reuse. Handles bfloat16 conversion
+        and provides warnings when values are clamped or replaced due to precision limits.
 
         Args:
-            init: The initializer to cast.
+            init: The initializer to convert.
             from_type: The original precision of the initializer.
             to_type: The new precision to cast the initializer to.
 
         Returns:
-            onnx.TensorProto: The casted initializer.
+            onnx.TensorProto: The converted initializer.
         """
-
-        def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
-            """Get the name of a node or input index tracker."""
-            if isinstance(node, onnx.NodeProto):
-                return node.name
-            elif isinstance(node, InputIndexTracker):
-                return node.node.name
-            else:
-                raise ValueError(f"Unexpected: {type(node)}")
-
-        # Ensure the initializer is of the expected type
-        assert init.data_type == from_type.onnx_type, (
-            f"Initializer {init.name} is not of type {from_type.str_short}"
-        )
-
-        if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes:
-            # The initializer is too large, so we need to convert it at runtime.
-            logger.debug(
-                f"Initializer {init.name} is too large, skipping initializer conversion, cast in "
-                "runtime instead"
-            )
-            exclude_consumers = (
-                low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes
-            )
-            exclude_consumers_names: list[str] = []
-
-            exclude_consumers_names = [_get_name(node) for node in exclude_consumers]
-            self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers_names)
-            return None
-
         np_array = numpy_helper.to_array(init)
-        # Numpy does not support bfloat16, use ml_dtypes to create the raw data instead
+
+        # Handle bfloat16 conversion
         if self._is_bf16(to_type) and self._is_fp32(from_type):
             new_init = onnx.TensorProto()
             new_init.dims.extend(np_array.shape)
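Editor's note: the bfloat16 branch exists because numpy has no native bfloat16 dtype, so the raw bytes are produced via `ml_dtypes`. A standalone sketch of the same re-encoding (the method body continues past this hunk, so details may differ):

import ml_dtypes
import numpy as np
import onnx

np_array = np.array([1.0, 3.14159, 65504.0], dtype=np.float32)

new_init = onnx.TensorProto()
new_init.name = "example_weight"
new_init.dims.extend(np_array.shape)
new_init.data_type = onnx.TensorProto.BFLOAT16
# astype(ml_dtypes.bfloat16) rounds each fp32 value to the nearest bfloat16;
# tobytes() yields the 2-bytes-per-element buffer ONNX expects in raw_data.
new_init.raw_data = np_array.astype(ml_dtypes.bfloat16).tobytes()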
@@ -779,6 +833,59 @@ def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
 
         return new_init
 
+    def _cast_initializer(
+        self,
+        init: onnx.TensorProto,
+        from_type: PrecisionTypes,
+        to_type: PrecisionTypes,
+        low_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
+        high_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
+    ) -> onnx.TensorProto | None:
+        """Cast an initializer to a new precision based on its consumer nodes.
+
+        This method converts an initializer to a new precision while handling special cases like bfloat16 conversion
+        and providing warnings when values are clamped or replaced due to precision limits.
+
+        Args:
+            init: The initializer to cast.
+            from_type: The original precision of the initializer.
+            to_type: The new precision to cast the initializer to.
+            low_precision_nodes: Low precision nodes that consume this initializer.
+            high_precision_nodes: High precision nodes that consume this initializer.
+
+        Returns:
+            onnx.TensorProto | None: The casted initializer, or None if a runtime cast was inserted instead.
+        """
+
+        def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
+            """Get the name of a node or input index tracker."""
+            if isinstance(node, onnx.NodeProto):
+                return node.name
+            elif isinstance(node, InputIndexTracker):
+                return node.node.name
+            else:
+                raise ValueError(f"Unexpected: {type(node)}")
+
+        # Ensure the initializer is of the expected type
+        assert init.data_type == from_type.onnx_type, (
+            f"Initializer {init.name} is not of type {from_type.str_short}"
+        )
+
+        if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes:
+            # The initializer is too large, so we need to convert it at runtime.
+            logger.debug(
+                f"Initializer {init.name} is too large, skipping initializer conversion, cast in "
+                "runtime instead"
+            )
+            exclude_consumers = (
+                low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes
+            )
+            exclude_consumers_names = [_get_name(node) for node in exclude_consumers]
+            self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers_names)
+            return None
+
+        return self._convert_initializer_data(init, from_type, to_type)
+
     def _replace_tensor_name(
         self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str
     ) -> None:
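Editor's note: the `self._add_cast(...)` fallback above is defined outside this diff. A hypothetical minimal version of inserting such a runtime cast; the names and the consumer rewiring here are illustrative only, and the real method also honors `exclude_consumers`:

import onnx
from onnx import helper

def add_runtime_cast(graph: onnx.GraphProto, tensor_name: str, onnx_type: int) -> str:
    """Insert a Cast node so 'tensor_name' is converted during inference
    instead of rewriting its (large) initializer bytes."""
    cast_output = f"{tensor_name}_cast"
    cast_node = helper.make_node(
        "Cast",
        inputs=[tensor_name],
        outputs=[cast_output],
        to=onnx_type,
        name=f"Cast_{tensor_name}",
    )
    graph.node.insert(0, cast_node)
    # A real implementation would rewire consumers (minus excluded ones)
    # to read 'cast_output' and then topologically re-sort the nodes.
    return cast_output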