diff --git a/modelopt/onnx/op_types.py b/modelopt/onnx/op_types.py
index 04dfa8f65..cc94a221f 100644
--- a/modelopt/onnx/op_types.py
+++ b/modelopt/onnx/op_types.py
@@ -96,9 +96,9 @@ def is_fusible_scaling_op(op_type: str):
     ]
 
 
-def is_copy_op(op_type: str):
-    """Returns whether the given op is a copy operator or not."""
-    return op_type in [
+def get_copy_ops():
+    """Returns the list of copy operators."""
+    return [
         "Flatten",
         "Transpose",
         "Concat",
@@ -118,6 +118,11 @@ def is_copy_op(op_type: str):
     ]
 
 
+def is_copy_op(op_type: str):
+    """Returns whether the given op is a copy operator or not."""
+    return op_type in get_copy_ops()
+
+
 def is_linear_op(op_type: str):
     """Returns whether the given op type is of Linear category or not."""
     return op_type in ["Conv", "ConvTranspose", "Gemm", "MatMul"]
diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py
index 31ea42764..0ba23fe07 100755
--- a/modelopt/onnx/quantization/graph_utils.py
+++ b/modelopt/onnx/quantization/graph_utils.py
@@ -29,7 +29,7 @@
 from onnxruntime.quantization.calibrate import CalibrationDataReader
 
 from modelopt.onnx.logging_config import logger
-from modelopt.onnx.op_types import is_copy_op, is_linear_op
+from modelopt.onnx.op_types import get_copy_ops, is_copy_op, is_linear_op
 from modelopt.onnx.quantization.ort_utils import create_inference_session
 from modelopt.onnx.utils import (
     find_lowest_common_ancestor,
@@ -173,7 +173,7 @@ def has_path_type(
 def get_fusible_backbone(node: Node, graph: Graph) -> Node | None:
     """Returns the linear backbone node for a given node if it matches the pattern.
 
-    TensorRT fuses convolution with BN, Relu etc. when in some specific pattern.
+    TensorRT fuses convolution with BN, Relu, MaxPool, etc. when they appear in specific patterns.
     This rule tries to match some of those patterns.
 
     Note. BiasAdd and ConstMul are optional in path types.
@@ -190,7 +190,7 @@ def _get_backbone(root: Node):
             return root
 
         for tensor in root.inputs:
-            if not isinstance(tensor, Constant):
+            if not isinstance(tensor, Constant) and tensor.inputs:
                 parent_node = tensor.inputs[0]
                 bb = _get_backbone(parent_node)
                 if bb:
@@ -207,7 +207,7 @@ def _get_backbone(root: Node):
         ["Mul", "Sigmoid", "BatchNormalization", conv_type],
     ]
     for idx, path_type in enumerate(fusible_linear_path_types):
-        if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=[]):
+        if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=get_copy_ops()):
             return _get_backbone(node)
 
     return None
@@ -1002,7 +1002,6 @@ def find_nodes_from_matmul_to_exclude(
         logger.debug("No MatMul nodes found in the model")
         return []
 
-    nodes_to_exclude = []
     logger.debug(f"Found {len(matmul_nodes)} MatMul nodes to analyze")
 
     if calibration_shapes:
diff --git a/modelopt/onnx/quantization/partitioning.py b/modelopt/onnx/quantization/partitioning.py
index 9e8a9163c..e238c83ba 100644
--- a/modelopt/onnx/quantization/partitioning.py
+++ b/modelopt/onnx/quantization/partitioning.py
@@ -44,9 +44,10 @@ def _build_fusible_partition(
     """Traverses the graph starting from cur_node and updates the fusible_partition list.
 
     Add a nodes to the partition if any of these holds:
-    1. The node is a unary or binary pointwise operation and fusible by cask
-    2. The node is BN and/or Relu and fusible with preceding Conv op
-    3. The node is a residual Add and fusible with current partition
+    1. The node is a unary or binary pointwise operation or a copy op and fusible by cask
+    2. The node is BN and/or Relu and fusible with preceding Conv op (Conv-Act fusion)
+    3. The node is MaxPool following a Conv-Act pattern (Conv-Act-Pool fusion)
+    4. The node is a residual Add and fusible with current partition
 
     Args:
         cur_node: Current candidate node for the partition.
@@ -131,11 +132,15 @@ def _is_fusible_mul(mul_node: Node) -> bool:
 
         if (
             (
+                is_copy_op(consumer_node.op)
+                and _is_cask_fusible(consumer_node, partition_node_outputs)
+            )
+            or (
                 is_pointwise_or_elementwise_op(consumer_node.op)
                 and _is_cask_fusible(consumer_node, partition_node_outputs)
             )
             or (
-                consumer_node.op in ["BatchNormalization", "Relu"]
+                consumer_node.op in ["BatchNormalization", "Relu", "MaxPool"]
                 and get_fusible_backbone(consumer_node, graph)
             )
             or _is_on_non_residual_path(consumer_node)
diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py
index 4650e99b2..283b68ea6 100644
--- a/modelopt/onnx/utils.py
+++ b/modelopt/onnx/utils.py
@@ -165,9 +165,7 @@ def get_dynamic_graph_inputs(onnx_model: onnx.ModelProto):
         List of dynamic inputs.
     """
     graph = gs.import_onnx(onnx_model)
-    return [
-        inp for inp in graph.inputs if -1 in inp.shape or any(isinstance(s, str) for s in inp.shape)
-    ]
+    return [inp for inp in graph.inputs if any(isinstance(s, str) or s <= 0 for s in inp.shape)]
 
 
 def _get_all_shapes(container: Any) -> dict[str, list[int]]:
diff --git a/tests/_test_utils/onnx/quantization/lib_test_models.py b/tests/_test_utils/onnx/quantization/lib_test_models.py
index e6846b9ec..ffd827b62 100644
--- a/tests/_test_utils/onnx/quantization/lib_test_models.py
+++ b/tests/_test_utils/onnx/quantization/lib_test_models.py
@@ -673,3 +673,152 @@ def build_conv_batchnorm_sig_mul_model():
     onnx.checker.check_model(model_inferred)
 
     return model_inferred
+
+
+def build_conv_act_pool_model(include_reshape_node=False):
+    # Define the model inputs and outputs
+    input_names = ["input_0"]
+    output_names = ["output_0"]
+    input_shapes = [(32, 64, 256, 256)]
+    output_shapes = [(32, 128, 128, 128)]
+
+    inputs = [
+        helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
+        for input_name, input_shape in zip(input_names, input_shapes)
+    ]
+    outputs = [
+        helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
+        for output_name, output_shape in zip(output_names, output_shapes)
+    ]
+
+    # Create the ONNX graph with the nodes
+    nodes = [
+        helper.make_node(
+            op_type="Conv",
+            inputs=["input_0", "weights_1", "bias_1"],
+            outputs=["conv1_conv/Conv2D:0"],
+            name="conv1_conv/Conv2D",
+            dilations=[1, 1],
+            group=1,
+            kernel_shape=[3, 3],
+            pads=[1, 1, 1, 1],
+            strides=[1, 1],
+        ),
+        helper.make_node(
+            op_type="BatchNormalization",
+            inputs=["conv1_conv/Conv2D:0", "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"],
+            outputs=["bn1_batchnorm/BatchNormalization:0"],
+            name="bn1_batchnorm/BatchNormalization",
+        ),
+        helper.make_node(
+            op_type="Relu",
+            inputs=["bn1_batchnorm/BatchNormalization:0"],
+            outputs=["relu1_relu/Relu:0"],
+            name="relu1_relu/Relu",
+        ),
+    ]
+    if include_reshape_node:
+        nodes.append(
+            helper.make_node(
+                op_type="Reshape",
+                inputs=["relu1_relu/Relu:0", "shape_1"],
+                outputs=["reshape1_reshape/Reshape:0"],
+                name="reshape1_reshape/Reshape",
+            ),
+        )
+    nodes.extend(
+        [
+            helper.make_node(
+                op_type="MaxPool",
+                inputs=[
+                    "reshape1_reshape/Reshape:0" if include_reshape_node else "relu1_relu/Relu:0"
+                ],
+                outputs=["maxpool1_maxpool/MaxPool2D:0"],
+                name="maxpool1_maxpool/MaxPool2D",
+                ceil_mode=False,
+                kernel_shape=[3, 3],
+                pads=[1, 1, 1, 1],
+                strides=[2, 2],
+            ),
+            helper.make_node(
+                op_type="Conv",
+                inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"],
+                outputs=["output_0"],
+                name="conv2_conv/Conv2D",
+                dilations=[1, 1],
+                group=1,
+                kernel_shape=[3, 3],
+                pads=[1, 1, 1, 1],
+                strides=[1, 1],
+            ),
+        ]
+    )
+
+    # Create the ONNX initializers
+    initializers = [
+        helper.make_tensor(
+            name="weights_1",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(128, 64, 3, 3),
+            vals=np.random.uniform(low=0.5, high=1.0, size=128 * 64 * 3 * 3),
+        ),
+        helper.make_tensor(
+            name="bias_1",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(128,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=128),
+        ),
+        helper.make_tensor(
+            name="bn1_scale",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(128,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=128),
+        ),
+        helper.make_tensor(
+            name="bn1_bias",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(128,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=128),
+        ),
+        helper.make_tensor(
+            name="bn1_mean",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(128,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=128),
+        ),
+        helper.make_tensor(
+            name="bn1_var",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(128,),
+            vals=np.random.uniform(low=0.5, high=1.0, size=128),
+        ),
+        helper.make_tensor(
+            name="weights_2",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(128, 128, 3, 3),
+            vals=np.random.uniform(low=0.5, high=1.0, size=128 * 128 * 3 * 3),
+        ),
+    ]
+    if include_reshape_node:
+        initializers.append(
+            helper.make_tensor(
+                name="shape_1",
+                data_type=onnx.TensorProto.INT64,
+                dims=(4,),
+                vals=(32, 128, 256, 256),
+            ),
+        )
+
+    # Create the ONNX graph with the nodes and initializers
+    graph = helper.make_graph(nodes, "conv_act_pool", inputs, outputs, initializer=initializers)
+
+    # Create the ONNX model
+    model = helper.make_model(graph)
+    model.opset_import[0].version = 13
+    model.ir_version = 10
+
+    # Check the ONNX model
+    model_inferred = onnx.shape_inference.infer_shapes(model)
+    onnx.checker.check_model(model_inferred)
+
+    return model_inferred
diff --git a/tests/unit/onnx/test_qdq_rules_int8.py b/tests/unit/onnx/test_qdq_rules_int8.py
index b11c9f496..3f6104427 100644
--- a/tests/unit/onnx/test_qdq_rules_int8.py
+++ b/tests/unit/onnx/test_qdq_rules_int8.py
@@ -18,8 +18,11 @@
 import numpy as np
 import onnx
 import onnx_graphsurgeon as gs
+import pytest
 from _test_utils.onnx.quantization.lib_test_models import (
+    build_conv_act_pool_model,
     build_conv_batchnorm_sig_mul_model,
+    build_convtranspose_conv_residual_model,
     build_r1a_model,
     build_resnet_block,
     build_resnet_block_with_downsample,
@@ -40,7 +43,7 @@ def assert_nodes_are_quantized(nodes):
     return True
 
 
-def _assert_nodes_are_not_quantized(nodes):
+def assert_nodes_are_not_quantized(nodes):
     for node in nodes:
         for inp_idx, inp in enumerate(node.inputs):
             if isinstance(inp, gs.Variable) and inp.inputs:
@@ -76,7 +79,7 @@ def test_bias_add_rule(tmp_path):
     other_nodes = [
         n for n in graph.nodes if n.op not in ["Conv", "QuantizeLinear", "DequantizeLinear"]
     ]
-    assert _assert_nodes_are_not_quantized(other_nodes)
+    assert assert_nodes_are_not_quantized(other_nodes)
 
 
 def _check_resnet_residual_connection(onnx_path):
@@ -106,7 +109,7 @@ def _check_resnet_residual_connection(onnx_path):
     other_nodes = [
         n for n in graph.nodes if n.op not in ["Conv", "Add", "QuantizeLinear", "DequantizeLinear"]
     ]
-    assert _assert_nodes_are_not_quantized(other_nodes)
+    assert assert_nodes_are_not_quantized(other_nodes)
 
 
 def test_resnet_residual_connections(tmp_path):
@@ -123,6 +126,35 @@ def test_resnet_residual_connection_with_downsample(tmp_path):
     _check_resnet_residual_connection(onnx_path)
 
 
+def test_convtranspose_conv_residual_int8(tmp_path):
+    onnx_model = build_convtranspose_conv_residual_model()
+    onnx_path = os.path.join(tmp_path, "convtranspose_conv_residual_model.onnx")
+    save_onnx(onnx_model, onnx_path)
+
+    quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
+
+    # Output model should be produced in the same tmp_path
+    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
+
+    # Check that quantized explicit model is generated
+    assert os.path.isfile(output_onnx_path)
+
+    # Load the output model and check QDQ node placements
+    graph = gs.import_onnx(onnx.load(output_onnx_path))
+
+    # Check that Conv and ConvTranspose are quantized
+    conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
+    assert assert_nodes_are_quantized(conv_nodes)
+
+    # Check that only 1 input of Add is quantized
+    add_nodes = [n for n in graph.nodes if n.op == "Add"]
+    for node in add_nodes:
+        quantized_inputs = [inp for inp in node.inputs if inp.inputs[0].op == "DequantizeLinear"]
+        assert len(quantized_inputs) == 1, (
+            f"More than one input of {node.name} is being quantized, but only one should be quantized!"
+        )
+
+
 def test_conv_batchnorm_sig_mul_int8(tmp_path):
     onnx_model = build_conv_batchnorm_sig_mul_model()
     onnx_path = os.path.join(tmp_path, "conv_batchnorm_sig_mul_model.onnx")
@@ -150,3 +182,29 @@ def test_conv_batchnorm_sig_mul_int8(tmp_path):
         assert len(quantized_inputs) == 1, (
             f"More than one input of {node.name} is being quantized, but only one should be quantized!"
         )
+
+
+@pytest.mark.parametrize("include_reshape_node", [False, True])
+def test_conv_act_pool_int8(tmp_path, include_reshape_node):
+    onnx_model = build_conv_act_pool_model(include_reshape_node)
+    onnx_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx")
+    save_onnx(onnx_model, onnx_path)
+
+    quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
+
+    # Output model should be produced in the same tmp_path
+    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
+
+    # Check that quantized explicit model is generated
+    assert os.path.isfile(output_onnx_path)
+
+    # Load the output model and check QDQ node placements
+    graph = gs.import_onnx(onnx.load(output_onnx_path))
+
+    # Check that Conv is quantized
+    conv_nodes = [n for n in graph.nodes if n.op == "Conv"]
+    assert assert_nodes_are_quantized(conv_nodes)
+
+    # Check that MaxPool is not quantized
+    pool_nodes = [n for n in graph.nodes if n.op == "MaxPool"]
+    assert assert_nodes_are_not_quantized(pool_nodes)
diff --git a/tests/unit/onnx/test_quantize_int8.py b/tests/unit/onnx/test_quantize_int8.py
index 703360aaa..31c84eff1 100644
--- a/tests/unit/onnx/test_quantize_int8.py
+++ b/tests/unit/onnx/test_quantize_int8.py
@@ -19,14 +19,9 @@
 import onnx_graphsurgeon as gs
 import pytest
 import torch
-from _test_utils.onnx.quantization.lib_test_models import (
-    SimpleMLP,
-    build_convtranspose_conv_residual_model,
-    export_as_onnx,
-)
+from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx
 
 import modelopt.onnx.quantization as moq
-from modelopt.onnx.utils import save_onnx
 
 
 def assert_nodes_are_quantized(nodes):
@@ -60,32 +55,3 @@ def test_int8(tmp_path, high_precision_dtype):
     # Check that all MatMul nodes are quantized
     mm_nodes = [n for n in graph.nodes if n.op == "MatMul"]
     assert assert_nodes_are_quantized(mm_nodes)
-
-
-def test_convtranspose_conv_residual_int8(tmp_path):
-    onnx_model = build_convtranspose_conv_residual_model()
-    onnx_path = os.path.join(tmp_path, "convtranspose_conv_residual_model.onnx")
-    save_onnx(onnx_model, onnx_path)
-
-    moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16")
-
-    # Output model should be produced in the same tmp_path
-    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
-
-    # Check that quantized explicit model is generated
-    assert os.path.isfile(output_onnx_path)
-
-    # Load the output model and check QDQ node placements
-    graph = gs.import_onnx(onnx.load(output_onnx_path))
-
-    # Check that Conv and ConvTransposed are quantized
-    conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
-    assert assert_nodes_are_quantized(conv_nodes)
-
-    # Check that only 1 input of Add is quantized
-    add_nodes = [n for n in graph.nodes if n.op == "Add"]
-    for node in add_nodes:
-        quantized_inputs = [inp for inp in node.inputs if inp.inputs[0].op == "DequantizeLinear"]
-        assert len(quantized_inputs) == 1, (
-            f"More than one input of {node.name} is being quantized, but only one should be quantized!"
-        )
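
A minimal sketch (not part of the patch) for exercising the new Conv-Act-Pool backbone tracing locally. It assumes this diff is applied and that both modelopt and the repo's tests/_test_utils package are importable; build_conv_act_pool_model and get_fusible_backbone are the functions touched above, and the expected backbone named in the comment reflects the patch's intent rather than a guaranteed result.

import onnx_graphsurgeon as gs

from _test_utils.onnx.quantization.lib_test_models import build_conv_act_pool_model
from modelopt.onnx.quantization.graph_utils import get_fusible_backbone

# Build the Conv -> BN -> Relu -> Reshape -> MaxPool model added by this patch.
model = build_conv_act_pool_model(include_reshape_node=True)
graph = gs.import_onnx(model)

# With wild_card_types=get_copy_ops(), the backward path search skips copy ops
# such as Reshape, so the MaxPool should trace back to conv1_conv/Conv2D.
maxpool = next(n for n in graph.nodes if n.op == "MaxPool")
backbone = get_fusible_backbone(maxpool, graph)
print(backbone.name if backbone is not None else "no fusible backbone")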