Cleanup mixed precision and gather node layer info mapping #434
Conversation
Signed-off-by: unknown <[email protected]>
Walkthrough

Refactors ONNX quantization to use a per-weight layer_info mapping ({precision, block_size, axis}). Renames get_precision_info → get_layer_info, adds block_size/quantize_axis params, propagates layer_info through int4, qdq, and quant utilities, adds gather-specific defaults, and sets model.ir_version = 10 on export.
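A minimal sketch of the mapping shape described in the walkthrough (the weight names and values are illustrative, not taken from this PR):

```python
# Per-weight configuration: each initializer name maps to a dict with
# precision (4 or 8), block_size (-1 means per-channel), and quantization axis.
layer_info = {
    "model.layers.0.mlp.down_proj.weight": {"precision": 4, "block_size": 128, "axis": 0},
    "model.layers.0.attn.qkv_proj.weight": {"precision": 8, "block_size": -1, "axis": 1},
}

precision = layer_info["model.layers.0.mlp.down_proj.weight"]["precision"]  # 4
```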
Sequence Diagram(s)

sequenceDiagram
autonumber
actor U as Caller
participant G as graph_utils.get_layer_info
participant Q as int4.quantize_* (RTN/AWQ/AWQ-lite)
participant QU as quant_utils
participant QDQ as qdq_utils.insert_qdq_nodes
participant M as ONNX Model
U->>G: request layer_info (block_size, quantize_axis, optional gather configs)
G-->>U: layer_info map or None
U->>Q: quantize(model, layer_info, ...)
Q->>QU: get_num_bits(name, layer_info)
QU-->>Q: precision per weight
Q->>QU: get block_size/axis per name from layer_info
QU-->>Q: block_size/axis per weight
Q->>QU: reshape_scales_for_per_channel_nodes(layer_info)
QU-->>Q: adjusted scales for per-channel
Q->>QDQ: insert_qdq_nodes(model, layer_info)
QDQ->>QU: get_num_bits(name, layer_info)
QU-->>QDQ: precision
QDQ-->>Q: Q/DQ nodes inserted
Q->>M: set ir_version = 10
Q-->>U: quantized model
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Codecov Report

❌ Patch coverage is below 100%. Additional details and impacted files:

@@           Coverage Diff           @@
##             main     #434      +/-   ##
=========================================
  Coverage   73.40%   73.40%
=========================================
  Files         180      180
  Lines       18047    18077      +30
=========================================
+ Hits        13247    13270      +23
- Misses       4800     4807       +7

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
🧹 Nitpick comments (5)
modelopt/onnx/quantization/graph_utils.py (2)
726-737: Docstring is stale: return type and params don't match implementation

get_layer_precision_mapping now returns a per-weight dict with keys precision, block_size, axis, and accepts block_size/quantize_axis, but the docstring still describes the old "precision only" mapping and omits the new params. Please update for clarity.
Apply this diff to refresh the doc:
- """Generate a mapping of layer names to their quantization precision (4 bits or 8 bits) for an ONNX model. + """Generate a mapping of weight (initializer) names to their quantization configuration. @@ - precision_pattern_8bit (str, optional): Comma-separated string of layer patterns to quantize to 8 bits. - If None, a default set of patterns is used to select layers for 8 bits quantization. - nodes_to_exclude (list[str], optional): List of node name patterns to exclude from quantization. - Defaults to [r"/lm_head"]. + precision_pattern_8bit (str | None): Comma-separated list of layer patterns to quantize to 8 bits. + If None, a heuristic/default set of patterns is used. + nodes_to_exclude (list[str] | None): Node name patterns to exclude from quantization. + block_size (int): Default block size to use for 4-bit layers. + quantize_axis (int): Default quantization axis to use for 4-bit layers. @@ - dict: A mapping from layer names to their quantization precision (e.g., {"layer_name": "8"}). + dict[str, dict[str, int]]: Mapping from weight name to a dict with: + - precision (int): 4 or 8 + - block_size (int): -1 for per-channel, or positive block size + - axis (int): quantization axis
721-725: Avoid default mutable arguments for nodes_to_exclude

Using a list literal as a default can lead to subtle bugs. Prefer None and set the default inside.
-def get_layer_precision_mapping(
-    onnx_model: onnx.ModelProto,
-    precision_pattern_8bit: str | None = None,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+def get_layer_precision_mapping(
+    onnx_model: onnx.ModelProto,
+    precision_pattern_8bit: str | None = None,
+    nodes_to_exclude: list[str] | None = None,
     block_size: int = 128,
     quantize_axis: int = 0,
 ):
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)

-def get_layer_info(
-    onnx_model: onnx.ModelProto,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+def get_layer_info(
+    onnx_model: onnx.ModelProto,
+    nodes_to_exclude: list[str] | None = None,
     block_size: int = 128,
     quantize_axis: int = 0,
     **kwargs: Any,
 ):
@@
-    if enable_mixed_quant:
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    if enable_mixed_quant:
         layer_info = get_layer_precision_mapping(
             onnx_model,
             layers_8bit,
             nodes_to_exclude,
             block_size,
             quantize_axis,
         )

Also applies to: 837-841
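For reference, a standalone illustration of the pitfall (unrelated to the actual modelopt code): the default list object is created once at function definition and shared across calls.

```python
# Mutable default: the same list instance is reused on every call.
def append_item(item, items=[]):
    items.append(item)
    return items

print(append_item("a"))  # ['a']
print(append_item("b"))  # ['a', 'b']  <- state leaks between calls

# Preferred pattern, as in the suggested diff above: default to None and
# construct the fallback inside the function body.
def append_item_safe(item, items=None):
    items = items if items is not None else []
    items.append(item)
    return items
```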
modelopt/onnx/quantization/qdq_utils.py (1)
328-338: Hardcoded axis=1 for per-channel might be too rigid

For INT8/per-channel you force axis=1 and drop block_size. That's fine for GEMM/MatMul, but could be wrong for other ops (e.g., Conv). Consider honoring an existing axis in attrs (if provided) or deriving it per op.
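A hypothetical sketch of that suggestion; the attrs dict and the "axis" key are assumptions for illustration, not the actual qdq_utils API:

```python
def resolve_per_channel_axis(attrs: dict, default_axis: int = 1) -> int:
    """Prefer an axis already provided in the node attributes over the hardcoded default."""
    axis = attrs.get("axis")
    return axis if axis is not None else default_axis

assert resolve_per_channel_axis({"axis": 0}) == 0  # e.g. a Conv-style weight layout
assert resolve_per_channel_axis({}) == 1           # falls back to the GEMM/MatMul default
```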
modelopt/onnx/quantization/quant_utils.py (2)
169-180: update_block_size docstring out of sync with signature

Doc still mentions num_bits and doesn't describe new behavior. Please update.
-    Args:
-        num_bits (int): Number of bits for quantization.
-        layer_info (dict[str, dict] | None): Optional dictionary mapping tensor names
-            to layer configuration dict.
-        name (str | None): Name of the tensor.
-        block_size (int): Current block size.
-        quantize_axis (int): Axis along which to quantize.
-        w (np.ndarray): Weight tensor to be quantized.
+    Args:
+        block_size (int): Base block size.
+        layer_info (dict[str, dict] | None): Optional mapping from tensor name to config dict.
+        name (str | None): Tensor name (used to look up layer_info).
+        quantize_axis (int): Axis along which to quantize.
+        w (np.ndarray): Weight tensor (required when block_size == -1 to expand per-channel).
@@
-        int: Updated block size.
+        int: Updated block size (if -1, expanded to size along quantize_axis).
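For context, a minimal sketch (not the actual update_block_size implementation) of the "-1 means per-channel" convention the revised docstring describes:

```python
import numpy as np

def resolve_block_size(block_size: int, w: np.ndarray, quantize_axis: int) -> int:
    # A block size of -1 expands to the full dimension along the quantize axis,
    # i.e. one scale per channel instead of one scale per block.
    return w.shape[quantize_axis] if block_size == -1 else block_size

w = np.zeros((4096, 128), dtype=np.float32)
assert resolve_block_size(-1, w, 0) == 4096   # per-channel along axis 0
assert resolve_block_size(128, w, 0) == 128   # regular block quantization
```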
288-296: reshape_scales_for_per_channel_nodes doc: add block_size description

The function takes block_size but the doc omits it. Minor clarity fix.
-    Args:
-        scales_map (dict[str, np.ndarray]): Dictionary mapping weight names to scale arrays.
-        layer_info (dict[str, dict] | None): Optional dictionary mapping tensor names
-            to layer configuration dict.
+    Args:
+        scales_map (dict[str, np.ndarray]): Map of weight name -> scale array.
+        block_size (int): Default block size used for 4-bit layers.
+        layer_info (dict[str, dict] | None): Optional per-layer configuration dict.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
modelopt/onnx/quantization/graph_utils.py (4 hunks)
modelopt/onnx/quantization/int4.py (20 hunks)
modelopt/onnx/quantization/qdq_utils.py (5 hunks)
modelopt/onnx/quantization/quant_utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/onnx/quantization/qdq_utils.py (1)
modelopt/onnx/quantization/quant_utils.py (1)
get_num_bits(189-204)
modelopt/onnx/quantization/quant_utils.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
num_bits (183-185)
num_bits (188-190)
modelopt/onnx/quantization/int4.py (3)
modelopt/onnx/quantization/graph_utils.py (1)
get_layer_info (835-889)
modelopt/onnx/quantization/quant_utils.py (4)
get_num_bits (189-204)
update_block_size (160-186)
reshape_scales_for_per_channel_nodes (282-303)
rtn (356-389)
modelopt/onnx/quantization/qdq_utils.py (2)
insert_dq_nodes (349-421)
insert_qdq_nodes (424-480)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
🔇 Additional comments (1)
modelopt/onnx/quantization/graph_utils.py (1)
815-833: Layer info mapping looks good

Per-weight dict with precision, block_size and axis is consistent with downstream usage.
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: ynankani-nv <[email protected]>
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/onnx/quantization/int4.py (1)
123-181: Critical: block_size is stale at line 181.

Line 181 references block_size, but this variable is only defined inside the loop at line 126. After the loop completes, block_size would hold the value from the last iteration, which is incorrect when calling reshape_scales_for_per_channel_nodes with a mixture of gather nodes that may have different block sizes.

Solution 1 (recommended): Pass gather_block_size as a parameter to _quantize_gather_nodes and use it as the default fallback:

 def _quantize_gather_nodes(
     graph: onnx.GraphProto,
     nodes_to_exclude: list[str],
     use_zero_point: bool,
     dq_only: bool,
+    gather_block_size: int,
     layer_info: dict[str, dict] | None,
 ):

Then at line 181:

-    scales_map = reshape_scales_for_per_channel_nodes(scales_map, block_size, layer_info)
+    scales_map = reshape_scales_for_per_channel_nodes(scales_map, gather_block_size, layer_info)

Solution 2 (alternative): Use the default constant directly:

-    scales_map = reshape_scales_for_per_channel_nodes(scales_map, block_size, layer_info)
+    scales_map = reshape_scales_for_per_channel_nodes(scales_map, DEFAULT_GATHER_BLOCK_SIZE, layer_info)

Note: You'll also need to update all callers of _quantize_gather_nodes (lines 253, 571, 1274) to pass the gather_block_size parameter if using Solution 1.
🧹 Nitpick comments (1)
modelopt/onnx/quantization/int4.py (1)
312-312: Consider adding explanatory comment for consistency.

For consistency with lines 605 and 1316, consider adding a comment explaining why ir_version = 10 is set:

+    # Set ir_version to 10, remove it once ORT supports ir_version 11
     model.ir_version = 10

This helps maintainers understand the temporary nature of this workaround.
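For illustration, a minimal standalone snippet of the same workaround (file paths are placeholders; the real change sets ir_version inside int4.py before export):

```python
import onnx

model = onnx.load("quantized_model.onnx")
# Set ir_version to 10, remove it once ORT supports ir_version 11
model.ir_version = 10
onnx.save(model, "quantized_model.ir10.onnx")
```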
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/onnx/quantization/int4.py (20 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/int4.py (3)
modelopt/onnx/quantization/graph_utils.py (1)
get_layer_info (835-889)
modelopt/onnx/quantization/quant_utils.py (3)
get_num_bits (189-204)
update_block_size (160-186)
reshape_scales_for_per_channel_nodes (282-303)
modelopt/onnx/quantization/qdq_utils.py (2)
insert_dq_nodes (349-421)
insert_qdq_nodes (424-480)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (5)
modelopt/onnx/quantization/int4.py (5)
43-43: LGTM: Import updated correctly.

The import of get_layer_info aligns with the refactoring to use per-layer configuration.

229-313: LGTM: layer_info propagated correctly through quantize_rtn.

The function properly retrieves layer_info early and passes it through the quantization pipeline. The setting of model.ir_version = 10 at lines 312-313 ensures compatibility.

492-597: LGTM: layer_info propagated correctly through _quantize_awq_clip.

The function properly retrieves layer_info early at line 492 and passes it consistently through all quantization steps, including AWQ clip search, weight quantization, and DQ node insertion.

694-738: LGTM: layer_info used correctly in scale search.

The function properly accepts layer_info as a parameter and uses it to determine per-layer configuration for bit-width and block size.

985-1305: LGTM: Past review addressed, layer_info propagated correctly.

Line 985 now correctly passes block_size to get_layer_info, addressing the previous review comment. The function consistently propagates layer_info through all quantization paths including AWQ-lite scale search, weight quantization, and DQ node insertion.
    onnx_model: onnx.ModelProto,
    precision_pattern_8bit: str | None = None,
    nodes_to_exclude: list[str] | None = [r"/lm_head"],
    block_size: int = 128,
Should we use default 32 here?
Default block_size is set to 128 in int4.py if block_size is passed as None.
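A sketch of that fallback; the constant and helper names here are hypothetical, only the "None falls back to 128" behavior is from the discussion:

```python
DEFAULT_BLOCK_SIZE = 128

def resolve_requested_block_size(block_size: int | None) -> int:
    return DEFAULT_BLOCK_SIZE if block_size is None else block_size

assert resolve_requested_block_size(None) == 128
assert resolve_requested_block_size(32) == 32
```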
…_nodes_quantization Signed-off-by: ynankani-nv <[email protected]>
What does this PR do?
Cleanup mixed precision and gather node layer info mapping
Type of change: Cleanup
Overview: Cleanup mixed precision and gather node layer info mapping
Usage
Testing
Tested on a combination of a few models and quantization algorithms, with gather-node quantization both enabled and disabled.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Refactor
Chores