ONNX 1.19 compatibility fix for INT4 quantization #423
Conversation
Walkthrough
Explicitly cast packed INT4/UINT4 values to int8/uint8 during ONNX export, add a CuPy-safe array helper (safe_cupy_array) that converts ml_dtypes.int4 tensors to numpy.int8, pass explicit axis/block_size attributes when inserting DequantizeLinear nodes in the INT4 RTN path, and pin the exported model's IR version to 10.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Caller as Caller
participant Quant as int4.quantize_rtn
participant Analyzer as Graph Analyzer
participant Inserter as DQ Inserter
participant Exporter as ONNX Export
Caller->>Quant: quantize_rtn(config)
Quant->>Analyzer: find quantizable paths (weights, Gather)
alt DQ-only / DQ path
Analyzer-->>Quant: provide axis, block_size
Quant->>Inserter: insert DQ nodes with {axis, block_size}
else no DQ
Quant-->>Quant: skip DQ insertion
end
Quant->>Exporter: export graph to ONNX
Exporter-->>Quant: model
Quant->>Quant: set model.ir_version = 10
Quant-->>Caller: return model
sequenceDiagram
autonumber
participant Exporter as _export_tensor_proto
participant Packer as 4-bit packer
participant Tensor as ONNX Tensor
Exporter->>Packer: obtain packed 4-bit vals
alt INT4 (signed)
Packer-->>Exporter: packed vals
Exporter-->>Exporter: cast vals -> int8
else UINT4 (unsigned)
Packer-->>Exporter: packed vals
Exporter-->>Exporter: cast vals -> uint8
end
Exporter->>Tensor: create ONNX tensor from vals
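As a rough sketch of the cast step in the diagram above, assuming the project's pack_float32_to_4bit_cpp_based packer (the surrounding names and toy values are illustrative, not the module's actual code):

import numpy as np
from modelopt.onnx.quantization.quant_utils import pack_float32_to_4bit_cpp_based  # path from the review context

values = np.array([1.0, -2.0, 3.0, -4.0], dtype=np.float32)  # toy weight values
packed = pack_float32_to_4bit_cpp_based(values, signed=True)

# Explicit cast so the exported tensor holds int8/uint8 rather than ml_dtypes.int4/uint4.
vals = packed.astype(np.int8)  # np.uint8 would be used for the unsigned case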
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
2abbcad to f6734ec (Compare)
Actionable comments posted: 0
🧹 Nitpick comments (1)
modelopt/onnx/quantization/int4.py (1)
102-123: Handle ml_dtypes.uint4 too and add light typing.
Good helper. To cover UINT4 and improve readability, consider handling ml_dtypes.uint4 and adding minimal type hints.
-def safe_cupy_array(tensor):
+from typing import Any
+
+def safe_cupy_array(tensor: Any):
@@
-    try:
-        import ml_dtypes
-
-        if hasattr(tensor, "dtype") and tensor.dtype == ml_dtypes.int4:
-            return np.asarray(tensor.astype(numpy.int8))
+    try:
+        import ml_dtypes
+        if hasattr(tensor, "dtype"):
+            # INT4 path
+            if tensor.dtype == ml_dtypes.int4:
+                return np.asarray(numpy.asarray(tensor, dtype=numpy.int8))
+            # UINT4 path (if available in ml_dtypes)
+            if getattr(ml_dtypes, "uint4", None) is not None and tensor.dtype == ml_dtypes.uint4:
+                return np.asarray(numpy.asarray(tensor, dtype=numpy.uint8))
     except ImportError:
         pass
@@
-    return np.asarray(tensor)
+    return np.asarray(tensor)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/onnx/quantization/gs_patching.py (1 hunks)
- modelopt/onnx/quantization/int4.py (3 hunks)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/onnx/quantization/gs_patching.py
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/int4.py (1)
modelopt/onnx/quantization/qdq_utils.py (1)
insert_dq_nodes(349-419)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (3)
modelopt/onnx/quantization/int4.py (3)
297-305: Passing DQ node attributes is correct and aligns with insert_dq_nodes.
Setting axis=0 and block_size for Gemm/MatMul is appropriate. insert_dq_nodes consumes attributes and propagates per-channel adjustments. Based on relevant code in modelopt/onnx/quantization/qdq_utils.py.
306-319: Gather DQ attributes look right.
axis=gather_quantize_axis and block_size=gather_block_size ensure correct DQ semantics for per-axis/block gather paths.
333-335: Setting IR version to 10 for RTN export improves ORT compatibility.
Matches AWQ paths; helps avoid ORT incompatibility with newer IRs. Please ensure no ops require IR > 10 in RTN graphs.
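A minimal sketch of that pin, assuming a GraphSurgeon-style export (the export call and the empty graph are assumptions made only to keep the snippet runnable):

import onnx_graphsurgeon as gs

graph = gs.Graph(nodes=[], inputs=[], outputs=[])  # placeholder graph for illustration
model = gs.export_onnx(graph)                      # export to an onnx.ModelProto
model.ir_version = 10                              # keep IR at 10 until ONNX Runtime accepts IR version 11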
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #423 +/- ##
==========================================
- Coverage 73.39% 73.37% -0.02%
==========================================
Files 180 180
Lines 17976 17986 +10
==========================================
+ Hits 13194 13198 +4
- Misses 4782 4788 +6
☔ View full report in Codecov by Sentry.
f6734ec to e453cc6 (Compare)
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/onnx/quantization/gs_patching.py (1 hunks)
- modelopt/onnx/quantization/int4.py (3 hunks)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/onnx/quantization/gs_patching.py
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/int4.py (1)
modelopt/onnx/quantization/qdq_utils.py (1)
insert_dq_nodes(349-419)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (2)
modelopt/onnx/quantization/int4.py (2)
297-318: LGTM! DQ node attributes properly added for RTN quantization.
The addition of explicit axis and block_size attributes for both weight DQ nodes (line 297) and gather DQ nodes (lines 308-311) aligns with the PR objective to handle 2D scales produced by block-wise quantization. This change is consistent with the existing pattern in AWQ code paths (lines 604, 1307).
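As a rough sketch of the attribute payload (plain dicts; the values and the keyword usage of insert_dq_nodes are assumptions based on the review notes, not the file's actual code):

# Illustrative values; in the module these come from the quantization config.
block_size = 128
gather_quantize_axis, gather_block_size = 0, 128

dq_node_attributes = {"axis": 0, "block_size": block_size}
gather_dq_node_attributes = {"axis": gather_quantize_axis, "block_size": gather_block_size}
# These dicts are then passed to insert_dq_nodes(..., attributes=...) per the review.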
332-335: LGTM! IR version setting ensures ONNX Runtime compatibility.
Setting ir_version = 10 after export makes the RTN quantization path consistent with AWQ paths (lines 629, 1340) and addresses the PR objective to ensure ONNX Runtime compatibility. The explicit version override is appropriate until ORT supports IR version 11.
e453cc6 to 22bdfb8 (Compare)
22bdfb8 to f6734ec (Compare)
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/onnx/quantization/int4.py (1)
102-123: Broaden helper to uint4 and wire it into weight-loading sites
- Currently converts only ml_dtypes.int4; CuPy will still fail on ml_dtypes.uint4. Extend to handle both.
- The helper is defined but not used; replace np.asarray(...) on ONNX-loaded weights with safe_cupy_array(...) to avoid CuPy dtype errors. This was raised earlier.
Apply within this function:
def safe_cupy_array(tensor):
    """Convert ml_dtypes.int4 tensor to numpy.int8 for CuPy compatibility.

    In ONNX 1.19, int4 tensors use ml_dtypes.int4 which CuPy doesn't support.
    This function converts them to regular numpy.int8 while preserving values.

    Args:
        tensor: numpy array that may have ml_dtypes.int4 dtype

    Returns:
-       cupy or numpy array (if cupy is not supported) with numpy.int8 dtype if input was ml_dtypes.int4,
+       cupy or numpy array (if cupy is not supported) with numpy.int8/uint8 dtype if input was ml_dtypes.int4/uint4,
        otherwise unchanged
    """
    try:
        import ml_dtypes

-       if hasattr(tensor, "dtype") and tensor.dtype == ml_dtypes.int4:
-           return np.asarray(tensor.astype(numpy.int8))
+       if hasattr(tensor, "dtype"):
+           if tensor.dtype == ml_dtypes.int4:
+               return np.asarray(tensor.astype(numpy.int8))
+           if hasattr(ml_dtypes, "uint4") and tensor.dtype == ml_dtypes.uint4:
+               return np.asarray(tensor.astype(numpy.uint8))
    except ImportError:
        pass

    return np.asarray(tensor)

To find candidate callsites to switch from np.asarray(...) to safe_cupy_array(...):
#!/bin/bash
# Locate places where ONNX tensors are loaded then wrapped with np.asarray
rg -n -A6 -B2 'numpy_helper\.to_array' modelopt/onnx/quantization/int4.py | rg -n 'numpy_helper\.to_array|np\.asarray\('
# Also list all np.asarray(...) in this file for manual audit
rg -n 'np\.asarray\(' modelopt/onnx/quantization/int4.py
🧹 Nitpick comments (1)
modelopt/onnx/quantization/gs_patching.py (1)
73-77: Minor: remove duplication in INT4/UINT4 packing cast
Same call repeated; simplify to one pack + conditional astype. Behavior unchanged.
-        if signed:
-            vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(np.int8)
-        else:
-            vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(np.uint8)
+        packed = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed)
+        vals = packed.astype(np.int8 if signed else np.uint8)
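For intuition only, here is a tiny self-contained illustration of why a packed pair of 4-bit values fits in one int8/uint8 byte (the nibble ordering shown is an assumption, not the packer's documented layout):

import numpy as np

lo, hi = -3, 5                                        # two signed 4-bit values
packed = np.uint8(((hi & 0x0F) << 4) | (lo & 0x0F))   # both nibbles stored in one byte
vals = np.array([packed]).astype(np.int8)             # explicit int8 keeps the buffer CuPy-friendly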
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/onnx/quantization/gs_patching.py (1 hunks)
- modelopt/onnx/quantization/int4.py (3 hunks)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (2)
modelopt/onnx/quantization/int4.py (2)
332-335: IR version pinned to 10
Setting model.ir_version = 10 improves ORT compatibility. LGTM.
297-317: Approve explicit DQ attributes for RTN and Gather paths: the insert_dq_nodes signature supports attributes and precision_info.
bf39fc5 to 29d1dff (Compare)
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)
54-56: Add assertions to verify value preservation during conversion.
The test verifies type/dtype correctness but doesn't confirm that data values are preserved during conversion. Consider adding assertions to check that the converted arrays contain the expected values.
For example:
 tensor = np.array([1, 2, 3, 4], dtype=np.int8)
 result = int4.safe_cupy_array(tensor)
 assert isinstance(result, np.ndarray)  # Should return numpy array
+assert np.array_equal(result, [1, 2, 3, 4])

Apply similar assertions at lines 79-83 and 89-92 to ensure the conversion logic correctly preserves data values across all code paths.
Also applies to: 79-83, 89-92
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/onnx/quantization/gs_patching.py (1 hunks)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/onnx/quantization/gs_patching.py
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3)
tests/_test_utils/import_helper.py (1)
- skip_if_no_libcudnn (37-43)
modelopt/onnx/quantization/int4.py (1)
- safe_cupy_array (102-122)
modelopt/onnx/quantization/quant_utils.py (1)
- dq_tensor (328-344)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (1)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)
148-149: LGTM! Clean integration of safe_cupy_array helper.
The changes correctly wrap array conversions with safe_cupy_array, which handles the ONNX 1.19 ml_dtypes.int4 type that CuPy doesn't support. The conditional application only when has_cupy is True is appropriate.
Also applies to: 159-160
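Concretely, the pattern being approved looks like this (mirroring the suggestion diffs later in this thread; the placeholder arrays exist only to make the sketch runnable):

import numpy as np
from modelopt.onnx.quantization import int4  # import path assumed for this sketch

wq_onnx_awq_lite = np.zeros((4, 4), dtype=np.int8)  # placeholder for the dequantized weight
scale_awq_lite = np.ones((4, 1), dtype=np.float32)  # placeholder for the per-block scales

if int4.has_cupy:
    wq_onnx_awq_lite = int4.safe_cupy_array(wq_onnx_awq_lite)
    scale_awq_lite = int4.safe_cupy_array(scale_awq_lite)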
8bfb51b to d767c33 (Compare)
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)
42-92: Enhance test with value verification.
The test comprehensively covers all code paths (ImportError, ml_dtypes.int4, and normal numpy), but it only validates types without verifying that the actual values are preserved during conversion.
Consider adding assertions to verify the values are preserved:
 mock_tensor = MockInt4Tensor(np.array([1, 2, 3, 4], dtype=np.int8))
 result = int4.safe_cupy_array(mock_tensor)
 assert isinstance(result, np.ndarray)
 assert result.dtype == np.int8
+assert np.array_equal(result, [1, 2, 3, 4])

Similarly for Test 1 and Test 3:
 tensor = np.array([1, 2, 3, 4], dtype=np.int8)
 result = int4.safe_cupy_array(tensor)
 assert isinstance(result, np.ndarray)  # Should return numpy array
+assert np.array_equal(result, [1, 2, 3, 4])
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/onnx/quantization/gs_patching.py (1 hunks)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3)
tests/_test_utils/import_helper.py (1)
- skip_if_no_libcudnn (37-43)
modelopt/onnx/quantization/int4.py (1)
- safe_cupy_array (102-122)
modelopt/onnx/quantization/quant_utils.py (1)
- dq_tensor (328-344)
modelopt/onnx/quantization/gs_patching.py (2)
modelopt/onnx/quantization/quant_utils.py (1)
- pack_float32_to_4bit_cpp_based (69-123)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)
- astype (73-74)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
modelopt/onnx/quantization/gs_patching.py (1)
73-74: LGTM! Explicit dtype casting ensures ONNX 1.19 compatibility.
The explicit cast to int8 for signed and uint8 for unsigned INT4 types ensures correct dtype handling for ONNX export, regardless of the underlying implementation details of pack_float32_to_4bit_cpp_based.
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)
147-148: LGTM! Safe CuPy array wrapper appropriately applied.
The safe_cupy_array calls correctly wrap tensor access when CuPy is available, ensuring compatibility with ONNX 1.19's ml_dtypes.int4 tensor type.
Also applies to: 158-159
…time and added axis, block size attributes to dq node Signed-off-by: Hrishith Thadicherla <[email protected]>
d767c33 to 1635bab (Compare)
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/onnx/quantization/gs_patching.py (1 hunks)
- modelopt/onnx/quantization/int4.py (3 hunks)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/onnx/quantization/int4.py (1)
modelopt/onnx/quantization/qdq_utils.py (1)
insert_dq_nodes(349-419)
modelopt/onnx/quantization/gs_patching.py (1)
modelopt/onnx/quantization/quant_utils.py (1)
pack_float32_to_4bit_cpp_based(69-123)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (4)
tests/_test_utils/import_helper.py (1)
- skip_if_no_libcudnn (37-43)
modelopt/onnx/quantization/int4.py (1)
- safe_cupy_array (102-122)
modelopt/onnx/quantization/quant_utils.py (1)
- dq_tensor (328-344)
tests/_test_utils/onnx_quantization/lib_test_models.py (1)
- find_init (235-243)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (5)
modelopt/onnx/quantization/gs_patching.py (1)
73-74: LGTM! Explicit cast ensures CuPy compatibility.
The explicit cast to int8 for signed INT4 and uint8 for unsigned UINT4 ensures that packed 4-bit values are in a dtype that CuPy supports, avoiding issues with ml_dtypes.int4 introduced in ONNX 1.19.
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (2)
42-72: Comprehensive test coverage for safe_cupy_array.
The test effectively covers all code paths in the safe_cupy_array function:
- Regular NumPy arrays (baseline case)
- ml_dtypes.int4 conversion when ml_dtypes is available
- Fallback behavior when the ml_dtypes import fails (using monkeypatch)
The use of .get() on line 57 correctly handles the CuPy-to-NumPy conversion for array comparison.
127-128: Correct usage of safe_cupy_array for INT4 weights and scales.
Wrapping the quantized weights and scales with safe_cupy_array when int4.has_cupy is True ensures that any ml_dtypes.int4 tensors are converted to numpy.int8 before CuPy operations, maintaining compatibility with ONNX 1.19.
Also applies to: 138-139
modelopt/onnx/quantization/int4.py (2)
297-317: DQ node attributes correctly propagated.
The explicit addition of dq_node_attributes and gather_dq_node_attributes containing axis and block_size ensures that DequantizeLinear nodes are created with the necessary attributes for block-wise quantization, which is required for ONNX 1.19 compatibility and proper dequantization.
332-335: IR version forced to 10 for ONNX Runtime compatibility.
Setting model.ir_version = 10 after exporting ensures compatibility with ONNX Runtime, which may not yet support IR version 11. This is a temporary workaround consistent with similar changes in other quantization methods (lines 629, 1340).
def safe_cupy_array(tensor):
    """Convert ml_dtypes.int4 tensor to numpy.int8 for CuPy compatibility.

    In ONNX 1.19, int4 tensors use ml_dtypes.int4 which CuPy doesn't support.
    This function converts them to regular numpy.int8 while preserving values.

    Args:
        tensor: numpy array that may have ml_dtypes.int4 dtype

    Returns:
        cupy or numpy array (if cupy is not supported) with numpy.int8 dtype if input was ml_dtypes.int4,
        otherwise unchanged
    """
    try:
        import ml_dtypes

        if hasattr(tensor, "dtype") and tensor.dtype == ml_dtypes.int4:
            return np.asarray(tensor.astype(numpy.int8))
    except ImportError:
        pass

    return np.asarray(tensor)
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Well-implemented compatibility helper for ml_dtypes.int4.
The safe_cupy_array function correctly handles the conversion of ml_dtypes.int4 tensors to numpy.int8 for CuPy compatibility. The dynamic import of ml_dtypes and conditional casting ensure backward compatibility with ONNX versions that don't use ml_dtypes.int4.
However, verify that all weight-loading scenarios are covered. The previous review comment flagged that weights loaded via numpy_helper.to_array (e.g., lines 531, 557, 750, 842, 931, 1162, 1249) followed by np.asarray might fail with CuPy if those weights are ml_dtypes.int4. While the typical quantization flow loads float32 weights, confirm there are no edge cases where already-quantized INT4 weights are loaded and processed within this module.
Run the following script to identify all weight-loading patterns:
🏁 Script executed:
#!/bin/bash
# Description: Find all numpy_helper.to_array usage followed by np.asarray in int4.py
rg -n -A 5 'numpy_helper\.to_array' modelopt/onnx/quantization/int4.py | \
rg -B 2 -A 2 'np\.asarray'

Length of output: 1093
Enforce safe_cupy_array on all weight loads
Wrap each raw np.asarray call following numpy_helper.to_array (lines 536, 562, 755, 847, 936, 1167 and line 1253) with safe_cupy_array to convert any ml_dtypes.int4 to numpy.int8 and ensure CuPy compatibility.
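The requested change is essentially a one-line substitution at each load site; a sketch (the initializer below is hypothetical and exists only to make the snippet executable, and the import path for the helper is assumed):

from onnx import TensorProto, helper, numpy_helper
from modelopt.onnx.quantization.int4 import safe_cupy_array  # import path assumed

# Hypothetical initializer standing in for a weight loaded from the model.
initializer = helper.make_tensor("w", TensorProto.INT8, dims=[4], vals=[1, 2, 3, 4])

# Reviewer's requested pattern: wrap the loaded tensor instead of calling np.asarray directly.
w = safe_cupy_array(numpy_helper.to_array(initializer))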
8c4ca7a to 07db172 (Compare)
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)
42-74: The monkeypatch test doesn't fully exercise the ImportError path.
The test at lines 63-73 attempts to verify behavior when the ml_dtypes import fails, but it only tests with a numpy.int8 tensor. This doesn't exercise the code path where an ml_dtypes.int4 tensor is encountered but the import fails (which would happen if tensor.dtype == ml_dtypes.int4 were checked when ml_dtypes is unavailable).
In practice, this edge case is unlikely (you can't have an ml_dtypes.int4 tensor without ml_dtypes), but for complete test coverage, consider either:
- Restructuring the test to use a mock tensor with a dtype that appears to be int4 without requiring ml_dtypes
For example, you could create a mock tensor with a fake dtype:
# Test 3: When ml_dtypes import fails with int4-like tensor
class MockInt4Dtype:
    pass

class MockInt4Tensor:
    def __init__(self, data):
        self.data = numpy.array(data, dtype=numpy.int8)
        self.dtype = MockInt4Dtype()

    def astype(self, dtype):
        return self.data.astype(dtype)

def mock_import(name, *args, **kwargs):
    if name == "ml_dtypes":
        raise ImportError("ml_dtypes not available")
    return builtins.__import__(name, *args, **kwargs)

monkeypatch.setattr(builtins, "__import__", mock_import)
mock_tensor = MockInt4Tensor([5, 6, 7, 8])
result = int4.safe_cupy_array(mock_tensor)
assert isinstance(result, np.ndarray)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (4)
tests/_test_utils/import_helper.py (1)
- skip_if_no_libcudnn (37-43)
modelopt/onnx/quantization/int4.py (1)
- safe_cupy_array (102-122)
modelopt/onnx/quantization/quant_utils.py (1)
- dq_tensor (328-344)
tests/_test_utils/onnx_quantization/lib_test_models.py (1)
- find_init (235-243)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (3)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3)
23-23: LGTM! Import change aligns with ONNX compatibility goals.
Replacing the ONNX version guard with the libcudnn guard is correct, as the code now handles both ONNX 1.18 and 1.19 through the safe_cupy_array helper rather than version-specific logic.
129-130: LGTM! Correct integration of safe_cupy_array.
The use of safe_cupy_array properly handles ONNX 1.19's ml_dtypes.int4 tensors by converting them to numpy.int8 before CuPy processing, maintaining compatibility across ONNX versions.
140-141: LGTM! Consistent usage of safe_cupy_array.
The integration pattern is consistent with lines 129-130, correctly handling INT4 tensor compatibility for the AWQ clip path.
07db172 to 3cea2a8 (Compare)
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (1)
54-54: Use a more direct path to create int4 tensors.
Creating an int4 tensor via numpy.float32 -> ml_dtypes.int4 is unconventional and may not properly test the conversion logic. Consider using a more direct approach:
- int4_tensor = numpy.array([1, 2, -3, 4], dtype=numpy.float32).astype(ml_dtypes.int4)
+ int4_tensor = numpy.array([1, 2, -3, 4], dtype=ml_dtypes.int4)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3)
tests/_test_utils/import_helper.py (1)
- skip_if_no_libcudnn (37-43)
modelopt/onnx/quantization/int4.py (1)
- safe_cupy_array (102-122)
modelopt/onnx/quantization/quant_utils.py (1)
- dq_tensor (328-344)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (3)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3)
23-23: LGTM! Import change aligns with PR objectives.
The removal of the ONNX version guard and use of skip_if_no_libcudnn is consistent with the PR's goal of supporting both ONNX 1.18 and 1.19 through the safe_cupy_array compatibility layer.
131-132: LGTM! Proper use of the compatibility helper.
The safe_cupy_array wrapper correctly handles both ONNX 1.18 (numpy.int8) and ONNX 1.19 (ml_dtypes.int4) tensor types, maintaining CuPy compatibility as described in the PR objectives.
Also applies to: 142-143
63-75: The ImportError test may not exercise the intended code path.
If ml_dtypes is successfully imported in Test 2 (lines 51-61), it remains cached in sys.modules. The monkeypatched __import__ won't be invoked for cached modules, so the except ImportError block at line 119 in safe_cupy_array may not be reached.
To properly test the ImportError fallback, clear the module cache before applying the monkeypatch:
 monkeypatch.setattr(builtins, "__import__", mock_import)
+# Clear cached ml_dtypes to ensure monkeypatch is effective
+import sys
+if 'ml_dtypes' in sys.modules:
+    monkeypatch.delitem(sys.modules, 'ml_dtypes')

Likely an incorrect or invalid review comment.
3cea2a8 to 0c74459 (Compare)
Signed-off-by: Hrishith Thadicherla <[email protected]>
0c74459 to 58bfca0 (Compare)
Actionable comments posted: 1
🧹 Nitpick comments (3)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3)
132-135: Call safe_cupy_array unconditionally (simpler, always correct).
safe_cupy_array is safe with or without CuPy/ml_dtypes. Drop the has_cupy guard to consistently normalize dtypes.

-        if int4.has_cupy:
-            wq_onnx_awq_lite = int4.safe_cupy_array(wq_onnx_awq_lite)
-            scale_awq_lite = int4.safe_cupy_array(scale_awq_lite)
+        wq_onnx_awq_lite = int4.safe_cupy_array(wq_onnx_awq_lite)
+        scale_awq_lite = int4.safe_cupy_array(scale_awq_lite)
@@
-        if int4.has_cupy:
-            wq_onnx_awq_clip = int4.safe_cupy_array(wq_onnx_awq_clip)
-            scale_awq_clip = int4.safe_cupy_array(scale_awq_clip)
+        wq_onnx_awq_clip = int4.safe_cupy_array(wq_onnx_awq_clip)
+        scale_awq_clip = int4.safe_cupy_array(scale_awq_clip)

Also applies to: 143-146
48-49: Avoid brittle references to source line numbers in test comments.
Comments like “covers lines 117–118” will rot. Make them descriptive instead.
-    # Test 1: Regular numpy array (should hit line 122)
+    # Test 1: Regular numpy array path
@@
-    # Test 2: With real ml_dtypes.int4 (covers lines 117-118)
+    # Test 2: With real ml_dtypes.int4 -> expect int8 cast
@@
-    # Test 3: When ml_dtypes import fails (covers ImportError catch and line 122)
+    # Test 3: Simulate ml_dtypes ImportError -> fallback path

Also applies to: 52-54, 65-66
149-150: Normalize to NumPy for robust comparisons across backends.
When CuPy is active, np is cupy and np.allclose may coerce torch tensors implicitly. Normalize explicitly to NumPy to avoid backend surprises.

-    assert np.allclose(wq_torch_awq_lite.detach(), wq_onnx_awq_lite.T, atol=1e-3)
-    assert np.allclose(wq_torch_awq_clip.detach(), wq_onnx_awq_clip.T, atol=1e-3)
+    import numpy as _numpy
+    lhs_lite = wq_torch_awq_lite.detach().cpu().numpy()
+    rhs_lite = wq_onnx_awq_lite.T
+    rhs_lite = rhs_lite.get() if int4.has_cupy else rhs_lite
+    assert _numpy.allclose(lhs_lite, rhs_lite, atol=1e-3)
+
+    lhs_clip = wq_torch_awq_clip.detach().cpu().numpy()
+    rhs_clip = wq_onnx_awq_clip.T
+    rhs_clip = rhs_clip.get() if int4.has_cupy else rhs_clip
+    assert _numpy.allclose(lhs_clip, rhs_clip, atol=1e-3)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py (4)
tests/_test_utils/import_helper.py (1)
- skip_if_no_libcudnn (37-43)
modelopt/onnx/quantization/int4.py (1)
- safe_cupy_array (102-122)
modelopt/onnx/quantization/quant_utils.py (1)
- dq_tensor (328-344)
tests/_test_utils/onnx_quantization/lib_test_models.py (1)
- find_init (235-243)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
- GitHub Check: wait-checks / wait
        assert isinstance(result, np.ndarray) and result.dtype == numpy.int8
        expected = int4_tensor.astype(numpy.int8)
        actual = result.get() if int4.has_cupy else result
        np.testing.assert_array_equal(actual, expected)
    except ImportError:
Use numpy.testing here to avoid cupy alias mismatch.
Under cupy, np is cupy, so np.testing.assert_array_equal may not accept NumPy arrays (actual/expected). Call numpy’s testing explicitly.
- np.testing.assert_array_equal(actual, expected)
+ numpy.testing.assert_array_equal(actual, expected)

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
        assert isinstance(result, np.ndarray) and result.dtype == numpy.int8
        expected = int4_tensor.astype(numpy.int8)
        actual = result.get() if int4.has_cupy else result
-       np.testing.assert_array_equal(actual, expected)
+       numpy.testing.assert_array_equal(actual, expected)
    except ImportError:
🤖 Prompt for AI Agents
In tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py around lines 58 to 62,
the test calls np.testing.assert_array_equal which under CuPy can resolve to
cupy.testing and fail when comparing NumPy arrays; replace that call with
numpy.testing.assert_array_equal so the NumPy testing function is used
explicitly (ensure the existing numpy import is used), i.e. change the assertion
to call numpy.testing.assert_array_equal(actual, expected).
What does this PR do?
Type of change: Bug Fix
Overview:
This PR adds compatibility support for ONNX 1.19 int4 quantization by handling the new ml_dtypes.int4 tensor type that was introduced in ONNX 1.19.
Background:
- ONNX 1.18 represented int4 tensors with numpy.int8
- ONNX 1.19 uses ml_dtypes.int4 to represent int4 instead
Issues encountered:
The issue this caused in ModelOpt is that CuPy doesn't support ml_dtypes.int4, leading to failures when loading int4 weight tensors. In addition, after packing two 4-bit weights into an np.int8 during ONNX export, the tensor was being cast to ml_dtypes.int4, which caused the int4 AWQ quantization test to fail.
Changes made:
- Added safe_cupy_array() in int4.py to maintain CuPy compatibility
- Avoid casting to ml_dtypes.int4 while packing; cast to numpy.int8 or numpy.uint8 instead
Testing
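A minimal illustration of how the helper is exercised (a sketch modeled on the unit test discussed in this review; the import path is an assumption):

import numpy
import ml_dtypes  # ONNX 1.19 uses this package for 4-bit dtypes, per this PR's background
from modelopt.onnx.quantization import int4  # import path assumed for this sketch

int4_tensor = numpy.array([1, 2, -3, 4], dtype=numpy.float32).astype(ml_dtypes.int4)
result = int4.safe_cupy_array(int4_tensor)
assert result.dtype == numpy.int8  # values preserved, dtype normalized so CuPy can consume it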
Summary by CodeRabbit
New Features
Bug Fixes
Tests