AWQ Qwen and Phi mappings #1440


Merged: 11 commits merged on May 21, 2025
25 changes: 15 additions & 10 deletions src/llmcompressor/modifiers/awq/base.py
@@ -16,6 +16,11 @@

from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.awq.mappings import (
AWQMapping,
ResolvedMapping,
get_layer_mappings_from_architecture,
)
from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale
from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
from llmcompressor.modifiers.utils.hooks import HooksMixin
@@ -27,8 +32,6 @@
get_parent_by_name,
)

from .mappings import AWQ_MAPPING_REGISTRY, AWQMapping, ResolvedMapping

__all__ = ["AWQModifier"]


@@ -120,7 +123,7 @@ class AWQModifier(Modifier, QuantizationMixin):

# User-provided vars (in addition to QuantizationMixin args)
sequential_targets: Union[str, List[str], None] = None
mappings: List[AWQMapping] = AWQ_MAPPING_REGISTRY["Llama"]
mappings: Optional[List[AWQMapping]] = None
max_chunk_memory: int = 1024 * 1024 * 1024
duo_scaling: bool = True

@@ -212,6 +215,12 @@ def on_initialize(self, state: State, **kwargs) -> bool:
if QuantizationMixin.has_config(self):
QuantizationMixin.initialize_quantization(self, state.model)

if self.mappings is None:
logger.info("No AWQModifier.mappings provided, inferring from model...")
self.mappings = get_layer_mappings_from_architecture(
architecture=state.model.__class__.__name__
)

self._set_resolved_mappings(state.model)

self._set_module_kwargs(state.model, state.data.calib)
@@ -500,13 +509,9 @@ def smooth(module):
# in this case, default to scaling the last output features
# because the desired smooth layer is v_proj
# https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/scale.py#L123
update_offload_parameter(
module,
"weight",
module.weight[-scales.size(0) :].div_(
scales.view(-1, 1)
),
)
weight = module.weight
weight[-scales.size(0) :].div_(scales.view(-1, 1))
update_offload_parameter(module, "weight", weight)
if hasattr(module, "bias") and module.bias is not None:
update_offload_parameter(
module,
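
Taken together, the base.py changes make mappings optional: when it is left as None, on_initialize infers the mappings from the model's class name. Below is a minimal usage sketch (not part of this diff); the oneshot entry point, the scheme/targets/ignore arguments, and the model and dataset names are assumptions for illustration only.

from llmcompressor import oneshot  # assumed entry point
from llmcompressor.modifiers.awq import AWQModifier

# No mappings= argument: on_initialize calls
# get_layer_mappings_from_architecture(state.model.__class__.__name__),
# so Qwen2/Qwen3/Phi-3 checkpoints pick up suitable mappings automatically.
recipe = AWQModifier(scheme="W4A16_ASYM", targets=["Linear"], ignore=["lm_head"])

oneshot(
    model="Qwen/Qwen2.5-7B-Instruct",  # hypothetical example model
    dataset="open_platypus",           # hypothetical calibration dataset
    recipe=recipe,
    max_seq_length=512,
    num_calibration_samples=256,
)
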
76 changes: 58 additions & 18 deletions src/llmcompressor/modifiers/awq/mappings.py
@@ -1,9 +1,10 @@
from dataclasses import dataclass
from typing import Dict, List, Optional

from loguru import logger
from torch.nn import Module

__all__ = ["AWQMapping", "AWQ_MAPPING_REGISTRY"]
__all__ = ["AWQMapping", "AWQ_MAPPING_REGISTRY", "get_layer_mappings_from_architecture"]


@dataclass
@@ -22,24 +23,48 @@ class AWQMapping:
balance_layers: list[str]


_default_mappings = [
AWQMapping(
"re:.*input_layernorm",
["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
),
AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
AWQMapping(
"re:.*post_attention_layernorm",
["re:.*gate_proj", "re:.*up_proj"],
),
AWQMapping(
"re:.*up_proj",
["re:.*down_proj"],
),
]

# Phi merges
# q, k, and v proj layers into a single qkv_proj layer
# gate and up proj layers into a single gate_up_proj layer
_phi_mappings = [
AWQMapping(
"re:.*input_layernorm",
["re:.*qkv_proj"],
),
AWQMapping("re:.*qkv_proj", ["re:.*o_proj"]),
AWQMapping(
"re:.*post_attention_layernorm",
["re:.*gate_up_proj"],
),
AWQMapping(
"re:.*gate_up_proj",
["re:.*down_proj"],
),
]

AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = {
"Llama": [
AWQMapping(
"re:.*input_layernorm",
["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
),
AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
AWQMapping(
"re:.*post_attention_layernorm",
["re:.*gate_proj", "re:.*up_proj"],
),
AWQMapping(
"re:.*up_proj",
["re:.*down_proj"],
),
],
# TODO (Brian INFERENG-529) Add Qwen mappings
# "Qwen": [ ],
"LlamaForCausalLM": _default_mappings,
"Qwen2ForCausalLM": _default_mappings,
"Qwen3ForCausalLM": _default_mappings,
"MistralForCausalLM": _default_mappings,
"Phi3ForCausalLM": _phi_mappings,
"Phi3VForCausalLM": _phi_mappings,
}


@@ -64,3 +89,18 @@ class ResolvedMapping:
balance_names: Optional[List[str]] = None
parent: Optional[Module] = None
parent_name: Optional[str] = None


def get_layer_mappings_from_architecture(architecture: str) -> List[AWQMapping]:
"""
:param architecture: str: The architecture of the model
:return: list: The layer mappings for the given architecture
"""

if architecture not in AWQ_MAPPING_REGISTRY:
logger.info(
f"Architecture {architecture} not found in mappings. "
f"Using default mappings: {_default_mappings}"
)

return AWQ_MAPPING_REGISTRY.get(architecture, _default_mappings)
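
For reference, a small sketch (not in the diff) of how the new helper resolves mappings, assuming the registry contents shown above:

from llmcompressor.modifiers.awq.mappings import (
    AWQ_MAPPING_REGISTRY,
    get_layer_mappings_from_architecture,
)

# Registered architecture: returns its explicit entry (Phi models use fused
# qkv_proj / gate_up_proj layers, so they resolve to the _phi_mappings).
phi = get_layer_mappings_from_architecture("Phi3ForCausalLM")
assert phi is AWQ_MAPPING_REGISTRY["Phi3ForCausalLM"]

# Unregistered architecture: logs an info message and falls back to the
# Llama-style _default_mappings.
fallback = get_layer_mappings_from_architecture("MyCustomForCausalLM")
assert fallback == AWQ_MAPPING_REGISTRY["LlamaForCausalLM"]
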