3 changes: 2 additions & 1 deletion CHANGELOG.rst
@@ -8,9 +8,10 @@ Model Optimizer Changelog (Linux)

**New Features**

- Add flag ``op_types_to_exclude_fp16`` in ONNX quantization to exclude ops from being converted to FP16/BF16. Alternatively, for custom TensorRT ops, this can also be done by indicating ``'fp32'`` precision in ``trt_plugins_precision``.
- Add LoRA mode support for MCore in a new peft submodule: ``modelopt.torch.peft.update_model(model, LORA_CFG)``.
- Support PTQ and fakequant in vLLM for fast evaluation of arbitrary quantization formats. See ``examples/vllm_serve`` for more details.
- Add support for enabling a custom emulated quantization backend. See :meth:`register_quant_backend <modelopt.torch.quantization.nn.modules.tensor_quantizer.register_quant_backend>` for more details. See an example in ``tests/unit/torch/quantization/test_custom_backend.py``.
- Add flag ``op_types_to_exclude_fp16`` in ONNX quantization to exclude ops from being converted to FP16/BF16. Alternatively, for custom TensorRT ops, this can also be done by indicating ``'fp32'`` precision in ``trt_plugins_precision``.
- Add support for ``nemotron-post-training-dataset-v2`` and ``nemotron-post-training-dataset-v1`` in ``examples/llm_ptq``. Default to a mix of ``cnn_dailymail`` and ``nemotron-post-training-dataset-v2`` if no dataset is specified.
- Allow specifying ``calib_seq`` in ``examples/llm_ptq`` to set the maximum sequence length for calibration.

35 changes: 32 additions & 3 deletions modelopt/torch/quantization/config.py
@@ -665,7 +665,7 @@ class QuantizerAttributeConfig(ModeloptBaseConfig):
description="""If True, enables the quantizer. If False, by-pass the quantizer and returns the input tensor.""",
)

num_bits: int | tuple[int, int] = ModeloptField(
num_bits: int | tuple[int, int] | str = ModeloptField(
default=8,
title="An integer or a tuple of two integers specifying the number of quantization bits.",
description="""`num_bits` can be:
@@ -675,7 +675,9 @@ class QuantizerAttributeConfig(ModeloptBaseConfig):

#. Constant integer tuple (E,M) for floating point quantization emulating
Nvidia's FPx quantization. E is the number of exponent bits and M is the number
of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1).""",
of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1).

#. String specifying the quantization format. This is currently used only for custom backends.""",
)

@model_validator(mode="before")
@@ -707,10 +709,16 @@ def _validate_recursive(value):
@model_validator(mode="after")
def validate_num_bits(self):
"""Validate `num_bits`."""
if self.backend is not None:
# For custom backends, we don't need to validate num_bits
return self

num_bits = self.num_bits

if isinstance(num_bits, int) and num_bits < 1:
raise ValueError("num_bits must be a positive integer or a tuple of positive integers.")
raise ValueError(
f"num_bits must be a positive integer or a tuple of positive integers. {num_bits}"
)

if not isinstance(num_bits, tuple):
return self
@@ -952,6 +960,27 @@ def validate_calibrator(cls, v, info: ValidationInfo):
""",
)

backend: str | None = ModeloptField(
default=None,
title="Name of custom quantization functional backend.",
description="""
Selects a non-default quantization functional backend by name. See
:meth:`register_quant_backend <modelopt.torch.quantization.nn.modules.tensor_quantizer.register_quant_backend>`
for more details on how to register a custom quantization backend.
""",
)
backend_extra_args: dict | None = ModeloptField(
default=None,
title="Extra arguments for the selected backend.",
description="""The extra arguments will saved on to the quantizer instance - this wont be
passed directly to the backend entrypoint. Can be any serializable dictionary.

Please use `backend_extra_args` to pass arguments that are not already supported by
`QuantizerAttributeConfig`. This will ensure maximum compatibility with the other modelopt
features such as modelopt's calibration algorithms.
""",
)


class QuantizeAlgorithmConfig(ModeloptBaseConfig):
"""Calibration algorithm config base."""
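For orientation, here is a minimal sketch of how the new backend / backend_extra_args fields added to QuantizerAttributeConfig above could be wired into a quantization config. The backend name, the string num_bits tag, and the extra-args keys are illustrative placeholders (not values defined by this PR), and the backend is assumed to have been registered beforehand via register_quant_backend.

import modelopt.torch.quantization as mtq

# Hypothetical config: "my_custom_backend" must already be registered via register_quant_backend().
CUSTOM_BACKEND_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": "my_fp8_variant",  # string format tag, only meaningful to the custom backend
            "backend": "my_custom_backend",  # selects the registered functional backend
            "backend_extra_args": {"clip_ratio": 0.98},  # stored on the quantizer instance
        },
        "*input_quantizer": {"enable": False},
    },
    "algorithm": "max",
}

# model = mtq.quantize(model, CUSTOM_BACKEND_CFG, forward_loop)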
2 changes: 1 addition & 1 deletion modelopt/torch/quantization/model_quant.py
@@ -228,7 +228,7 @@ def forward_loop(model) -> None:
Returns: A pytorch model which has been quantized and calibrated.
"""
model = apply_mode(model, mode=[("quantize", config)], registry=QuantizeModeRegistry)
return calibrate(model, config["algorithm"], forward_loop=forward_loop)
return calibrate(model, config.get("algorithm"), forward_loop=forward_loop)

Comment on lines +231 to 232
Contributor
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Preserve default calibration behavior (avoid silent no‑calibration).

Passing config.get("algorithm") forwards None when the key is absent, which skips calibration. Previously, missing "algorithm" implied "max". Recommend defaulting to "max" unless explicitly set to None.

-    return calibrate(model, config.get("algorithm"), forward_loop=forward_loop)
+    return calibrate(model, config.get("algorithm", "max"), forward_loop=forward_loop)

Please confirm whether this PR intends to change the behavior when "algorithm" is omitted. If yes, update docs to highlight that omission now skips calibration.


🏁 Script executed:

#!/bin/bash
# Find the calibrate function definition
echo "=== Searching for calibrate function definition ==="
rg -n "def calibrate" --type=py

echo -e "\n=== Context around current code (lines 231-232) ==="
fd "model_quant.py" --type=py --exec head -250 {} \; 2>/dev/null | tail -30

echo -e "\n=== Check for algorithm default in config ==="
rg -n "algorithm" modelopt/torch/quantization/model_quant.py -B2 -A2

echo -e "\n=== Search for related tests ==="
fd "test.*\.py" --type=py | xargs rg -l "calibrate.*algorithm" 2>/dev/null | head -5

Length of output: 5645



🏁 Script executed:

#!/bin/bash

echo "=== Inspect calibrate function to confirm default parameter ==="
rg -n "def calibrate" modelopt/torch/quantization/model_quant.py -A10

echo -e "\n=== Look for git history of line 231 ==="
git log -p --all -S 'config.get("algorithm")' -- modelopt/torch/quantization/model_quant.py 2>/dev/null | head -80

echo -e "\n=== Search for tests calling quantize() without algorithm ==="
rg -n "quantize\(" --type=py -g '*/test*' -B3 -A3 | grep -A3 -B3 "algorithm" | head -40

echo -e "\n=== Check what happens when None is passed to calibrate() ==="
rg -n "if algorithm" modelopt/torch/quantization/model_quant.py -A5 -B2

echo -e "\n=== Look for examples/docs showing quantize() without algorithm key ==="
rg -n 'config.*=.*{' modelopt/torch/quantization/model_quant.py -A15 | grep -A15 '"quant_cfg"'

Length of output: 956


🏁 Script executed:

#!/bin/bash

echo "=== Inspect calibrate function to confirm default parameter ==="
rg -n "def calibrate" modelopt/torch/quantization/model_quant.py -A10

echo -e "\n=== Look for git history of line 231 ==="
git log -p --all -S 'config.get("algorithm")' -- modelopt/torch/quantization/model_quant.py 2>/dev/null | head -100

echo -e "\n=== Search for tests calling quantize without algorithm key ==="
find . -path ./build -prune -o -name "*test*.py" -type f -print | xargs grep -l "quantize.*config" 2>/dev/null | head -5

echo -e "\n=== Check algorithm handling in calibrate function ==="
sed -n '53,110p' modelopt/torch/quantization/model_quant.py

echo -e "\n=== Search for get_modelike_from_algo_cfg to see None handling ==="
rg -n "def get_modelike_from_algo_cfg" modelopt/torch/quantization/ -A15

Length of output: 4705


🏁 Script executed:

#!/bin/bash

echo "=== Look for tests with quantize() missing algorithm key ==="
rg -n "quantize\(" tests/ -B5 -A5 --type=py | grep -B8 -A8 -E "(config\s*=|\"quant_cfg\")" | head -60

echo -e "\n=== Check if None/missing algorithm is tested ==="
rg -n "algorithm.*None\|None.*algorithm" tests/ --type=py -B2 -A2

echo -e "\n=== Look for weight-only quantization examples ==="
rg -n "weight.*only\|weight-only" modelopt/torch/quantization/ --type=py -B3 -A3 | head -40

echo -e "\n=== Check if calibrate() is called without algorithm in examples ==="
rg -n "calibrate\(" examples/ --type=py -B2 -A2 | head -40

echo -e "\n=== Verify line 73-76 of model_quant.py to confirm None documentation ==="
sed -n '73,76p' modelopt/torch/quantization/model_quant.py

Length of output: 7555


🏁 Script executed:

#!/bin/bash

echo "=== Check if quantize() without algorithm key is valid/tested ==="
rg -n "quantize\(" tests/ --type=py -A10 | grep -B10 "\"quant_cfg\"" | grep -v "algorithm" | head -50

echo -e "\n=== Search for any PR/issue context about algorithm default ==="
git log --all --oneline --grep="algorithm" -- modelopt/torch/quantization/model_quant.py 2>/dev/null | head -20

echo -e "\n=== Check git diff of lines around 231 ==="
git log -p -- modelopt/torch/quantization/model_quant.py 2>/dev/null | grep -B5 -A5 "config.get" | head -40

echo -e "\n=== Verify: Does config schema require algorithm? ==="
rg -n "algorithm" modelopt/torch/quantization/config.py -B3 -A3 | head -40

Length of output: 3733


🏁 Script executed:

#!/bin/bash

echo "=== Search for QuantizeConfig class definition ==="
rg -n "class QuantizeConfig" modelopt/torch/quantization/ -A30

echo -e "\n=== Check if algorithm field is marked as required ==="
rg -n "algorithm.*required\|required.*algorithm" modelopt/torch/quantization/ --type=py

echo -e "\n=== Look for config validation that checks algorithm presence ==="
rg -n "def.*quantize\(" modelopt/torch/quantization/model_quant.py -A50 | grep -A50 "def quantize"

echo -e "\n=== Check actual quantize() implementation around config.get ==="
sed -n '215,235p' modelopt/torch/quantization/model_quant.py

Length of output: 6346


Preserve default calibration behavior when "algorithm" key is omitted.

The QuantizeConfig defines algorithm with default="max", but line 231 bypasses this by passing config.get("algorithm") without a default fallback. When the config dict omits the "algorithm" key, None is passed to calibrate(), which explicitly skips calibration (as documented in its docstring). This contradicts the QuantizeConfig default.

-    return calibrate(model, config.get("algorithm"), forward_loop=forward_loop)
+    return calibrate(model, config.get("algorithm", "max"), forward_loop=forward_loop)
📝 Committable suggestion

Suggested change
return calibrate(model, config.get("algorithm"), forward_loop=forward_loop)
return calibrate(model, config.get("algorithm", "max"), forward_loop=forward_loop)
🤖 Prompt for AI Agents
In modelopt/torch/quantization/model_quant.py around lines 231 to 232, the call
passes config.get("algorithm") which yields None when the key is omitted and
causes calibrate() to skip calibration; change the call to supply the intended
default (e.g., use config.get("algorithm", "max") or resolve the
QuantizeConfig.default) so that when "algorithm" is missing the default "max" is
forwarded to calibrate().


def auto_quantize(
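To make the behavioral question above concrete, a small sketch with hypothetical configs (whether omitting "algorithm" should mean "max" or skip calibration is exactly what the comment asks the authors to confirm):

import modelopt.torch.quantization as mtq

cfg_explicit = {"quant_cfg": {"*weight_quantizer": {"num_bits": 8}}, "algorithm": "max"}
cfg_omitted = {"quant_cfg": {"*weight_quantizer": {"num_bits": 8}}}  # no "algorithm" key

# With config.get("algorithm"), cfg_omitted forwards None to calibrate(), which skips calibration;
# with config.get("algorithm", "max") it would fall back to max calibration as before.
# model = mtq.quantize(model, cfg_omitted, forward_loop)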
100 changes: 87 additions & 13 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -18,7 +18,8 @@
import contextlib
import math
import warnings
from typing import TYPE_CHECKING, Any
from collections.abc import Callable
from typing import Any

import torch
import torch.distributed as dist
@@ -36,7 +37,7 @@
import torch.nn.functional as F
from torch import nn

from modelopt.torch.utils import standardize_constructor_args
from modelopt.torch.utils import same_device_as, standardize_constructor_args
from modelopt.torch.utils.distributed import DistributedProcessGroup

from ... import calib
@@ -56,10 +57,58 @@
from ...utils import is_torch_export_mode
from ..functional import normalized_hadamard_transform

if TYPE_CHECKING:
from collections.abc import Callable
__all__ = [
"SequentialQuantizer",
"TensorQuantizer",
"is_registered_quant_backend",
"register_quant_backend",
"unregister_quant_backend",
]

__all__ = ["SequentialQuantizer", "TensorQuantizer"]

QuantBackendEntrypoint = Callable[[torch.Tensor, "TensorQuantizer"], torch.Tensor]

_QUANT_FUNCTIONAL_BACKENDS: dict[str, QuantBackendEntrypoint] = {}


def register_quant_backend(name: str, entrypoint: QuantBackendEntrypoint) -> None:
"""Register a custom quantization backend.

Args:
name: The name of the backend.
entrypoint: The entrypoint of the backend. The entrypoint should be a callable that takes in
the inputs and the tensor quantizer as arguments and returns the quantized tensor.
See :class:`modelopt.torch.quantization.config.QuantizerAttributeConfig`
for details on choosing from the registered backends via the ``backend`` and
``backend_extra_args`` fields.
"""
if not isinstance(name, str) or not name:
raise ValueError("Backend name must be a non-empty string.")
if not callable(entrypoint):
raise TypeError("Entrypoint must be callable.")
if name in _QUANT_FUNCTIONAL_BACKENDS:
warnings.warn(f"Overwriting existing backend: {name}")
_QUANT_FUNCTIONAL_BACKENDS[name] = entrypoint


def unregister_quant_backend(name: str) -> None:
"""Unregister a custom quantization backend.

Args:
name: The name of the backend to unregister.
"""
if not isinstance(name, str) or not name:
raise ValueError("Backend name must be a non-empty string.")
_QUANT_FUNCTIONAL_BACKENDS.pop(name, None)


def is_registered_quant_backend(name: str) -> bool:
"""Check if a custom quantization backend is registered.

Args:
name: The name of the backend to check.
"""
return name in _QUANT_FUNCTIONAL_BACKENDS


class TensorQuantizer(nn.Module):
@@ -153,6 +202,8 @@ def _calibrator_setter(val):
"enable": ("_disabled", lambda val: val is False),
"type": ("_dynamic", lambda val: val == "dynamic"),
"calibrator": ("_calibrator", _calibrator_setter),
"backend": ("backend", lambda val: val),
"backend_extra_args": ("backend_extra_args", lambda val: val or {}),
}
Comment on lines +205 to 207
Contributor
⚠️ Potential issue | 🟠 Major

Defensive defaults for backend fields to avoid AttributeError.

_fake_quantize and extra_repr access self.backend/self.backend_extra_args directly. If QuantizerAttributeConfig ever omits these keys, attribute access can raise. Initialize defaults in __init__ and use getattr in hot paths.

Apply:

@@ class TensorQuantizer(nn.Module):
     def __init__(...
         super().__init__()
         quant_attribute_cfg = (
             quant_attribute_cfg if quant_attribute_cfg is not None else QuantizerAttributeConfig()
         )
+        # Defaults for optional backend fields; safe even if config overrides them.
+        self.backend = None
+        self.backend_extra_args = {}
@@
-        if self.backend is not None:
-            if self.backend not in _QUANT_FUNCTIONAL_BACKENDS:
+        backend = getattr(self, "backend", None)
+        if backend is not None:
+            if backend not in _QUANT_FUNCTIONAL_BACKENDS:
                 raise KeyError(f"Quant backend '{self.backend}' is not registered.")
-            entrypoint = _QUANT_FUNCTIONAL_BACKENDS[self.backend]
+            entrypoint = _QUANT_FUNCTIONAL_BACKENDS[backend]
             return entrypoint(inputs, self)

Also applies to: 675-680

🤖 Prompt for AI Agents
In modelopt/torch/quantization/nn/modules/tensor_quantizer.py around lines
205-207, the mapping for backend fields can be omitted by
QuantizerAttributeConfig which makes later direct accesses to self.backend and
self.backend_extra_args raise AttributeError; initialize safe defaults in
__init__ (e.g., self.backend = None and self.backend_extra_args = {} or
appropriate neutral defaults) and update hot paths (_fake_quantize and
extra_repr) to use getattr(self, "backend", None) and getattr(self,
"backend_extra_args", {}) when reading them; apply the same initialization +
getattr pattern to the code region around lines 675-680 as well.


for attribute, val in attribute_cfg.items():
@@ -621,6 +672,12 @@ def _real_quantize(self, inputs):

def _fake_quantize(self, inputs):
"""Fake quantization."""
if self.backend is not None:
if self.backend not in _QUANT_FUNCTIONAL_BACKENDS:
raise KeyError(f"Quant backend '{self.backend}' is not registered.")
entrypoint = _QUANT_FUNCTIONAL_BACKENDS[self.backend]
return entrypoint(inputs, self)

amax = None
if not self.is_mx_format:
amax = self._get_amax(inputs)
@@ -927,7 +984,8 @@ def forward(self, inputs):
if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous():
inputs.data = inputs.data.contiguous()
if self.fake_quant:
outputs = self._fake_quantize(inputs)
with same_device_as(inputs):
outputs = self._fake_quantize(inputs)
elif not self._dequantize:
outputs = self._real_quantize(inputs)
else:
@@ -961,16 +1019,23 @@ def _short_amax(self, fmt=".4f"):
return "None"
if self._amax.is_meta:
return "meta"
if self._amax.numel() == 1:
return f"{self._amax.item():{fmt}}"
return (
f"[{self._amax.min().item():{fmt}},"
f" {self._amax.max().item():{fmt}}]({self._amax.numel()})"
)
return self._short_tensor(self._amax, fmt)

def _short_tensor(self, tensor: torch.Tensor, fmt=".4f"):
"""Short description of tensor."""
if tensor.numel() == 1:
return f"{tensor.item():{fmt}}"
return f"[{tensor.min().item():{fmt}}, {tensor.max().item():{fmt}}]({tensor.numel()})"

def extra_repr(self):
"""Set the extra information about this module."""
if self._disabled:
s = "disabled"
s += (
f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}"
if self.pre_quant_scale is not None
else ""
)
return "disabled"
Comment on lines +1033 to 1039
Contributor
⚠️ Potential issue | 🟡 Minor

extra_repr returns constant instead of built string when disabled.

You build s with pre_quant_scale but return the literal "disabled". Return s to expose the extra detail.

Apply:

-            return "disabled"
+            return s
📝 Committable suggestion

Suggested change
s = "disabled"
s += (
f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}"
if self.pre_quant_scale is not None
else ""
)
return "disabled"
s = "disabled"
s += (
f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}"
if self.pre_quant_scale is not None
else ""
)
return s
🤖 Prompt for AI Agents
In modelopt/torch/quantization/nn/modules/tensor_quantizer.py around lines 1033
to 1039, the extra_repr branch for the disabled state builds a string s
including pre_quant_scale but then returns the literal "disabled" instead of s;
change the return to return s so the constructed string (e.g., "disabled
pre_quant_scale=...") is returned, preserving the existing conditional
construction and formatting.

s = f"{'unsigned ' if self._unsigned else ''}{self._num_bits} bit"
s += " narrow" if (self._narrow_range) else ""
@@ -980,7 +1045,11 @@ def extra_repr(self):
else:
s += f" axis={self._axis}" if self._axis is not None else " per-tensor"
s += f" amax={self._short_amax()}"
s += " pre_quant_scale" if self.pre_quant_scale is not None else ""
s += (
f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}"
if self.pre_quant_scale is not None
else ""
)
s += " rotated" if self._rotate else ""
s += (
f" calibrator={self._calibrator.__class__.__name__}"
@@ -992,6 +1061,11 @@ def extra_repr(self):

s += " quant" if (self._if_quant) else ""
s += " calib" if (self._if_calib) else ""
s += (
f" backend={self.backend}, extra_args={self.backend_extra_args}"
if self.backend is not None
else ""
)
return s

def _get_properties_for_modelopt_state(self):
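As a usage sketch for the backend registry added in this file: an entrypoint is any callable that takes the input tensor and the owning TensorQuantizer and returns the (fake-)quantized tensor. The rounding logic and the backend name below are purely illustrative, and backend_extra_args is read defensively since the attribute is only set when configured.

import torch

from modelopt.torch.quantization.nn.modules.tensor_quantizer import (
    TensorQuantizer,
    register_quant_backend,
)


def int8_roundtrip_backend(inputs: torch.Tensor, quantizer: TensorQuantizer) -> torch.Tensor:
    """Emulated symmetric int8 quantize-dequantize, configurable via backend_extra_args."""
    extra = getattr(quantizer, "backend_extra_args", None) or {}
    clip_ratio = extra.get("clip_ratio", 1.0)
    amax = inputs.abs().amax().clamp(min=1e-8) * clip_ratio
    scale = amax / 127.0
    return (inputs / scale).round().clamp_(-127, 127) * scale


register_quant_backend("int8_roundtrip", int8_roundtrip_backend)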
8 changes: 8 additions & 0 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -231,6 +231,14 @@ def _setup(self):
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
)

if getattr(self, "gradient_accumulation_fusion", False):
warnings.warn(
"gradient_accumulation_fusion is not supported with ModelOpt quantization. "
"Setting gradient_accumulation_fusion to False."
)
self.gradient_accumulation_fusion = False

Comment on lines +234 to +241
Contributor Author
@jenchen13 @ChenhanYu this should fix the occasional errors people run into by unknowingly enabling gradient_accumulation_fusion

super()._setup()

def _process_quantizer_amax(self, k, v, quantizer_state_dict):
65 changes: 28 additions & 37 deletions modelopt/torch/quantization/tensor_quant.py
@@ -79,14 +79,11 @@ def scaled_e4m3_impl(
if cuda_ext_fp8 is None:
return fp8_eager(inputs, amax)

with torch.cuda.device(
None if inputs.device.index == torch.cuda.current_device() else inputs.device.index
):
if amax.numel() == 1:
outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
elif amax.squeeze().ndim == 1:
axis = amax.shape.index(amax.numel())
outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
if amax.numel() == 1:
outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
elif amax.squeeze().ndim == 1:
axis = amax.shape.index(amax.numel())
outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
return outputs
Comment on lines +82 to 87
Contributor
⚠️ Potential issue | 🔴 Critical

Axis inference bug for per‑channel amax (wrong source tensor).

axis is computed from amax.shape, which always yields 0 for 1D amax. The CUDA kernels expect the axis in inputs. This breaks cases where the channel axis isn’t 0 (e.g., last dim activations, column/row‑parallel weights).

Use inputs.shape to infer axis and disambiguate duplicates; otherwise fail fast. Patch both scaled_e4m3_impl and fake_quant_impl:

@@
-    if amax.numel() == 1:
-        outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
-    elif amax.squeeze().ndim == 1:
-        axis = amax.shape.index(amax.numel())
-        outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
+    if amax.numel() == 1:
+        outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
+    elif amax.squeeze().ndim == 1:
+        n = amax.numel()
+        candidates = [i for i, s in enumerate(inputs.shape) if s == n]
+        if not candidates:
+            raise ValueError(
+                f"Cannot infer per-channel axis: amax.numel()={n} does not match any inputs dim {tuple(inputs.shape)}"
+            )
+        # Prefer last-dim if it matches, otherwise require uniqueness
+        axis = inputs.dim() - 1 if inputs.shape[-1] == n else (candidates[0] if len(candidates) == 1 else None)
+        if axis is None:
+            raise ValueError(
+                f"Ambiguous per-channel axis for amax of length {n} in inputs shape {tuple(inputs.shape)}; "
+                "please provide broadcast-shaped amax."
+            )
+        outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
@@
-    if amax.numel() == 1:
-        outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
-    else:
-        axis = amax.shape.index(amax.numel())
-        outputs = cuda_ext.fake_tensor_quant_with_axis(
-            inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range
-        )
+    if amax.numel() == 1:
+        outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
+    else:
+        n = amax.numel()
+        candidates = [i for i, s in enumerate(inputs.shape) if s == n]
+        if not candidates:
+            raise ValueError(
+                f"Cannot infer per-channel axis: amax.numel()={n} does not match any inputs dim {tuple(inputs.shape)}"
+            )
+        axis = inputs.dim() - 1 if inputs.shape[-1] == n else (candidates[0] if len(candidates) == 1 else None)
+        if axis is None:
+            raise ValueError(
+                f"Ambiguous per-channel axis for amax of length {n} in inputs shape {tuple(inputs.shape)}; "
+                "please provide broadcast-shaped amax."
+            )
+        outputs = cuda_ext.fake_tensor_quant_with_axis(
+            inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range
+        )

Also applies to: 100-107

🤖 Prompt for AI Agents
In modelopt/torch/quantization/tensor_quant.py around lines 82-87 (and similarly
for lines ~100-107 in scaled_e4m3_impl and fake_quant_impl), the code currently
computes axis from amax.shape which yields 0 for 1D amax and is wrong when the
channel axis in inputs isn't 0; instead infer the axis by searching inputs.shape
for dimensions that match amax.numel(), disambiguate duplicates by ensuring
amax.squeeze().ndim == 1 and matching dimension lengths, and if multiple
candidate axes exist or no match is found, raise an explicit error (fail fast).
Replace axis = amax.shape.index(amax.numel()) with logic that derives candidate
axes from inputs.shape, selects the single matching axis or errors on ambiguity,
and then call the CUDA kernels with that inferred axis; apply the same fix to
both scaled_e4m3_impl and fake_quant_impl.



@@ -100,17 +97,14 @@ def fake_quant_impl(
"""Implementation of fake quantizing input according to number of bits."""
cuda_ext = get_cuda_ext()

with torch.cuda.device(
None if inputs.device.index == torch.cuda.current_device() else inputs.device.index
):
if amax.numel() == 1:
outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
else:
axis = amax.shape.index(amax.numel())
outputs = cuda_ext.fake_tensor_quant_with_axis(
inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range
)
return outputs
if amax.numel() == 1:
outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
else:
axis = amax.shape.index(amax.numel())
outputs = cuda_ext.fake_tensor_quant_with_axis(
inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range
)
return outputs


def _quantize_impl(
@@ -173,25 +167,22 @@ def _dynamic_block_quantize_impl(
assert amax.is_cuda, "amax must be a CUDA tensor for dynamic block quantization."
if amax.numel() != 1:
amax = amax.amax()
with torch.cuda.device(
None if inputs.device.index == torch.cuda.current_device() else inputs.device.index
if (
num_bits == (2, 1) # type: ignore[comparison-overlap]
and scale_bits == (4, 3)
and triton_kernel.IS_AVAILABLE
and not DISABLE_TRITON_KERNEL
and amax is not None
):
if (
num_bits == (2, 1) # type: ignore[comparison-overlap]
and scale_bits == (4, 3)
and triton_kernel.IS_AVAILABLE
and not DISABLE_TRITON_KERNEL
and amax is not None
):
return triton_kernel.fp4_fake_quant_block(inputs, amax)
cuda_ext_mx = get_cuda_ext_mx(raise_if_failed=True)
return cuda_ext_mx.fused_amax_convert(
inputs,
block_size,
getattr(cuda_ext_mx.Types, mx_format_map[num_bits]),
getattr(cuda_ext_mx.Types, mx_format_map[scale_bits]),
amax,
)
return triton_kernel.fp4_fake_quant_block(inputs, amax)
cuda_ext_mx = get_cuda_ext_mx(raise_if_failed=True)
return cuda_ext_mx.fused_amax_convert(
inputs,
block_size,
getattr(cuda_ext_mx.Types, mx_format_map[num_bits]),
getattr(cuda_ext_mx.Types, mx_format_map[scale_bits]),
amax,
)
else:
raise NotImplementedError(
f"Unsupported num_bits: {num_bits}, scale_bits: {scale_bits} for dynamic block quantization."
12 changes: 12 additions & 0 deletions modelopt/torch/utils/tensor.py
@@ -16,19 +16,31 @@
"""Utility functions for PyTorch tensors."""

from collections import abc
from contextlib import nullcontext

import numpy as np
import torch

__all__ = [
"numpy_to_torch",
"same_device_as",
"to_empty_if_meta_device",
"torch_detach",
"torch_to",
"torch_to_numpy",
]


def same_device_as(inputs: torch.Tensor):
"""Return a context manager that sets the CUDA device to be the same as the input tensor.

Returns a null context if the tensor is on CPU or on the same device as the current CUDA device.
"""
if not inputs.is_cuda or inputs.device.index == torch.cuda.current_device():
return nullcontext()
return torch.cuda.device(inputs.device.index)


def torch_to(data, *args, **kwargs):
"""Try to recursively move the data to the specified args/kwargs."""
if isinstance(data, torch.Tensor):
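A small usage sketch for the new same_device_as helper, which replaces the inline torch.cuda.device(...) guards removed from tensor_quant.py and wraps _fake_quantize in the quantizer forward path; the device index below is arbitrary and only exercised when more than one GPU is present.

import torch

from modelopt.torch.utils import same_device_as


def halve_on_own_device(x: torch.Tensor) -> torch.Tensor:
    # CUDA work launched inside the block targets x's device;
    # for CPU tensors (or tensors already on the current device) this is a null context.
    with same_device_as(x):
        return x * 0.5


if torch.cuda.device_count() > 1:
    y = halve_on_own_device(torch.randn(4, device="cuda:1"))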