
Commit 79ed200

rename, change num samples help text, add readme
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 7fc2686 commit 79ed200

2 files changed: +66 −1 lines changed


examples/awq/README.md

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# Quantizing Models with Activation-Aware Quantization (AWQ) #

Activation-Aware Quantization (AWQ) is a state-of-the-art technique for quantizing the weights of large language models using a small calibration dataset. The AWQ algorithm uses the calibration data to derive scaling factors that reduce the dynamic range of the weights while minimizing accuracy loss on the most salient weight values.

The AWQ implementation in LLM Compressor is derived from the pioneering work of [AutoAWQ](https://github.com/casper-hansen/AutoAWQ) and was developed with assistance from its original maintainer, [casper-hansen](https://github.com/casper-hansen).

## AWQ Recipe ##

The AWQ recipe is interfaced as follows, where the `AWQModifier` adjusts model scales ahead of efficient weight quantization by the `QuantizationModifier`:

```python
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import QuantizationModifier

recipe = [
    AWQModifier(bits=4, symmetric=False),
    QuantizationModifier(
        ignore=["lm_head"],
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"],
                weights=QuantizationArgs(
                    num_bits=4,
                    type=QuantizationType.INT,
                    dynamic=False,
                    symmetric=False,
                    strategy=QuantizationStrategy.GROUP,
                    group_size=128,
                ),
            )
        },
    ),
]
```
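
As a usage sketch (not part of the original example), the recipe can then be applied with LLM Compressor's `oneshot` entrypoint; here `model` and `ds` are assumed to be a loaded model and a tokenized calibration dataset, prepared as in the example scripts:

```python
from llmcompressor import oneshot

# Run one-shot calibration: AWQModifier derives the scales, then
# QuantizationModifier quantizes the weights according to the recipe.
# `model`, `ds`, and the sample counts below are assumptions.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=512,
    num_calibration_samples=512,
)
```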

## Compressing Your Own Model ##

To use your own model, start with an existing example and change the `model_id` to match your own model stub.

```python
from transformers import AutoModelForCausalLM

model_id = "path/to/your/model"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)
```
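
After calibration, the compressed model can be saved for reuse. A minimal sketch, assuming the `save_compressed` argument supported by LLM Compressor's wrapped `save_pretrained` and a hypothetical output directory:

```python
from transformers import AutoTokenizer

SAVE_DIR = "your-model-awq-w4a16"  # hypothetical output directory

# Save quantized weights in compressed-tensors format, plus the tokenizer.
# `save_compressed=True` is an assumption carried over from other examples.
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained(SAVE_DIR)
```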

## Adding Mappings ##

In order to target the weight and activation scaling locations within the model, the `AWQModifier` must be provided with an AWQ mapping. For example, the AWQ mapping for the Llama family of models looks like this:
46+
47+
```python
48+
[
49+
AWQMapping(
50+
"re:.*input_layernorm",
51+
["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
52+
),
53+
AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
54+
AWQMapping(
55+
"re:.*post_attention_layernorm",
56+
["re:.*gate_proj", "re:.*up_proj"],
57+
),
58+
AWQMapping(
59+
"re:.*up_proj",
60+
["re:.*down_proj"],
61+
),
62+
]
63+
```

To support other model families, you can supply your own mappings via the `mappings` argument when instantiating the `AWQModifier`, or you can add them to the registry [here](src/llmcompressor/modifiers/awq/mappings.py) (contributions are welcome!).
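
For instance, a minimal sketch of supplying custom mappings, assuming `AWQMapping` is importable from `llmcompressor.modifiers.awq`; the regex targets below are hypothetical placeholders and must match your architecture's actual module names:

```python
from llmcompressor.modifiers.awq import AWQMapping, AWQModifier

# Hypothetical mappings for an architecture not yet in the registry.
custom_mappings = [
    AWQMapping(
        "re:.*input_layernorm",
        ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
    ),
    AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
]

awq = AWQModifier(bits=4, symmetric=False, mappings=custom_mappings)
```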

examples/awq/awq_one_shot.py renamed to examples/awq/llama_example.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -28,7 +28,7 @@
 
 # Select number of samples. 512 samples is a good place to start.
 # Increasing the number of samples can improve accuracy.
-NUM_CALIBRATION_SAMPLES = 256
+NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 512
 
 # Load dataset and preprocess.
```
