[1/2] Registry interface for custom quantization functional backend #449
base: main
Conversation
Signed-off-by: realAsma <[email protected]>

minor
minor
moved external changes to this PR
addressed PR comments; clean ups
minor test fix
rebasing external changes from [2/2]
minor unrelated fix

Signed-off-by: realAsma <[email protected]>
Walkthrough

Adds a pluggable custom quantization backend system, extends quantizer configuration to include backend name/args and flexible num_bits, refactors device handling for CUDA/Triton paths, updates TensorQuantizer to delegate fake-quant to registered backends, and adds tests and warnings for related plugins.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
actor User
participant Registry as Backend Registry
participant ModelOpt as ModelOpt Quantize
participant Quantizer as TensorQuantizer
participant Backend as Custom Backend
User->>Registry: register_quant_backend("name", impl)
Registry-->>User: ack
User->>ModelOpt: quantize(model, config{backend:"name", backend_extra_args:{...}})
ModelOpt->>Quantizer: set_from_attribute_config(...)
Quantizer->>Quantizer: store backend name & args
ModelOpt->>Quantizer: forward(tensor)
alt backend configured
Quantizer->>Registry: lookup("name")
Registry-->>Quantizer: callable
Quantizer->>Quantizer: same_device_as(tensor) ctx
Quantizer->>Backend: callable(tensor, quantizer)
Backend-->>Quantizer: quantized tensor
else no backend
Quantizer->>Quantizer: builtin fake-quant path
end
User->>Registry: unregister_quant_backend("name")
    Registry-->>User: ack
```
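Putting the flow above into code, here is a minimal usage sketch. It follows the walkthrough and the backend call signature shown in the diagram (`callable(tensor, quantizer)`); the dummy backend body, config values, and the tiny model/forward loop are illustrative, not the PR's actual test code.

```python
import torch
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.nn import register_quant_backend, unregister_quant_backend


def dummy_backend(inputs: torch.Tensor, quantizer) -> torch.Tensor:
    # Registered entrypoint: receives the input tensor and the owning TensorQuantizer,
    # and returns the (fake-)quantized tensor. Here: a trivial round-to-half-step stand-in.
    return torch.round(inputs * 2.0) / 2.0


register_quant_backend("dummy_backend", dummy_backend)

config = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": "dummy_fmt",  # free-form string; only meaningful to the custom backend
            "backend": "dummy_backend",
            "backend_extra_args": {"example_arg": 1},  # stored on the quantizer instance
            "enable": True,
        },
    },
    "algorithm": "max",
}

model = torch.nn.Linear(16, 16)


def forward_loop(m):
    m(torch.randn(2, 16))


model = mtq.quantize(model, config, forward_loop)
unregister_quant_backend("dummy_backend")  # clean up so other code is unaffected
```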
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
```python
if getattr(self, "gradient_accumulation_fusion", False):
    warnings.warn(
        "gradient_accumulation_fusion is not supported with ModelOpt quantization. "
        "Setting gradient_accumulation_fusion to False."
    )
    self.gradient_accumulation_fusion = False
```
@jenchen13 @ChenhanYu this should fix the occasional errors people run into by unknowingly enabling gradient_accumulation_fusion
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/tensor_quant.py (1)
75-87: Add device context guards to CUDA extension calls in tensor_quant.py to prevent multi-GPU errors.

The review correctly identifies a multi-GPU correctness issue. CUDA extension calls execute on the current device (via `getCurrentCUDAStream()`), not necessarily where the input tensor resides. This causes failures when current_device ≠ inputs.device.

However, line number references are offset. The actual functions are:

- `scaled_e4m3_impl`: lines 62–87 (extension calls at 80, 84)
- `fake_quant_impl`: lines 90–108 (extension calls at 101, 106)
- `_dynamic_block_quantize_impl`: lines 152–181 (extension call at 176)

The proposed fix is valid: import `same_device_as` from `modelopt.torch.utils.tensor` and wrap each extension call. This pattern already exists in `tensor_quantizer.py:987` and matches the helper's design.

Apply the import and wrapping changes as suggested in the diff, adjusting line numbers to match the actual file locations.
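For illustration, a sketch of the wrapping pattern the fix describes (the wrapper name and simplified body are hypothetical; `same_device_as` is the existing helper from `modelopt.torch.utils.tensor`):

```python
import torch
from modelopt.torch.utils.tensor import same_device_as


def guarded_ext_call(ext_fn, inputs: torch.Tensor, *args):
    # Pin the current CUDA device to the input tensor's device for the duration
    # of the extension call, so the kernel launches on the correct GPU.
    with same_device_as(inputs):
        return ext_fn(inputs, *args)
```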
🧹 Nitpick comments (7)
modelopt/torch/utils/tensor.py (1)
34-41: Small enhancement: prefer device_of and clarify CUDA-only scope.

Consider using torch.cuda.device_of(inputs) for brevity and to future-proof against potential device index edge cases. Also clarify in the docstring that this helper is CUDA-only (returns a no-op for CPU/other backends).
Apply this diff if you prefer:
```diff
-from contextlib import nullcontext
+from contextlib import nullcontext
+from typing import ContextManager

-def same_device_as(inputs: torch.Tensor):
-    """Return a context manager that sets the CUDA device to be the same as the input tensor.
+def same_device_as(inputs: torch.Tensor) -> ContextManager[None]:
+    """CUDA-only: return a context manager that sets the current CUDA device to the tensor's device.
@@
-    if not inputs.is_cuda or inputs.device.index == torch.cuda.current_device():
+    if not inputs.is_cuda or inputs.device.index == torch.cuda.current_device():
         return nullcontext()
-    return torch.cuda.device(inputs.device.index)
+    return torch.cuda.device_of(inputs)
```

modelopt/torch/quantization/config.py (1)
668-681: Doc/errmsg polish and typo fix.

Minor text fixes to improve clarity and the error message string formatting.
```diff
@@
-    num_bits: int | tuple[int, int] | str = ModeloptField(
+    num_bits: int | tuple[int, int] | str = ModeloptField(
@@
-            of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1).
+            of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6 (E3M2, E2M3), FP4 (E2M1).
@@
-            #. String specifying the quantization format. This is current used only for custom backends.""",
+            #. String specifying the quantization format. This is currently used only for custom backends.""",
@@
-    backend_extra_args: dict | None = ModeloptField(
+    backend_extra_args: dict | None = ModeloptField(
@@
-        description="""The extra arguments will saved on to the quantizer instance - this wont be
-            passed directly to the backend entrypoint. Can be any serializable dictionary.
+        description="""The extra arguments will be saved on the quantizer instance; this won't be
+            passed directly to the backend entrypoint. Can be any serializable dictionary.
@@
-            raise ValueError(
-                "Only blockwise dynamic quantization is supported with quantization "
-                "formats E{num_bis[0]}M{num_bits[1]}."
-            )
+            raise ValueError(
+                f"Only blockwise dynamic quantization is supported with quantization "
+                f"formats E{num_bits[0]}M{num_bits[1]}."
+            )
```

Also applies to: 972-982, 743-749
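As a reference for the `num_bits` field discussed above, the three accepted forms look roughly like this in a quantizer attribute config (values and keys shown are illustrative only):

```python
# 1. Integer bit-width for integer fake quantization.
int_cfg = {"num_bits": 8, "axis": 0}

# 2. (exponent_bits, mantissa_bits) tuple for FPx formats, e.g. FP8 E4M3.
fp8_cfg = {"num_bits": (4, 3)}

# 3. Free-form string, interpreted only by a registered custom backend.
custom_cfg = {"num_bits": "my_fp6_variant", "backend": "my_backend"}
```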
modelopt/torch/quantization/plugins/megatron.py (1)
235-241: Use logger and reduce warning spam.

warnings.warn here can flood logs per layer. Prefer logger.warning for consistency with this module and to honor logging config.
```diff
-        if getattr(self, "gradient_accumulation_fusion", False):
-            warnings.warn(
-                "gradient_accumulation_fusion is not supported with ModelOpt quantization. "
-                "Setting gradient_accumulation_fusion to False."
-            )
-            self.gradient_accumulation_fusion = False
+        if getattr(self, "gradient_accumulation_fusion", False):
+            logger.warning(
+                "gradient_accumulation_fusion is not supported with ModelOpt quantization; "
+                "forcing gradient_accumulation_fusion = False."
+            )
+            self.gradient_accumulation_fusion = False
```

If many layers set this flag, consider gating the message (warn once per process) to avoid repetition.
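One way to implement the warn-once suggestion, sketched with standard-library tools (the helper name is hypothetical):

```python
import functools
import logging

logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def _warn_gradient_accumulation_fusion_disabled() -> None:
    # lru_cache on a zero-argument function makes this a cheap "warn once per process" gate.
    logger.warning(
        "gradient_accumulation_fusion is not supported with ModelOpt quantization; "
        "forcing gradient_accumulation_fusion = False."
    )
```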
CHANGELOG.rst (1)
13-14: Fix Sphinx role and tighten wording.
- Remove the extra trailing backtick to avoid a docs build error.
- Prefer “Add support for a custom emulated quantization backend.”
Apply:
```diff
-- Add support to enable custom emulated quantization backend. See :meth:`register_quant_backend <modelopt.torch.quantization.nn.modules.tensor_quantizer.register_quant_backend>`` for more details. See an example in ``tests/unit/torch/quantization/test_custom_backend.py``.
+- Add support for a custom emulated quantization backend. See :meth:`register_quant_backend <modelopt.torch.quantization.nn.modules.tensor_quantizer.register_quant_backend>` for details. Example: ``tests/unit/torch/quantization/test_custom_backend.py``.
```

tests/unit/torch/quantization/test_custom_backend.py (2)
6-9: Clarify quantizer being exercised.

The comment says "output quantizer," but the config targets "*weight_quantizer". Update to avoid confusion.
24-53: Ensure backend cleanup even on assertion failure.

Wrap registration/unregistration in try/finally to avoid leaking the backend into other tests. Optionally assert registry state.
Apply:
```diff
-    register_quant_backend("dummy_backend", dummy_backend)
+    register_quant_backend("dummy_backend", dummy_backend)
+    try:
         ...
-    # Unregister the backend to avoid impacting other tests
-    unregister_quant_backend("dummy_backend")
+    finally:
+        # Unregister the backend to avoid impacting other tests
+        unregister_quant_backend("dummy_backend")
```

Optionally:
```diff
-    from modelopt.torch.quantization.nn import register_quant_backend, unregister_quant_backend
+    from modelopt.torch.quantization.nn import (
+        register_quant_backend,
+        unregister_quant_backend,
+        is_registered_quant_backend,
+    )
@@
-    register_quant_backend("dummy_backend", dummy_backend)
+    register_quant_backend("dummy_backend", dummy_backend)
+    assert is_registered_quant_backend("dummy_backend")
@@
-    unregister_quant_backend("dummy_backend")
+    unregister_quant_backend("dummy_backend")
+    assert not is_registered_quant_backend("dummy_backend")
```

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
1048-1052: Guard optional backend fields in extra_repr.

If `backend` is set but `backend_extra_args` isn't, attribute access can fail. Use `getattr`.

Apply:

```diff
-        s += (
-            f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}"
+        s += (
+            f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}"
             if self.pre_quant_scale is not None
             else ""
         )
@@
-        s += (
-            f" backend={self.backend}, extra_args={self.backend_extra_args}"
-            if self.backend is not None
-            else ""
-        )
+        if getattr(self, "backend", None) is not None:
+            extra = getattr(self, "backend_extra_args", {})
+            s += f" backend={self.backend}, extra_args={extra}"
```

Also applies to: 1065-1068
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- CHANGELOG.rst (1 hunks)
- modelopt/torch/quantization/config.py (4 hunks)
- modelopt/torch/quantization/model_quant.py (1 hunks)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (9 hunks)
- modelopt/torch/quantization/plugins/megatron.py (1 hunks)
- modelopt/torch/quantization/tensor_quant.py (3 hunks)
- modelopt/torch/utils/tensor.py (1 hunks)
- tests/unit/torch/quantization/test_custom_backend.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
modelopt/torch/quantization/tensor_quant.py (5)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  - amax (287-292), amax (295-306), axis (334-336), axis (339-341)
- modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
  - fake_e4m3fy (87-101), fake_e4m3fy (87-87), fake_e4m3fy_with_axis (66-85), fake_e4m3fy_with_axis (66-66)
- modelopt/torch/quantization/triton/fp4_kernel.py (1)
  - fp4_fake_quant_block (122-164)
- modelopt/torch/quantization/extensions.py (1)
  - get_cuda_ext_mx (55-71)
- modelopt/torch/quantization/src/tensor_quant_mx.cu (2)
  - fused_amax_convert (355-387), fused_amax_convert (355-356)

tests/unit/torch/quantization/test_custom_backend.py (2)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  - register_quant_backend (74-91), unregister_quant_backend (94-102)
- modelopt/torch/quantization/model_quant.py (2)
  - forward_loop (94-95), quantize (136-231)

modelopt/torch/quantization/model_quant.py (3)
- tests/unit/torch/quantization/test_custom_backend.py (1)
  - forward_loop (44-45)
- modelopt/torch/quantization/algorithms.py (1)
  - forward_loop (567-572)
- modelopt/torch/quantization/plugins/transformers_trainer.py (1)
  - forward_loop (200-205)

modelopt/torch/quantization/config.py (1)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  - num_bits (234-236), num_bits (239-241)

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
- modelopt/torch/utils/tensor.py (1)
  - same_device_as (34-41)
- modelopt/torch/utils/network.py (1)
  - standardize_constructor_args (471-484)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (7)
modelopt/torch/quantization/config.py (1)
712-715: Bypass of num_bits validation for backends looks good.

Short-circuiting validation when backend is set aligns with pluggable backends and avoids over-constraining custom formats.
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (6)
60-66: Public API surface looks good.

Exporting backend management functions alongside quantizer classes is appropriate.
69-72: Type alias and registry are clear and well-scoped.

Good use of `Callable` and a private dict for backends.
74-92: Backend registration API: sound validation and ergonomics.

Input validation, overwrite warning, and idempotent unregister are appropriate.
Also applies to: 94-103, 105-112
987-989: Good device-safety around backend path.

Wrapping fake-quant in `same_device_as(inputs)` avoids cross-device CUDA context issues.
1024-1029: Nice helper for concise tensor summaries.
`_short_tensor` improves `extra_repr` readability without overhead.
40-40: Import re-exports are confirmed and working as designed.

The symbols `same_device_as` and `standardize_constructor_args` are correctly re-exported from `modelopt.torch.utils` via wildcard imports in the package's `__init__.py`. Both `from .tensor import *` and `from .network import *` are used to re-export these modules' public APIs, making the current import statement valid and idiomatic.

The suggested refactor to explicit imports would contradict the module's intentional design of using wildcard re-exports to provide a unified public API. No changes are needed.
Likely an incorrect or invalid review comment.
```python
return calibrate(model, config.get("algorithm"), forward_loop=forward_loop)
```
🧩 Analysis chain
Preserve default calibration behavior (avoid silent no‑calibration).
Passing config.get("algorithm") forwards None when the key is absent, which skips calibration. Previously, missing "algorithm" implied "max". Recommend defaulting to "max" unless explicitly set to None.
```diff
-    return calibrate(model, config.get("algorithm"), forward_loop=forward_loop)
+    return calibrate(model, config.get("algorithm", "max"), forward_loop=forward_loop)
```

Please confirm whether this PR intends to change the behavior when "algorithm" is omitted. If yes, update docs to highlight that omission now skips calibration.
🏁 Script executed:
```bash
#!/bin/bash
# Find the calibrate function definition
echo "=== Searching for calibrate function definition ==="
rg -n "def calibrate" --type=py
echo -e "\n=== Context around current code (lines 231-232) ==="
fd "model_quant.py" --type=py --exec head -250 {} \; 2>/dev/null | tail -30
echo -e "\n=== Check for algorithm default in config ==="
rg -n "algorithm" modelopt/torch/quantization/model_quant.py -B2 -A2
echo -e "\n=== Search for related tests ==="
fd "test.*\.py" --type=py | xargs rg -l "calibrate.*algorithm" 2>/dev/null | head -5Length of output: 5645
🏁 Script executed:
```bash
#!/bin/bash
echo "=== Inspect calibrate function to confirm default parameter ==="
rg -n "def calibrate" modelopt/torch/quantization/model_quant.py -A10
echo -e "\n=== Look for git history of line 231 ==="
git log -p --all -S 'config.get("algorithm")' -- modelopt/torch/quantization/model_quant.py 2>/dev/null | head -80
echo -e "\n=== Search for tests calling quantize() without algorithm ==="
rg -n "quantize\(" --type=py -g '*/test*' -B3 -A3 | grep -A3 -B3 "algorithm" | head -40
echo -e "\n=== Check what happens when None is passed to calibrate() ==="
rg -n "if algorithm" modelopt/torch/quantization/model_quant.py -A5 -B2
echo -e "\n=== Look for examples/docs showing quantize() without algorithm key ==="
rg -n 'config.*=.*{' modelopt/torch/quantization/model_quant.py -A15 | grep -A15 '"quant_cfg"'
```

Length of output: 956
🏁 Script executed:
```bash
#!/bin/bash
echo "=== Inspect calibrate function to confirm default parameter ==="
rg -n "def calibrate" modelopt/torch/quantization/model_quant.py -A10
echo -e "\n=== Look for git history of line 231 ==="
git log -p --all -S 'config.get("algorithm")' -- modelopt/torch/quantization/model_quant.py 2>/dev/null | head -100
echo -e "\n=== Search for tests calling quantize without algorithm key ==="
find . -path ./build -prune -o -name "*test*.py" -type f -print | xargs grep -l "quantize.*config" 2>/dev/null | head -5
echo -e "\n=== Check algorithm handling in calibrate function ==="
sed -n '53,110p' modelopt/torch/quantization/model_quant.py
echo -e "\n=== Search for get_modelike_from_algo_cfg to see None handling ==="
rg -n "def get_modelike_from_algo_cfg" modelopt/torch/quantization/ -A15Length of output: 4705
🏁 Script executed:
```bash
#!/bin/bash
echo "=== Look for tests with quantize() missing algorithm key ==="
rg -n "quantize\(" tests/ -B5 -A5 --type=py | grep -B8 -A8 -E "(config\s*=|\"quant_cfg\")" | head -60
echo -e "\n=== Check if None/missing algorithm is tested ==="
rg -n "algorithm.*None\|None.*algorithm" tests/ --type=py -B2 -A2
echo -e "\n=== Look for weight-only quantization examples ==="
rg -n "weight.*only\|weight-only" modelopt/torch/quantization/ --type=py -B3 -A3 | head -40
echo -e "\n=== Check if calibrate() is called without algorithm in examples ==="
rg -n "calibrate\(" examples/ --type=py -B2 -A2 | head -40
echo -e "\n=== Verify line 73-76 of model_quant.py to confirm None documentation ==="
sed -n '73,76p' modelopt/torch/quantization/model_quant.py
```

Length of output: 7555
🏁 Script executed:
```bash
#!/bin/bash
echo "=== Check if quantize() without algorithm key is valid/tested ==="
rg -n "quantize\(" tests/ --type=py -A10 | grep -B10 "\"quant_cfg\"" | grep -v "algorithm" | head -50
echo -e "\n=== Search for any PR/issue context about algorithm default ==="
git log --all --oneline --grep="algorithm" -- modelopt/torch/quantization/model_quant.py 2>/dev/null | head -20
echo -e "\n=== Check git diff of lines around 231 ==="
git log -p -- modelopt/torch/quantization/model_quant.py 2>/dev/null | grep -B5 -A5 "config.get" | head -40
echo -e "\n=== Verify: Does config schema require algorithm? ==="
rg -n "algorithm" modelopt/torch/quantization/config.py -B3 -A3 | head -40Length of output: 3733
🏁 Script executed:
```bash
#!/bin/bash
echo "=== Search for QuantizeConfig class definition ==="
rg -n "class QuantizeConfig" modelopt/torch/quantization/ -A30
echo -e "\n=== Check if algorithm field is marked as required ==="
rg -n "algorithm.*required\|required.*algorithm" modelopt/torch/quantization/ --type=py
echo -e "\n=== Look for config validation that checks algorithm presence ==="
rg -n "def.*quantize\(" modelopt/torch/quantization/model_quant.py -A50 | grep -A50 "def quantize"
echo -e "\n=== Check actual quantize() implementation around config.get ==="
sed -n '215,235p' modelopt/torch/quantization/model_quant.py
```

Length of output: 6346
Preserve default calibration behavior when "algorithm" key is omitted.
The QuantizeConfig defines algorithm with default="max", but line 231 bypasses this by passing config.get("algorithm") without a default fallback. When the config dict omits the "algorithm" key, None is passed to calibrate(), which explicitly skips calibration (as documented in its docstring). This contradicts the QuantizeConfig default.
```diff
-    return calibrate(model, config.get("algorithm"), forward_loop=forward_loop)
+    return calibrate(model, config.get("algorithm", "max"), forward_loop=forward_loop)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
return calibrate(model, config.get("algorithm", "max"), forward_loop=forward_loop)
```
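To make the behavioral difference concrete, a tiny standalone illustration (the config contents are placeholders):

```python
config = {"quant_cfg": {"*weight_quantizer": {"num_bits": 8}}}  # "algorithm" key omitted

print(config.get("algorithm"))         # None  -> calibrate() skips calibration entirely
print(config.get("algorithm", "max"))  # "max" -> default max calibration is preserved
```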
🤖 Prompt for AI Agents
In modelopt/torch/quantization/model_quant.py around lines 231 to 232, the call
passes config.get("algorithm") which yields None when the key is omitted and
causes calibrate() to skip calibration; change the call to supply the intended
default (e.g., use config.get("algorithm", "max") or resolve the
QuantizeConfig.default) so that when "algorithm" is missing the default "max" is
forwarded to calibrate().
| "backend": ("backend", lambda val: val), | ||
| "backend_extra_args": ("backend_extra_args", lambda val: val or {}), | ||
| } |
Defensive defaults for backend fields to avoid AttributeError.
_fake_quantize and extra_repr access self.backend/self.backend_extra_args directly. If QuantizerAttributeConfig ever omits these keys, attribute access can raise. Initialize defaults in __init__ and use getattr in hot paths.
Apply:
```diff
@@ class TensorQuantizer(nn.Module):
     def __init__(...
         super().__init__()
         quant_attribute_cfg = (
             quant_attribute_cfg if quant_attribute_cfg is not None else QuantizerAttributeConfig()
         )
+        # Defaults for optional backend fields; safe even if config overrides them.
+        self.backend = None
+        self.backend_extra_args = {}
@@
-        if self.backend is not None:
-            if self.backend not in _QUANT_FUNCTIONAL_BACKENDS:
+        backend = getattr(self, "backend", None)
+        if backend is not None:
+            if backend not in _QUANT_FUNCTIONAL_BACKENDS:
                 raise KeyError(f"Quant backend '{self.backend}' is not registered.")
-            entrypoint = _QUANT_FUNCTIONAL_BACKENDS[self.backend]
+            entrypoint = _QUANT_FUNCTIONAL_BACKENDS[backend]
             return entrypoint(inputs, self)
```

Also applies to: 675-680
🤖 Prompt for AI Agents
In modelopt/torch/quantization/nn/modules/tensor_quantizer.py around lines
205-207, the mapping for backend fields can be omitted by
QuantizerAttributeConfig which makes later direct accesses to self.backend and
self.backend_extra_args raise AttributeError; initialize safe defaults in
__init__ (e.g., self.backend = None and self.backend_extra_args = {} or
appropriate neutral defaults) and update hot paths (_fake_quantize and
extra_repr) to use getattr(self, "backend", None) and getattr(self,
"backend_extra_args", {}) when reading them; apply the same initialization +
getattr pattern to the code region around lines 675-680 as well.
| s = "disabled" | ||
| s += ( | ||
| f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}" | ||
| if self.pre_quant_scale is not None | ||
| else "" | ||
| ) | ||
| return "disabled" |
extra_repr returns constant instead of built string when disabled.
You build s with pre_quant_scale but return the literal "disabled". Return s to expose the extra detail.
Apply:
- return "disabled"
+ return s📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| s = "disabled" | |
| s += ( | |
| f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}" | |
| if self.pre_quant_scale is not None | |
| else "" | |
| ) | |
| return "disabled" | |
| s = "disabled" | |
| s += ( | |
| f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}" | |
| if self.pre_quant_scale is not None | |
| else "" | |
| ) | |
| return s |
🤖 Prompt for AI Agents
In modelopt/torch/quantization/nn/modules/tensor_quantizer.py around lines 1033
to 1039, the extra_repr branch for the disabled state builds a string s
including pre_quant_scale but then returns the literal "disabled" instead of s;
change the return to return s so the constructed string (e.g., "disabled
pre_quant_scale=...") is returned, preserving the existing conditional
construction and formatting.
```python
if amax.numel() == 1:
    outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
elif amax.squeeze().ndim == 1:
    axis = amax.shape.index(amax.numel())
    outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
return outputs
```
Axis inference bug for per‑channel amax (wrong source tensor).
axis is computed from amax.shape, which always yields 0 for 1D amax. The CUDA kernels expect the axis in inputs. This breaks cases where the channel axis isn’t 0 (e.g., last dim activations, column/row‑parallel weights).
Use inputs.shape to infer axis and disambiguate duplicates; otherwise fail fast. Patch both scaled_e4m3_impl and fake_quant_impl:
```diff
@@
- if amax.numel() == 1:
- outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
- elif amax.squeeze().ndim == 1:
- axis = amax.shape.index(amax.numel())
- outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
+ if amax.numel() == 1:
+ outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
+ elif amax.squeeze().ndim == 1:
+ n = amax.numel()
+ candidates = [i for i, s in enumerate(inputs.shape) if s == n]
+ if not candidates:
+ raise ValueError(
+ f"Cannot infer per-channel axis: amax.numel()={n} does not match any inputs dim {tuple(inputs.shape)}"
+ )
+ # Prefer last-dim if it matches, otherwise require uniqueness
+ axis = inputs.dim() - 1 if inputs.shape[-1] == n else (candidates[0] if len(candidates) == 1 else None)
+ if axis is None:
+ raise ValueError(
+ f"Ambiguous per-channel axis for amax of length {n} in inputs shape {tuple(inputs.shape)}; "
+ "please provide broadcast-shaped amax."
+ )
+ outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
@@
- if amax.numel() == 1:
- outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
- else:
- axis = amax.shape.index(amax.numel())
- outputs = cuda_ext.fake_tensor_quant_with_axis(
- inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range
- )
+ if amax.numel() == 1:
+ outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
+ else:
+ n = amax.numel()
+ candidates = [i for i, s in enumerate(inputs.shape) if s == n]
+ if not candidates:
+ raise ValueError(
+ f"Cannot infer per-channel axis: amax.numel()={n} does not match any inputs dim {tuple(inputs.shape)}"
+ )
+ axis = inputs.dim() - 1 if inputs.shape[-1] == n else (candidates[0] if len(candidates) == 1 else None)
+ if axis is None:
+ raise ValueError(
+ f"Ambiguous per-channel axis for amax of length {n} in inputs shape {tuple(inputs.shape)}; "
+ "please provide broadcast-shaped amax."
+ )
+ outputs = cuda_ext.fake_tensor_quant_with_axis(
+ inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range
+        )
```

Also applies to: 100-107
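A small standalone repro of the inference bug described above (shapes are illustrative):

```python
import torch

inputs = torch.randn(8, 64)        # channel axis is the last dimension here
amax = inputs.abs().amax(dim=0)    # per-channel amax, shape (64,)

# Original logic: searching amax's own shape always yields axis 0 for a 1-D amax.
axis_from_amax = list(amax.shape).index(amax.numel())
print(axis_from_amax)              # 0 (wrong for this layout)

# Searching the input's shape instead recovers the actual channel axis.
candidates = [i for i, s in enumerate(inputs.shape) if s == amax.numel()]
print(candidates)                  # [1]
```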
🤖 Prompt for AI Agents
In modelopt/torch/quantization/tensor_quant.py around lines 82-87 (and similarly
for lines ~100-107 in scaled_e4m3_impl and fake_quant_impl), the code currently
computes axis from amax.shape which yields 0 for 1D amax and is wrong when the
channel axis in inputs isn't 0; instead infer the axis by searching inputs.shape
for dimensions that match amax.numel(), disambiguate duplicates by ensuring
amax.squeeze().ndim == 1 and matching dimension lengths, and if multiple
candidate axes exist or no match is found, raise an explicit error (fail fast).
Replace axis = amax.shape.index(amax.numel()) with logic that derives candidate
axes from inputs.shape, selects the single matching axis or errors on ambiguity,
and then call the CUDA kernels with that inferred axis; apply the same fix to
both scaled_e4m3_impl and fake_quant_impl.
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```diff
@@           Coverage Diff            @@
##             main     #449    +/-   ##
========================================
+ Coverage   73.38%   73.42%   +0.04%
========================================
  Files         180      180
  Lines       17934    17968      +34
========================================
+ Hits        13160    13193      +33
- Misses       4774     4775       +1
```

☔ View full report in Codecov by Sentry.
Signed-off-by: realAsma <[email protected]>
What does this PR do?
Type of change: new feature
Overview:
Add support to enable a custom emulated quantization backend. See :meth:`register_quant_backend <modelopt.torch.quantization.nn.modules.tensor_quantizer.register_quant_backend>` for more details.
Usage
See an example in ``tests/unit/torch/quantization/test_custom_backend.py``.

Testing
See the unit tests.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Improvements
Tests