
Conversation

Contributor

@realAsma realAsma commented Oct 17, 2025

What does this PR do?

Type of change: New feature

Overview:

Add support for a custom emulated quantization backend. See :meth:`register_quant_backend <modelopt.torch.quantization.nn.modules.tensor_quantizer.register_quant_backend>` for more details.

Usage

See an example in tests/unit/torch/quantization/test_custom_backend.py.
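Roughly, the flow looks like the sketch below (adapted loosely from that test; the dummy backend, its `scale` extra arg, the toy model, and the exact config layout are illustrative, not the test's actual contents):

```python
import torch
import torch.nn as nn

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.nn import register_quant_backend, unregister_quant_backend


def dummy_backend(inputs: torch.Tensor, quantizer) -> torch.Tensor:
    # Illustrative fake-quant: scale the tensor by an extra arg stored on the quantizer.
    scale = (quantizer.backend_extra_args or {}).get("scale", 1.0)
    return inputs * scale


register_quant_backend("dummy_backend", dummy_backend)

model = nn.Sequential(nn.Linear(16, 16))
config = {
    "quant_cfg": {
        "*weight_quantizer": {
            "backend": "dummy_backend",
            "backend_extra_args": {"scale": 0.5},
        },
    },
    "algorithm": "max",
}

model = mtq.quantize(model, config, forward_loop=lambda m: m(torch.randn(4, 16)))

# Unregister so the custom backend does not leak into other configs/tests.
unregister_quant_backend("dummy_backend")
```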

Testing

See the unit tests in tests/unit/torch/quantization/test_custom_backend.py.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Support for pluggable custom quantization backends with configurable backend args.
    • Quantization config accepts string-based bit specifications.
  • Improvements

    • Algorithm parameter made optional in quantization calls.
    • Device-aware CUDA handling to ensure correct device placement during quantization.
    • Optional Triton-accelerated FP4 path for improved performance where available.
    • Warning and automatic disabling when gradient_accumulation_fusion is incompatible.
  • Tests

    • Added unit test covering custom backend usage; updated GPU tests for non-current device scenarios.

realAsma added 2 commits October 17, 2025 18:06
Signed-off-by: realAsma <[email protected]>

minor

minor

moved external changes to this PR

addressed PR comments; clean ups

minor test fix

rebasing external changes from [2/2]

minor unrelated fix

Signed-off-by: realAsma <[email protected]>
Signed-off-by: realAsma <[email protected]>
@realAsma realAsma requested review from a team as code owners October 17, 2025 18:13
@realAsma realAsma requested a review from cjluo-nv October 17, 2025 18:13

copy-pr-bot bot commented Oct 17, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Contributor

coderabbitai bot commented Oct 17, 2025

Walkthrough

Adds a pluggable custom quantization backend system, extends quantizer configuration to include backend name/args and flexible num_bits, refactors device handling for CUDA/Triton paths, updates TensorQuantizer to delegate fake-quant to registered backends, and adds tests and warnings for related plugins.

Changes

Cohort / File(s) Summary
Changelog
CHANGELOG.rst
Added feature entry for custom emulated quantization backend and reordered an ONNX quantization flag entry.
Quantizer configuration
modelopt/torch/quantization/config.py
QuantizerAttributeConfig.num_bits now accepts `str` in addition to `int` and `tuple[int, int]`; adds `backend` and `backend_extra_args` fields.
Backend registry & TensorQuantizer
modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Added registry _QUANT_FUNCTIONAL_BACKENDS, type alias QuantBackendEntrypoint, and public APIs: register_quant_backend, unregister_quant_backend, is_registered_quant_backend; TensorQuantizer now supports backend/backend_extra_args, delegates fake-quant to registered backends, and updates extra_repr and setters.
Quantization entrypoint
modelopt/torch/quantization/model_quant.py
quantize() now passes config.get("algorithm") to calibrate() (algorithm becomes optional).
Device utility
modelopt/torch/utils/tensor.py
Added same_device_as(torch.Tensor) context manager and exported it.
Tensor quantization paths
modelopt/torch/quantization/tensor_quant.py
Removed explicit torch.cuda.device context blocks; introduced a Triton FP4 fast path with conditional routing and simplified device handling for CUDA extension calls.
Megatron plugin
modelopt/torch/quantization/plugins/megatron.py
_MegatronParallelLinear._setup() now warns and disables gradient_accumulation_fusion when incompatible with ModelOpt quantization.
Unit tests
tests/unit/torch/quantization/test_custom_backend.py
New test registering a dummy backend, running quantize with backend-specific config and backend_extra_args, asserting expected outputs, and unregistering the backend.
GPU tests
tests/gpu/torch/quantization/test_tensor_quant_cuda.py
tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py
Removed two non-current-GPU tests; added test_non_current_gpu to TestTensorQuantizerE4M3 validating cross-device consistency for TensorQuantizer on CUDA.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    actor User
    participant Registry as Backend Registry
    participant ModelOpt as ModelOpt Quantize
    participant Quantizer as TensorQuantizer
    participant Backend as Custom Backend

    User->>Registry: register_quant_backend("name", impl)
    Registry-->>User: ack

    User->>ModelOpt: quantize(model, config{backend:"name", backend_extra_args:{...}})
    ModelOpt->>Quantizer: set_from_attribute_config(...) 
    Quantizer->>Quantizer: store backend name & args

    ModelOpt->>Quantizer: forward(tensor)
    alt backend configured
        Quantizer->>Registry: lookup("name")
        Registry-->>Quantizer: callable
        Quantizer->>Quantizer: same_device_as(tensor) ctx
        Quantizer->>Backend: callable(tensor, quantizer)
        Backend-->>Quantizer: quantized tensor
    else no backend
        Quantizer->>Quantizer: builtin fake-quant path
    end

    User->>Registry: unregister_quant_backend("name")
    Registry-->>User: ack

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐇 I hopped to register a backend so spry,
Tiny offsets, tensors soaring high.
I guard the device, I mind each stride,
We quantize together — paw and pride.
Hooray for backends, small and grand!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The pull request title "[1/2] Registry interface for custom quantization functional backend" is clearly and specifically related to the main feature introduced in this changeset. The title directly identifies the registry interface (register_quant_backend, unregister_quant_backend, is_registered_quant_backend) as the primary change, which aligns with the PR objectives describing the addition of "a registry interface" for custom backends. The title is concise, avoids vague terms, and provides enough specificity that a developer reviewing the commit history would understand the core contribution. While the changeset includes supporting modifications to configurations and backend handling, the registry interface represents the primary architectural addition, which the title appropriately highlights.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1b2a57d and b119fd1.

📒 Files selected for processing (2)
  • tests/gpu/torch/quantization/test_tensor_quant_cuda.py (0 hunks)
  • tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py (1 hunks)
💤 Files with no reviewable changes (1)
  • tests/gpu/torch/quantization/test_tensor_quant_cuda.py
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py (2)
tests/gpu/torch/conftest.py (1)
  • need_2_gpus (32-34)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  • to (115-123)
🔇 Additional comments (1)
tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py (1)

58-64: LGTM! Cross-device quantization test is well-structured.

The test correctly validates that the quantizer can process inputs on different CUDA devices and produce consistent results. This aligns with the PR's goal of improving device handling through the new backend system.

Optional suggestion: Consider adding a brief docstring or comment explaining that this test validates cross-device quantization behavior, especially in the context of the new same_device_as utility.
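If adopted, the docstring could read roughly as follows (illustrative wording only; the test body is elided):

```python
def test_non_current_gpu(self, need_2_gpus):
    """Validate cross-device quantization behavior.

    The quantizer should produce consistent results for inputs placed on a CUDA
    device other than the current one, exercising the new same_device_as-based
    device handling.
    """
    ...
```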



Comment on lines +234 to +241

if getattr(self, "gradient_accumulation_fusion", False):
warnings.warn(
"gradient_accumulation_fusion is not supported with ModelOpt quantization. "
"Setting gradient_accumulation_fusion to False."
)
self.gradient_accumulation_fusion = False

Contributor Author


@jenchen13 @ChenhanYu this should fix the occasional errors people run into by unknowingly enabling gradient_accumulation_fusion

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/tensor_quant.py (1)

75-87: Add device context guards to CUDA extension calls in tensor_quant.py to prevent multi-GPU errors.

The review correctly identifies a multi-GPU correctness issue. CUDA extension calls execute on the current device (via getCurrentCUDAStream()), not necessarily where the input tensor resides. This causes failures when current_device ≠ inputs.device.

However, line number references are offset. The actual functions are:

  • scaled_e4m3_impl: lines 62–87 (extension calls at 80, 84)
  • fake_quant_impl: lines 90–108 (extension calls at 101, 106)
  • _dynamic_block_quantize_impl: lines 152–181 (extension call at 176)

The proposed fix is valid: import same_device_as from modelopt.torch.utils.tensor and wrap each extension call. This pattern already exists in tensor_quantizer.py:987 and matches the helper's design.

Apply the import and wrapping changes as suggested in the diff, adjusting line numbers to match the actual file locations.
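A minimal sketch of that wrapping, assuming the `same_device_as` helper added in this PR (the wrapper name is made up, the extension module is passed in explicitly for illustration, and only the scalar-amax path is shown):

```python
import torch

from modelopt.torch.utils.tensor import same_device_as


def fp8_fake_quant_on_tensor_device(inputs: torch.Tensor, amax: torch.Tensor, cuda_ext_fp8) -> torch.Tensor:
    # Launch the CUDA extension on the device that owns `inputs`, not the process's
    # current CUDA device; same_device_as() is a no-op for CPU tensors or when the
    # tensor already lives on the current device.
    with same_device_as(inputs):
        return cuda_ext_fp8.fake_e4m3fy(inputs, amax)
```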

🧹 Nitpick comments (7)
modelopt/torch/utils/tensor.py (1)

34-41: Small enhancement: prefer device_of and clarify CUDA-only scope.

Consider using torch.cuda.device_of(inputs) for brevity and to future‑proof against potential device index edge cases. Also clarify in the docstring that this helper is CUDA‑only (returns a no‑op for CPU/other backends).

Apply this diff if you prefer:

-from contextlib import nullcontext
+from contextlib import nullcontext
+from typing import ContextManager

-def same_device_as(inputs: torch.Tensor):
-    """Return a context manager that sets the CUDA device to be the same as the input tensor.
+def same_device_as(inputs: torch.Tensor) -> ContextManager[None]:
+    """CUDA-only: return a context manager that sets the current CUDA device to the tensor's device.
@@
-    if not inputs.is_cuda or inputs.device.index == torch.cuda.current_device():
+    if not inputs.is_cuda or inputs.device.index == torch.cuda.current_device():
         return nullcontext()
-    return torch.cuda.device(inputs.device.index)
+    return torch.cuda.device_of(inputs)
modelopt/torch/quantization/config.py (1)

668-681: Doc/errmsg polish and typo fix.

Minor text fixes to improve clarity and the error message string formatting.

@@
-    num_bits: int | tuple[int, int] | str = ModeloptField(
+    num_bits: int | tuple[int, int] | str = ModeloptField(
@@
-            of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1).
+            of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6 (E3M2, E2M3), FP4 (E2M1).
@@
-        #. String specifying the quantization format. This is current used only for custom backends.""",
+        #. String specifying the quantization format. This is currently used only for custom backends.""",
@@
-    backend_extra_args: dict | None = ModeloptField(
+    backend_extra_args: dict | None = ModeloptField(
@@
-        description="""The extra arguments will saved on to the quantizer instance - this wont be
-        passed directly to the backend entrypoint. Can be any serializable dictionary.
+        description="""The extra arguments will be saved on the quantizer instance — this won't be
+        passed directly to the backend entrypoint. Can be any serializable dictionary.
@@
-            raise ValueError(
-                "Only blockwise dynamic quantization is supported with quantization "
-                "formats E{num_bis[0]}M{num_bits[1]}."
-            )
+            raise ValueError(
+                f"Only blockwise dynamic quantization is supported with quantization "
+                f"formats E{num_bits[0]}M{num_bits[1]}."
+            )

Also applies to: 972-982, 743-749

modelopt/torch/quantization/plugins/megatron.py (1)

235-241: Use logger and reduce warning spam.

warnings.warn here can flood logs per layer. Prefer logger.warning for consistency with this module and to honor logging config.

-        if getattr(self, "gradient_accumulation_fusion", False):
-            warnings.warn(
-                "gradient_accumulation_fusion is not supported with ModelOpt quantization. "
-                "Setting gradient_accumulation_fusion to False."
-            )
-            self.gradient_accumulation_fusion = False
+        if getattr(self, "gradient_accumulation_fusion", False):
+            logger.warning(
+                "gradient_accumulation_fusion is not supported with ModelOpt quantization; "
+                "forcing gradient_accumulation_fusion = False."
+            )
+            self.gradient_accumulation_fusion = False

If many layers set this flag, consider gating the message (warn once per process) to avoid repetition.
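One way to gate it, sketched below (the module-level flag and helper name are hypothetical):

```python
import logging

logger = logging.getLogger(__name__)
_warned_grad_accum_fusion = False  # hypothetical module-level warn-once flag


def _maybe_disable_grad_accum_fusion(module) -> None:
    """Warn once per process, then silently disable the flag on subsequent layers."""
    global _warned_grad_accum_fusion
    if getattr(module, "gradient_accumulation_fusion", False):
        if not _warned_grad_accum_fusion:
            logger.warning(
                "gradient_accumulation_fusion is not supported with ModelOpt quantization; "
                "forcing gradient_accumulation_fusion = False."
            )
            _warned_grad_accum_fusion = True
        module.gradient_accumulation_fusion = False
```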

CHANGELOG.rst (1)

13-14: Fix Sphinx role and tighten wording.

  • Remove the extra trailing backtick to avoid a docs build error.
  • Prefer “Add support for a custom emulated quantization backend.”

Apply:

- - Add support to enable custom emulated quantization backend. See :meth:`register_quant_backend <modelopt.torch.quantization.nn.modules.tensor_quantizer.register_quant_backend>`` for more details. See an example in ``tests/unit/torch/quantization/test_custom_backend.py``.
+ - Add support for a custom emulated quantization backend. See :meth:`register_quant_backend <modelopt.torch.quantization.nn.modules.tensor_quantizer.register_quant_backend>` for details. Example: ``tests/unit/torch/quantization/test_custom_backend.py``.
tests/unit/torch/quantization/test_custom_backend.py (2)

6-9: Clarify quantizer being exercised.

The comment says “output quantizer,” but the config targets “*weight_quantizer”. Update to avoid confusion.


24-53: Ensure backend cleanup even on assertion failure.

Wrap registration/unregistration in try/finally to avoid leaking the backend into other tests. Optionally assert registry state.

Apply:

-    register_quant_backend("dummy_backend", dummy_backend)
+    register_quant_backend("dummy_backend", dummy_backend)
+    try:
         ...
-    # Unregister the backend to avoid impacting other tests
-    unregister_quant_backend("dummy_backend")
+    finally:
+        # Unregister the backend to avoid impacting other tests
+        unregister_quant_backend("dummy_backend")

Optionally:

- from modelopt.torch.quantization.nn import register_quant_backend, unregister_quant_backend
+ from modelopt.torch.quantization.nn import (
+     register_quant_backend,
+     unregister_quant_backend,
+     is_registered_quant_backend,
+ )
@@
-    register_quant_backend("dummy_backend", dummy_backend)
+    register_quant_backend("dummy_backend", dummy_backend)
+    assert is_registered_quant_backend("dummy_backend")
@@
-    unregister_quant_backend("dummy_backend")
+    unregister_quant_backend("dummy_backend")
+    assert not is_registered_quant_backend("dummy_backend")
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

1048-1052: Guard optional backend fields in extra_repr.

If backend is set but backend_extra_args isn’t, attribute access can fail. Use getattr.

Apply:

-        s += (
-            f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}"
+        s += (
+            f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}"
             if self.pre_quant_scale is not None
             else ""
         )
@@
-        s += (
-            f" backend={self.backend}, extra_args={self.backend_extra_args}"
-            if self.backend is not None
-            else ""
-        )
+        if getattr(self, "backend", None) is not None:
+            extra = getattr(self, "backend_extra_args", {})
+            s += f" backend={self.backend}, extra_args={extra}"

Also applies to: 1065-1068

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f8a9353 and 1b2a57d.

📒 Files selected for processing (8)
  • CHANGELOG.rst (1 hunks)
  • modelopt/torch/quantization/config.py (4 hunks)
  • modelopt/torch/quantization/model_quant.py (1 hunks)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (9 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (1 hunks)
  • modelopt/torch/quantization/tensor_quant.py (3 hunks)
  • modelopt/torch/utils/tensor.py (1 hunks)
  • tests/unit/torch/quantization/test_custom_backend.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
modelopt/torch/quantization/tensor_quant.py (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • amax (287-292)
  • amax (295-306)
  • axis (334-336)
  • axis (339-341)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
  • fake_e4m3fy (87-101)
  • fake_e4m3fy (87-87)
  • fake_e4m3fy_with_axis (66-85)
  • fake_e4m3fy_with_axis (66-66)
modelopt/torch/quantization/triton/fp4_kernel.py (1)
  • fp4_fake_quant_block (122-164)
modelopt/torch/quantization/extensions.py (1)
  • get_cuda_ext_mx (55-71)
modelopt/torch/quantization/src/tensor_quant_mx.cu (2)
  • fused_amax_convert (355-387)
  • fused_amax_convert (355-356)
tests/unit/torch/quantization/test_custom_backend.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • register_quant_backend (74-91)
  • unregister_quant_backend (94-102)
modelopt/torch/quantization/model_quant.py (2)
  • forward_loop (94-95)
  • quantize (136-231)
modelopt/torch/quantization/model_quant.py (3)
tests/unit/torch/quantization/test_custom_backend.py (1)
  • forward_loop (44-45)
modelopt/torch/quantization/algorithms.py (1)
  • forward_loop (567-572)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
  • forward_loop (200-205)
modelopt/torch/quantization/config.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • num_bits (234-236)
  • num_bits (239-241)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
modelopt/torch/utils/tensor.py (1)
  • same_device_as (34-41)
modelopt/torch/utils/network.py (1)
  • standardize_constructor_args (471-484)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (7)
modelopt/torch/quantization/config.py (1)

712-715: Bypass of num_bits validation for backends looks good.

Short‑circuiting validation when backend is set aligns with pluggable backends and avoids over‑constraining custom formats.

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)

60-66: Public API surface looks good.

Exporting backend management functions alongside quantizer classes is appropriate.


69-72: Type alias and registry are clear and well-scoped.

Good use of Callable and a private dict for backends.


74-92: Backend registration API: sound validation and ergonomics.

Input validation, overwrite warning, and idempotent unregister are appropriate.

Also applies to: 94-103, 105-112


987-989: Good device-safety around backend path.

Wrapping fake-quant in same_device_as(inputs) avoids cross-device CUDA context issues.


1024-1029: Nice helper for concise tensor summaries.

_short_tensor improves extra_repr readability without overhead.


40-40: Import re-exports are confirmed and working as designed.

The symbols same_device_as and standardize_constructor_args are correctly re-exported from modelopt.torch.utils via wildcard imports in the package's __init__.py. Both from .tensor import * and from .network import * are used to re-export these modules' public APIs, making the current import statement valid and idiomatic.

The suggested refactor to explicit imports would contradict the module's intentional design of using wildcard re-exports to provide a unified public API. No changes are needed.

Likely an incorrect or invalid review comment.

Comment on lines +231 to 232
return calibrate(model, config.get("algorithm"), forward_loop=forward_loop)


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Preserve default calibration behavior (avoid silent no‑calibration).

Passing config.get("algorithm") forwards None when the key is absent, which skips calibration. Previously, missing "algorithm" implied "max". Recommend defaulting to "max" unless explicitly set to None.

-    return calibrate(model, config.get("algorithm"), forward_loop=forward_loop)
+    return calibrate(model, config.get("algorithm", "max"), forward_loop=forward_loop)

Please confirm whether this PR intends to change the behavior when "algorithm" is omitted. If yes, update docs to highlight that omission now skips calibration.


🏁 Script executed:

#!/bin/bash
# Find the calibrate function definition
echo "=== Searching for calibrate function definition ==="
rg -n "def calibrate" --type=py

echo -e "\n=== Context around current code (lines 231-232) ==="
fd "model_quant.py" --type=py --exec head -250 {} \; 2>/dev/null | tail -30

echo -e "\n=== Check for algorithm default in config ==="
rg -n "algorithm" modelopt/torch/quantization/model_quant.py -B2 -A2

echo -e "\n=== Search for related tests ==="
fd "test.*\.py" --type=py | xargs rg -l "calibrate.*algorithm" 2>/dev/null | head -5

Length of output: 5645



🏁 Script executed:

#!/bin/bash

echo "=== Inspect calibrate function to confirm default parameter ==="
rg -n "def calibrate" modelopt/torch/quantization/model_quant.py -A10

echo -e "\n=== Look for git history of line 231 ==="
git log -p --all -S 'config.get("algorithm")' -- modelopt/torch/quantization/model_quant.py 2>/dev/null | head -80

echo -e "\n=== Search for tests calling quantize() without algorithm ==="
rg -n "quantize\(" --type=py -g '*/test*' -B3 -A3 | grep -A3 -B3 "algorithm" | head -40

echo -e "\n=== Check what happens when None is passed to calibrate() ==="
rg -n "if algorithm" modelopt/torch/quantization/model_quant.py -A5 -B2

echo -e "\n=== Look for examples/docs showing quantize() without algorithm key ==="
rg -n 'config.*=.*{' modelopt/torch/quantization/model_quant.py -A15 | grep -A15 '"quant_cfg"'

Length of output: 956


🏁 Script executed:

#!/bin/bash

echo "=== Inspect calibrate function to confirm default parameter ==="
rg -n "def calibrate" modelopt/torch/quantization/model_quant.py -A10

echo -e "\n=== Look for git history of line 231 ==="
git log -p --all -S 'config.get("algorithm")' -- modelopt/torch/quantization/model_quant.py 2>/dev/null | head -100

echo -e "\n=== Search for tests calling quantize without algorithm key ==="
find . -path ./build -prune -o -name "*test*.py" -type f -print | xargs grep -l "quantize.*config" 2>/dev/null | head -5

echo -e "\n=== Check algorithm handling in calibrate function ==="
sed -n '53,110p' modelopt/torch/quantization/model_quant.py

echo -e "\n=== Search for get_modelike_from_algo_cfg to see None handling ==="
rg -n "def get_modelike_from_algo_cfg" modelopt/torch/quantization/ -A15

Length of output: 4705


🏁 Script executed:

#!/bin/bash

echo "=== Look for tests with quantize() missing algorithm key ==="
rg -n "quantize\(" tests/ -B5 -A5 --type=py | grep -B8 -A8 -E "(config\s*=|\"quant_cfg\")" | head -60

echo -e "\n=== Check if None/missing algorithm is tested ==="
rg -n "algorithm.*None\|None.*algorithm" tests/ --type=py -B2 -A2

echo -e "\n=== Look for weight-only quantization examples ==="
rg -n "weight.*only\|weight-only" modelopt/torch/quantization/ --type=py -B3 -A3 | head -40

echo -e "\n=== Check if calibrate() is called without algorithm in examples ==="
rg -n "calibrate\(" examples/ --type=py -B2 -A2 | head -40

echo -e "\n=== Verify line 73-76 of model_quant.py to confirm None documentation ==="
sed -n '73,76p' modelopt/torch/quantization/model_quant.py

Length of output: 7555


🏁 Script executed:

#!/bin/bash

echo "=== Check if quantize() without algorithm key is valid/tested ==="
rg -n "quantize\(" tests/ --type=py -A10 | grep -B10 "\"quant_cfg\"" | grep -v "algorithm" | head -50

echo -e "\n=== Search for any PR/issue context about algorithm default ==="
git log --all --oneline --grep="algorithm" -- modelopt/torch/quantization/model_quant.py 2>/dev/null | head -20

echo -e "\n=== Check git diff of lines around 231 ==="
git log -p -- modelopt/torch/quantization/model_quant.py 2>/dev/null | grep -B5 -A5 "config.get" | head -40

echo -e "\n=== Verify: Does config schema require algorithm? ==="
rg -n "algorithm" modelopt/torch/quantization/config.py -B3 -A3 | head -40

Length of output: 3733


🏁 Script executed:

#!/bin/bash

echo "=== Search for QuantizeConfig class definition ==="
rg -n "class QuantizeConfig" modelopt/torch/quantization/ -A30

echo -e "\n=== Check if algorithm field is marked as required ==="
rg -n "algorithm.*required\|required.*algorithm" modelopt/torch/quantization/ --type=py

echo -e "\n=== Look for config validation that checks algorithm presence ==="
rg -n "def.*quantize\(" modelopt/torch/quantization/model_quant.py -A50 | grep -A50 "def quantize"

echo -e "\n=== Check actual quantize() implementation around config.get ==="
sed -n '215,235p' modelopt/torch/quantization/model_quant.py

Length of output: 6346


Preserve default calibration behavior when "algorithm" key is omitted.

The QuantizeConfig defines algorithm with default="max", but line 231 bypasses this by passing config.get("algorithm") without a default fallback. When the config dict omits the "algorithm" key, None is passed to calibrate(), which explicitly skips calibration (as documented in its docstring). This contradicts the QuantizeConfig default.

-    return calibrate(model, config.get("algorithm"), forward_loop=forward_loop)
+    return calibrate(model, config.get("algorithm", "max"), forward_loop=forward_loop)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
return calibrate(model, config.get("algorithm"), forward_loop=forward_loop)
return calibrate(model, config.get("algorithm", "max"), forward_loop=forward_loop)
🤖 Prompt for AI Agents
In modelopt/torch/quantization/model_quant.py around lines 231 to 232, the call
passes config.get("algorithm") which yields None when the key is omitted and
causes calibrate() to skip calibration; change the call to supply the intended
default (e.g., use config.get("algorithm", "max") or resolve the
QuantizeConfig.default) so that when "algorithm" is missing the default "max" is
forwarded to calibrate().

Comment on lines +205 to 207
"backend": ("backend", lambda val: val),
"backend_extra_args": ("backend_extra_args", lambda val: val or {}),
}

⚠️ Potential issue | 🟠 Major

Defensive defaults for backend fields to avoid AttributeError.

_fake_quantize and extra_repr access self.backend/self.backend_extra_args directly. If QuantizerAttributeConfig ever omits these keys, attribute access can raise. Initialize defaults in __init__ and use getattr in hot paths.

Apply:

@@ class TensorQuantizer(nn.Module):
     def __init__(...
         super().__init__()
         quant_attribute_cfg = (
             quant_attribute_cfg if quant_attribute_cfg is not None else QuantizerAttributeConfig()
         )
+        # Defaults for optional backend fields; safe even if config overrides them.
+        self.backend = None
+        self.backend_extra_args = {}
@@
-        if self.backend is not None:
-            if self.backend not in _QUANT_FUNCTIONAL_BACKENDS:
+        backend = getattr(self, "backend", None)
+        if backend is not None:
+            if backend not in _QUANT_FUNCTIONAL_BACKENDS:
                 raise KeyError(f"Quant backend '{self.backend}' is not registered.")
-            entrypoint = _QUANT_FUNCTIONAL_BACKENDS[self.backend]
+            entrypoint = _QUANT_FUNCTIONAL_BACKENDS[backend]
             return entrypoint(inputs, self)

Also applies to: 675-680

🤖 Prompt for AI Agents
In modelopt/torch/quantization/nn/modules/tensor_quantizer.py around lines
205-207, the mapping for backend fields can be omitted by
QuantizerAttributeConfig which makes later direct accesses to self.backend and
self.backend_extra_args raise AttributeError; initialize safe defaults in
__init__ (e.g., self.backend = None and self.backend_extra_args = {} or
appropriate neutral defaults) and update hot paths (_fake_quantize and
extra_repr) to use getattr(self, "backend", None) and getattr(self,
"backend_extra_args", {}) when reading them; apply the same initialization +
getattr pattern to the code region around lines 675-680 as well.

Comment on lines +1033 to 1039
s = "disabled"
s += (
f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}"
if self.pre_quant_scale is not None
else ""
)
return "disabled"

⚠️ Potential issue | 🟡 Minor

extra_repr returns constant instead of built string when disabled.

You build s with pre_quant_scale but return the literal "disabled". Return s to expose the extra detail.

Apply:

-            return "disabled"
+            return s
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
s = "disabled"
s += (
f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}"
if self.pre_quant_scale is not None
else ""
)
return "disabled"
s = "disabled"
s += (
f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}"
if self.pre_quant_scale is not None
else ""
)
return s
🤖 Prompt for AI Agents
In modelopt/torch/quantization/nn/modules/tensor_quantizer.py around lines 1033
to 1039, the extra_repr branch for the disabled state builds a string s
including pre_quant_scale but then returns the literal "disabled" instead of s;
change the return to return s so the constructed string (e.g., "disabled
pre_quant_scale=...") is returned, preserving the existing conditional
construction and formatting.

Comment on lines +82 to 87
if amax.numel() == 1:
outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
elif amax.squeeze().ndim == 1:
axis = amax.shape.index(amax.numel())
outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
return outputs

⚠️ Potential issue | 🔴 Critical

Axis inference bug for per‑channel amax (wrong source tensor).

axis is computed from amax.shape, which always yields 0 for 1D amax. The CUDA kernels expect the axis in inputs. This breaks cases where the channel axis isn’t 0 (e.g., last dim activations, column/row‑parallel weights).

Use inputs.shape to infer axis and disambiguate duplicates; otherwise fail fast. Patch both scaled_e4m3_impl and fake_quant_impl:

@@
-    if amax.numel() == 1:
-        outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
-    elif amax.squeeze().ndim == 1:
-        axis = amax.shape.index(amax.numel())
-        outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
+    if amax.numel() == 1:
+        outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
+    elif amax.squeeze().ndim == 1:
+        n = amax.numel()
+        candidates = [i for i, s in enumerate(inputs.shape) if s == n]
+        if not candidates:
+            raise ValueError(
+                f"Cannot infer per-channel axis: amax.numel()={n} does not match any inputs dim {tuple(inputs.shape)}"
+            )
+        # Prefer last-dim if it matches, otherwise require uniqueness
+        axis = inputs.dim() - 1 if inputs.shape[-1] == n else (candidates[0] if len(candidates) == 1 else None)
+        if axis is None:
+            raise ValueError(
+                f"Ambiguous per-channel axis for amax of length {n} in inputs shape {tuple(inputs.shape)}; "
+                "please provide broadcast-shaped amax."
+            )
+        outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
@@
-    if amax.numel() == 1:
-        outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
-    else:
-        axis = amax.shape.index(amax.numel())
-        outputs = cuda_ext.fake_tensor_quant_with_axis(
-            inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range
-        )
+    if amax.numel() == 1:
+        outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
+    else:
+        n = amax.numel()
+        candidates = [i for i, s in enumerate(inputs.shape) if s == n]
+        if not candidates:
+            raise ValueError(
+                f"Cannot infer per-channel axis: amax.numel()={n} does not match any inputs dim {tuple(inputs.shape)}"
+            )
+        axis = inputs.dim() - 1 if inputs.shape[-1] == n else (candidates[0] if len(candidates) == 1 else None)
+        if axis is None:
+            raise ValueError(
+                f"Ambiguous per-channel axis for amax of length {n} in inputs shape {tuple(inputs.shape)}; "
+                "please provide broadcast-shaped amax."
+            )
+        outputs = cuda_ext.fake_tensor_quant_with_axis(
+            inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range
+        )

Also applies to: 100-107

🤖 Prompt for AI Agents
In modelopt/torch/quantization/tensor_quant.py around lines 82-87 (and similarly
for lines ~100-107 in scaled_e4m3_impl and fake_quant_impl), the code currently
computes axis from amax.shape which yields 0 for 1D amax and is wrong when the
channel axis in inputs isn't 0; instead infer the axis by searching inputs.shape
for dimensions that match amax.numel(), disambiguate duplicates by ensuring
amax.squeeze().ndim == 1 and matching dimension lengths, and if multiple
candidate axes exist or no match is found, raise an explicit error (fail fast).
Replace axis = amax.shape.index(amax.numel()) with logic that derives candidate
axes from inputs.shape, selects the single matching axis or errors on ambiguity,
and then call the CUDA kernels with that inferred axis; apply the same fix to
both scaled_e4m3_impl and fake_quant_impl.


codecov bot commented Oct 17, 2025

Codecov Report

❌ Patch coverage is 68.96552% with 18 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.42%. Comparing base (4df4091) to head (1b2a57d).
⚠️ Report is 5 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/quantization/tensor_quant.py 0.00% 14 Missing ⚠️
.../torch/quantization/nn/modules/tensor_quantizer.py 90.90% 3 Missing ⚠️
modelopt/torch/utils/tensor.py 80.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #449      +/-   ##
==========================================
+ Coverage   73.38%   73.42%   +0.04%     
==========================================
  Files         180      180              
  Lines       17934    17968      +34     
==========================================
+ Hits        13160    13193      +33     
- Misses       4774     4775       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


Signed-off-by: realAsma <[email protected]>