@@ -67,9 +67,6 @@ class InitializerConsumerTracker:
 
 OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"]
 
-# Temporarily block these ops in low precision, as they are not supported yet
-OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(["Scan", "If", "Loop"])
-
 # Mapping of op types to indices of inputs that should not be converted to low precision.
 SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {2}}
 SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}}
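
A side note on the mappings kept above: per the ONNX Resize signature (opset 11+), the inputs are X (0), roi (1), scales (2), and sizes (3), so the FP16 mapping keeps `scales` in float32 and the BF16 mapping additionally keeps `roi`. A standalone sketch to make the indices concrete:

```python
import onnx
from onnx import helper

# Resize inputs per the ONNX spec: X(0), roi(1), scales(2), sizes(3).
# SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {2}} keeps `scales` in fp32;
# the BF16 mapping {1, 2} also keeps `roi` in fp32.
resize = helper.make_node("Resize", inputs=["X", "roi", "scales"], outputs=["Y"], mode="linear")
assert resize.input[1] == "roi" and resize.input[2] == "scales"
```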
@@ -240,8 +237,8 @@ def convert(
             tensor_to_producers=tensor_to_producers,
         )
 
-        # Convert initializers to correct precision according to the consumer nodes
-        self._convert_initializers(
+        # Convert initializers to correct precision according to the consumer nodes (main graph + subgraphs)
+        self._convert_initializers_recursive(
             low_precision_nodes=low_precision_nodes, high_precision_nodes=high_precision_nodes
         )
 
@@ -250,17 +247,8 @@ def convert(
             # Populate type information with inferred types
             self.model = self._propagate_types_shapes_custom_ops(self.model)
         else:
-            # Clear type/shape information for intermediates and outputs
-            for vi in self.model.graph.value_info:
-                vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
-                for idx, d in enumerate(vi.type.tensor_type.shape.dim):
-                    if d.dim_value:
-                        vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
-            for out in self.model.graph.output:
-                out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
-                for idx, d in enumerate(out.type.tensor_type.shape.dim):
-                    if d.dim_value:
-                        out.type.tensor_type.shape.dim[idx].dim_param = "unk"
+            # Clear type/shape information for intermediates and outputs (including subgraphs)
+            self._clear_types_and_shapes_recursive(self.model.graph)
             # Populate type information with inferred types
             self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
             self._ensure_types_are_defined()
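
For context on the reset this hunk delegates to `_clear_types_and_shapes_recursive` (next hunk): `dim_value` and `dim_param` live in a protobuf oneof, so assigning `dim_param = "unk"` automatically clears the stale concrete `dim_value`. A standalone illustration:

```python
import onnx
from onnx import helper

# Build a value_info with a concrete type and shape, then reset it the same
# way the converter does before re-running shape inference.
vi = helper.make_tensor_value_info("t", onnx.TensorProto.FLOAT, [4, 8])
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
    if d.dim_value:
        # dim_value/dim_param are a protobuf oneof: setting dim_param
        # clears the stale dim_value.
        vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
assert all(d.WhichOneof("value") == "dim_param" for d in vi.type.tensor_type.shape.dim)
```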
@@ -285,6 +273,47 @@ def _ensure_types_are_defined(self):
             if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
                 vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type
 
+    def _clear_types_and_shapes_recursive(
+        self, graph: onnx.GraphProto, is_subgraph: bool = False
+    ) -> None:
+        """Recursively clear type/shape information for a graph and all its subgraphs.
+
+        This is necessary for control flow operators (Scan, If, Loop), which have subgraphs.
+
+        Args:
+            graph: The ONNX graph to clear types and shapes for.
+            is_subgraph: Whether this is a subgraph (True) or the main graph (False).
+        """
+
+        def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> None:
+            logger.debug(
+                f"Clearing types/shapes in {'subgraph' if is_sub else 'main graph'}: {g.name}"
+            )
+
+            # Clear type/shape information for inputs (only for subgraphs, not main graph inputs)
+            if is_sub:
+                for inp in g.input:
+                    if inp.type.HasField("tensor_type"):
+                        inp.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+                        for idx, d in enumerate(inp.type.tensor_type.shape.dim):
+                            if d.dim_value:
+                                inp.type.tensor_type.shape.dim[idx].dim_param = "unk"
+
+            # Clear type/shape information for intermediates and outputs
+            for vi in g.value_info:
+                vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+                for idx, d in enumerate(vi.type.tensor_type.shape.dim):
+                    if d.dim_value:
+                        vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
+
+            for out in g.output:
+                out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+                for idx, d in enumerate(out.type.tensor_type.shape.dim):
+                    if d.dim_value:
+                        out.type.tensor_type.shape.dim[idx].dim_param = "unk"
+
+        utils.walk_subgraphs_recursive(graph, _clear_callback, is_subgraph=is_subgraph)
+
     def _propagate_types_shapes_custom_ops(self, model):
         """Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
         logger.info("Propagating tensor shapes and types in model with custom ops.")
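
The new helper leans on `utils.walk_subgraphs_recursive`, which is not part of this diff. A minimal sketch of what such a walker presumably does (the real helper's signature and details may differ): visit the given graph, then recurse into every GRAPH/GRAPHS attribute of its nodes, flagging those as subgraphs and passing the owning node as `parent`.

```python
import onnx

def walk_subgraphs_recursive(graph, callback, parent=None, is_subgraph=False):
    """Depth-first walk: call `callback(graph, parent_node, is_subgraph)` on
    `graph`, then on each subgraph attached to control flow nodes (If/Loop/Scan)."""
    callback(graph, parent, is_subgraph)
    for node in graph.node:
        for attr in node.attribute:
            if attr.type == onnx.AttributeProto.GRAPH:
                walk_subgraphs_recursive(attr.g, callback, parent=node, is_subgraph=True)
            elif attr.type == onnx.AttributeProto.GRAPHS:
                for sub in attr.graphs:
                    walk_subgraphs_recursive(sub, callback, parent=node, is_subgraph=True)
```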
@@ -682,6 +711,118 @@ def _convert_initializers(
                 node.node.input[node.node_index] = new_init_name
             self.model.graph.initializer.extend([new_init])
 
+    def _convert_initializers_recursive(
+        self, low_precision_nodes: list[str], high_precision_nodes: list[str]
+    ) -> None:
+        """Convert initializers in the main graph and all subgraphs to the appropriate precision.
+
+        For the main graph, uses full consumer tracking to determine precision.
+        For subgraphs, inherits precision from the parent control flow node and converts
+        all initializers to that precision (no runtime casts).
+
+        Args:
+            low_precision_nodes: Names of main-graph nodes that are low precision.
+            high_precision_nodes: Names of main-graph nodes that are high precision.
+        """
+        # Convert main graph initializers with full consumer tracking
+        self._convert_initializers(low_precision_nodes, high_precision_nodes)
+
+        # Convert subgraph initializers: walk all subgraphs and convert based on parent node precision
+        low_precision_nodes_set = set(low_precision_nodes)
+
+        def _convert_subgraph_callback(
+            graph: onnx.GraphProto, parent: onnx.NodeProto, is_subgraph: bool
+        ) -> None:
+            if not is_subgraph or parent is None:
+                return
+
+            # Inherit precision from the parent control flow node
+            target_type = (
+                self.low_precision_type
+                if parent.name in low_precision_nodes_set
+                else self.high_precision_type
+            )
+
+            # Convert all float initializers to the target precision
+            for init in graph.initializer:
+                if init.data_type not in ONNX_TYPES or init.data_type == target_type.onnx_type:
+                    continue
+
+                from_type = (
+                    self.high_precision_type
+                    if init.data_type == self.high_precision_type.onnx_type
+                    else self.low_precision_type
+                    if init.data_type == self.low_precision_type.onnx_type
+                    else None
+                )
+
+                if from_type is None:
+                    logger.debug(
+                        f"Skipping subgraph initializer {init.name} with unsupported type {init.data_type}"
+                    )
+                    continue
+
+                new_init = self._convert_initializer_data(init, from_type, target_type)
+                init.CopyFrom(new_init)
+
+        utils.walk_subgraphs_recursive(self.model.graph, _convert_subgraph_callback)
+
+    def _convert_initializer_data(
+        self,
+        init: onnx.TensorProto,
+        from_type: PrecisionTypes,
+        to_type: PrecisionTypes,
+    ) -> onnx.TensorProto:
+        """Convert initializer data to a new precision.
+
+        This is the core conversion logic, extracted for reuse. It handles bfloat16
+        conversion and warns when values are clamped or replaced due to precision limits.
+
+        Args:
+            init: The initializer to convert.
+            from_type: The original precision of the initializer.
+            to_type: The new precision to cast the initializer to.
+
+        Returns:
+            onnx.TensorProto: The converted initializer.
+        """
+        np_array = numpy_helper.to_array(init)
+
+        # Numpy does not support bfloat16; use ml_dtypes to create the raw data instead
+        if self._is_bf16(to_type) and self._is_fp32(from_type):
+            new_init = onnx.TensorProto()
+            new_init.dims.extend(np_array.shape)
+            new_init.name = init.name
+            new_init.data_type = onnx.TensorProto.BFLOAT16
+            bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16)
+            new_init.raw_data = bf16_bytes.tobytes()
+        else:
+            assert to_type.numpy_type is not None
+            data_max, data_lowest = (
+                np.finfo(to_type.numpy_type).max,
+                np.finfo(to_type.numpy_type).smallest_subnormal,
+            )
+            if np.any(np.abs(np_array) > data_max):
+                logger.warning(
+                    f"Initializer '{init.name}' contains values larger than the largest "
+                    f"{to_type.str_short} value; values will be clamped to {data_max}."
+                )
+                np_array = np.clip(np_array, -1 * data_max, data_max)
+            if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)):
+                logger.warning(
+                    f"Initializer '{init.name}' contains values smaller than the smallest "
+                    f"{to_type.str_short} value; values will be replaced with {data_lowest:.1e}."
+                )
+                np_array = np.where(
+                    (np_array != 0.0) & (np.abs(np_array) < data_lowest),
+                    data_lowest,
+                    np_array,
+                )
+            new_array = np_array.astype(to_type.numpy_type)
+            new_init = numpy_helper.from_array(new_array, init.name)
+
+        return new_init
+
     def _cast_initializer(
         self,
         init: onnx.TensorProto,
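
A standalone demo of the two branches in `_convert_initializer_data` above: fp32→bf16 goes through `ml_dtypes` (NumPy has no native bfloat16) and stores the bits as raw uint16 data, while fp32→fp16 clamps overflowing magnitudes to the fp16 max and lifts tiny non-zero values to the smallest fp16 subnormal before casting:

```python
import numpy as np
import ml_dtypes
import onnx
from onnx import numpy_helper

arr = np.array([1.0, 3.0e38, 1e-10], dtype=np.float32)

# fp32 -> bf16: cast with ml_dtypes, then store the bits as raw uint16 data.
t = onnx.TensorProto()
t.name = "w_bf16"
t.data_type = onnx.TensorProto.BFLOAT16
t.dims.extend(arr.shape)
t.raw_data = arr.astype(ml_dtypes.bfloat16).view(np.uint16).tobytes()

# fp32 -> fp16: 3.0e38 exceeds the fp16 max (65504) and gets clamped;
# 1e-10 is below the smallest fp16 subnormal (~6e-8) and gets lifted.
fi = np.finfo(np.float16)
clamped = np.clip(arr, -fi.max, fi.max)
clamped = np.where((clamped != 0.0) & (np.abs(clamped) < fi.smallest_subnormal),
                   fi.smallest_subnormal, clamped)
w_fp16 = numpy_helper.from_array(clamped.astype(np.float16), "w_fp16")
```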
@@ -699,9 +840,11 @@ def _cast_initializer(
             init: The initializer to cast.
             from_type: The original precision of the initializer.
             to_type: The new precision to cast the initializer to.
+            low_precision_nodes: Low precision nodes that consume this initializer.
+            high_precision_nodes: High precision nodes that consume this initializer.
 
         Returns:
-            onnx.TensorProto: The cast initializer.
+            onnx.TensorProto | None: The cast initializer, or None if a runtime cast was inserted instead.
         """
 
         def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
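
Since `_cast_initializer` now returns `None` when it inserts a runtime Cast rather than rewriting the initializer, call sites need a guard. A hypothetical caller-side pattern (the stub below stands in for the method and is not from this diff):

```python
from typing import Optional

def cast_initializer(init_name: str) -> Optional[str]:
    """Stub mirroring the new contract: returns None when a runtime Cast
    is inserted instead of converting the initializer's data in place."""
    return None

new_init = cast_initializer("weight")
if new_init is not None:
    print("swap the converted initializer into the graph")
else:
    print("runtime Cast inserted; original initializer kept as-is")
```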
@@ -727,47 +870,11 @@ def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
             exclude_consumers = (
                 low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes
             )
-            exclude_consumers_names: list[str] = []
-
             exclude_consumers_names = [_get_name(node) for node in exclude_consumers]
             self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers_names)
             return None
 
-        np_array = numpy_helper.to_array(init)
-        # Numpy does not support bfloat16, use ml_dtypes to create the raw data instead
-        if self._is_bf16(to_type) and self._is_fp32(from_type):
-            new_init = onnx.TensorProto()
-            new_init.dims.extend(np_array.shape)
-            new_init.name = init.name
-            new_init.data_type = onnx.TensorProto.BFLOAT16
-            bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16)
-            new_init.raw_data = bf16_bytes.tobytes()
-        else:
-            assert to_type.numpy_type is not None
-            data_max, data_lowest = (
-                np.finfo(to_type.numpy_type).max,
-                np.finfo(to_type.numpy_type).smallest_subnormal,
-            )
-            if np.any(np.abs(np_array) > data_max):
-                logger.warning(
-                    f"Initializer {init.name} contains values larger than largest "
-                    f"{to_type.str_short} value, values will be clamped to {data_max}."
-                )
-                np_array = np.clip(np_array, -1 * data_max, data_max)
-            if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)):
-                logger.warning(
-                    f"Initializer {init.name} contains values smaller than smallest "
-                    f"{to_type.str_short} value, values will be replaced with {data_lowest:.1e}."
-                )
-                np_array = np.where(
-                    (np_array != 0.0) & (np.abs(np_array) < data_lowest),
-                    data_lowest,
-                    np_array,
-                )
-            new_array = np_array.astype(to_type.numpy_type)
-            new_init = numpy_helper.from_array(new_array, init.name)
-
-        return new_init
+        return self._convert_initializer_data(init, from_type, to_type)
 
     def _replace_tensor_name(
         self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str