-
Couldn't load subscription status.
- Fork 187
[1/2] Registry interface for custom quantization functional backend #449
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,7 +18,8 @@ | |||||||||||||||||||||||||||||
| import contextlib | ||||||||||||||||||||||||||||||
| import math | ||||||||||||||||||||||||||||||
| import warnings | ||||||||||||||||||||||||||||||
| from typing import TYPE_CHECKING, Any | ||||||||||||||||||||||||||||||
| from collections.abc import Callable | ||||||||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||
| import torch.distributed as dist | ||||||||||||||||||||||||||||||
|
|
@@ -36,7 +37,7 @@ | |||||||||||||||||||||||||||||
| import torch.nn.functional as F | ||||||||||||||||||||||||||||||
| from torch import nn | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| from modelopt.torch.utils import standardize_constructor_args | ||||||||||||||||||||||||||||||
| from modelopt.torch.utils import same_device_as, standardize_constructor_args | ||||||||||||||||||||||||||||||
| from modelopt.torch.utils.distributed import DistributedProcessGroup | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| from ... import calib | ||||||||||||||||||||||||||||||
|
|
@@ -56,10 +57,58 @@ | |||||||||||||||||||||||||||||
| from ...utils import is_torch_export_mode | ||||||||||||||||||||||||||||||
| from ..functional import normalized_hadamard_transform | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if TYPE_CHECKING: | ||||||||||||||||||||||||||||||
| from collections.abc import Callable | ||||||||||||||||||||||||||||||
| __all__ = [ | ||||||||||||||||||||||||||||||
| "SequentialQuantizer", | ||||||||||||||||||||||||||||||
| "TensorQuantizer", | ||||||||||||||||||||||||||||||
| "is_registered_quant_backend", | ||||||||||||||||||||||||||||||
| "register_quant_backend", | ||||||||||||||||||||||||||||||
| "unregister_quant_backend", | ||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| __all__ = ["SequentialQuantizer", "TensorQuantizer"] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| QuantBackendEntrypoint = Callable[[torch.Tensor, "TensorQuantizer"], torch.Tensor] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| _QUANT_FUNCTIONAL_BACKENDS: dict[str, QuantBackendEntrypoint] = {} | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def register_quant_backend(name: str, entrypoint: QuantBackendEntrypoint) -> None: | ||||||||||||||||||||||||||||||
| """Register a custom quantization backend. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||
| name: The name of the backend. | ||||||||||||||||||||||||||||||
| entrypoint: The entrypoint of the backend. The entrypoint should be a callable that takes in | ||||||||||||||||||||||||||||||
| the inputs and the tensor quantizer as arguments and returns the quantized tensor. | ||||||||||||||||||||||||||||||
| See :class:`modelopt.torch.quantization.config.QuantizerAttributeConfig` | ||||||||||||||||||||||||||||||
| for details on choosing from the registered backends via the ``backend`` and | ||||||||||||||||||||||||||||||
| ``backend_extra_args`` fields. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| if not isinstance(name, str) or not name: | ||||||||||||||||||||||||||||||
| raise ValueError("Backend name must be a non-empty string.") | ||||||||||||||||||||||||||||||
| if not callable(entrypoint): | ||||||||||||||||||||||||||||||
| raise TypeError("Entrypoint must be callable.") | ||||||||||||||||||||||||||||||
| if name in _QUANT_FUNCTIONAL_BACKENDS: | ||||||||||||||||||||||||||||||
| warnings.warn(f"Overwriting existing backend: {name}") | ||||||||||||||||||||||||||||||
| _QUANT_FUNCTIONAL_BACKENDS[name] = entrypoint | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def unregister_quant_backend(name: str) -> None: | ||||||||||||||||||||||||||||||
| """Unregister a custom quantization backend. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||
| name: The name of the backend to unregister. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| if not isinstance(name, str) or not name: | ||||||||||||||||||||||||||||||
| raise ValueError("Backend name must be a non-empty string.") | ||||||||||||||||||||||||||||||
| _QUANT_FUNCTIONAL_BACKENDS.pop(name, None) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def is_registered_quant_backend(name: str) -> bool: | ||||||||||||||||||||||||||||||
| """Check if a custom quantization backend is registered. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||
| name: The name of the backend to check. | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| return name in _QUANT_FUNCTIONAL_BACKENDS | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| class TensorQuantizer(nn.Module): | ||||||||||||||||||||||||||||||
|
|
@@ -153,6 +202,8 @@ def _calibrator_setter(val): | |||||||||||||||||||||||||||||
| "enable": ("_disabled", lambda val: val is False), | ||||||||||||||||||||||||||||||
| "type": ("_dynamic", lambda val: val == "dynamic"), | ||||||||||||||||||||||||||||||
| "calibrator": ("_calibrator", _calibrator_setter), | ||||||||||||||||||||||||||||||
| "backend": ("backend", lambda val: val), | ||||||||||||||||||||||||||||||
| "backend_extra_args": ("backend_extra_args", lambda val: val or {}), | ||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||
|
Comment on lines
+205
to
207
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Defensive defaults for backend fields to avoid AttributeError.
Apply: @@ class TensorQuantizer(nn.Module):
def __init__(...
super().__init__()
quant_attribute_cfg = (
quant_attribute_cfg if quant_attribute_cfg is not None else QuantizerAttributeConfig()
)
+ # Defaults for optional backend fields; safe even if config overrides them.
+ self.backend = None
+ self.backend_extra_args = {}
@@
- if self.backend is not None:
- if self.backend not in _QUANT_FUNCTIONAL_BACKENDS:
+ backend = getattr(self, "backend", None)
+ if backend is not None:
+ if backend not in _QUANT_FUNCTIONAL_BACKENDS:
raise KeyError(f"Quant backend '{self.backend}' is not registered.")
- entrypoint = _QUANT_FUNCTIONAL_BACKENDS[self.backend]
+ entrypoint = _QUANT_FUNCTIONAL_BACKENDS[backend]
return entrypoint(inputs, self)Also applies to: 675-680 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| for attribute, val in attribute_cfg.items(): | ||||||||||||||||||||||||||||||
|
|
@@ -621,6 +672,12 @@ def _real_quantize(self, inputs): | |||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _fake_quantize(self, inputs): | ||||||||||||||||||||||||||||||
| """Fake quantization.""" | ||||||||||||||||||||||||||||||
| if self.backend is not None: | ||||||||||||||||||||||||||||||
| if self.backend not in _QUANT_FUNCTIONAL_BACKENDS: | ||||||||||||||||||||||||||||||
| raise KeyError(f"Quant backend '{self.backend}' is not registered.") | ||||||||||||||||||||||||||||||
| entrypoint = _QUANT_FUNCTIONAL_BACKENDS[self.backend] | ||||||||||||||||||||||||||||||
| return entrypoint(inputs, self) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| amax = None | ||||||||||||||||||||||||||||||
| if not self.is_mx_format: | ||||||||||||||||||||||||||||||
| amax = self._get_amax(inputs) | ||||||||||||||||||||||||||||||
|
|
@@ -927,7 +984,8 @@ def forward(self, inputs): | |||||||||||||||||||||||||||||
| if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous(): | ||||||||||||||||||||||||||||||
| inputs.data = inputs.data.contiguous() | ||||||||||||||||||||||||||||||
| if self.fake_quant: | ||||||||||||||||||||||||||||||
| outputs = self._fake_quantize(inputs) | ||||||||||||||||||||||||||||||
| with same_device_as(inputs): | ||||||||||||||||||||||||||||||
| outputs = self._fake_quantize(inputs) | ||||||||||||||||||||||||||||||
| elif not self._dequantize: | ||||||||||||||||||||||||||||||
| outputs = self._real_quantize(inputs) | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
|
|
@@ -961,16 +1019,23 @@ def _short_amax(self, fmt=".4f"): | |||||||||||||||||||||||||||||
| return "None" | ||||||||||||||||||||||||||||||
| if self._amax.is_meta: | ||||||||||||||||||||||||||||||
| return "meta" | ||||||||||||||||||||||||||||||
| if self._amax.numel() == 1: | ||||||||||||||||||||||||||||||
| return f"{self._amax.item():{fmt}}" | ||||||||||||||||||||||||||||||
| return ( | ||||||||||||||||||||||||||||||
| f"[{self._amax.min().item():{fmt}}," | ||||||||||||||||||||||||||||||
| f" {self._amax.max().item():{fmt}}]({self._amax.numel()})" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| return self._short_tensor(self._amax, fmt) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _short_tensor(self, tensor: torch.Tensor, fmt=".4f"): | ||||||||||||||||||||||||||||||
| """Short description of tensor.""" | ||||||||||||||||||||||||||||||
| if tensor.numel() == 1: | ||||||||||||||||||||||||||||||
| return f"{tensor.item():{fmt}}" | ||||||||||||||||||||||||||||||
| return f"[{tensor.min().item():{fmt}}, {tensor.max().item():{fmt}}]({tensor.numel()})" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def extra_repr(self): | ||||||||||||||||||||||||||||||
| """Set the extra information about this module.""" | ||||||||||||||||||||||||||||||
| if self._disabled: | ||||||||||||||||||||||||||||||
| s = "disabled" | ||||||||||||||||||||||||||||||
| s += ( | ||||||||||||||||||||||||||||||
| f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}" | ||||||||||||||||||||||||||||||
| if self.pre_quant_scale is not None | ||||||||||||||||||||||||||||||
| else "" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| return "disabled" | ||||||||||||||||||||||||||||||
|
Comment on lines
+1033
to
1039
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. extra_repr returns constant instead of built string when disabled. You build Apply: - return "disabled"
+ return s📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||
| s = f"{'unsigned ' if self._unsigned else ''}{self._num_bits} bit" | ||||||||||||||||||||||||||||||
| s += " narrow" if (self._narrow_range) else "" | ||||||||||||||||||||||||||||||
|
|
@@ -980,7 +1045,11 @@ def extra_repr(self): | |||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| s += f" axis={self._axis}" if self._axis is not None else " per-tensor" | ||||||||||||||||||||||||||||||
| s += f" amax={self._short_amax()}" | ||||||||||||||||||||||||||||||
| s += " pre_quant_scale" if self.pre_quant_scale is not None else "" | ||||||||||||||||||||||||||||||
| s += ( | ||||||||||||||||||||||||||||||
| f" pre_quant_scale={self._short_tensor(self.pre_quant_scale)}" | ||||||||||||||||||||||||||||||
| if self.pre_quant_scale is not None | ||||||||||||||||||||||||||||||
| else "" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| s += " rotated" if self._rotate else "" | ||||||||||||||||||||||||||||||
| s += ( | ||||||||||||||||||||||||||||||
| f" calibrator={self._calibrator.__class__.__name__}" | ||||||||||||||||||||||||||||||
|
|
@@ -992,6 +1061,11 @@ def extra_repr(self): | |||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| s += " quant" if (self._if_quant) else "" | ||||||||||||||||||||||||||||||
| s += " calib" if (self._if_calib) else "" | ||||||||||||||||||||||||||||||
| s += ( | ||||||||||||||||||||||||||||||
| f" backend={self.backend}, extra_args={self.backend_extra_args}" | ||||||||||||||||||||||||||||||
| if self.backend is not None | ||||||||||||||||||||||||||||||
| else "" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| return s | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _get_properties_for_modelopt_state(self): | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -231,6 +231,14 @@ def _setup(self): | |
| data_parallel_group, | ||
| mcore_parallel.get_tensor_model_parallel_group(), | ||
| ) | ||
|
|
||
| if getattr(self, "gradient_accumulation_fusion", False): | ||
| warnings.warn( | ||
| "gradient_accumulation_fusion is not supported with ModelOpt quantization. " | ||
| "Setting gradient_accumulation_fusion to False." | ||
| ) | ||
| self.gradient_accumulation_fusion = False | ||
|
|
||
|
Comment on lines
+234
to
+241
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jenchen13 @ChenhanYu this should fix the occasional errors people run into by unknowingly enabling |
||
| super()._setup() | ||
|
|
||
| def _process_quantizer_amax(self, k, v, quantizer_state_dict): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -79,14 +79,11 @@ def scaled_e4m3_impl( | |
| if cuda_ext_fp8 is None: | ||
| return fp8_eager(inputs, amax) | ||
|
|
||
| with torch.cuda.device( | ||
| None if inputs.device.index == torch.cuda.current_device() else inputs.device.index | ||
| ): | ||
| if amax.numel() == 1: | ||
| outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax) | ||
| elif amax.squeeze().ndim == 1: | ||
| axis = amax.shape.index(amax.numel()) | ||
| outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis) | ||
| if amax.numel() == 1: | ||
| outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax) | ||
| elif amax.squeeze().ndim == 1: | ||
| axis = amax.shape.index(amax.numel()) | ||
| outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis) | ||
| return outputs | ||
|
Comment on lines
+82
to
87
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Axis inference bug for per‑channel amax (wrong source tensor). axis is computed from amax.shape, which always yields 0 for 1D amax. The CUDA kernels expect the axis in inputs. This breaks cases where the channel axis isn’t 0 (e.g., last dim activations, column/row‑parallel weights). Use inputs.shape to infer axis and disambiguate duplicates; otherwise fail fast. Patch both scaled_e4m3_impl and fake_quant_impl: @@
- if amax.numel() == 1:
- outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
- elif amax.squeeze().ndim == 1:
- axis = amax.shape.index(amax.numel())
- outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
+ if amax.numel() == 1:
+ outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
+ elif amax.squeeze().ndim == 1:
+ n = amax.numel()
+ candidates = [i for i, s in enumerate(inputs.shape) if s == n]
+ if not candidates:
+ raise ValueError(
+ f"Cannot infer per-channel axis: amax.numel()={n} does not match any inputs dim {tuple(inputs.shape)}"
+ )
+ # Prefer last-dim if it matches, otherwise require uniqueness
+ axis = inputs.dim() - 1 if inputs.shape[-1] == n else (candidates[0] if len(candidates) == 1 else None)
+ if axis is None:
+ raise ValueError(
+ f"Ambiguous per-channel axis for amax of length {n} in inputs shape {tuple(inputs.shape)}; "
+ "please provide broadcast-shaped amax."
+ )
+ outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
@@
- if amax.numel() == 1:
- outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
- else:
- axis = amax.shape.index(amax.numel())
- outputs = cuda_ext.fake_tensor_quant_with_axis(
- inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range
- )
+ if amax.numel() == 1:
+ outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
+ else:
+ n = amax.numel()
+ candidates = [i for i, s in enumerate(inputs.shape) if s == n]
+ if not candidates:
+ raise ValueError(
+ f"Cannot infer per-channel axis: amax.numel()={n} does not match any inputs dim {tuple(inputs.shape)}"
+ )
+ axis = inputs.dim() - 1 if inputs.shape[-1] == n else (candidates[0] if len(candidates) == 1 else None)
+ if axis is None:
+ raise ValueError(
+ f"Ambiguous per-channel axis for amax of length {n} in inputs shape {tuple(inputs.shape)}; "
+ "please provide broadcast-shaped amax."
+ )
+ outputs = cuda_ext.fake_tensor_quant_with_axis(
+ inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range
+ )Also applies to: 100-107 🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
|
|
@@ -100,17 +97,14 @@ def fake_quant_impl( | |
| """Implementation of fake quantizing input according to number of bits.""" | ||
| cuda_ext = get_cuda_ext() | ||
|
|
||
| with torch.cuda.device( | ||
| None if inputs.device.index == torch.cuda.current_device() else inputs.device.index | ||
| ): | ||
| if amax.numel() == 1: | ||
| outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range) | ||
| else: | ||
| axis = amax.shape.index(amax.numel()) | ||
| outputs = cuda_ext.fake_tensor_quant_with_axis( | ||
| inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range | ||
| ) | ||
| return outputs | ||
| if amax.numel() == 1: | ||
| outputs = cuda_ext.fake_tensor_quant(inputs, amax, num_bits, unsigned, narrow_range) | ||
| else: | ||
| axis = amax.shape.index(amax.numel()) | ||
| outputs = cuda_ext.fake_tensor_quant_with_axis( | ||
| inputs, amax.squeeze(), axis, num_bits, unsigned, narrow_range | ||
| ) | ||
| return outputs | ||
|
|
||
|
|
||
| def _quantize_impl( | ||
|
|
@@ -173,25 +167,22 @@ def _dynamic_block_quantize_impl( | |
| assert amax.is_cuda, "amax must be a CUDA tensor for dynamic block quantization." | ||
| if amax.numel() != 1: | ||
| amax = amax.amax() | ||
| with torch.cuda.device( | ||
| None if inputs.device.index == torch.cuda.current_device() else inputs.device.index | ||
| if ( | ||
| num_bits == (2, 1) # type: ignore[comparison-overlap] | ||
| and scale_bits == (4, 3) | ||
| and triton_kernel.IS_AVAILABLE | ||
| and not DISABLE_TRITON_KERNEL | ||
| and amax is not None | ||
| ): | ||
| if ( | ||
| num_bits == (2, 1) # type: ignore[comparison-overlap] | ||
| and scale_bits == (4, 3) | ||
| and triton_kernel.IS_AVAILABLE | ||
| and not DISABLE_TRITON_KERNEL | ||
| and amax is not None | ||
| ): | ||
| return triton_kernel.fp4_fake_quant_block(inputs, amax) | ||
| cuda_ext_mx = get_cuda_ext_mx(raise_if_failed=True) | ||
| return cuda_ext_mx.fused_amax_convert( | ||
| inputs, | ||
| block_size, | ||
| getattr(cuda_ext_mx.Types, mx_format_map[num_bits]), | ||
| getattr(cuda_ext_mx.Types, mx_format_map[scale_bits]), | ||
| amax, | ||
| ) | ||
| return triton_kernel.fp4_fake_quant_block(inputs, amax) | ||
| cuda_ext_mx = get_cuda_ext_mx(raise_if_failed=True) | ||
| return cuda_ext_mx.fused_amax_convert( | ||
| inputs, | ||
| block_size, | ||
| getattr(cuda_ext_mx.Types, mx_format_map[num_bits]), | ||
| getattr(cuda_ext_mx.Types, mx_format_map[scale_bits]), | ||
| amax, | ||
| ) | ||
| else: | ||
| raise NotImplementedError( | ||
| f"Unsupported num_bits: {num_bits}, scale_bits: {scale_bits} for dynamic block quantization." | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Preserve default calibration behavior (avoid silent no‑calibration).
Passing config.get("algorithm") forwards None when the key is absent, which skips calibration. Previously, missing "algorithm" implied "max". Recommend defaulting to "max" unless explicitly set to None.
Please confirm whether this PR intends to change the behavior when "algorithm" is omitted. If yes, update docs to highlight that omission now skips calibration.
🏁 Script executed:
Length of output: 5645
🏁 Script executed:
Length of output: 956
🏁 Script executed:
Length of output: 4705
🏁 Script executed:
Length of output: 7555
🏁 Script executed:
Length of output: 3733
🏁 Script executed:
Length of output: 6346
Preserve default calibration behavior when "algorithm" key is omitted.
The QuantizeConfig defines
algorithmwithdefault="max", but line 231 bypasses this by passingconfig.get("algorithm")without a default fallback. When the config dict omits the "algorithm" key,Noneis passed tocalibrate(), which explicitly skips calibration (as documented in its docstring). This contradicts the QuantizeConfig default.📝 Committable suggestion
🤖 Prompt for AI Agents