diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index ab6386b04..7ae61f3e2 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -26,6 +26,7 @@ quantize_weight, ) from llmcompressor.modifiers.quantization.quantization import QuantizationMixin +from llmcompressor.sentinel import Sentinel from llmcompressor.utils.metric_logging import CompressionLogger __all__ = ["GPTQModifier"] @@ -109,7 +110,7 @@ class GPTQModifier(Modifier, QuantizationMixin): sequential_targets: Union[str, List[str], None] = None block_size: int = 128 dampening_frac: Optional[float] = 0.01 - actorder: Optional[ActivationOrdering] = None + actorder: Optional[Union[ActivationOrdering, Sentinel]] = None offload_hessians: bool = False # private variables @@ -131,23 +132,29 @@ def validate_sequential_update(cls, value: bool) -> bool: def resolve_quantization_config(self) -> QuantizationConfig: config = super().resolve_quantization_config() - # Resolve config with `self.actorder` + def resolve_actorder(existing): + # sentinel default only overrides if existing is None + if self.actorder == Sentinel("static"): + return ActivationOrdering.STATIC if existing is None else existing + + # user-provided value always attempts to override + if self.actorder is not None: + if existing is None or self.actorder == existing: + return self.actorder + raise ValueError( + "Cannot resolve activation ordering when both " + "`GPTQModifier.actorder` and `QuantizationScheme.actorder` " + "are provided and differ. Either set `GPTQModifier.actorder = " + "None` or remove `actorder` from config groups." + ) + + # setting `GPTQModifier.actorder = None` does nothing + return existing + for scheme in config.config_groups.values(): - assert isinstance(scheme, QuantizationScheme) # (1) + assert isinstance(scheme, QuantizationScheme) if scheme.weights is not None: - existing = scheme.weights.actorder - assert isinstance(existing, (ActivationOrdering, type(None))) # (2) - if existing is not None and existing != self.actorder: - raise ValueError( - "Cannot resolve activation ordering when both " - "`GPTQModifier.actorder` and `QuantizationScheme.actorder` " - "both are provided. Either set `GPTQModifier.actorder = None` " - "or remove `actorder` from config groups" - ) - scheme.weights.actorder = self.actorder - - # (1) QuantizationConfig.model_post_init - # (2) QuantizationScheme.validate_actorder + scheme.weights.actorder = resolve_actorder(scheme.weights.actorder) return config diff --git a/src/llmcompressor/sentinel.py b/src/llmcompressor/sentinel.py new file mode 100644 index 000000000..55b829460 --- /dev/null +++ b/src/llmcompressor/sentinel.py @@ -0,0 +1,52 @@ +import inspect + +from pydantic_core import core_schema + +_registry = {} + + +class Sentinel: + """ + Unique sentinel values. Implements https://peps.python.org/pep-0661/ + with dummy pydantic validation + """ + + def __new__(cls, name, module_name=None): + name = str(name) + + if module_name is None: + module_name = inspect.currentframe().f_globals.get("__file__") + if module_name is None: + module_name = __name__ + + registry_key = f"{module_name}-{name}" + + sentinel = _registry.get(registry_key, None) + if sentinel is not None: + return sentinel + + sentinel = super().__new__(cls) + sentinel._name = name + sentinel._module_name = module_name + + return _registry.setdefault(registry_key, sentinel) + + def __repr__(self): + return self._name + + def __reduce__(self): + return ( + self.__class__, + ( + self._name, + self._module_name, + ), + ) + + @classmethod + def __get_pydantic_core_schema__(cls, _source_type, _handler): + return core_schema.no_info_plain_validator_function(cls.validate) + + @classmethod + def validate(cls, value: "Sentinel") -> "Sentinel": + return value diff --git a/tests/llmcompressor/test_sentinel.py b/tests/llmcompressor/test_sentinel.py new file mode 100644 index 000000000..f60a1814b --- /dev/null +++ b/tests/llmcompressor/test_sentinel.py @@ -0,0 +1,6 @@ +from llmcompressor.sentinel import Sentinel + + +def test_sentinel(): + assert Sentinel("MISSING") == Sentinel("MISSING") + assert Sentinel("MISSING", "module_one") != Sentinel("MISSING", "module_two")