
Commit 6fa33a7

AWQ QuantizationMixin + SequentialPipeline (#1426)
SUMMARY:
- Add QuantizationMixin to AWQModifier so we don't have redundant inputs (num_bits, symmetric, group_size).
- Move AWQModifier to sequential pipelines, to avoid the huge memory requirements of caching all activations at once.

Regression test results are acceptable; results are all roughly the same and within stderr, see test plan below.

Resolves #1409
Resolves #1369
Related to #1383
Related to #1406
Related to #1368
Related to #1410
More improvements split into #1435

TEST PLAN:
- [x] Rerun tests to validate

No regression in tests, comparing against those reported in the [original AWQ PR](#1177 (comment)). All gsm8k results are within stderr:

| Type | gsm8k | wikitext |
| ------ | ------ | ----- |
| Old AWQ+QuantModifier Sym | .1054, .1069 | 9.1931 |
| New AWQ+QuantMixin Sym | .1077, .1084 | 9.1841 |
| Old AWQ+QuantModifier Asym | .1274, .1281 | 9.0281 |
| New AWQ+QuantMixin Asym | .1312, .1350 | 9.0288 |

---------

Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
1 parent 9a11f52 commit 6fa33a7
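For context, the user-facing change boils down to a single consolidated modifier in the recipe. The sketch below is illustrative, not taken verbatim from the repository: the `oneshot()` keyword arguments (dataset name, sequence length, sample count) are assumptions, and the full calibration flow lives in the updated examples/awq/llama_example.py shown in the diff below.

```python
# Minimal sketch of the consolidated AWQ recipe after this commit.
from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

# AWQModifier now carries the quantization config itself (via QuantizationMixin),
# so the separate QuantizationModifier with config_groups is no longer needed.
recipe = [
    AWQModifier(ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"]),
]

oneshot(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # model used in the example
    dataset="open_platypus",          # assumption: any oneshot-supported calibration dataset
    recipe=recipe,
    max_seq_length=512,               # assumption: example calibration settings
    num_calibration_samples=256,      # assumption: example calibration settings
)
```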

File tree

10 files changed: +323 −235 lines


examples/awq/llama_example.py

Lines changed: 1 addition & 44 deletions
@@ -1,17 +1,8 @@
-import lm_eval
-from compressed_tensors.quantization import (
-    QuantizationArgs,
-    QuantizationScheme,
-    QuantizationStrategy,
-    QuantizationType,
-)
 from datasets import load_dataset
-from lm_eval.utils import make_table
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.awq import AWQModifier
-from llmcompressor.modifiers.quantization import QuantizationModifier
 
 # Select model and load it.
 MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -61,23 +52,7 @@ def tokenize(sample):
 
 # Configure the quantization algorithm to run.
 recipe = [
-    AWQModifier(bits=4, symmetric=False),
-    QuantizationModifier(
-        ignore=["lm_head"],
-        config_groups={
-            "group_0": QuantizationScheme(
-                targets=["Linear"],
-                weights=QuantizationArgs(
-                    num_bits=4,
-                    type=QuantizationType.INT,
-                    dynamic=False,
-                    symmetric=False,
-                    strategy=QuantizationStrategy.GROUP,
-                    group_size=128,
-                ),
-            )
-        },
-    ),
+    AWQModifier(ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"]),
 ]
 
 # Apply algorithms.
@@ -101,21 +76,3 @@ def tokenize(sample):
 SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-asym"
 model.save_pretrained(SAVE_DIR, save_compressed=True)
 tokenizer.save_pretrained(SAVE_DIR)
-
-#
-# 2) Evaluate model on wikitext perplexity
-#
-
-results = lm_eval.simple_evaluate(
-    model="hf",
-    model_args={
-        "pretrained": SAVE_DIR,
-        "add_bos_token": True,
-        "dtype": "bfloat16",
-        "gpu_memory_utilization": 0.5,
-    },
-    tasks=["wikitext"],
-    num_fewshot=5,
-    batch_size="auto",
-)
-print(make_table(results))
