Skip to content
11 changes: 8 additions & 3 deletions modelopt/onnx/op_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ def is_fusible_scaling_op(op_type: str):
]


def is_copy_op(op_type: str):
"""Returns whether the given op is a copy operator or not."""
return op_type in [
def get_copy_ops():
"""Returns list of copy operators."""
return [
"Flatten",
"Transpose",
"Concat",
Expand All @@ -118,6 +118,11 @@ def is_copy_op(op_type: str):
]


def is_copy_op(op_type: str):
    """Check whether *op_type* belongs to the copy-operator category."""
    copy_ops = get_copy_ops()
    return op_type in copy_ops


def is_linear_op(op_type: str):
    """Returns whether the given op type is of Linear category or not."""
    linear_ops = ("Conv", "ConvTranspose", "Gemm", "MatMul")
    return op_type in linear_ops
Expand Down
7 changes: 3 additions & 4 deletions modelopt/onnx/quantization/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from onnxruntime.quantization.calibrate import CalibrationDataReader

from modelopt.onnx.logging_config import logger
from modelopt.onnx.op_types import is_copy_op, is_linear_op
from modelopt.onnx.op_types import get_copy_ops, is_copy_op, is_linear_op
from modelopt.onnx.quantization.ort_utils import create_inference_session
from modelopt.onnx.utils import (
find_lowest_common_ancestor,
Expand Down Expand Up @@ -173,7 +173,7 @@ def has_path_type(
def get_fusible_backbone(node: Node, graph: Graph) -> Node | None:
"""Returns the linear backbone node for a given node if it matches the pattern.

TensorRT fuses convolution with BN, Relu etc. when in some specific pattern.
TensorRT fuses convolution with BN, Relu, MaxPool etc. when in some specific pattern.
This rule tries to match some of those patterns.
Note. BiasAdd and ConstMul are optional in path types.

Expand Down Expand Up @@ -207,7 +207,7 @@ def _get_backbone(root: Node):
["Mul", "Sigmoid", "BatchNormalization", conv_type],
]
for idx, path_type in enumerate(fusible_linear_path_types):
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=[]):
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=get_copy_ops()):
return _get_backbone(node)

return None
Expand Down Expand Up @@ -1002,7 +1002,6 @@ def find_nodes_from_matmul_to_exclude(
logger.debug("No MatMul nodes found in the model")
return []

nodes_to_exclude = []
logger.debug(f"Found {len(matmul_nodes)} MatMul nodes to analyze")

if calibration_shapes:
Expand Down
13 changes: 9 additions & 4 deletions modelopt/onnx/quantization/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ def _build_fusible_partition(
"""Traverses the graph starting from cur_node and updates the fusible_partition list.

Add a node to the partition if any of these holds:
1. The node is a unary or binary pointwise operation and fusible by cask
2. The node is BN and/or Relu and fusible with preceding Conv op
3. The node is a residual Add and fusible with current partition
1. The node is a unary or binary pointwise operation or a copy op and fusible by cask
2. The node is BN and/or Relu and fusible with preceding Conv op (Conv-Act fusion)
3. The node is MaxPool following a Conv-Act pattern (Conv-Act-Pool fusion)
4. The node is a residual Add and fusible with current partition

Args:
cur_node: Current candidate node for the partition.
Expand Down Expand Up @@ -131,11 +132,15 @@ def _is_fusible_mul(mul_node: Node) -> bool:

if (
(
is_copy_op(consumer_node.op)
and _is_cask_fusible(consumer_node, partition_node_outputs)
)
or (
is_pointwise_or_elementwise_op(consumer_node.op)
and _is_cask_fusible(consumer_node, partition_node_outputs)
)
or (
consumer_node.op in ["BatchNormalization", "Relu"]
consumer_node.op in ["BatchNormalization", "Relu", "MaxPool"]
and get_fusible_backbone(consumer_node, graph)
)
or _is_on_non_residual_path(consumer_node)
Expand Down
149 changes: 149 additions & 0 deletions tests/_test_utils/onnx_quantization/lib_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,3 +673,152 @@ def build_conv_batchnorm_sig_mul_model():
onnx.checker.check_model(model_inferred)

return model_inferred


def build_conv_act_pool_model(include_reshape_node=False):
    """Build a Conv -> BatchNormalization -> Relu [-> Reshape] -> MaxPool -> Conv model.

    Args:
        include_reshape_node: If True, a shape-preserving Reshape node is inserted
            between the Relu and the MaxPool.

    Returns:
        A shape-inferred onnx.ModelProto that passes the ONNX checker.
    """
    # Graph I/O: a single NCHW input; the MaxPool (stride 2) halves H and W.
    inputs = [
        helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (32, 64, 256, 256))
    ]
    outputs = [
        helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (32, 128, 128, 128))
    ]

    # Conv -> BN -> Relu backbone.
    nodes = [
        helper.make_node(
            op_type="Conv",
            inputs=["input_0", "weights_1", "bias_1"],
            outputs=["conv1_conv/Conv2D:0"],
            name="conv1_conv/Conv2D",
            dilations=[1, 1],
            group=1,
            kernel_shape=[3, 3],
            pads=[1, 1, 1, 1],
            strides=[1, 1],
        ),
        helper.make_node(
            op_type="BatchNormalization",
            inputs=["conv1_conv/Conv2D:0", "bn1_scale", "bn1_bias", "bn1_mean", "bn1_var"],
            outputs=["bn1_batchnorm/BatchNormalization:0"],
            name="bn1_batchnorm/BatchNormalization",
        ),
        helper.make_node(
            op_type="Relu",
            inputs=["bn1_batchnorm/BatchNormalization:0"],
            outputs=["relu1_relu/Relu:0"],
            name="relu1_relu/Relu",
        ),
    ]

    # Optionally break the Conv-Act-Pool chain with a shape-preserving Reshape.
    pool_input = "relu1_relu/Relu:0"
    if include_reshape_node:
        nodes.append(
            helper.make_node(
                op_type="Reshape",
                inputs=["relu1_relu/Relu:0", "shape_1"],
                outputs=["reshape1_reshape/Reshape:0"],
                name="reshape1_reshape/Reshape",
            ),
        )
        pool_input = "reshape1_reshape/Reshape:0"

    # MaxPool followed by a second Conv producing the graph output.
    nodes.append(
        helper.make_node(
            op_type="MaxPool",
            inputs=[pool_input],
            outputs=["maxpool1_maxpool/MaxPool2D:0"],
            name="maxpool1_maxpool/MaxPool2D",
            ceil_mode=False,
            kernel_shape=[3, 3],
            pads=[1, 1, 1, 1],
            strides=[2, 2],
        ),
    )
    nodes.append(
        helper.make_node(
            op_type="Conv",
            inputs=["maxpool1_maxpool/MaxPool2D:0", "weights_2"],
            outputs=["output_0"],
            name="conv2_conv/Conv2D",
            dilations=[1, 1],
            group=1,
            kernel_shape=[3, 3],
            pads=[1, 1, 1, 1],
            strides=[1, 1],
        ),
    )

    def _random_fp32_tensor(name, dims):
        # Positive values in [0.5, 1.0) keep the BN variance initializer valid.
        return helper.make_tensor(
            name=name,
            data_type=onnx.TensorProto.FLOAT,
            dims=dims,
            vals=np.random.uniform(low=0.5, high=1.0, size=int(np.prod(dims))),
        )

    # Initializers, created in a fixed order so RNG consumption is deterministic.
    initializer_specs = [
        ("weights_1", (128, 64, 3, 3)),
        ("bias_1", (128,)),
        ("bn1_scale", (128,)),
        ("bn1_bias", (128,)),
        ("bn1_mean", (128,)),
        ("bn1_var", (128,)),
        ("weights_2", (128, 128, 3, 3)),
    ]
    initializers = [_random_fp32_tensor(name, dims) for name, dims in initializer_specs]
    if include_reshape_node:
        # Target shape equals the Relu output shape, so the Reshape is a no-op.
        initializers.append(
            helper.make_tensor(
                name="shape_1",
                data_type=onnx.TensorProto.INT64,
                dims=(4,),
                vals=(32, 128, 256, 256),
            ),
        )

    # Assemble, pin opset/IR versions, and validate the model.
    graph = helper.make_graph(nodes, "conv_act_pool", inputs, outputs, initializer=initializers)
    model = helper.make_model(graph)
    model.opset_import[0].version = 13
    model.ir_version = 10

    model_inferred = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model_inferred)

    return model_inferred
28 changes: 28 additions & 0 deletions tests/unit/onnx/test_qdq_rules_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import numpy as np
import onnx
import onnx_graphsurgeon as gs
import pytest
from _test_utils.onnx_quantization.lib_test_models import (
build_conv_act_pool_model,
build_conv_batchnorm_sig_mul_model,
build_r1a_model,
build_resnet_block,
Expand Down Expand Up @@ -150,3 +152,29 @@ def test_conv_batchnorm_sig_mul_int8(tmp_path):
assert len(quantized_inputs) == 1, (
f"More than one input of {node.name} is being quantized, but only one should be quantized!"
)


@pytest.mark.parametrize("include_reshape_node", [False, True])
def test_conv_act_pool_int8(tmp_path, include_reshape_node):
    """Conv must be quantized while MaxPool stays unquantized in a Conv-Act-Pool chain."""
    model = build_conv_act_pool_model(include_reshape_node)
    input_model_path = os.path.join(tmp_path, f"conv_act_pool_model_{include_reshape_node}.onnx")
    save_onnx(model, input_model_path)

    quantize(input_model_path, quantize_mode="int8", high_precision_dtype="fp16")

    # The quantizer emits the explicit-QDQ model next to the input model.
    quantized_model_path = input_model_path.replace(".onnx", ".quant.onnx")
    assert os.path.isfile(quantized_model_path)

    # Inspect QDQ placement in the quantized graph.
    quantized_graph = gs.import_onnx(onnx.load(quantized_model_path))

    # Conv inputs should carry Q/DQ pairs.
    conv_nodes = [node for node in quantized_graph.nodes if node.op == "Conv"]
    assert _assert_nodes_are_quantized(conv_nodes)

    # MaxPool fuses with the preceding Conv-Act pattern, so it must not be quantized.
    pool_nodes = [node for node in quantized_graph.nodes if node.op == "MaxPool"]
    assert _assert_nodes_are_not_quantized(pool_nodes)