Skip to content

Commit 965b1ec

Browse files
committed
[5336829][AutoCast] Support subgraphs in AutoCast
The initial bug report mentioned conditional operators, but the issue generalizes to any subgraph in the ONNX model. Support is added by recursively traversing subgraphs in PrecisionConverter. Signed-off-by: Gal Hubara Agam <[email protected]>
1 parent 53a2dde commit 965b1ec

File tree

3 files changed

+393
-54
lines changed

3 files changed

+393
-54
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 161 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ class InitializerConsumerTracker:
6767

6868
OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"]
6969

70-
# Temporarily block these ops in low precision, as they are not supported yet
71-
OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(["Scan", "If", "Loop"])
72-
7370
# Mapping of op types to indices of inputs that should not be converted to low precision.
7471
SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {2}}
7572
SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}}
@@ -240,8 +237,8 @@ def convert(
240237
tensor_to_producers=tensor_to_producers,
241238
)
242239

243-
# Convert initializers to correct precision according to the consumer nodes
244-
self._convert_initializers(
240+
# Convert initializers to correct precision according to the consumer nodes (main graph + subgraphs)
241+
self._convert_initializers_recursive(
245242
low_precision_nodes=low_precision_nodes, high_precision_nodes=high_precision_nodes
246243
)
247244

@@ -250,17 +247,8 @@ def convert(
250247
# Populate type information with inferred types
251248
self.model = self._propagate_types_shapes_custom_ops(self.model)
252249
else:
253-
# Clear type/shape information for intermediates and outputs
254-
for vi in self.model.graph.value_info:
255-
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
256-
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
257-
if d.dim_value:
258-
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
259-
for out in self.model.graph.output:
260-
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
261-
for idx, d in enumerate(out.type.tensor_type.shape.dim):
262-
if d.dim_value:
263-
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
250+
# Clear type/shape information for intermediates and outputs (including subgraphs)
251+
self._clear_types_and_shapes_recursive(self.model.graph)
264252
# Populate type information with inferred types
265253
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
266254
self._ensure_types_are_defined()
@@ -285,6 +273,47 @@ def _ensure_types_are_defined(self):
285273
if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
286274
vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type
287275

276+
def _clear_types_and_shapes_recursive(
    self, graph: onnx.GraphProto, is_subgraph: bool = False
) -> None:
    """Recursively wipe tensor type/shape information from a graph and every nested subgraph.

    Control flow operators (Scan, If, Loop) carry their bodies as graph attributes, so a
    plain walk over ``graph.value_info`` would miss them; the traversal is delegated to
    ``utils.walk_subgraphs_recursive``.

    Args:
        graph: The ONNX graph whose type/shape annotations should be cleared.
        is_subgraph: True when ``graph`` is itself a subgraph; in that case its inputs
            are cleared as well (main-graph inputs are left untouched).
    """

    def _wipe(value_info: onnx.ValueInfoProto) -> None:
        # Reset the element type and replace every concrete dimension with the
        # symbolic name "unk" so shape inference recomputes it from scratch.
        # NOTE(review): assumes the entry is tensor-typed; sequence/optional types
        # (possible inside Loop/Scan bodies) would get a tensor_type field created
        # here — confirm upstream guarantees tensors only.
        tensor_type = value_info.type.tensor_type
        tensor_type.elem_type = onnx.TensorProto.UNDEFINED
        for dim_idx, dim in enumerate(tensor_type.shape.dim):
            if dim.dim_value:
                tensor_type.shape.dim[dim_idx].dim_param = "unk"

    def _reset_graph(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> None:
        logger.debug(
            f"Clearing types/shapes in {'subgraph' if is_sub else 'main graph'}: {g.name}"
        )

        # Subgraph inputs inherit their types from the parent node, so they must be
        # cleared too; main-graph inputs define the model signature and are kept.
        if is_sub:
            for graph_input in g.input:
                if graph_input.type.HasField("tensor_type"):
                    _wipe(graph_input)

        # Intermediate values and outputs are always cleared.
        for value_info in g.value_info:
            _wipe(value_info)
        for graph_output in g.output:
            _wipe(graph_output)

    utils.walk_subgraphs_recursive(graph, _reset_graph, is_subgraph=is_subgraph)
316+
288317
def _propagate_types_shapes_custom_ops(self, model):
289318
"""Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
290319
logger.info("Propagating tensor shapes and types in model with custom ops.")
@@ -682,6 +711,118 @@ def _convert_initializers(
682711
node.node.input[node.node_index] = new_init_name
683712
self.model.graph.initializer.extend([new_init])
684713

714+
def _convert_initializers_recursive(
    self, low_precision_nodes: list[str], high_precision_nodes: list[str]
) -> None:
    """Convert initializers to the appropriate precision in the main graph and all subgraphs.

    The main graph goes through the full consumer-tracking path (``_convert_initializers``),
    which may insert runtime casts. Subgraph initializers instead inherit the precision of
    the control-flow node that owns the subgraph and are converted in place, without casts.

    Args:
        low_precision_nodes: Names of main-graph nodes kept in low precision.
        high_precision_nodes: Names of main-graph nodes kept in high precision.
    """
    # Main graph: full consumer tracking, possibly inserting casts.
    self._convert_initializers(low_precision_nodes, high_precision_nodes)

    # Set for O(1) membership tests while walking subgraphs.
    low_precision_names = set(low_precision_nodes)

    def _on_graph(
        graph: onnx.GraphProto, parent: onnx.NodeProto, is_subgraph: bool
    ) -> None:
        # The walk also visits the main graph; skip it — it was handled above.
        if not is_subgraph or parent is None:
            return

        # A subgraph runs at the precision of its owning control-flow node.
        if parent.name in low_precision_names:
            target_type = self.low_precision_type
        else:
            target_type = self.high_precision_type

        for init in graph.initializer:
            # Skip non-float initializers and those already at the target precision.
            if init.data_type not in ONNX_TYPES or init.data_type == target_type.onnx_type:
                continue

            if init.data_type == self.high_precision_type.onnx_type:
                from_type = self.high_precision_type
            elif init.data_type == self.low_precision_type.onnx_type:
                from_type = self.low_precision_type
            else:
                logger.debug(
                    f"Skipping subgraph initializer {init.name} with unsupported type {init.data_type}"
                )
                continue

            # Convert in place so node references to the initializer stay valid.
            init.CopyFrom(self._convert_initializer_data(init, from_type, target_type))

    utils.walk_subgraphs_recursive(self.model.graph, _on_graph)
769+
770+
def _convert_initializer_data(
    self,
    init: onnx.TensorProto,
    from_type: PrecisionTypes,
    to_type: PrecisionTypes,
) -> onnx.TensorProto:
    """Convert an initializer's data to a new precision.

    Core conversion logic shared by main-graph and subgraph paths. Handles the
    fp32 -> bfloat16 case (numpy has no native bfloat16, so ``ml_dtypes`` is used)
    and warns when values must be clamped or replaced because they fall outside
    the representable range of the target type.

    Args:
        init: The initializer to convert.
        from_type: The initializer's current precision.
        to_type: The precision to convert to.

    Returns:
        onnx.TensorProto: The converted initializer (same name as ``init``).
    """
    values = numpy_helper.to_array(init)

    # fp32 -> bf16: build the tensor proto by hand from ml_dtypes raw bytes.
    if self._is_bf16(to_type) and self._is_fp32(from_type):
        converted = onnx.TensorProto()
        converted.dims.extend(values.shape)
        converted.name = init.name
        converted.data_type = onnx.TensorProto.BFLOAT16
        converted.raw_data = values.astype(ml_dtypes.bfloat16).view(np.uint16).tobytes()
        return converted

    assert to_type.numpy_type is not None
    target_info = np.finfo(to_type.numpy_type)
    data_max = target_info.max
    data_lowest = target_info.smallest_subnormal

    # Magnitudes above the target's max would overflow to inf — clamp instead.
    if np.any(np.abs(values) > data_max):
        logger.warning(
            f"Initializer '{init.name}' contains values larger than largest "
            f"{to_type.str_short} value, values will be clamped to {data_max}."
        )
        values = np.clip(values, -data_max, data_max)

    # Non-zero magnitudes below the smallest subnormal would flush to zero —
    # replace them with the smallest representable value instead.
    underflow_mask = (values != 0.0) & (np.abs(values) < data_lowest)
    if np.any(underflow_mask):
        logger.warning(
            f"Initializer '{init.name}' contains values smaller than smallest "
            f"{to_type.str_short} value, values will be replaced with {data_lowest:.1e}."
        )
        values = np.where(underflow_mask, data_lowest, values)

    return numpy_helper.from_array(values.astype(to_type.numpy_type), init.name)
825+
685826
def _cast_initializer(
686827
self,
687828
init: onnx.TensorProto,
@@ -699,9 +840,11 @@ def _cast_initializer(
699840
init: The initializer to cast.
700841
from_type: The original precision of the initializer.
701842
to_type: The new precision to cast the initializer to.
843+
low_precision_nodes: Low precision nodes that consume this initializer.
844+
high_precision_nodes: High precision nodes that consume this initializer.
702845
703846
Returns:
704-
onnx.TensorProto: The casted initializer.
847+
onnx.TensorProto | None: The casted initializer, or None if a runtime cast was inserted instead.
705848
"""
706849

707850
def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
@@ -727,47 +870,11 @@ def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
727870
exclude_consumers = (
728871
low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes
729872
)
730-
exclude_consumers_names: list[str] = []
731-
732873
exclude_consumers_names = [_get_name(node) for node in exclude_consumers]
733874
self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers_names)
734875
return None
735876

736-
np_array = numpy_helper.to_array(init)
737-
# Numpy does not support bfloat16, use ml_dtypes to create the raw data instead
738-
if self._is_bf16(to_type) and self._is_fp32(from_type):
739-
new_init = onnx.TensorProto()
740-
new_init.dims.extend(np_array.shape)
741-
new_init.name = init.name
742-
new_init.data_type = onnx.TensorProto.BFLOAT16
743-
bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16)
744-
new_init.raw_data = bf16_bytes.tobytes()
745-
else:
746-
assert to_type.numpy_type is not None
747-
data_max, data_lowest = (
748-
np.finfo(to_type.numpy_type).max,
749-
np.finfo(to_type.numpy_type).smallest_subnormal,
750-
)
751-
if np.any(np.abs(np_array) > data_max):
752-
logger.warning(
753-
f"Initializer {init.name} contains values larger than largest "
754-
f"{to_type.str_short} value, values will be clamped to {data_max}."
755-
)
756-
np_array = np.clip(np_array, -1 * data_max, data_max)
757-
if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)):
758-
logger.warning(
759-
f"Initializer {init.name} contains values smaller than smallest "
760-
f"{to_type.str_short} value, values will be replaced with {data_lowest:.1e}."
761-
)
762-
np_array = np.where(
763-
(np_array != 0.0) & (np.abs(np_array) < data_lowest),
764-
data_lowest,
765-
np_array,
766-
)
767-
new_array = np_array.astype(to_type.numpy_type)
768-
new_init = numpy_helper.from_array(new_array, init.name)
769-
770-
return new_init
877+
return self._convert_initializer_data(init, from_type, to_type)
771878

772879
def _replace_tensor_name(
773880
self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str

modelopt/onnx/autocast/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import logging
2525
from collections import defaultdict
26+
from collections.abc import Callable
2627

2728
import onnx
2829

@@ -122,6 +123,41 @@ def get_cast_to_type(cast_node: onnx.NodeProto) -> int:
122123
raise ValueError("Cast node does not have 'to' attribute")
123124

124125

126+
def walk_subgraphs_recursive(
    graph: onnx.GraphProto,
    callback: Callable[[onnx.GraphProto, "onnx.NodeProto | None", bool], None],
    parent_node: "onnx.NodeProto | None" = None,
    is_subgraph: bool = False,
) -> None:
    """Recursively walk through a graph and all its subgraphs, applying a callback.

    This utility function traverses an ONNX graph and all nested subgraphs by examining
    graph attributes in nodes. It works with standard control flow operators (Scan, If, Loop)
    as well as custom operators that define subgraphs using ONNX graph attributes.

    Args:
        graph: The graph to walk.
        callback: Function to call for each graph. Signature: callback(graph, parent_node, is_subgraph).
        parent_node: The parent node containing this subgraph (None for main graph).
        is_subgraph: Whether this is a subgraph (True) or the main graph (False).

    Note:
        Works with any node that has attributes of type AttributeProto.GRAPH or
        AttributeProto.GRAPHS, including custom operators.
    """
    # Apply callback to current graph (parent_node is None only for the root call).
    callback(graph, parent_node, is_subgraph)

    # Recurse into every graph-valued attribute: single subgraphs (GRAPH, e.g.
    # If's then/else branches, Loop/Scan bodies) and subgraph lists (GRAPHS).
    for node in graph.node:
        for attr in node.attribute:
            if attr.type == onnx.AttributeProto.GRAPH:
                walk_subgraphs_recursive(attr.g, callback, parent_node=node, is_subgraph=True)
            elif attr.type == onnx.AttributeProto.GRAPHS:
                for subgraph in attr.graphs:
                    walk_subgraphs_recursive(subgraph, callback, parent_node=node, is_subgraph=True)
160+
125161
def get_op_types_not_supported_in_low_precision(
126162
model: onnx.ModelProto,
127163
min_opset: int,

0 commit comments

Comments
 (0)