[5271050, 5274346][ONNX] Add support for Conv-Act-Pool fusion #448
base: main
Conversation
Walkthrough

The PR adds a `get_copy_ops()` helper, extends fusion/backbone matching to treat copy ops and MaxPool as fusible in quantization graph utilities and partitioning, updates usages to pass `wild_card_types=get_copy_ops()`, and adds a Conv→BN→ReLU→(optional Reshape)→MaxPool→Conv test model and unit test (a duplicate model builder is present).
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test Runner
    participant Builder as build_conv_act_pool_model()
    participant Save as save_onnx()
    participant Quant as Quantizer
    participant Partition as partitioning._build_fusible_partition()
    participant Graph as graph_utils.get_fusible_backbone()
    participant Ops as op_types.get_copy_ops()
    Test->>Builder: build model (include_reshape_node?)
    Builder-->>Test: ONNX model
    Test->>Save: persist model
    Test->>Quant: quantize to int8
    Quant->>Partition: analyze fusible partitions
    Partition->>Graph: request fusible backbone
    Graph->>Ops: get_copy_ops() (includes copy ops, MaxPool)
    Ops-->>Graph: copy op list
    Graph-->>Partition: matched backbone paths (Conv-Act(-Pool))
    Partition-->>Quant: fusion decisions
    Quant-->>Test: quantized model
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
🚧 Files skipped from review as they are similar to previous changes (1)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/unit/onnx/test_qdq_rules_int8.py (2)
modelopt/onnx/quantization/partitioning.py (1)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
🔇 Additional comments (7)
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #448      +/-   ##
==========================================
+ Coverage   73.40%   73.42%   +0.01%
==========================================
  Files         180      180
  Lines       18077    18078       +1
==========================================
+ Hits        13270    13273       +3
+ Misses       4807     4805       -2
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
🧹 Nitpick comments (4)
modelopt/onnx/op_types.py (2)
99-118: Clarify “copy” semantics; consider narrowing the set and caching.

Including Gather*/Scatter*/OneHot here blurs “copy/layout” with indexing, write, and generation ops. This may over-traverse/fuse across nodes that change values. If intentional, document it; otherwise, consider restricting to pure layout/data-movement ops (e.g., Flatten/Transpose/Concat/Split/Squeeze/Reshape/Tile/Expand/Slice). Also avoid rebuilding the list on every call.
Suggested refactor (keeps API, improves perf):
```diff
+from functools import lru_cache
+
-def copy_ops():
-    """Returns list of copy operators."""
-    return [
+@lru_cache(None)
+def copy_ops():
+    """Op types treated as copy/layout for traversal and CASK fusibility.
+
+    Note: if indexing/scatter/generation ops are intentionally included,
+    keep them and update this docstring accordingly."""
+    return [
         "Flatten",
         "Transpose",
         "Concat",
         "Split",
         "Squeeze",
         "Expand",
         "ReverseSequence",
         "Reshape",
         "Tile",
         "Gather",
         "Slice",
         "GatherElements",
         "GatherND",
         "ScatterElements",
         "ScatterND",
         "OneHot",
     ]
```

Would you like this set narrowed to pure layout ops, or is the broader set required for current fusibility rules?
121-124: Minor: use a set for membership.

Membership on a list is O(n). If we keep the function, consider a module-level frozenset for checks.
```diff
-def is_copy_op(op_type: str):
-    """Returns whether the given op is a copy operator or not."""
-    return op_type in copy_ops()
+def is_copy_op(op_type: str):
+    """Returns whether the given op is a copy operator or not."""
+    return op_type in set(copy_ops())
```
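For reference, a minimal sketch of the module-level frozenset variant the nitpick alludes to (the `_COPY_OPS` name is hypothetical):

```python
# Hypothetical module-level constant, built once at import time.
_COPY_OPS: frozenset[str] = frozenset(copy_ops())


def is_copy_op(op_type: str) -> bool:
    """Returns whether the given op is a copy operator or not."""
    return op_type in _COPY_OPS  # O(1) membership instead of an O(n) list scan
```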
modelopt/onnx/quantization/partitioning.py (1)

47-51: Doc nit and scope clarification.
- Grammar: “Add nodes to the partition…”.
- Confirm intent: bullet 3 explicitly mentions MaxPool; if AveragePool is also supported (now or later), call it out or keep it MaxPool-only by design.
```diff
-    Add a nodes to the partition if any of these holds:
+    Add nodes to the partition if any of these holds:
-    3. The node is MaxPool following a Conv-Act pattern (Conv-Act-Pool fusion)
+    3. The node is MaxPool following a Conv-Act pattern (Conv-Act-Pool fusion)
+       # Note: AveragePool is intentionally excluded.
```

Is AveragePool intentionally excluded from Conv-Act-Pool fusion?
tests/unit/onnx/test_quantize_int8.py (1)
33-45: Simplify helper: no need to return True or wrap with assert.

The helper already asserts per input. Returning True and then asserting the return is redundant.
```diff
-def _assert_nodes_quantization(nodes, should_be_quantized=True):
+def _assert_nodes_quantization(nodes, should_be_quantized=True):
     for node in nodes:
         for inp_idx, inp in enumerate(node.inputs):
             if isinstance(inp, gs.Variable):
                 if should_be_quantized:
                     assert node.i(inp_idx).op == "DequantizeLinear", (
                         f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
                     )
                 else:
                     assert node.i(inp_idx).op != "DequantizeLinear", (
                         f"Input '{inp.name}' of node '{node.name}' is quantized but should not be!"
                     )
-    return True
+    return None
```

And drop outer asserts at call sites (lines 68-69, 89-90, 119-120, 123):
```diff
-assert _assert_nodes_quantization(mm_nodes)
+_assert_nodes_quantization(mm_nodes)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- modelopt/onnx/op_types.py (2 hunks)
- modelopt/onnx/quantization/graph_utils.py (3 hunks)
- modelopt/onnx/quantization/partitioning.py (2 hunks)
- tests/_test_utils/onnx_quantization/lib_test_models.py (1 hunks)
- tests/unit/onnx/test_quantize_int8.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/onnx/quantization/graph_utils.py
- tests/_test_utils/onnx_quantization/lib_test_models.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/onnx/quantization/partitioning.py (1)
modelopt/onnx/op_types.py (2)
- is_copy_op (121-123)
- is_pointwise_or_elementwise_op (131-136)
tests/unit/onnx/test_quantize_int8.py (1)
tests/_test_utils/onnx_quantization/lib_test_models.py (1)
- build_conv_act_pool_model (559-705)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: wait-checks / wait
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
tests/unit/onnx/test_quantize_int8.py (1)
100-123: Good coverage for Conv-Act-Pool (+optional Reshape).

Parametrizing include_reshape_node validates copy-op traversal and MaxPool staying in the quantized domain. LGTM.
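The parametrized test presumably follows this shape (a sketch; the import path, fixture names, and assertions are assumptions, not the actual test code):

```python
import pytest

# Import path assumed; the builder lives in tests/_test_utils/onnx_quantization/lib_test_models.py.
from _test_utils.onnx_quantization.lib_test_models import build_conv_act_pool_model


@pytest.mark.parametrize("include_reshape_node", [False, True])
def test_conv_act_pool_fusion(tmp_path, include_reshape_node):
    # Build the Conv->BN->ReLU->(optional Reshape)->MaxPool->Conv model.
    model = build_conv_act_pool_model(include_reshape_node=include_reshape_node)
    onnx_path = str(tmp_path / "conv_act_pool.onnx")
    # Persist, quantize to int8, then assert Conv inputs are preceded by
    # DequantizeLinear while MaxPool stays in the fused (unquantized) region.
    ...
```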
modelopt/onnx/quantization/partitioning.py (1)
135-146: Verification confirms fusion rule extension is sound.

The script output confirms all three verification points:
- `get_fusible_backbone` recognizes MaxPool chains via the `fusible_linear_path_types` including `["MaxPool", "Relu", "BatchNormalization", "BiasAdd", conv_type]`
- `copy_ops()` is used as wildcards in path traversal with `has_path_type(..., wild_card_types=copy_ops())`
- The MaxPool fusion in partitioning.py (lines 143-144) is properly gated by `get_fusible_backbone`
No issues found. Code changes align with backbone discovery logic and wildcard handling.
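To make the wildcard mechanics concrete, here is a self-contained toy version of wildcard-aware path matching (an illustration only, not the real `has_path_type` implementation):

```python
def matches_path(ops: list[str], path_type: list[str], wild_card_types: set[str]) -> bool:
    """Match a sequence of node op types against an expected path,
    treating any op in wild_card_types as transparent."""
    it = iter(ops)
    for expected in path_type:
        for op in it:
            if op in wild_card_types:
                continue  # wildcards (copy ops) are skipped over
            if op != expected:
                return False
            break  # matched this step; move to the next expected op
        else:
            return False  # ran out of ops before matching the full path
    return True


# Conv-Act-Pool with a Reshape between Relu and MaxPool still matches:
ops = ["MaxPool", "Reshape", "Relu", "BatchNormalization", "Conv"]
assert matches_path(ops, ["MaxPool", "Relu", "BatchNormalization", "Conv"], {"Reshape"})
```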
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/onnx/quantization/graph_utils.py (2)
185-207: Two issues: missing guard in _get_backbone() and incomplete Conv‑Act‑Pool patterns.
- Runtime risk: IndexError when a tensor is a graph input (len(inputs)==0).
- Pattern gap: Conv→Act→Pool without BN (common) won’t match; current list requires BN before MaxPool.
Apply this diff:
```diff
@@ def get_fusible_backbone(node: Node, graph: Graph) -> Node | None:
@@
-    def _get_backbone(root: Node):
+    def _get_backbone(root: Node):
         if root.op in ["Conv", "ConvTranspose"]:
             return root
-
-        for tensor in root.inputs:
-            if not isinstance(tensor, Constant):
-                parent_node = tensor.inputs[0]
-                bb = _get_backbone(parent_node)
-                if bb:
-                    return bb
+        for tensor in root.inputs:
+            if isinstance(tensor, Constant):
+                continue
+            # Guard against graph inputs (no producers)
+            if len(tensor.inputs) == 0:
+                continue
+            parent_node = tensor.inputs[0]
+            bb = _get_backbone(parent_node)
+            if bb:
+                return bb
@@
-    for conv_type in ["Conv", "ConvTranspose"]:
+    for conv_type in ["Conv", "ConvTranspose"]:
         fusible_linear_path_types += [
             ["BiasAdd", "ConstMul", conv_type],
             ["Relu", "BiasAdd", "ConstMul", conv_type],
             ["BatchNormalization", "BiasAdd", conv_type],
             ["Relu", "BatchNormalization", "BiasAdd", conv_type],
-            ["MaxPool", "Relu", "BatchNormalization", "BiasAdd", conv_type],
+            # Conv-Act-Pool (with BN)
+            ["MaxPool", "Relu", "BatchNormalization", "BiasAdd", conv_type],
+            # Conv-Act-Pool (no BN)
+            ["MaxPool", "Relu", "BiasAdd", "ConstMul", conv_type],
         ]
```

Additionally, consider supporting other TRT-fusible activations (e.g., LeakyRelu/Clip) if applicable.
1060-1111: Robustness of MatMul shape inference: include graph outputs and use stricter infer options.
- value_info_map only scans value_info; MatMul outputs that are graph outputs won’t be found → RuntimeError.
- Prefer stricter inference (check types, data propagation) to reduce “unknown” shapes.
Apply this diff:
```diff
@@
-    model.graph.ClearField("value_info")
-    model = infer_shapes(model)
-    value_info_map = {vi.name: vi for vi in model.graph.value_info}
+    model.graph.ClearField("value_info")
+    # Enable stricter inference to populate more shapes
+    model = infer_shapes(model, check_type=True, data_prop=True)
+    value_info_map = {vi.name: vi for vi in model.graph.value_info}
+    # Also consider graph outputs (MatMul outputs may be final graph outputs)
+    value_info_map.update({vi.name: vi for vi in model.graph.output})
@@
-    value_info = value_info_map.get(output_name)
-    if not value_info:
-        raise RuntimeError(f"Shape inference did not find shape for {output_name}.")
+    value_info = value_info_map.get(output_name)
+    if not value_info:
+        # Skip if shape couldn't be inferred; fallback is the inference-based path upstream
+        logger.debug(f"Shape inference missing for {output_name}; skipping MatMul '{matmul_node.name}'.")
+        continue
```
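For context, these flags come from the underlying `onnx.shape_inference` API, which the local `infer_shapes` wrapper presumably forwards to (the wrapper's behavior is an assumption; the onnx call itself is standard):

```python
import onnx
from onnx import shape_inference

model = onnx.load("model.onnx")  # hypothetical path
# check_type validates that input/output tensor types are consistent;
# data_prop propagates constant data (e.g., a Reshape's shape input) so
# more static output shapes can be resolved.
inferred = shape_inference.infer_shapes(model, check_type=True, data_prop=True)
```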
♻️ Duplicate comments (1)
modelopt/onnx/quantization/graph_utils.py (1)
1090-1090: Minor: pass explicit kwargs to infer_shapes for consistency.

Already covered in the diff above by adding check_type=True, data_prop=True.
🧹 Nitpick comments (1)
modelopt/onnx/quantization/graph_utils.py (1)
206-207: Mutable default arg in has_path_type() can leak state across calls.

path_nodes defaulting to [] is mutated (append) and reused. Make defaults None and init inside.
Apply this diff:
```diff
-def has_path_type(
+def has_path_type(
     node: Node,
     graph: Graph,
     path_type: list[str],
     is_forward: bool,
-    wild_card_types: list[str] = [],
-    path_nodes: list[Node] = [],
+    wild_card_types: list[str] | None = None,
+    path_nodes: list[Node] | None = None,
 ) -> bool:
@@
-    optional_path_types = ["BiasAdd", "ConstMul"]
+    optional_path_types = ["BiasAdd", "ConstMul"]
+    if wild_card_types is None:
+        wild_card_types = []
+    if path_nodes is None:
+        path_nodes = []
```
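A standalone illustration of the pitfall, independent of this codebase:

```python
def collect(item, acc=[]):  # the default list is created once and shared
    acc.append(item)
    return acc


print(collect("a"))  # ['a']
print(collect("b"))  # ['a', 'b'] -- state leaked from the previous call


def collect_fixed(item, acc=None):
    if acc is None:
        acc = []  # fresh list on every call
    acc.append(item)
    return acc


print(collect_fixed("a"))  # ['a']
print(collect_fixed("b"))  # ['b']
```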
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- modelopt/onnx/quantization/graph_utils.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/graph_utils.py (2)
modelopt/onnx/op_types.py (3)
- copy_ops (99-118)
- is_copy_op (121-123)
- is_linear_op (126-128)

modelopt/onnx/utils.py (4)
- find_lowest_common_ancestor (572-613)
- get_child_nodes (625-628)
- get_parent_nodes (616-622)
- infer_shapes (723-736)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (3)
modelopt/onnx/quantization/graph_utils.py (3)
32-32: Centralizing wildcard ops via copy_ops() looks good.

Helps keep path matching consistent with op_types.
173-176: Docstring update aligns with intent to include MaxPool.

Matches PR goal of Conv-Act-Pool fusion mention.
34-41: The review incorrectly identifies the bug. The actual code at line 603 of modelopt/onnx/utils.py reads:

```python
lowest_common_ancestor = common_ancestors.pop()
```

not `lowest_common_ancestor = common_ancestor` as claimed. The variable `common_ancestors` (plural) is a properly defined set from line 601, and `.pop()` is a valid set method. There is no undefined variable, and no build paths would be broken.

Likely an incorrect or invalid review comment.
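A quick confirmation that `set.pop()` behaves as described (toy values, not the real ancestor sets):

```python
ancestors_a = {"Conv_0", "Relu_1"}
ancestors_b = {"Conv_0"}
common_ancestors = ancestors_a & ancestors_b  # {'Conv_0'}
# set.pop() removes and returns an arbitrary element; valid on any non-empty set.
print(common_ancestors.pop())  # Conv_0
```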
Force-pushed 99766cd to 858c3da (Compare)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/onnx/quantization/graph_utils.py (1)
173-205: Fix mutable default arguments in `has_path_type` function signature (lines 84-90).

The function has mutable default parameters `wild_card_types: list[str] = []` (line 89) and `path_nodes: list[Node] = []` (line 90). Since `path_nodes` is modified via `append()` at line 150, these lists persist across function calls, causing state leakage. When callers omit `path_nodes` (e.g., line 206), the same default list accumulates nodes from previous invocations.

Apply the suggested refactor to the function signature:

```python
def has_path_type(
    node: Node,
    graph: Graph,
    path_type: list[str],
    is_forward: bool,
    wild_card_types: list[str] | None = None,
    path_nodes: list[Node] | None = None,
) -> bool:
    if wild_card_types is None:
        wild_card_types = []
    if path_nodes is None:
        path_nodes = []
    # ... rest of function
```
♻️ Duplicate comments (1)
modelopt/onnx/quantization/graph_utils.py (1)
1060-1110: Shape-inference path drops dict calibration_shapes (re-raising prior feedback).

Only string specs are parsed; dict inputs are ignored, defaulting dims to 1 and misclassifying MatMuls. Please accept dicts directly.
Apply this diff:
```diff
@@ def _exclude_matmuls_by_shape_inference(
-    input_shapes = (
-        parse_shapes_spec(calibration_shapes)
-        if (calibration_shapes and isinstance(calibration_shapes, str))
-        else {}
-    )
+    input_shapes = {}
+    if calibration_shapes:
+        if isinstance(calibration_shapes, str):
+            input_shapes = parse_shapes_spec(calibration_shapes)
+        elif isinstance(calibration_shapes, dict):
+            input_shapes = calibration_shapes
```

Additionally, consider using both value_info and outputs in the lookup to be resilient to where ONNX stores inferred shapes:
```diff
-    value_info_map = {vi.name: vi for vi in model.graph.value_info}
+    value_info_map = {vi.name: vi for vi in model.graph.value_info}
+    value_info_map.update({vi.name: vi for vi in model.graph.output})
```
🧹 Nitpick comments (2)
modelopt/onnx/op_types.py (1)
99-119: Broaden copy-ops: include Unsqueeze and Identity.

These are commonly treated as copy/no-op in traversals; adding them improves wildcard path robustness for fusion.
Apply this diff:
```diff
 def get_copy_ops():
     """Returns list of copy operators."""
     return [
         "Flatten",
         "Transpose",
         "Concat",
         "Split",
         "Squeeze",
+        "Unsqueeze",
         "Expand",
         "ReverseSequence",
         "Reshape",
         "Tile",
         "Gather",
         "Slice",
         "GatherElements",
         "GatherND",
         "ScatterElements",
         "ScatterND",
         "OneHot",
+        "Identity",
     ]
```

Also applies to: 121-124
tests/unit/onnx/test_quantize_int8.py (1)
33-45: Harden producer lookup to avoid graph-input edge cases.
`node.i(inp_idx)` can fail when an input is a graph input (no producer). Guard with `inp.inputs`.

Apply this diff:
```diff
 def _assert_nodes_quantization(nodes, should_be_quantized=True):
     for node in nodes:
         for inp_idx, inp in enumerate(node.inputs):
             if isinstance(inp, gs.Variable):
-                if should_be_quantized:
-                    assert node.i(inp_idx).op == "DequantizeLinear", (
+                producer = node.i(inp_idx) if inp.inputs else None
+                if should_be_quantized:
+                    assert producer is not None and producer.op == "DequantizeLinear", (
                         f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
                     )
                 else:
-                    assert node.i(inp_idx).op != "DequantizeLinear", (
+                    assert producer is None or producer.op != "DequantizeLinear", (
                         f"Input '{inp.name}' of node '{node.name}' is quantized but should not be!"
                     )
     return True
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- modelopt/onnx/op_types.py (2 hunks)
- modelopt/onnx/quantization/graph_utils.py (6 hunks)
- modelopt/onnx/quantization/partitioning.py (2 hunks)
- tests/_test_utils/onnx_quantization/lib_test_models.py (1 hunks)
- tests/unit/onnx/test_quantize_int8.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/onnx/quantization/partitioning.py
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/onnx/quantization/graph_utils.py (2)
modelopt/onnx/op_types.py (3)
- get_copy_ops (99-118)
- is_copy_op (121-123)
- is_linear_op (126-128)

modelopt/onnx/utils.py (3)
- get_child_nodes (625-628)
- get_parent_nodes (616-622)
- infer_shapes (723-736)
tests/_test_utils/onnx_quantization/lib_test_models.py (1)
modelopt/onnx/utils.py (1)
- check_model (557-569)
tests/unit/onnx/test_quantize_int8.py (1)
tests/_test_utils/onnx_quantization/lib_test_models.py (1)
- build_conv_act_pool_model (560-706)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (3)
modelopt/onnx/quantization/graph_utils.py (1)
206-207: Good: wildcarding through copy-ops enables Conv→Act→(copy)→Pool patterns.

Using `wild_card_types=get_copy_ops()` here is the right call to allow optional reshape/transpose between backbone ops.

tests/unit/onnx/test_quantize_int8.py (1)
100-123: Nice coverage: Conv is quantized; MaxPool remains unquantized across optional Reshape.

This directly validates the Conv‑Act‑(copy)‑Pool fusion behavior the PR targets.
tests/_test_utils/onnx_quantization/lib_test_models.py (1)
560-706: No issues found. The optional refactor suggestion is valid and safe to apply.

The verification confirms:
- Single definition and call site (no shadowing)
- Test function contains no hardcoded shape assertions or dimension-dependent logic
- Proposed tensor shapes are mathematically consistent with the MaxPool stride=2 operation
- Refactor reduces model size and inference cost without affecting test correctness
The suggestion to reduce from (32, 64, 256, 256) to (1, 64, 16, 16) input shapes is sound and will improve CI performance without breaking the quantization tests.
Force-pushed 60dc169 to 3db4dc8 (Compare)
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/onnx/quantization/graph_utils.py (1)
1060-1076: Dict calibration_shapes are silently ignored.

The existing issue flagged in past reviews remains: when `calibration_shapes` is a dict, it's not parsed (only string specs are handled at lines 1072-1076), causing dimensions to default to 1 and potentially leading to false positive GEMV exclusions.

Consider applying the fix suggested in the past review:

```diff
-    input_shapes = (
-        parse_shapes_spec(calibration_shapes)
-        if (calibration_shapes and isinstance(calibration_shapes, str))
-        else {}
-    )
+    input_shapes = {}
+    if calibration_shapes:
+        if isinstance(calibration_shapes, str):
+            input_shapes = parse_shapes_spec(calibration_shapes)
+        elif isinstance(calibration_shapes, dict):
+            input_shapes = calibration_shapes
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- modelopt/onnx/op_types.py (2 hunks)
- modelopt/onnx/quantization/graph_utils.py (6 hunks)
- modelopt/onnx/quantization/partitioning.py (2 hunks)
- tests/_test_utils/onnx_quantization/lib_test_models.py (1 hunks)
- tests/unit/onnx/test_qdq_rules_int8.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
modelopt/onnx/quantization/partitioning.py (1)
modelopt/onnx/op_types.py (2)
- is_copy_op (121-123)
- is_pointwise_or_elementwise_op (131-136)
tests/_test_utils/onnx_quantization/lib_test_models.py (1)
modelopt/onnx/utils.py (1)
- check_model (557-569)
modelopt/onnx/quantization/graph_utils.py (2)
modelopt/onnx/op_types.py (1)
- get_copy_ops (99-118)

modelopt/onnx/utils.py (1)
- infer_shapes (723-736)
tests/unit/onnx/test_qdq_rules_int8.py (3)
tests/_test_utils/onnx_quantization/lib_test_models.py (1)
- build_conv_act_pool_model (560-706)

modelopt/onnx/quantization/int8.py (1)
- quantize (113-291)

modelopt/onnx/quantization/quantize.py (1)
- quantize (209-564)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (5)
modelopt/onnx/op_types.py (1)
99-123: LGTM! Clean refactoring that improves maintainability.

The extraction of copy operators into a dedicated `get_copy_ops()` helper is a solid refactoring that centralizes the definition and enables reuse across the codebase (e.g., in graph_utils.py and partitioning.py).

modelopt/onnx/quantization/partitioning.py (1)
133-146: LGTM! Conv-Act-Pool fusion support correctly implemented.

The changes appropriately extend the fusible criteria to include:
- Copy ops as fusible consumers (lines 135-137)
- MaxPool as a fusible backbone operation (line 143)
This enables the Conv-Act-Pool pattern to be correctly fused during quantization, which aligns with the PR objectives.
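To make the gating concrete, here is a rough paraphrase of how these criteria might compose (a sketch only, not the actual `_build_fusible_partition` code; the helper name is hypothetical):

```python
def _is_fusible_with_partition(node, graph):
    # Sketch: mirrors the two criteria described above.
    if is_copy_op(node.op):
        # Copy ops are treated as fusible consumers and ride along
        # with the surrounding partition.
        return True
    if node.op == "MaxPool":
        # MaxPool joins only when a Conv backbone is reachable upstream,
        # i.e., it terminates a Conv-Act(-copy)-Pool chain.
        return get_fusible_backbone(node, graph) is not None
    return False
```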
modelopt/onnx/quantization/graph_utils.py (2)
32-32: LGTM! Import changes support the Conv-Act-Pool fusion feature.

The addition of the `get_copy_ops` import and `infer_shapes` utility, along with their usage at line 206, correctly enables MaxPool and copy operators to be treated as wildcards in fusible backbone matching.

Also applies to: 38-38, 206-206
1060-1060: Function rename is consistent.

The rename from `_exclude_matmuls_by_symbolic_inference` to `_exclude_matmuls_by_shape_inference` is applied consistently at the definition (line 1060) and call site (line 968), and aligns with the refactored shape inference approach.

Also applies to: 968-968
tests/_test_utils/onnx_quantization/lib_test_models.py (1)
560-706: LGTM! Well-structured test model builder.

The `build_conv_act_pool_model` function correctly constructs a Conv-Act-Pool graph with:
- Proper node sequencing (Conv → BatchNormalization → Relu → optional Reshape → MaxPool → Conv)
- Correctly sized initializers matching the declared dimensions
- Conditional Reshape node and its initializer based on the `include_reshape_node` parameter
- Shape inference and validation before returning
This provides solid test coverage for the Conv-Act-Pool fusion feature.
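For orientation, a condensed sketch of the topology the builder constructs (shapes, names, and attributes here are illustrative; the actual builder in lib_test_models.py uses different dimensions):

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper


def build_conv_act_pool_sketch(include_reshape_node: bool = False) -> onnx.ModelProto:
    """Illustrative Conv->BN->ReLU->(Reshape)->MaxPool->Conv graph."""
    n, c_in, h, w, c_out = 1, 3, 16, 16, 8
    inp = helper.make_tensor_value_info("input", TensorProto.FLOAT, [n, c_in, h, w])
    out = helper.make_tensor_value_info("output", TensorProto.FLOAT, [n, c_out, h // 2, w // 2])
    inits = [
        numpy_helper.from_array(np.random.randn(c_out, c_in, 3, 3).astype(np.float32), "w0"),
        numpy_helper.from_array(np.ones(c_out, np.float32), "bn_scale"),
        numpy_helper.from_array(np.zeros(c_out, np.float32), "bn_bias"),
        numpy_helper.from_array(np.zeros(c_out, np.float32), "bn_mean"),
        numpy_helper.from_array(np.ones(c_out, np.float32), "bn_var"),
        numpy_helper.from_array(np.random.randn(c_out, c_out, 1, 1).astype(np.float32), "w1"),
    ]
    nodes = [
        helper.make_node("Conv", ["input", "w0"], ["conv0"], pads=[1, 1, 1, 1]),
        helper.make_node(
            "BatchNormalization", ["conv0", "bn_scale", "bn_bias", "bn_mean", "bn_var"], ["bn0"]
        ),
        helper.make_node("Relu", ["bn0"], ["relu0"]),
    ]
    pool_in = "relu0"
    if include_reshape_node:
        # Identity reshape: same shape, exercises the copy-op wildcard path.
        inits.append(numpy_helper.from_array(np.array([n, c_out, h, w], np.int64), "shape"))
        nodes.append(helper.make_node("Reshape", ["relu0", "shape"], ["reshape0"]))
        pool_in = "reshape0"
    nodes += [
        helper.make_node("MaxPool", [pool_in], ["pool0"], kernel_shape=[2, 2], strides=[2, 2]),
        helper.make_node("Conv", ["pool0", "w1"], ["output"]),
    ]
    graph = helper.make_graph(nodes, "conv_act_pool", [inp], [out], inits)
    return helper.make_model(graph)
```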
Force-pushed 3db4dc8 to 7eae896 (Compare)
Signed-off-by: gcunhase <[email protected]>
Force-pushed 7eae896 to fd59bdf (Compare)
Signed-off-by: gcunhase <[email protected]>
What does this PR do?
Type of change: Bug fix
Overview: QDQ node placement was breaking Conv-Act-Pool fusions. Fixed by adding this pattern to the list of fusible partitions.
Usage
```bash
$ python -m modelopt.onnx.quantization --onnx_path=$MODEL_NAME.onnx
```

Testing
Added a unit test.
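The unit test presumably exercises the same flow via the Python API, roughly as below (a sketch; only `onnx_path` is confirmed by the CLI above, the other parameter names are assumptions):

```python
from modelopt.onnx.quantization import quantize

# Parameter names other than onnx_path are assumed; see quantize() in
# modelopt/onnx/quantization/quantize.py for the authoritative signature.
quantize(onnx_path="conv_act_pool.onnx", quantize_mode="int8")
```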
Before your PR is "Ready for review"
Summary by CodeRabbit
New Features
Refactor
Tests