From 8212940b1a91d9610032e0ce603542ab21535231 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 17 Oct 2025 10:31:29 -0400 Subject: [PATCH 01/11] Add support for Conv-Act-Pool fusion Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/quantization/partitioning.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modelopt/onnx/quantization/partitioning.py b/modelopt/onnx/quantization/partitioning.py index 9e8a9163c..a74f8b978 100644 --- a/modelopt/onnx/quantization/partitioning.py +++ b/modelopt/onnx/quantization/partitioning.py @@ -45,8 +45,9 @@ def _build_fusible_partition( Add a nodes to the partition if any of these holds: 1. The node is a unary or binary pointwise operation and fusible by cask - 2. The node is BN and/or Relu and fusible with preceding Conv op - 3. The node is a residual Add and fusible with current partition + 2. The node is BN and/or Relu and fusible with preceding Conv op (Conv-Act fusion) + 3. The node is MaxPool following a Conv-Act pattern (Conv-Act-Pool fusion) + 4. The node is a residual Add and fusible with current partition Args: cur_node: Current candidate node for the partition. @@ -135,7 +136,7 @@ def _is_fusible_mul(mul_node: Node) -> bool: and _is_cask_fusible(consumer_node, partition_node_outputs) ) or ( - consumer_node.op in ["BatchNormalization", "Relu"] + consumer_node.op in ["MaxPool", "BatchNormalization", "Relu"] and get_fusible_backbone(consumer_node, graph) ) or _is_on_non_residual_path(consumer_node) From 3222ba38b70b243d1ea5ab12d2f3e95a0b5d622d Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 17 Oct 2025 11:04:04 -0400 Subject: [PATCH 02/11] nit: Adjusted order to match Conv-Act-Pool sequence Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/quantization/partitioning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/onnx/quantization/partitioning.py b/modelopt/onnx/quantization/partitioning.py index a74f8b978..7a53674d6 100644 --- a/modelopt/onnx/quantization/partitioning.py +++ b/modelopt/onnx/quantization/partitioning.py @@ -136,7 +136,7 @@ def _is_fusible_mul(mul_node: Node) -> bool: and _is_cask_fusible(consumer_node, partition_node_outputs) ) or ( - consumer_node.op in ["MaxPool", "BatchNormalization", "Relu"] + consumer_node.op in ["BatchNormalization", "Relu", "MaxPool"] and get_fusible_backbone(consumer_node, graph) ) or _is_on_non_residual_path(consumer_node) From 49d7c389a18cbddf05f919634da850e1e123be97 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 17 Oct 2025 11:05:21 -0400 Subject: [PATCH 03/11] Added unittest Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- .../onnx_quantization/lib_test_models.py | 125 ++++++++++++++++++ tests/unit/onnx/test_quantize_int8.py | 43 +++++- 2 files changed, 162 insertions(+), 6 deletions(-) diff --git a/tests/_test_utils/onnx_quantization/lib_test_models.py b/tests/_test_utils/onnx_quantization/lib_test_models.py index e6846b9ec..2e3b98582 100644 --- a/tests/_test_utils/onnx_quantization/lib_test_models.py +++ b/tests/_test_utils/onnx_quantization/lib_test_models.py @@ -673,3 +673,128 @@ def build_conv_batchnorm_sig_mul_model(): onnx.checker.check_model(model_inferred) return model_inferred + + +def build_conv_act_pool_model(): + # Define your model inputs and outputs + input_names = ["input_0"] + output_names = 
["output_0"] + input_shapes = [(32, 64, 256, 256)] + output_shapes = [(32, 128, 128, 128)] + + inputs = [ + helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape) + for input_name, input_shape in zip(input_names, input_shapes) + ] + outputs = [ + helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape) + for output_name, output_shape in zip(output_names, output_shapes) + ] + + # Create the ONNX graph with the nodes + nodes = [ + helper.make_node( + op_type="Conv", + inputs=["input_0", "weights_1", "bias_1"], + outputs=["conv1_conv/Conv2D:0"], + name="conv1_conv/Conv2D", + dilations=[1, 1], + group=1, + kernel_shape=[3, 3], + pads=[0, 0, 0, 0], + strides=[1, 1], + ), + helper.make_node( + op_type="BatchNormalization", + inputs=["conv1_conv/Conv2D:0", "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"], + outputs=["bn1_batchnorm/BatchNormalization:0"], + name="bn1_batchnorm/BatchNormalization", + ), + helper.make_node( + op_type="Relu", + inputs=["bn1_batchnorm/BatchNormalization:0"], + outputs=["relu1_relu/Relu:0"], + name="relu1_relu/Relu", + ), + helper.make_node( + op_type="MaxPool", + inputs=["relu1_relu/Relu:0"], + outputs=["maxpool1_maxpool/MaxPool2D:0"], + name="maxpool1_maxpool/MaxPool2D", + ceil_mode=False, + kernel_shape=[3, 3], + pads=[0, 0, 0, 0], + strides=[2, 2], + ), + helper.make_node( + op_type="Conv", + inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"], + outputs=["output_0"], + name="conv2_conv/Conv2D", + dilations=[1, 1], + group=1, + kernel_shape=[3, 3], + pads=[0, 0, 0, 0], + strides=[1, 1], + ), + ] + + # Create the ONNX initializers + initializers = [ + helper.make_tensor( + name="weights_1", + data_type=onnx.TensorProto.FLOAT, + dims=(128, 64, 3, 3), + vals=np.random.uniform(low=0.5, high=1.0, size=128 * 64 * 3 * 3), + ), + helper.make_tensor( + name="bias_1", + data_type=onnx.TensorProto.FLOAT, + dims=(128,), + vals=np.random.uniform(low=0.5, high=1.0, size=128), + ), + helper.make_tensor( + name="bn1_scale", + data_type=onnx.TensorProto.FLOAT, + dims=(128,), + vals=np.random.uniform(low=0.5, high=1.0, size=128), + ), + helper.make_tensor( + name="bn1_bias", + data_type=onnx.TensorProto.FLOAT, + dims=(128,), + vals=np.random.uniform(low=0.5, high=1.0, size=128), + ), + helper.make_tensor( + name="bn1_mean", + data_type=onnx.TensorProto.FLOAT, + dims=(128,), + vals=np.random.uniform(low=0.5, high=1.0, size=128), + ), + helper.make_tensor( + name="bn1_var", + data_type=onnx.TensorProto.FLOAT, + dims=(128,), + vals=np.random.uniform(low=0.5, high=1.0, size=128), + ), + helper.make_tensor( + name="weights_2", + data_type=onnx.TensorProto.FLOAT, + dims=(128, 128, 3, 3), + vals=np.random.uniform(low=0.5, high=1.0, size=128 * 128 * 3 * 3), + ), + ] + + # Create the ONNX graph with the nodes and initializers + graph = helper.make_graph(nodes, "conv_act_pool", inputs, outputs, initializer=initializers) + + # Create the ONNX model + model = helper.make_model(graph) + model.opset_import[0].version = 13 + model.ir_version = 10 + + # Check the ONNX model + model_inferred = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model_inferred) + + return model_inferred diff --git a/tests/unit/onnx/test_quantize_int8.py b/tests/unit/onnx/test_quantize_int8.py index b474558f8..ce843153e 100644 --- a/tests/unit/onnx/test_quantize_int8.py +++ b/tests/unit/onnx/test_quantize_int8.py @@ -21,6 +21,7 @@ import torch from _test_utils.onnx_quantization.lib_test_models import ( SimpleMLP, + build_conv_act_pool_model, 
build_convtranspose_conv_residual_model, export_as_onnx, ) @@ -29,13 +30,18 @@ from modelopt.onnx.utils import save_onnx -def _assert_nodes_are_quantized(nodes): +def _assert_nodes_quantization(nodes, should_be_quantized=True): for node in nodes: for inp_idx, inp in enumerate(node.inputs): if isinstance(inp, gs.Variable): - assert node.i(inp_idx).op == "DequantizeLinear", ( - f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!" - ) + if should_be_quantized: + assert node.i(inp_idx).op == "DequantizeLinear", ( + f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!" + ) + else: + assert node.i(inp_idx).op != "DequantizeLinear", ( + f"Input '{inp.name}' of node '{node.name}' is quantized but should not be!" + ) return True @@ -59,7 +65,7 @@ def test_int8(tmp_path, high_precision_dtype): # Check that all MatMul nodes are quantized mm_nodes = [n for n in graph.nodes if n.op == "MatMul"] - assert _assert_nodes_are_quantized(mm_nodes) + assert _assert_nodes_quantization(mm_nodes) def test_convtranspose_conv_residual_int8(tmp_path): @@ -80,7 +86,7 @@ def test_convtranspose_conv_residual_int8(tmp_path): # Check that Conv and ConvTransposed are quantized conv_nodes = [n for n in graph.nodes if "Conv" in n.op] - assert _assert_nodes_are_quantized(conv_nodes) + assert _assert_nodes_quantization(conv_nodes) # Check that only 1 input of Add is quantized add_nodes = [n for n in graph.nodes if n.op == "Add"] @@ -89,3 +95,28 @@ def test_convtranspose_conv_residual_int8(tmp_path): assert len(quantized_inputs) == 1, ( f"More than one input of {node.name} is being quantized, but only one should be quantized!" ) + + +def test_conv_act_pool_int8(tmp_path): + onnx_model = build_conv_act_pool_model() + onnx_path = os.path.join(tmp_path, "conv_act_pool_model.onnx") + save_onnx(onnx_model, onnx_path) + + moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16") + + # Output model should be produced in the same tmp_path + output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx") + + # Check that quantized explicit model is generated + assert os.path.isfile(output_onnx_path) + + # Load the output model and check QDQ node placements + graph = gs.import_onnx(onnx.load(output_onnx_path)) + + # Check that Conv is quantized + conv_nodes = [n for n in graph.nodes if n.op == "Conv"] + assert _assert_nodes_quantization(conv_nodes) + + # Check that MaxPool is not quantized + pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"] + assert _assert_nodes_quantization(pool_nodes, should_be_quantized=False) From a96b10817c28c1ffcf90abe2e0bb80cd118986fb Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 17 Oct 2025 11:05:41 -0400 Subject: [PATCH 04/11] nit: added MaxPool to comment Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/quantization/graph_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index ce252bc8f..a3c49c98a 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -173,7 +173,7 @@ def has_path_type( def get_fusible_backbone(node: Node, graph: Graph) -> Node | None: """Returns the linear backbone node for a given node if it matches the pattern. - TensorRT fuses convolution with BN, Relu etc. when in some specific pattern. + TensorRT fuses convolution with BN, Relu, MaxPool etc. when in some specific pattern. 
This rule tries to match some of those patterns. Note. BiasAdd and ConstMul are optional in path types. From 2c23352ccdc07eeadd77b70404b9494fef348407 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 17 Oct 2025 12:22:34 -0400 Subject: [PATCH 05/11] [5274346] Skip copy ops in CASK patterns, added unittest Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/op_types.py | 11 ++- modelopt/onnx/quantization/graph_utils.py | 4 +- modelopt/onnx/quantization/partitioning.py | 6 +- .../onnx_quantization/lib_test_models.py | 69 +++++++++++++------ tests/unit/onnx/test_quantize_int8.py | 7 +- 5 files changed, 66 insertions(+), 31 deletions(-) diff --git a/modelopt/onnx/op_types.py b/modelopt/onnx/op_types.py index 04dfa8f65..3e60bc1d5 100644 --- a/modelopt/onnx/op_types.py +++ b/modelopt/onnx/op_types.py @@ -96,9 +96,9 @@ def is_fusible_scaling_op(op_type: str): ] -def is_copy_op(op_type: str): - """Returns whether the given op is a copy operator or not.""" - return op_type in [ +def copy_ops(): + """Returns list of copy operators.""" + return [ "Flatten", "Transpose", "Concat", @@ -118,6 +118,11 @@ def is_copy_op(op_type: str): ] +def is_copy_op(op_type: str): + """Returns whether the given op is a copy operator or not.""" + return op_type in copy_ops() + + def is_linear_op(op_type: str): """Returns whether the given op type is of Linear category or not.""" return op_type in ["Conv", "ConvTranspose", "Gemm", "MatMul"] diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index a3c49c98a..a7ca6db33 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -29,7 +29,7 @@ from onnxruntime.quantization.calibrate import CalibrationDataReader from modelopt.onnx.logging_config import logger -from modelopt.onnx.op_types import is_copy_op, is_linear_op +from modelopt.onnx.op_types import copy_ops, is_copy_op, is_linear_op from modelopt.onnx.quantization.ort_utils import create_inference_session from modelopt.onnx.utils import ( find_lowest_common_ancestor, @@ -207,7 +207,7 @@ def _get_backbone(root: Node): ["Mul", "Sigmoid", "BatchNormalization", conv_type], ] for idx, path_type in enumerate(fusible_linear_path_types): - if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=[]): + if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=copy_ops()): return _get_backbone(node) return None diff --git a/modelopt/onnx/quantization/partitioning.py b/modelopt/onnx/quantization/partitioning.py index 7a53674d6..e238c83ba 100644 --- a/modelopt/onnx/quantization/partitioning.py +++ b/modelopt/onnx/quantization/partitioning.py @@ -44,7 +44,7 @@ def _build_fusible_partition( """Traverses the graph starting from cur_node and updates the fusible_partition list. Add a nodes to the partition if any of these holds: - 1. The node is a unary or binary pointwise operation and fusible by cask + 1. The node is a unary or binary pointwise operation or a copy op and fusible by cask 2. The node is BN and/or Relu and fusible with preceding Conv op (Conv-Act fusion) 3. The node is MaxPool following a Conv-Act pattern (Conv-Act-Pool fusion) 4. 
The node is a residual Add and fusible with current partition @@ -132,6 +132,10 @@ def _is_fusible_mul(mul_node: Node) -> bool: if ( ( + is_copy_op(consumer_node.op) + and _is_cask_fusible(consumer_node, partition_node_outputs) + ) + or ( is_pointwise_or_elementwise_op(consumer_node.op) and _is_cask_fusible(consumer_node, partition_node_outputs) ) diff --git a/tests/_test_utils/onnx_quantization/lib_test_models.py b/tests/_test_utils/onnx_quantization/lib_test_models.py index 2e3b98582..580cad79f 100644 --- a/tests/_test_utils/onnx_quantization/lib_test_models.py +++ b/tests/_test_utils/onnx_quantization/lib_test_models.py @@ -676,6 +676,7 @@ def build_conv_batchnorm_sig_mul_model(): def build_conv_act_pool_model(): +def build_conv_act_pool_model(include_reshape_node=False): # Define your model inputs and outputs input_names = ["input_0"] output_names = ["output_0"] @@ -701,7 +702,7 @@ def build_conv_act_pool_model(): dilations=[1, 1], group=1, kernel_shape=[3, 3], - pads=[0, 0, 0, 0], + pads=[1, 1, 1, 1], strides=[1, 1], ), helper.make_node( @@ -716,28 +717,43 @@ def build_conv_act_pool_model(): outputs=["relu1_relu/Relu:0"], name="relu1_relu/Relu", ), - helper.make_node( - op_type="MaxPool", - inputs=["relu1_relu/Relu:0"], - outputs=["maxpool1_maxpool/MaxPool2D:0"], - name="maxpool1_maxpool/MaxPool2D", - ceil_mode=False, - kernel_shape=[3, 3], - pads=[0, 0, 0, 0], - strides=[2, 2], - ), - helper.make_node( - op_type="Conv", - inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"], - outputs=["output_0"], - name="conv2_conv/Conv2D", - dilations=[1, 1], - group=1, - kernel_shape=[3, 3], - pads=[0, 0, 0, 0], - strides=[1, 1], - ), ] + if include_reshape_node: + nodes.append( + helper.make_node( + op_type="Reshape", + inputs=["relu1_relu/Relu:0", "shape_1"], + outputs=["reshape1_reshape/Reshape:0"], + name="reshape1_reshape/Reshape", + ), + ) + nodes.extend( + [ + helper.make_node( + op_type="MaxPool", + inputs=[ + "reshape1_reshape/Reshape:0" if include_reshape_node else "relu1_relu/Relu:0" + ], + outputs=["maxpool1_maxpool/MaxPool2D:0"], + name="maxpool1_maxpool/MaxPool2D", + ceil_mode=False, + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[2, 2], + ), + helper.make_node( + op_type="Conv", + inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"], + outputs=["output_0"], + name="conv2_conv/Conv2D", + dilations=[1, 1], + group=1, + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[1, 1], + ), + ] + ) # Create the ONNX initializers initializers = [ @@ -784,6 +800,15 @@ def build_conv_act_pool_model(): vals=np.random.uniform(low=0.5, high=1.0, size=128 * 128 * 3 * 3), ), ] + if include_reshape_node: + initializers.append( + helper.make_tensor( + name="shape_1", + data_type=onnx.TensorProto.INT64, + dims=(4,), + vals=(32, 128, 256, 256), + ), + ) # Create the ONNX graph with the nodes and initializers graph = helper.make_graph(nodes, "conv_act_pool", inputs, outputs, initializer=initializers) diff --git a/tests/unit/onnx/test_quantize_int8.py b/tests/unit/onnx/test_quantize_int8.py index ce843153e..49d92928e 100644 --- a/tests/unit/onnx/test_quantize_int8.py +++ b/tests/unit/onnx/test_quantize_int8.py @@ -97,9 +97,10 @@ def test_convtranspose_conv_residual_int8(tmp_path): ) -def test_conv_act_pool_int8(tmp_path): - onnx_model = build_conv_act_pool_model() - onnx_path = os.path.join(tmp_path, "conv_act_pool_model.onnx") +@pytest.mark.parametrize("include_reshape_node", [False, True]) +def test_conv_act_pool_int8(tmp_path, include_reshape_node): + onnx_model = 
build_conv_act_pool_model(include_reshape_node) + onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx") save_onnx(onnx_model, onnx_path) moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16") From cb36cb512d2328718a869982b8ef76d746d320eb Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 17 Oct 2025 13:58:36 -0400 Subject: [PATCH 06/11] Replaced SymbolicShapeInference with infer_shapes Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/quantization/graph_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index a7ca6db33..e11a3a8df 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -1002,7 +1002,6 @@ def find_nodes_from_matmul_to_exclude( logger.debug("No MatMul nodes found in the model") return [] - nodes_to_exclude = [] logger.debug(f"Found {len(matmul_nodes)} MatMul nodes to analyze") if calibration_shapes: From 4ed1713fc4bf7e48d0a6a3538b642ab9a708ae7f Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 17 Oct 2025 14:23:27 -0400 Subject: [PATCH 07/11] nit: copy_ops -> get_copy_ops Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/op_types.py | 4 ++-- modelopt/onnx/quantization/graph_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modelopt/onnx/op_types.py b/modelopt/onnx/op_types.py index 3e60bc1d5..cc94a221f 100644 --- a/modelopt/onnx/op_types.py +++ b/modelopt/onnx/op_types.py @@ -96,7 +96,7 @@ def is_fusible_scaling_op(op_type: str): ] -def copy_ops(): +def get_copy_ops(): """Returns list of copy operators.""" return [ "Flatten", @@ -120,7 +120,7 @@ def copy_ops(): def is_copy_op(op_type: str): """Returns whether the given op is a copy operator or not.""" - return op_type in copy_ops() + return op_type in get_copy_ops() def is_linear_op(op_type: str): diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index e11a3a8df..362f0fc59 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -29,7 +29,7 @@ from onnxruntime.quantization.calibrate import CalibrationDataReader from modelopt.onnx.logging_config import logger -from modelopt.onnx.op_types import copy_ops, is_copy_op, is_linear_op +from modelopt.onnx.op_types import get_copy_ops, is_copy_op, is_linear_op from modelopt.onnx.quantization.ort_utils import create_inference_session from modelopt.onnx.utils import ( find_lowest_common_ancestor, @@ -207,7 +207,7 @@ def _get_backbone(root: Node): ["Mul", "Sigmoid", "BatchNormalization", conv_type], ] for idx, path_type in enumerate(fusible_linear_path_types): - if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=copy_ops()): + if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=get_copy_ops()): return _get_backbone(node) return None From 45aad6dbff5244701c3e5a8463d565cde9894cf7 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 17 Oct 2025 15:12:48 -0400 Subject: [PATCH 08/11] Moved unittest to qdq_rules script Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/unit/onnx/test_qdq_rules_int8.py | 28 ++++++++++++++++++++++++++ tests/unit/onnx/test_quantize_int8.py | 27 ------------------------- 2 
files changed, 28 insertions(+), 27 deletions(-) diff --git a/tests/unit/onnx/test_qdq_rules_int8.py b/tests/unit/onnx/test_qdq_rules_int8.py index 272911d2e..b969abe31 100644 --- a/tests/unit/onnx/test_qdq_rules_int8.py +++ b/tests/unit/onnx/test_qdq_rules_int8.py @@ -18,7 +18,9 @@ import numpy as np import onnx import onnx_graphsurgeon as gs +import pytest from _test_utils.onnx_quantization.lib_test_models import ( + build_conv_act_pool_model, build_conv_batchnorm_sig_mul_model, build_r1a_model, build_resnet_block, @@ -150,3 +152,29 @@ def test_conv_batchnorm_sig_mul_int8(tmp_path): assert len(quantized_inputs) == 1, ( f"More than one input of {node.name} is being quantized, but only one should be quantized!" ) + + +@pytest.mark.parametrize("include_reshape_node", [False, True]) +def test_conv_act_pool_int8(tmp_path, include_reshape_node): + onnx_model = build_conv_act_pool_model(include_reshape_node) + onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx") + save_onnx(onnx_model, onnx_path) + + quantize.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16") + + # Output model should be produced in the same tmp_path + output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx") + + # Check that quantized explicit model is generated + assert os.path.isfile(output_onnx_path) + + # Load the output model and check QDQ node placements + graph = gs.import_onnx(onnx.load(output_onnx_path)) + + # Check that Conv is quantized + conv_nodes = [n for n in graph.nodes if n.op == "Conv"] + assert _assert_nodes_are_quantized(conv_nodes) + + # Check that MaxPool is not quantized + pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"] + assert _assert_nodes_are_not_quantized(pool_nodes) diff --git a/tests/unit/onnx/test_quantize_int8.py b/tests/unit/onnx/test_quantize_int8.py index 49d92928e..35c9401f7 100644 --- a/tests/unit/onnx/test_quantize_int8.py +++ b/tests/unit/onnx/test_quantize_int8.py @@ -21,7 +21,6 @@ import torch from _test_utils.onnx_quantization.lib_test_models import ( SimpleMLP, - build_conv_act_pool_model, build_convtranspose_conv_residual_model, export_as_onnx, ) @@ -95,29 +94,3 @@ def test_convtranspose_conv_residual_int8(tmp_path): assert len(quantized_inputs) == 1, ( f"More than one input of {node.name} is being quantized, but only one should be quantized!" 
) - - -@pytest.mark.parametrize("include_reshape_node", [False, True]) -def test_conv_act_pool_int8(tmp_path, include_reshape_node): - onnx_model = build_conv_act_pool_model(include_reshape_node) - onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx") - save_onnx(onnx_model, onnx_path) - - moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16") - - # Output model should be produced in the same tmp_path - output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx") - - # Check that quantized explicit model is generated - assert os.path.isfile(output_onnx_path) - - # Load the output model and check QDQ node placements - graph = gs.import_onnx(onnx.load(output_onnx_path)) - - # Check that Conv is quantized - conv_nodes = [n for n in graph.nodes if n.op == "Conv"] - assert _assert_nodes_quantization(conv_nodes) - - # Check that MaxPool is not quantized - pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"] - assert _assert_nodes_quantization(pool_nodes, should_be_quantized=False) From fd59bdfe886688dbb78e4e3c06fd9307254f7985 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 17 Oct 2025 15:14:02 -0400 Subject: [PATCH 09/11] Revert modification of nodes_are_quantized function Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/unit/onnx/test_quantize_int8.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/tests/unit/onnx/test_quantize_int8.py b/tests/unit/onnx/test_quantize_int8.py index 35c9401f7..b474558f8 100644 --- a/tests/unit/onnx/test_quantize_int8.py +++ b/tests/unit/onnx/test_quantize_int8.py @@ -29,18 +29,13 @@ from modelopt.onnx.utils import save_onnx -def _assert_nodes_quantization(nodes, should_be_quantized=True): +def _assert_nodes_are_quantized(nodes): for node in nodes: for inp_idx, inp in enumerate(node.inputs): if isinstance(inp, gs.Variable): - if should_be_quantized: - assert node.i(inp_idx).op == "DequantizeLinear", ( - f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!" - ) - else: - assert node.i(inp_idx).op != "DequantizeLinear", ( - f"Input '{inp.name}' of node '{node.name}' is quantized but should not be!" - ) + assert node.i(inp_idx).op == "DequantizeLinear", ( + f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!" 
+ ) return True @@ -64,7 +59,7 @@ def test_int8(tmp_path, high_precision_dtype): # Check that all MatMul nodes are quantized mm_nodes = [n for n in graph.nodes if n.op == "MatMul"] - assert _assert_nodes_quantization(mm_nodes) + assert _assert_nodes_are_quantized(mm_nodes) def test_convtranspose_conv_residual_int8(tmp_path): @@ -85,7 +80,7 @@ def test_convtranspose_conv_residual_int8(tmp_path): # Check that Conv and ConvTransposed are quantized conv_nodes = [n for n in graph.nodes if "Conv" in n.op] - assert _assert_nodes_quantization(conv_nodes) + assert _assert_nodes_are_quantized(conv_nodes) # Check that only 1 input of Add is quantized add_nodes = [n for n in graph.nodes if n.op == "Add"] From 3a1e3fb7f9f9c28642ee8ae7c5cca91b3197ee89 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Fri, 24 Oct 2025 16:40:45 -0400 Subject: [PATCH 10/11] Fix: reverted 1 line from rebase + quantize.quantize -> quantize Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/_test_utils/onnx_quantization/lib_test_models.py | 1 - tests/unit/onnx/test_qdq_rules_int8.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/_test_utils/onnx_quantization/lib_test_models.py b/tests/_test_utils/onnx_quantization/lib_test_models.py index 580cad79f..ffd827b62 100644 --- a/tests/_test_utils/onnx_quantization/lib_test_models.py +++ b/tests/_test_utils/onnx_quantization/lib_test_models.py @@ -675,7 +675,6 @@ def build_conv_batchnorm_sig_mul_model(): return model_inferred -def build_conv_act_pool_model(): def build_conv_act_pool_model(include_reshape_node=False): # Define your model inputs and outputs input_names = ["input_0"] diff --git a/tests/unit/onnx/test_qdq_rules_int8.py b/tests/unit/onnx/test_qdq_rules_int8.py index b969abe31..02ef24967 100644 --- a/tests/unit/onnx/test_qdq_rules_int8.py +++ b/tests/unit/onnx/test_qdq_rules_int8.py @@ -160,7 +160,7 @@ def test_conv_act_pool_int8(tmp_path, include_reshape_node): onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx") save_onnx(onnx_model, onnx_path) - quantize.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16") + quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16") # Output model should be produced in the same tmp_path output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx") From 041597823f221939f27e782264473f16d7f75848 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Mon, 27 Oct 2025 11:04:43 -0400 Subject: [PATCH 11/11] Move ConvT residual test to qdq_rules Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/unit/onnx/test_qdq_rules_int8.py | 30 +++++++++++++++++++++ tests/unit/onnx/test_quantize_int8.py | 36 +------------------------- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/tests/unit/onnx/test_qdq_rules_int8.py b/tests/unit/onnx/test_qdq_rules_int8.py index 02ef24967..e73f82be4 100644 --- a/tests/unit/onnx/test_qdq_rules_int8.py +++ b/tests/unit/onnx/test_qdq_rules_int8.py @@ -22,6 +22,7 @@ from _test_utils.onnx_quantization.lib_test_models import ( build_conv_act_pool_model, build_conv_batchnorm_sig_mul_model, + build_convtranspose_conv_residual_model, build_r1a_model, build_resnet_block, build_resnet_block_with_downsample, @@ -125,6 +126,35 @@ def test_resnet_residual_connection_with_downsample(tmp_path): _check_resnet_residual_connection(onnx_path) +def test_convtranspose_conv_residual_int8(tmp_path): + onnx_model 
= build_convtranspose_conv_residual_model() + onnx_path = os.path.join(tmp_path, "convtranspose_conv_residual_model.onnx") + save_onnx(onnx_model, onnx_path) + + quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16") + + # Output model should be produced in the same tmp_path + output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx") + + # Check that quantized explicit model is generated + assert os.path.isfile(output_onnx_path) + + # Load the output model and check QDQ node placements + graph = gs.import_onnx(onnx.load(output_onnx_path)) + + # Check that Conv and ConvTransposed are quantized + conv_nodes = [n for n in graph.nodes if "Conv" in n.op] + assert _assert_nodes_are_quantized(conv_nodes) + + # Check that only 1 input of Add is quantized + add_nodes = [n for n in graph.nodes if n.op == "Add"] + for node in add_nodes: + quantized_inputs = [inp for inp in node.inputs if inp.inputs[0].op == "DequantizeLinear"] + assert len(quantized_inputs) == 1, ( + f"More than one input of {node.name} is being quantized, but only one should be quantized!" + ) + + def test_conv_batchnorm_sig_mul_int8(tmp_path): onnx_model = build_conv_batchnorm_sig_mul_model() onnx_path = os.path.join(tmp_path, "conv_batchnorm_sig_mul_model.onnx") diff --git a/tests/unit/onnx/test_quantize_int8.py b/tests/unit/onnx/test_quantize_int8.py index b474558f8..2d4d192a0 100644 --- a/tests/unit/onnx/test_quantize_int8.py +++ b/tests/unit/onnx/test_quantize_int8.py @@ -19,14 +19,9 @@ import onnx_graphsurgeon as gs import pytest import torch -from _test_utils.onnx_quantization.lib_test_models import ( - SimpleMLP, - build_convtranspose_conv_residual_model, - export_as_onnx, -) +from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx import modelopt.onnx.quantization as moq -from modelopt.onnx.utils import save_onnx def _assert_nodes_are_quantized(nodes): @@ -60,32 +55,3 @@ def test_int8(tmp_path, high_precision_dtype): # Check that all MatMul nodes are quantized mm_nodes = [n for n in graph.nodes if n.op == "MatMul"] assert _assert_nodes_are_quantized(mm_nodes) - - -def test_convtranspose_conv_residual_int8(tmp_path): - onnx_model = build_convtranspose_conv_residual_model() - onnx_path = os.path.join(tmp_path, "convtranspose_conv_residual_model.onnx") - save_onnx(onnx_model, onnx_path) - - moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16") - - # Output model should be produced in the same tmp_path - output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx") - - # Check that quantized explicit model is generated - assert os.path.isfile(output_onnx_path) - - # Load the output model and check QDQ node placements - graph = gs.import_onnx(onnx.load(output_onnx_path)) - - # Check that Conv and ConvTransposed are quantized - conv_nodes = [n for n in graph.nodes if "Conv" in n.op] - assert _assert_nodes_are_quantized(conv_nodes) - - # Check that only 1 input of Add is quantized - add_nodes = [n for n in graph.nodes if n.op == "Add"] - for node in add_nodes: - quantized_inputs = [inp for inp in node.inputs if inp.inputs[0].op == "DequantizeLinear"] - assert len(quantized_inputs) == 1, ( - f"More than one input of {node.name} is being quantized, but only one should be quantized!" - )
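
Note on the copy-op wild-card change (patches 05/07): the intent of passing `get_copy_ops()` as `wild_card_types` to `has_path_type` can be illustrated with a small standalone sketch. This is a simplified, hypothetical illustration, not the actual `has_path_type`/`get_fusible_backbone` code from `modelopt/onnx/quantization/graph_utils.py`; the `Node` class, its `parent` link, and the `backbone_if_path_matches` helper below are stand-ins invented for this example, and `COPY_OPS` is only a subset of what `get_copy_ops()` returns.

from dataclasses import dataclass
from typing import Optional

# Hypothetical, trimmed-down stand-in for an onnx_graphsurgeon node: each node
# only knows its op type and its single producer ("parent").
@dataclass
class Node:
    op: str
    parent: Optional["Node"] = None

# Subset of the copy ops listed in modelopt.onnx.op_types.get_copy_ops().
COPY_OPS = {"Flatten", "Transpose", "Concat", "Reshape", "Squeeze", "Unsqueeze", "Identity"}

def backbone_if_path_matches(node, path_type):
    """Walk producers backward from `node`, skipping copy ops as wild cards, and
    return the last matched node (the Conv backbone) if the remaining ops match
    `path_type`; otherwise return None."""
    cur, backbone = node, None
    for expected_op in path_type:
        # Copy ops (e.g. the optional Reshape in the new test model) do not change
        # tensor values, so they are transparent to the fusion pattern.
        while cur is not None and cur.op in COPY_OPS:
            cur = cur.parent
        if cur is None or cur.op != expected_op:
            return None
        backbone, cur = cur, cur.parent
    return backbone

# Conv -> BN -> Relu -> Reshape -> MaxPool, mirroring build_conv_act_pool_model(include_reshape_node=True).
conv = Node("Conv")
bn = Node("BatchNormalization", conv)
relu = Node("Relu", bn)
reshape = Node("Reshape", relu)
maxpool = Node("MaxPool", reshape)

# The Reshape is skipped, so MaxPool still resolves to the Conv backbone and can be
# folded into the Conv-Act-Pool partition instead of receiving its own Q/DQ pair.
assert backbone_if_path_matches(maxpool, ["MaxPool", "Relu", "BatchNormalization", "Conv"]) is conv

This mirrors what the parametrized test_conv_act_pool_int8 checks end to end: after quantization, the Conv nodes are fed by DequantizeLinear, while the MaxPool (with or without an intervening Reshape) is not, because it belongs to the fused Conv-Act-Pool partition.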