-
Notifications
You must be signed in to change notification settings - Fork 183
ONNX 1.19 compatibility fix for INT4 quantization #423
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4d62165
bcc60b8
ef69ab7
39e8922
e36413d
ecea581
1e9631c
1635bab
58bfca0
5e51553
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,7 +20,7 @@ | |||||||||||||||||||||
| from functools import partial | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import torch | ||||||||||||||||||||||
| from _test_utils.import_helper import skip_if_no_libcudnn, skip_if_onnx_version_above_1_18 | ||||||||||||||||||||||
| from _test_utils.import_helper import skip_if_no_libcudnn | ||||||||||||||||||||||
| from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx, find_init | ||||||||||||||||||||||
| from _test_utils.torch_quantization.quantize_common import get_awq_config | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -39,9 +39,45 @@ | |||||||||||||||||||||
| # test_qdq_utils_fp8.py::test_fused_q[bf16,fp16] fails if this script runs after the int4 test, but not before. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def test_int4_awq(tmp_path): | ||||||||||||||||||||||
| skip_if_onnx_version_above_1_18() | ||||||||||||||||||||||
| def test_safe_cupy_array(monkeypatch): | ||||||||||||||||||||||
| """Comprehensive test for safe_cupy_array covering all code paths.""" | ||||||||||||||||||||||
| import builtins | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import numpy # Import actual numpy for creating int4 tensors | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Test 1: Regular numpy array (should hit line 122) | ||||||||||||||||||||||
| result = int4.safe_cupy_array(numpy.array([1, 2, 3, 4], dtype=numpy.float32)) | ||||||||||||||||||||||
| assert isinstance(result, np.ndarray) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Test 2: With real ml_dtypes.int4 (covers lines 117-118) | ||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| import ml_dtypes | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| int4_tensor = numpy.array([1, 2, -3, 4], dtype=numpy.float32).astype(ml_dtypes.int4) | ||||||||||||||||||||||
| result = int4.safe_cupy_array(int4_tensor) | ||||||||||||||||||||||
| assert isinstance(result, np.ndarray) and result.dtype == numpy.int8 | ||||||||||||||||||||||
| expected = int4_tensor.astype(numpy.int8) | ||||||||||||||||||||||
| actual = result.get() if int4.has_cupy else result | ||||||||||||||||||||||
| np.testing.assert_array_equal(actual, expected) | ||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||
|
Comment on lines
+58
to
+62
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use numpy.testing here to avoid cupy alias mismatch. Under cupy, - np.testing.assert_array_equal(actual, expected)
+ numpy.testing.assert_array_equal(actual, expected)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||
| pass # ml_dtypes not available | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Test 3: When ml_dtypes import fails (covers ImportError catch and line 122) | ||||||||||||||||||||||
| original_import = builtins.__import__ | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def mock_import(name, *args, **kwargs): | ||||||||||||||||||||||
| if name == "ml_dtypes": | ||||||||||||||||||||||
| raise ImportError("ml_dtypes not available") | ||||||||||||||||||||||
| return original_import(name, *args, **kwargs) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| monkeypatch.setattr(builtins, "__import__", mock_import) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Use actual numpy for creating the array | ||||||||||||||||||||||
| result = int4.safe_cupy_array(numpy.array([5, 6, 7, 8], dtype=numpy.int8)) | ||||||||||||||||||||||
| assert isinstance(result, np.ndarray) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def test_int4_awq(tmp_path): | ||||||||||||||||||||||
| def _forward_loop(model, dataloader): | ||||||||||||||||||||||
| """Forward loop for calibration.""" | ||||||||||||||||||||||
| for data in dataloader: | ||||||||||||||||||||||
|
|
@@ -94,20 +130,19 @@ def _forward_loop(model, dataloader): | |||||||||||||||||||||
| scale_awq_lite = find_init(onnx_model_awq_lite, scale_names[i]) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if int4.has_cupy: | ||||||||||||||||||||||
| wq_onnx_awq_lite = np.array(wq_onnx_awq_lite) | ||||||||||||||||||||||
| scale_awq_lite = np.array(scale_awq_lite) | ||||||||||||||||||||||
| wq_onnx_awq_lite = int4.safe_cupy_array(wq_onnx_awq_lite) | ||||||||||||||||||||||
| scale_awq_lite = int4.safe_cupy_array(scale_awq_lite) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| wq_onnx_awq_lite = dq_tensor(wq_onnx_awq_lite, scale_awq_lite, block_size) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| wq_torch_awq_clip = model_torch_copy.net[i * 2].weight_quantizer( | ||||||||||||||||||||||
| model_torch_copy.net[i * 2].weight | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| wq_onnx_awq_clip = find_init(onnx_model_awq_clip, wq_names[i]) | ||||||||||||||||||||||
| scale_awq_clip = find_init(onnx_model_awq_clip, scale_names[i]) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if int4.has_cupy: | ||||||||||||||||||||||
| wq_onnx_awq_clip = np.array(wq_onnx_awq_clip) | ||||||||||||||||||||||
| scale_awq_clip = np.array(scale_awq_clip) | ||||||||||||||||||||||
| wq_onnx_awq_clip = int4.safe_cupy_array(wq_onnx_awq_clip) | ||||||||||||||||||||||
| scale_awq_clip = int4.safe_cupy_array(scale_awq_clip) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| wq_onnx_awq_clip = dq_tensor(wq_onnx_awq_clip, scale_awq_clip, block_size) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Well-implemented compatibility helper for
ml_dtypes.int4.The
safe_cupy_arrayfunction correctly handles the conversion ofml_dtypes.int4tensors tonumpy.int8for CuPy compatibility. The dynamic import ofml_dtypesand conditional casting ensure backward compatibility with ONNX versions that don't useml_dtypes.int4.However, verify that all weight-loading scenarios are covered. The previous review comment flagged that weights loaded via
numpy_helper.to_array(e.g., lines 531, 557, 750, 842, 931, 1162, 1249) followed bynp.asarraymight fail with CuPy if those weights areml_dtypes.int4. While the typical quantization flow loads float32 weights, confirm there are no edge cases where already-quantized INT4 weights are loaded and processed within this module.Run the following script to identify all weight-loading patterns:
🏁 Script executed:
Length of output: 1093
Enforce
`safe_cupy_array` on all weight loads
Wrap each raw `np.asarray` call following `numpy_helper.to_array` (lines 536, 562, 755, 847, 936, 1167 and line 1253) with `safe_cupy_array` to convert any `ml_dtypes.int4` to `numpy.int8` and ensure CuPy compatibility.