Skip to content

Commit 11b94b1

Browse files
committed
[5336829][AutoCast] Support subgraphs in AutoCast
The initial bug report mentioned conditional operators, but the issue generalizes to any subgraph in the ONNX model. Support this by recursively traversing subgraphs in PrecisionConverter. Signed-off-by: Gal Hubara Agam <[email protected]>
1 parent ed25176 commit 11b94b1

File tree

3 files changed

+395
-56
lines changed

3 files changed

+395
-56
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 163 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ class InitializerConsumerTracker:
6767

6868
OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"]
6969

70-
# Temporarily block these ops in low precision, as they are not supported yet
71-
OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(["Scan", "If", "Loop"])
72-
7370
# Mapping of op types to indices of inputs that should not be converted to low precision.
7471
SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {2}}
7572
SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}}
@@ -244,8 +241,8 @@ def convert(
244241
tensor_to_producers=tensor_to_producers,
245242
)
246243

247-
# Convert initializers to correct precision according to the consumer nodes
248-
self._convert_initializers(
244+
# Convert initializers to correct precision according to the consumer nodes (main graph + subgraphs)
245+
self._convert_initializers_recursive(
249246
low_precision_nodes=low_precision_nodes, high_precision_nodes=high_precision_nodes
250247
)
251248

@@ -254,17 +251,8 @@ def convert(
254251
# Populate type information with inferred types
255252
self.model = self._propagate_types_shapes_custom_ops(self.model)
256253
else:
257-
# Clear type/shape information for intermediates and outputs
258-
for vi in self.model.graph.value_info:
259-
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
260-
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
261-
if d.dim_value:
262-
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
263-
for out in self.model.graph.output:
264-
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
265-
for idx, d in enumerate(out.type.tensor_type.shape.dim):
266-
if d.dim_value:
267-
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
254+
# Clear type/shape information for intermediates and outputs (including subgraphs)
255+
self._clear_types_and_shapes_recursive(self.model.graph)
268256
# Populate type information with inferred types
269257
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
270258
self._ensure_types_are_defined()
@@ -289,6 +277,47 @@ def _ensure_types_are_defined(self):
289277
if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
290278
vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type
291279

280+
def _clear_types_and_shapes_recursive(
    self, graph: onnx.GraphProto, is_subgraph: bool = False
) -> None:
    """Recursively wipe tensor type/shape information from a graph and all nested subgraphs.

    Control flow operators (Scan, If, Loop) carry subgraphs as node attributes, so
    clearing only the main graph is insufficient; every nested graph is visited too.

    Args:
        graph: The ONNX graph to clear types and shapes for.
        is_subgraph: Whether this is a subgraph (True) or the main graph (False).
    """

    def _wipe(value_info: onnx.ValueInfoProto) -> None:
        # Drop the element type and turn every fixed dimension into a symbolic one,
        # forcing shape inference to recompute the information from scratch.
        value_info.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
        for i, dim in enumerate(value_info.type.tensor_type.shape.dim):
            if dim.dim_value:
                value_info.type.tensor_type.shape.dim[i].dim_param = "unk"

    def _on_graph(g: onnx.GraphProto, parent: onnx.NodeProto, nested: bool) -> None:
        logger.debug(
            f"Clearing types/shapes in {'subgraph' if nested else 'main graph'}: {g.name}"
        )

        # Main-graph inputs keep their declared types; only subgraph inputs are cleared.
        if nested:
            for inp in g.input:
                if inp.type.HasField("tensor_type"):
                    _wipe(inp)

        # Intermediates and outputs are always cleared.
        for vi in g.value_info:
            _wipe(vi)
        for out in g.output:
            _wipe(out)

    utils.walk_subgraphs_recursive(graph, _on_graph, is_subgraph=is_subgraph)
320+
292321
def _propagate_types_shapes_custom_ops(self, model):
293322
"""Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
294323
logger.info("Propagating tensor shapes and types in model with custom ops.")
@@ -688,59 +717,84 @@ def _convert_initializers(
688717
node.node.input[node.node_index] = new_init_name
689718
self.model.graph.initializer.extend([new_init])
690719

691-
def _cast_initializer(
720+
def _convert_initializers_recursive(
    self, low_precision_nodes: list[str], high_precision_nodes: list[str]
) -> None:
    """Convert initializers in the main graph and all subgraphs to the appropriate precision.

    The main graph uses full consumer tracking to decide each initializer's precision.
    Subgraphs inherit the precision of the parent control flow node and have all of
    their float initializers converted to that precision (no runtime casts).

    Args:
        low_precision_nodes: Names of main-graph nodes that run in low precision.
        high_precision_nodes: Names of main-graph nodes that run in high precision.
    """
    # Main graph: consumer tracking decides per-initializer precision.
    self._convert_initializers(low_precision_nodes, high_precision_nodes)

    low_precision_name_set = set(low_precision_nodes)

    def _handle_subgraph(
        graph: onnx.GraphProto, parent: onnx.NodeProto, is_subgraph: bool
    ) -> None:
        # Only nested graphs are handled here; the main graph was converted above.
        if parent is None or not is_subgraph:
            return

        # A subgraph inherits the precision of the control flow node that owns it.
        if parent.name in low_precision_name_set:
            target_type = self.low_precision_type
        else:
            target_type = self.high_precision_type

        for init in graph.initializer:
            # Skip non-float initializers and those already at the target precision.
            if init.data_type not in ONNX_TYPES or init.data_type == target_type.onnx_type:
                continue

            if init.data_type == self.high_precision_type.onnx_type:
                from_type = self.high_precision_type
            elif init.data_type == self.low_precision_type.onnx_type:
                from_type = self.low_precision_type
            else:
                logger.debug(
                    f"Skipping subgraph initializer {init.name} with unsupported type {init.data_type}"
                )
                continue

            # Convert in place; subgraph initializers never get runtime casts.
            init.CopyFrom(self._convert_initializer_data(init, from_type, target_type))

    utils.walk_subgraphs_recursive(self.model.graph, _handle_subgraph)
775+
776+
def _convert_initializer_data(
692777
self,
693778
init: onnx.TensorProto,
694779
from_type: PrecisionTypes,
695780
to_type: PrecisionTypes,
696-
low_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
697-
high_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
698-
) -> onnx.TensorProto | None:
699-
"""Cast an initializer to a new precision based on its consumer nodes.
781+
) -> onnx.TensorProto:
782+
"""Convert initializer data to a new precision.
700783
701-
This method converts an initializer to a new precision while handling special cases like bfloat16 conversion
702-
and providing warnings when values are clamped or replaced due to precision limits.
784+
This is the core conversion logic extracted for reuse. Handles bfloat16 conversion
785+
and provides warnings when values are clamped or replaced due to precision limits.
703786
704787
Args:
705-
init: The initializer to cast.
788+
init: The initializer to convert.
706789
from_type: The original precision of the initializer.
707790
to_type: The new precision to cast the initializer to.
708791
709792
Returns:
710-
onnx.TensorProto: The casted initializer.
793+
onnx.TensorProto: The converted initializer.
711794
"""
712-
713-
def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
714-
"""Get the name of a node or input index tracker."""
715-
if isinstance(node, onnx.NodeProto):
716-
return node.name
717-
elif isinstance(node, InputIndexTracker):
718-
return node.node.name
719-
else:
720-
raise ValueError(f"Unexpected: {type(node)}")
721-
722-
# Ensure the initializer is of the expected type
723-
assert init.data_type == from_type.onnx_type, (
724-
f"Initializer {init.name} is not of type {from_type.str_short}"
725-
)
726-
727-
if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes:
728-
# The initializer is too large, so we need to convert it at runtime.
729-
logger.debug(
730-
f"Initializer {init.name} is too large, skipping initializer conversion, cast in "
731-
"runtime instead"
732-
)
733-
exclude_consumers = (
734-
low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes
735-
)
736-
exclude_consumers_names: list[str] = []
737-
738-
exclude_consumers_names = [_get_name(node) for node in exclude_consumers]
739-
self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers_names)
740-
return None
741-
742795
np_array = numpy_helper.to_array(init)
743-
# Numpy does not support bfloat16, use ml_dtypes to create the raw data instead
796+
797+
# Handle bfloat16 conversion
744798
if self._is_bf16(to_type) and self._is_fp32(from_type):
745799
new_init = onnx.TensorProto()
746800
new_init.dims.extend(np_array.shape)
@@ -779,6 +833,59 @@ def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
779833

780834
return new_init
781835

836+
def _cast_initializer(
    self,
    init: onnx.TensorProto,
    from_type: PrecisionTypes,
    to_type: PrecisionTypes,
    low_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
    high_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
) -> onnx.TensorProto | None:
    """Cast an initializer to a new precision, based on its consumer nodes.

    Handles special cases such as bfloat16 conversion and emits warnings when values
    are clamped or replaced due to precision limits. Initializers larger than the
    configured byte limit are not converted in place; a runtime Cast is inserted instead.

    Args:
        init: The initializer to cast.
        from_type: The original precision of the initializer.
        to_type: The new precision to cast the initializer to.
        low_precision_nodes: Low precision nodes that consume this initializer.
        high_precision_nodes: High precision nodes that consume this initializer.

    Returns:
        onnx.TensorProto | None: The casted initializer, or None if a runtime cast was inserted instead.
    """

    def _node_name(entry: onnx.NodeProto | InputIndexTracker) -> str:
        """Resolve the node name regardless of the wrapper type."""
        if isinstance(entry, onnx.NodeProto):
            return entry.name
        if isinstance(entry, InputIndexTracker):
            return entry.node.name
        raise ValueError(f"Unexpected: {type(entry)}")

    # The caller must hand us an initializer whose stored type matches from_type.
    assert init.data_type == from_type.onnx_type, (
        f"Initializer {init.name} is not of type {from_type.str_short}"
    )

    # Oversized initializers are converted at runtime rather than rewritten in place.
    if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes:
        logger.debug(
            f"Initializer {init.name} is too large, skipping initializer conversion, cast in "
            "runtime instead"
        )
        # Consumers already at the destination precision must not receive the cast.
        skip = low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes
        self._add_cast(
            init.name, to_type, exclude_consumers=[_node_name(n) for n in skip]
        )
        return None

    return self._convert_initializer_data(init, from_type, to_type)
888+
782889
def _replace_tensor_name(
783890
self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str
784891
) -> None:

modelopt/onnx/autocast/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import logging
2525
from collections import defaultdict
26+
from collections.abc import Callable
2627

2728
import onnx
2829

@@ -122,6 +123,41 @@ def get_cast_to_type(cast_node: onnx.NodeProto) -> int:
122123
raise ValueError("Cast node does not have 'to' attribute")
123124

124125

126+
def walk_subgraphs_recursive(
    graph: onnx.GraphProto,
    callback: Callable[[onnx.GraphProto, "onnx.NodeProto | None", bool], None],
    parent_node: onnx.NodeProto | None = None,
    is_subgraph: bool = False,
) -> None:
    """Recursively walk through a graph and all its subgraphs, applying a callback.

    This utility function traverses an ONNX graph and all nested subgraphs by examining
    graph attributes in nodes. It works with standard control flow operators (Scan, If, Loop)
    as well as custom operators that define subgraphs using ONNX graph attributes.

    Args:
        graph: The graph to walk.
        callback: Function to call for each graph. Signature: callback(graph, parent_node, is_subgraph).
        parent_node: The parent node containing this subgraph (None for main graph).
        is_subgraph: Whether this is a subgraph (True) or the main graph (False).

    Note:
        Works with any node that has attributes of type AttributeProto.GRAPH or
        AttributeProto.GRAPHS, including custom operators.
    """
    # Pre-order traversal: visit the current graph before descending into nested ones.
    callback(graph, parent_node, is_subgraph)

    for node in graph.node:
        for attr in node.attribute:
            # Single-graph attribute (e.g. If branches, Loop/Scan bodies).
            if attr.type == onnx.AttributeProto.GRAPH:
                walk_subgraphs_recursive(attr.g, callback, parent_node=node, is_subgraph=True)
            # Repeated-graph attribute, permitted by the ONNX spec for custom operators.
            elif attr.type == onnx.AttributeProto.GRAPHS:
                for subgraph in attr.graphs:
                    walk_subgraphs_recursive(subgraph, callback, parent_node=node, is_subgraph=True)
159+
160+
125161
def get_op_types_not_supported_in_low_precision(
126162
model: onnx.ModelProto,
127163
min_opset: int,

0 commit comments

Comments
 (0)