
Commit 1fb1377 (parent d9881b0)
AWQ Qwen and Phi mappings (#1440)
SUMMARY: This PR shows users how they can add more mappings to AWQ to support additional models. It turns out Qwen uses exactly the same mappings as Llama, so I added a set for Phi as well. I also updated the naming and adopted the inference pattern employed in SmoothQuant, rather than requiring users to set the mappings themselves.

TEST PLAN: `examples/awq/llama_example.py` works on this branch with

```python
MODEL_ID = "microsoft/Phi-4-mini-reasoning"
```

TODOs:

- [x] Merge in after #1451 lands

Signed-off-by: Brian Dellabetta <[email protected]>
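As a sketch of the new default behavior (the import path comes from this diff; constructing the modifier with no other arguments is an assumption made for brevity):

```python
from llmcompressor.modifiers.awq import AWQModifier

# `mappings` now defaults to None. During on_initialize, the modifier
# infers layer mappings from the loaded model's class name: e.g.
# "Phi3ForCausalLM" resolves to the Phi mappings, and architectures
# missing from the registry fall back to the Llama-style defaults.
modifier = AWQModifier()
```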

File tree: 2 files changed (+73 −28 lines)


src/llmcompressor/modifiers/awq/base.py

Lines changed: 15 additions & 10 deletions
```diff
@@ -16,6 +16,11 @@
 
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
+from llmcompressor.modifiers.awq.mappings import (
+    AWQMapping,
+    ResolvedMapping,
+    get_layer_mappings_from_architecture,
+)
 from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale
 from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
 from llmcompressor.modifiers.utils.hooks import HooksMixin
@@ -27,8 +32,6 @@
     get_parent_by_name,
 )
 
-from .mappings import AWQ_MAPPING_REGISTRY, AWQMapping, ResolvedMapping
-
 __all__ = ["AWQModifier"]
 
 
@@ -120,7 +123,7 @@ class AWQModifier(Modifier, QuantizationMixin):
 
     # User-provided vars (in addition to QuantizationMixin args)
     sequential_targets: Union[str, List[str], None] = None
-    mappings: List[AWQMapping] = AWQ_MAPPING_REGISTRY["Llama"]
+    mappings: Optional[List[AWQMapping]] = None
     max_chunk_memory: int = 1024 * 1024 * 1024
     duo_scaling: bool = True
 
@@ -212,6 +215,12 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         if QuantizationMixin.has_config(self):
             QuantizationMixin.initialize_quantization(self, state.model)
 
+        if self.mappings is None:
+            logger.info("No AWQModifier.mappings provided, inferring from model...")
+            self.mappings = get_layer_mappings_from_architecture(
+                architecture=state.model.__class__.__name__
+            )
+
         self._set_resolved_mappings(state.model)
 
         self._set_module_kwargs(state.model, state.data.calib)
@@ -500,13 +509,9 @@ def smooth(module):
                 # in this case, default to scaling the last output features
                 # because the desired smooth layer is v_proj
                 # https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/scale.py#L123
-                update_offload_parameter(
-                    module,
-                    "weight",
-                    module.weight[-scales.size(0) :].div_(
-                        scales.view(-1, 1)
-                    ),
-                )
+                weight = module.weight
+                weight[-scales.size(0) :].div_(scales.view(-1, 1))
+                update_offload_parameter(module, "weight", weight)
             if hasattr(module, "bias") and module.bias is not None:
                 update_offload_parameter(
                     module,
```
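The `smooth` rewrite above relies on standard PyTorch view semantics: slicing a tensor returns a view, so the in-place `div_` mutates the underlying weight, and the full tensor is then handed to `update_offload_parameter` once. A standalone illustration in plain PyTorch (toy shapes, independent of this codebase):

```python
import torch

# A 4x2 "weight" where only the last rows correspond to the output
# features that need scaling.
weight = torch.ones(4, 2)
scales = torch.tensor([2.0, 4.0])

# Slicing yields a view, so div_ modifies `weight` in place.
weight[-scales.size(0):].div_(scales.view(-1, 1))

print(weight)
# tensor([[1.0000, 1.0000],
#         [1.0000, 1.0000],
#         [0.5000, 0.5000],
#         [0.2500, 0.2500]])
```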

src/llmcompressor/modifiers/awq/mappings.py

Lines changed: 58 additions & 18 deletions
```diff
@@ -1,9 +1,10 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional
 
+from loguru import logger
 from torch.nn import Module
 
-__all__ = ["AWQMapping", "AWQ_MAPPING_REGISTRY"]
+__all__ = ["AWQMapping", "AWQ_MAPPING_REGISTRY", "get_layer_mappings_from_architecture"]
 
 
 @dataclass
@@ -22,24 +23,48 @@ class AWQMapping:
     balance_layers: list[str]
 
 
+_default_mappings = [
+    AWQMapping(
+        "re:.*input_layernorm",
+        ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
+    ),
+    AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
+    AWQMapping(
+        "re:.*post_attention_layernorm",
+        ["re:.*gate_proj", "re:.*up_proj"],
+    ),
+    AWQMapping(
+        "re:.*up_proj",
+        ["re:.*down_proj"],
+    ),
+]
+
+# Phi merges
+#   q, k, and v proj layers into a single qkv_proj layer
+#   gate and up proj layers into a single gate_up_proj layer
+_phi_mappings = [
+    AWQMapping(
+        "re:.*input_layernorm",
+        ["re:.*qkv_proj"],
+    ),
+    AWQMapping("re:.*qkv_proj", ["re:.*o_proj"]),
+    AWQMapping(
+        "re:.*post_attention_layernorm",
+        ["re:.*gate_up_proj"],
+    ),
+    AWQMapping(
+        "re:.*gate_up_proj",
+        ["re:.*down_proj"],
+    ),
+]
+
 AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = {
-    "Llama": [
-        AWQMapping(
-            "re:.*input_layernorm",
-            ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
-        ),
-        AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
-        AWQMapping(
-            "re:.*post_attention_layernorm",
-            ["re:.*gate_proj", "re:.*up_proj"],
-        ),
-        AWQMapping(
-            "re:.*up_proj",
-            ["re:.*down_proj"],
-        ),
-    ],
-    # TODO (Brian INFERENG-529) Add Qwen mappings
-    # "Qwen": [ ],
+    "LlamaForCausalLM": _default_mappings,
+    "Qwen2ForCausalLM": _default_mappings,
+    "Qwen3ForCausalLM": _default_mappings,
+    "MistralForCausalLM": _default_mappings,
+    "Phi3ForCausalLM": _phi_mappings,
+    "Phi3VForCausalLM": _phi_mappings,
 }
 
 
@@ -64,3 +89,18 @@ class ResolvedMapping:
     balance_names: Optional[List[str]] = None
     parent: Optional[Module] = None
     parent_name: Optional[str] = None
+
+
+def get_layer_mappings_from_architecture(architecture: str) -> List[AWQMapping]:
+    """
+    :param architecture: str: The architecture of the model
+    :return: list: The layer mappings for the given architecture
+    """
+    if architecture not in AWQ_MAPPING_REGISTRY:
+        logger.info(
+            f"Architecture {architecture} not found in mappings. "
+            f"Using default mappings: {_default_mappings}"
+        )
+
+    return AWQ_MAPPING_REGISTRY.get(architecture, _default_mappings)
```
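Adding support for another architecture now amounts to registering its mappings. A sketch using the public names from this diff (`MyModelForCausalLM` and its layer patterns are hypothetical):

```python
from llmcompressor.modifiers.awq.mappings import (
    AWQMapping,
    AWQ_MAPPING_REGISTRY,
    get_layer_mappings_from_architecture,
)

# Known architecture: returns the registered Phi mappings.
phi_mappings = get_layer_mappings_from_architecture("Phi3ForCausalLM")

# Unknown architecture: logs an info message and falls back to the
# Llama-style default mappings.
fallback = get_layer_mappings_from_architecture("MyModelForCausalLM")

# Registering a new entry (hypothetical model with a fused qkv_proj):
AWQ_MAPPING_REGISTRY["MyModelForCausalLM"] = [
    AWQMapping("re:.*input_layernorm", ["re:.*qkv_proj"]),
    AWQMapping("re:.*qkv_proj", ["re:.*o_proj"]),
]
```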
