
Conversation

@ynankani ynankani commented Oct 14, 2025

What does this PR do?

Cleanup mixed precision and gather node layer info mapping

Type of change: Cleanup

Overview: Cleanup of the mixed precision and gather node layer info mapping

Usage

python quantize.py --model_name=meta-llama/Llama-3.1-8B-Instruct --onnx_path=G:\llama-3-8b-instruct\llama3.1-8B-genai-cuda-fp16\model.onnx --output_path=G:\llama3.1-8B-genai-cuda-rtn_dq_gather__before_refactor_testing\model.onnx  --algo=rtn_dq --gather_quantize_axis=1 --gather_block_size=32

Testing

Tested on a combination of a few models, quantization algorithms, and Gather-node quantization enabled/disabled:

  1. Llama-3.1-8B-Instruct => awq_lite, rtn_dq, gather_node quantized

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: NA
  • Did you add or update any necessary documentation?: NA
  • Did you update Changelog?: NA

Additional Information

Summary by CodeRabbit

  • New Features

    • Per-layer quantization configuration now exposes precision, block size, and axis, with explicit defaults for 4-bit, 8-bit, and Gather quantization.
  • Refactor

    • Unified layer-level configuration replaces the legacy precision mapping and is propagated across mixed-precision and quantization flows for consistent per-weight behavior.
  • Chores

    • Quantized ONNX models are now exported with an updated IR version.

@ynankani ynankani requested a review from a team as a code owner October 14, 2025 16:40
@ynankani ynankani requested a review from gcunhase October 14, 2025 16:40

coderabbitai bot commented Oct 14, 2025

Walkthrough

Refactors ONNX quantization to use a per-weight layer_info mapping ({precision, block_size, axis}). Renames get_precision_info → get_layer_info, adds block_size/quantize_axis params, propagates layer_info through int4, qdq, and quant utilities, adds gather-specific defaults, and sets model.ir_version = 10 on export.
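For orientation, here is a minimal sketch of what such a per-weight mapping could look like (weight names and values are hypothetical and chosen to mirror the description above; get_layer_info may produce different defaults):

```python
# Hypothetical illustration of the per-weight layer_info mapping described in the walkthrough.
# Weight names are invented; the Gather entry mirrors the CLI example above
# (--gather_quantize_axis=1 --gather_block_size=32), not necessarily the library defaults.
layer_info = {
    "model.layers.0.attn.qkv_proj.MatMul.weight": {"precision": 4, "block_size": 128, "axis": 0},
    "model.layers.0.mlp.down_proj.MatMul.weight": {"precision": 8, "block_size": -1, "axis": 1},  # -1: per-channel
    "model.embed_tokens.Gather.weight": {"precision": 4, "block_size": 32, "axis": 1},
}
```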

Changes

Cohort / File(s) and summary:

  • Graph utils: per-layer config & API (modelopt/onnx/quantization/graph_utils.py)
    Adds layer_info mapping {precision, block_size, axis} and gather defaults. Renames get_precision_info → get_layer_info, updates get_layer_precision_mapping signature to accept block_size and quantize_axis, and returns richer per-weight configs.
  • Quantization core: int4 pipeline updates (modelopt/onnx/quantization/int4.py)
    Replaces precision_info with layer_info across RTN, AWQ, AWQ-lite flows. Per-weight num_bits, axis, and block_size are read from layer_info; gather quantization derives settings per input; DQ insertions and scale reshaping use layer_info. Sets model.ir_version = 10 before return.
  • QDQ utilities: API and usage switch (modelopt/onnx/quantization/qdq_utils.py)
    insert_dq_nodes / insert_qdq_nodes updated to accept layer_info. Internal logic uses get_num_bits(layer_info, name) and related layer_info-driven behavior. Docstrings updated.
  • Quant utilities: helpers & per-layer behavior (modelopt/onnx/quantization/quant_utils.py)
    get_num_bits now reads layer_info[name]["precision"]; adds get_layer_block_size and get_layer_axis. update_block_size and reshape_scales_for_per_channel_nodes derive block_size/axis from layer_info (block_size == -1 treated as per-channel sentinel). Signatures and docs updated.
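To make the helper behavior above concrete, a rough sketch of how the per-weight lookups and the block_size == -1 sentinel might be handled (assumed signatures and defaults; the actual functions in quant_utils.py may differ):

```python
import numpy as np

def get_num_bits(layer_info: dict[str, dict] | None, name: str, default: int = 4) -> int:
    # Sketch: per-weight precision lookup with a fallback default (4 bits assumed here).
    if layer_info and name in layer_info:
        return layer_info[name]["precision"]
    return default

def update_block_size(
    block_size: int,
    layer_info: dict[str, dict] | None,
    name: str | None,
    quantize_axis: int,
    w: np.ndarray,
) -> int:
    # Sketch: per-weight overrides, with block_size == -1 acting as the per-channel
    # sentinel that expands to the full dimension length along the quantization axis.
    if layer_info and name is not None and name in layer_info:
        block_size = layer_info[name].get("block_size", block_size)
        quantize_axis = layer_info[name].get("axis", quantize_axis)
    return int(w.shape[quantize_axis]) if block_size == -1 else block_size
```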

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor U as Caller
  participant G as graph_utils.get_layer_info
  participant Q as int4.quantize_* (RTN/AWQ/AWQ-lite)
  participant QU as quant_utils
  participant QDQ as qdq_utils.insert_qdq_nodes
  participant M as ONNX Model

  U->>G: request layer_info (block_size, quantize_axis)\n(+ optional gather configs)
  G-->>U: layer_info map or None

  U->>Q: quantize(model, layer_info, ...)
  Q->>QU: get_num_bits(name, layer_info)
  QU-->>Q: precision per weight
  Q->>QU: get_layer_block_size/name axis from layer_info
  QU-->>Q: block_size/axis per weight
  Q->>QU: reshape_scales_for_per_channel_nodes(layer_info)
  QU-->>Q: adjusted scales for per-channel

  Q->>QDQ: insert_qdq_nodes(model, layer_info)
  QDQ->>QU: get_num_bits(name, layer_info)
  QU-->>QDQ: precision
  QDQ-->>Q: Q/DQ nodes inserted

  Q->>M: set ir_version = 10
  Q-->>U: quantized model

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

Thump-thump, my paws map bits anew,
Each weight a secret—axis, block, and hue.
Gather hops find 4-bit cheer,
DQ blossoms—models steer.
A carrot-coded tweak, and off we chew! 🥕

Pre-merge checks

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title “Cleanup mixed precision and gather node layer info mapping” accurately summarizes the main refactoring effort in the pull request, which focuses on cleaning up mixed-precision handling and the mapping of layer information for gather nodes. It is concise and specific, clearly reflecting the scope without introducing extraneous details or vague terms. This phrasing will help reviewers quickly understand the primary change in the quantization utilities.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.

codecov bot commented Oct 14, 2025

Codecov Report

❌ Patch coverage is 78.26087% with 15 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.40%. Comparing base (99c76ff) to head (99919d3).
⚠️ Report is 1 commit behind head on main.

Files with missing lines:
  • modelopt/onnx/quantization/quant_utils.py: patch 68.42%, 6 lines missing ⚠️
  • modelopt/onnx/quantization/graph_utils.py: patch 73.68%, 5 lines missing ⚠️
  • modelopt/onnx/quantization/int4.py: patch 86.66%, 4 lines missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #434   +/-   ##
=======================================
  Coverage   73.40%   73.40%           
=======================================
  Files         180      180           
  Lines       18047    18077   +30     
=======================================
+ Hits        13247    13270   +23     
- Misses       4800     4807    +7     

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (5)
modelopt/onnx/quantization/graph_utils.py (2)

726-737: Docstring is stale: return type and params don’t match implementation

get_layer_precision_mapping now returns a dict per-weight with keys precision, block_size, axis, and accepts block_size/quantize_axis, but the docstring still describes the old “precision only” mapping and omits new params. Please update for clarity.

Apply this diff to refresh the doc:

-    """Generate a mapping of layer names to their quantization precision (4 bits or 8 bits) for an ONNX model.
+    """Generate a mapping of weight (initializer) names to their quantization configuration.
@@
-        precision_pattern_8bit (str, optional): Comma-separated string of layer patterns to quantize to 8 bits.
-            If None, a default set of patterns is used to select layers for 8 bits quantization.
-        nodes_to_exclude (list[str], optional): List of node name patterns to exclude from quantization.
-            Defaults to [r"/lm_head"].
+        precision_pattern_8bit (str | None): Comma-separated list of layer patterns to quantize to 8 bits.
+            If None, a heuristic/default set of patterns is used.
+        nodes_to_exclude (list[str] | None): Node name patterns to exclude from quantization.
+        block_size (int): Default block size to use for 4-bit layers.
+        quantize_axis (int): Default quantization axis to use for 4-bit layers.
@@
-        dict: A mapping from layer names to their quantization precision (e.g., {"layer_name": "8"}).
+        dict[str, dict[str, int]]: Mapping from weight name to a dict with:
+            - precision (int): 4 or 8
+            - block_size (int): -1 for per-channel, or positive block size
+            - axis (int): quantization axis

721-725: Avoid default mutable arguments for nodes_to_exclude

Using a list literal as a default can lead to subtle bugs (a generic illustration follows after the diffs below). Prefer None and set the default inside.

-def get_layer_precision_mapping(
-    onnx_model: onnx.ModelProto,
-    precision_pattern_8bit: str | None = None,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+def get_layer_precision_mapping(
+    onnx_model: onnx.ModelProto,
+    precision_pattern_8bit: str | None = None,
+    nodes_to_exclude: list[str] | None = None,
     block_size: int = 128,
     quantize_axis: int = 0,
 ):
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
-def get_layer_info(
-    onnx_model: onnx.ModelProto,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+def get_layer_info(
+    onnx_model: onnx.ModelProto,
+    nodes_to_exclude: list[str] | None = None,
     block_size: int = 128,
     quantize_axis: int = 0,
     **kwargs: Any,
 ):
@@
-    if enable_mixed_quant:
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    if enable_mixed_quant:
         layer_info = get_layer_precision_mapping(
             onnx_model,
             layers_8bit,
             nodes_to_exclude,
             block_size,
             quantize_axis,
         )

Also applies to: 837-841
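For readers unfamiliar with the pitfall mentioned above, a generic illustration (not project code) of why a mutable default argument is risky:

```python
def append_tag(tags=[]):      # the default list is created once, at function definition time
    tags.append("quantized")
    return tags

print(append_tag())  # ['quantized']
print(append_tag())  # ['quantized', 'quantized']  <- state leaks across calls
```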

modelopt/onnx/quantization/qdq_utils.py (1)

328-338: Hardcoded axis=1 for per‑channel might be too rigid

For INT8/per‑channel you force axis=1 and drop block_size. That’s fine for GEMM/MatMul, but could be wrong for other ops (e.g., Conv). Consider honoring an existing axis in attrs (if provided) or deriving it per op.
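One possible shape of that suggestion, sketched with a hypothetical attrs dict (the real attribute handling in qdq_utils.py may differ):

```python
def resolve_per_channel_axis(attrs: dict | None) -> int:
    # Hypothetical: honor an axis already provided in the node's attributes and only
    # fall back to axis=1 (suitable for GEMM/MatMul weights) when none is given.
    if attrs and "axis" in attrs:
        return attrs["axis"]
    return 1
```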

modelopt/onnx/quantization/quant_utils.py (2)

169-180: update_block_size docstring out of sync with signature

Doc still mentions num_bits and doesn’t describe new behavior. Please update.

-    Args:
-        num_bits (int): Number of bits for quantization.
-        layer_info (dict[str, dict] | None): Optional dictionary mapping tensor names
-            to layer configuration dict.
-        name (str | None): Name of the tensor.
-        block_size (int): Current block size.
-        quantize_axis (int): Axis along which to quantize.
-        w (np.ndarray): Weight tensor to be quantized.
+    Args:
+        block_size (int): Base block size.
+        layer_info (dict[str, dict] | None): Optional mapping from tensor name to config dict.
+        name (str | None): Tensor name (used to look up layer_info).
+        quantize_axis (int): Axis along which to quantize.
+        w (np.ndarray): Weight tensor (required when block_size == -1 to expand per‑channel).
@@
-        int: Updated block size.
+        int: Updated block size (if -1, expanded to size along quantize_axis).

288-296: reshape_scales_for_per_channel_nodes doc: add block_size description

The function takes block_size but the doc omits it. Minor clarity fix.

-    Args:
-        scales_map (dict[str, np.ndarray]): Dictionary mapping weight names to scale arrays.
-        layer_info (dict[str, dict] | None): Optional dictionary mapping tensor names
-            to layer configuration dict.
+    Args:
+        scales_map (dict[str, np.ndarray]): Map of weight name -> scale array.
+        block_size (int): Default block size used for 4‑bit layers.
+        layer_info (dict[str, dict] | None): Optional per‑layer configuration dict.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 35f90d0 and a5d2dbe.

📒 Files selected for processing (4)
  • modelopt/onnx/quantization/graph_utils.py (4 hunks)
  • modelopt/onnx/quantization/int4.py (20 hunks)
  • modelopt/onnx/quantization/qdq_utils.py (5 hunks)
  • modelopt/onnx/quantization/quant_utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/onnx/quantization/qdq_utils.py (1)
modelopt/onnx/quantization/quant_utils.py (1)
  • get_num_bits (189-204)
modelopt/onnx/quantization/quant_utils.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • num_bits (183-185)
  • num_bits (188-190)
modelopt/onnx/quantization/int4.py (3)
modelopt/onnx/quantization/graph_utils.py (1)
  • get_layer_info (835-889)
modelopt/onnx/quantization/quant_utils.py (4)
  • get_num_bits (189-204)
  • update_block_size (160-186)
  • reshape_scales_for_per_channel_nodes (282-303)
  • rtn (356-389)
modelopt/onnx/quantization/qdq_utils.py (2)
  • insert_dq_nodes (349-421)
  • insert_qdq_nodes (424-480)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
🔇 Additional comments (1)
modelopt/onnx/quantization/graph_utils.py (1)

815-833: Layer info mapping looks good

Per-weight dict with precision, block_size and axis is consistent with downstream usage.

@gcunhase gcunhase requested a review from ajrasane October 14, 2025 16:56
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: ynankani-nv <[email protected]>

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/onnx/quantization/int4.py (1)

123-181: Critical: block_size undefined at line 181.

Line 181 references block_size, but this variable is only defined inside the loop at line 126. After the loop completes, block_size would hold the value from the last iteration, which is incorrect when calling reshape_scales_for_per_channel_nodes with a mixture of gather nodes that may have different block sizes.

Solution 1 (recommended): Pass gather_block_size as a parameter to _quantize_gather_nodes and use it as the default fallback:

 def _quantize_gather_nodes(
     graph: onnx.GraphProto,
     nodes_to_exclude: list[str],
     use_zero_point: bool,
     dq_only: bool,
+    gather_block_size: int,
     layer_info: dict[str, dict] | None,
 ):

Then at line 181:

-    scales_map = reshape_scales_for_per_channel_nodes(scales_map, block_size, layer_info)
+    scales_map = reshape_scales_for_per_channel_nodes(scales_map, gather_block_size, layer_info)

Solution 2 (alternative): Use the default constant directly:

-    scales_map = reshape_scales_for_per_channel_nodes(scales_map, block_size, layer_info)
+    scales_map = reshape_scales_for_per_channel_nodes(scales_map, DEFAULT_GATHER_BLOCK_SIZE, layer_info)

Note: You'll also need to update all callers of _quantize_gather_nodes (lines 253, 571, 1274) to pass the gather_block_size parameter if using Solution 1.

🧹 Nitpick comments (1)
modelopt/onnx/quantization/int4.py (1)

312-312: Consider adding explanatory comment for consistency.

For consistency with lines 605 and 1316, consider adding a comment explaining why ir_version = 10 is set:

+    # Set ir_version to 10, remove it once ORT supports ir_version 11
     model.ir_version = 10

This helps maintainers understand the temporary nature of this workaround.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a5d2dbe and 1fd7973.

📒 Files selected for processing (1)
  • modelopt/onnx/quantization/int4.py (20 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/int4.py (3)
modelopt/onnx/quantization/graph_utils.py (1)
  • get_layer_info (835-889)
modelopt/onnx/quantization/quant_utils.py (3)
  • get_num_bits (189-204)
  • update_block_size (160-186)
  • reshape_scales_for_per_channel_nodes (282-303)
modelopt/onnx/quantization/qdq_utils.py (2)
  • insert_dq_nodes (349-421)
  • insert_qdq_nodes (424-480)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (5)
modelopt/onnx/quantization/int4.py (5)

43-43: LGTM: Import updated correctly.

The import of get_layer_info aligns with the refactoring to use per-layer configuration.


229-313: LGTM: layer_info propagated correctly through quantize_rtn.

The function properly retrieves layer_info early and passes it through the quantization pipeline. The setting of model.ir_version = 10 at lines 312-313 ensures compatibility.


492-597: LGTM: layer_info propagated correctly through _quantize_awq_clip.

The function properly retrieves layer_info early at line 492 and passes it consistently through all quantization steps, including AWQ clip search, weight quantization, and DQ node insertion.


694-738: LGTM: layer_info used correctly in scale search.

The function properly accepts layer_info as a parameter and uses it to determine per-layer configuration for bit-width and block size.


985-1305: LGTM: Past review addressed, layer_info propagated correctly.

Line 985 now correctly passes block_size to get_layer_info, addressing the previous review comment. The function consistently propagates layer_info through all quantization paths including AWQ-lite scale search, weight quantization, and DQ node insertion.

Based on learnings

onnx_model: onnx.ModelProto,
precision_pattern_8bit: str | None = None,
nodes_to_exclude: list[str] | None = [r"/lm_head"],
block_size: int = 128,
Contributor

Should we use default 32 here?

Contributor Author

The default block_size is set to 128 in int4.py if block_size is passed as None.
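A minimal sketch of the fallback described in this reply (hypothetical helper and constant name; the actual handling in int4.py may be structured differently):

```python
DEFAULT_BLOCK_SIZE = 128  # assumed constant name; int4.py may inline this value

def resolve_block_size(block_size: int | None) -> int:
    # Fall back to 128 when the caller does not pass a block size.
    return DEFAULT_BLOCK_SIZE if block_size is None else block_size
```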

@ynankani ynankani merged commit f5c209d into main Oct 24, 2025
26 checks passed
@ynankani ynankani deleted the ynankani/refactor_mixed_precision_and_gather_nodes_quantization branch October 24, 2025 06:32