diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 49fc491fc..5ffe53082 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -16,6 +16,11 @@
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
+from llmcompressor.modifiers.awq.mappings import (
+    AWQMapping,
+    ResolvedMapping,
+    get_layer_mappings_from_architecture,
+)
 from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale
 from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
 from llmcompressor.modifiers.utils.hooks import HooksMixin
@@ -27,8 +32,6 @@
     get_parent_by_name,
 )
 
-from .mappings import AWQ_MAPPING_REGISTRY, AWQMapping, ResolvedMapping
-
 __all__ = ["AWQModifier"]
 
 
@@ -120,7 +123,7 @@ class AWQModifier(Modifier, QuantizationMixin):
 
     # User-provided vars (in addition to QuantizationMixin args)
     sequential_targets: Union[str, List[str], None] = None
-    mappings: List[AWQMapping] = AWQ_MAPPING_REGISTRY["Llama"]
+    mappings: Optional[List[AWQMapping]] = None
     max_chunk_memory: int = 1024 * 1024 * 1024
     duo_scaling: bool = True
 
@@ -212,6 +215,12 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         if QuantizationMixin.has_config(self):
             QuantizationMixin.initialize_quantization(self, state.model)
 
+        if self.mappings is None:
+            logger.info("No AWQModifier.mappings provided, inferring from model...")
+            self.mappings = get_layer_mappings_from_architecture(
+                architecture=state.model.__class__.__name__
+            )
+
         self._set_resolved_mappings(state.model)
 
         self._set_module_kwargs(state.model, state.data.calib)
@@ -500,13 +509,9 @@ def smooth(module):
                     # in this case, default to scaling the last output features
                     # because the desired smooth layer is v_proj
                     # https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/scale.py#L123
-                    update_offload_parameter(
-                        module,
-                        "weight",
-                        module.weight[-scales.size(0) :].div_(
-                            scales.view(-1, 1)
-                        ),
-                    )
+                    weight = module.weight
+                    weight[-scales.size(0) :].div_(scales.view(-1, 1))
+                    update_offload_parameter(module, "weight", weight)
                 if hasattr(module, "bias") and module.bias is not None:
                     update_offload_parameter(
                         module,
diff --git a/src/llmcompressor/modifiers/awq/mappings.py b/src/llmcompressor/modifiers/awq/mappings.py
index 7a23fcaf4..700525ed8 100644
--- a/src/llmcompressor/modifiers/awq/mappings.py
+++ b/src/llmcompressor/modifiers/awq/mappings.py
@@ -1,9 +1,10 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional
 
+from loguru import logger
 from torch.nn import Module
 
-__all__ = ["AWQMapping", "AWQ_MAPPING_REGISTRY"]
+__all__ = ["AWQMapping", "AWQ_MAPPING_REGISTRY", "get_layer_mappings_from_architecture"]
 
 
 @dataclass
@@ -22,24 +23,48 @@ class AWQMapping:
     balance_layers: list[str]
 
 
+_default_mappings = [
+    AWQMapping(
+        "re:.*input_layernorm",
+        ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
+    ),
+    AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
+    AWQMapping(
+        "re:.*post_attention_layernorm",
+        ["re:.*gate_proj", "re:.*up_proj"],
+    ),
+    AWQMapping(
+        "re:.*up_proj",
+        ["re:.*down_proj"],
+    ),
+]
+
+# Phi merges
+#  q, k, and v proj layers into a single qkv_proj layer
+#  gate and up proj layers into a single gate_up_proj layer
+_phi_mappings = [
+    AWQMapping(
+        "re:.*input_layernorm",
+        ["re:.*qkv_proj"],
+    ),
+    AWQMapping("re:.*qkv_proj", ["re:.*o_proj"]),
+    AWQMapping(
+        "re:.*post_attention_layernorm",
+        ["re:.*gate_up_proj"],
+    ),
+    AWQMapping(
+        "re:.*gate_up_proj",
+        ["re:.*down_proj"],
+    ),
+]
+
 AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = {
-    "Llama": [
-        AWQMapping(
-            "re:.*input_layernorm",
-            ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
-        ),
-        AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
-        AWQMapping(
-            "re:.*post_attention_layernorm",
-            ["re:.*gate_proj", "re:.*up_proj"],
-        ),
-        AWQMapping(
-            "re:.*up_proj",
-            ["re:.*down_proj"],
-        ),
-    ],
-    # TODO (Brian INFERENG-529) Add Qwen mappings
-    # "Qwen": [ ],
+    "LlamaForCausalLM": _default_mappings,
+    "Qwen2ForCausalLM": _default_mappings,
+    "Qwen3ForCausalLM": _default_mappings,
+    "MistralForCausalLM": _default_mappings,
+    "Phi3ForCausalLM": _phi_mappings,
+    "Phi3VForCausalLM": _phi_mappings,
 }
 
 
@@ -64,3 +89,18 @@ class ResolvedMapping:
     balance_names: Optional[List[str]] = None
     parent: Optional[Module] = None
     parent_name: Optional[str] = None
+
+
+def get_layer_mappings_from_architecture(architecture: str) -> List[AWQMapping]:
+    """
+    :param architecture: str: The architecture of the model
+    :return: list: The layer mappings for the given architecture
+    """
+
+    if architecture not in AWQ_MAPPING_REGISTRY:
+        logger.info(
+            f"Architecture {architecture} not found in mappings. "
+            f"Using default mappings: {_default_mappings}"
+        )
+
+    return AWQ_MAPPING_REGISTRY.get(architecture, _default_mappings)