
Commit d1d9171

more cleanup
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent: ec234ea

2 files changed: +35 −24 lines changed

examples/awq/awq_one_shot.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 # 1) Run the `llm-compressor` implementation of AWQ
 # 2) Evaluate the compressed model with the lm_eval framework
 
-MODEL_ID = "meta-llama/Llama-3.2-3B"  # "meta-llama/Meta-Llama-3-8B-Instruct"
+MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
 DATASET_ID = "mit-han-lab/pile-val-backup"
 DATASET_SPLIT = "validation"
 NUM_CALIBRATION_SAMPLES = 256
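For reference, these constants feed the rest of the example script. A minimal sketch of how such a one-shot AWQ run is typically wired together is shown below; the import paths and the `AWQModifier`/`oneshot` arguments are assumptions drawn from the general llm-compressor pattern, not copied from the example file itself.

# Hypothetical sketch of the one-shot flow that consumes these constants;
# import paths and argument names are assumptions and may differ.
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
DATASET_ID = "mit-han-lab/pile-val-backup"
DATASET_SPLIT = "validation"
NUM_CALIBRATION_SAMPLES = 256

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")

# AWQ is weight-only (WNA16): low-bit grouped weights, fp16 activations.
recipe = AWQModifier(targets=["Linear"], scheme="W4A16", ignore=["lm_head"])

oneshot(
    model=model,
    recipe=recipe,
    dataset=DATASET_ID,  # calibration data, split per DATASET_SPLIT
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    max_seq_length=512,
)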

src/llmcompressor/modifiers/awq/base.py

Lines changed: 34 additions & 23 deletions
@@ -137,6 +137,11 @@ class AWQModifier(Modifier, QuantizationMixin):
 
     @model_validator(mode="after")
     def validate_model_after(model: "AWQModifier") -> "AWQModifier":
+        """
+        Confirm only one configuration for group_size, symmetric, and num_bits,
+        as the AWQ algorithm depends on it.
+        Confirm no activation quantization, as AWQ only works with WNA16.
+        """
         if not model.config_groups and not model.scheme:
             raise ValueError("AWQ requires either a config_groups or a scheme")
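The new docstring describes a consistency requirement: every targeted weight scheme must resolve to a single num_bits, symmetric, and group_size. A standalone sketch of that kind of check, using hypothetical field and variable names rather than the modifier's actual code:

# Hypothetical sketch of the "only one configuration" rule the docstring
# describes; names are illustrative, not the real implementation.
def resolve_unique(config_groups: dict, field: str):
    values = {getattr(group.weights, field) for group in config_groups.values()}
    if len(values) != 1:
        raise ValueError(
            f"AWQ requires a single {field} across config groups, got {values}"
        )
    return next(iter(values))

# num_bits   = resolve_unique(model.config_groups, "num_bits")
# symmetric  = resolve_unique(model.config_groups, "symmetric")
# group_size = resolve_unique(model.config_groups, "group_size")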

@@ -178,6 +183,8 @@ def validate_model_after(model: "AWQModifier") -> "AWQModifier":
 
         model._group_size = next(iter(group_size_set))
 
+        # TODO confirm no activation quantization
+
         return model
 
     def on_initialize(self, state: State, **kwargs) -> bool:
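The TODO corresponds to the second half of the docstring added above. A hypothetical version of that pending check, assuming config groups expose `input_activations`/`output_activations` as compressed-tensors schemes do:

# Hypothetical sketch of the pending "no activation quantization" check;
# attribute names are assumptions, not the eventual implementation.
for name, group in (model.config_groups or {}).items():
    if group.input_activations is not None or group.output_activations is not None:
        raise ValueError(
            f"AWQ is weight-only (WNA16); config group '{name}' "
            "must not quantize activations"
        )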
@@ -195,8 +202,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         self._set_resolved_mappings(state.model)
 
-        with calibration_forward_context(state.model):
-            self._set_module_kwargs(state.model, state.data.calib)
+        self._set_module_kwargs(state.model, state.data.calib)
 
         return True
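This hunk drops the calibration context from `on_initialize`; the later hunks re-introduce it directly around the forward passes that need it. As a rough mental model (a simplified stand-in, not the actual llmcompressor utility), such a context typically looks like:

# Simplified stand-in for a calibration forward context: eval mode and no
# autograd while the block runs, restoring training mode afterwards.
import contextlib

import torch

@contextlib.contextmanager
def calibration_forward_context(model: torch.nn.Module):
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            yield
    finally:
        if was_training:
            model.train()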

@@ -222,13 +228,11 @@ def on_event(self, state: State, event: Event, **kwargs):
 
         elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
             # Run smoothing in case of sequential pipeline
-            with calibration_forward_context(state.model), HooksMixin.disable_hooks():
-                self._apply_smoothing(state.model)
+            self._apply_smoothing(state.model)
 
         elif event.type_ == EventType.CALIBRATION_EPOCH_END:
             # Run smoothing in case of basic pipeline
-            with calibration_forward_context(state.model), HooksMixin.disable_hooks():
-                self._apply_smoothing(state.model)
+            self._apply_smoothing(state.model)
 
             if not self.ended_:
                 self.on_end(state, None)
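After this change, `_apply_smoothing` enters the calibration context and disables hooks itself (see the hunks below), so the event handler calls it directly. For intuition, a toy version of the hook-disabling pattern is sketched here; HooksMixin's real registry and API are not reproduced.

# Toy sketch of "disable hooks during an auxiliary forward pass": registered
# hooks become no-ops while disable_hooks() is active.
import contextlib

import torch

class HookRegistry:
    def __init__(self):
        self._enabled = True

    def register_forward_hook(self, module: torch.nn.Module, fn):
        def wrapped(mod, inputs, output):
            if self._enabled:
                return fn(mod, inputs, output)
        return module.register_forward_hook(wrapped)

    @contextlib.contextmanager
    def disable_hooks(self):
        previous, self._enabled = self._enabled, False
        try:
            yield  # forward passes here skip all registered hooks
        finally:
            self._enabled = previous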
@@ -385,8 +389,8 @@ def _apply_smoothing(self, model: Module) -> None:
         :param model: model to apply smoothing to
         """
         for mapping in tqdm(self._resolved_mappings, desc="Smoothing"):
-            # When using SequentialPipeline, not all the mappings will have
-            # activations not found in this segment
+            # NOTE: When using SequentialPipeline, not all the mappings
+            # will have cached activations in the segment being updated
             if mapping.smooth_name not in self._activations:
                 continue

@@ -437,14 +441,16 @@ def _apply_smoothing(self, model: Module) -> None:
             x_mean = (x_sum / num_elements).to(inp.dtype)
 
             # [STEP 3]: Compute output of module
-            fp16_output = self._forward_input_with_kwargs(
-                module=module2inspect,
-                inputs=inp,
-                input_kwargs=_sanitize_kwargs(self._module_kwargs, module2inspect),
-            )
-            fp16_output = fp16_output.clip(
-                torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max
-            )
+            with calibration_forward_context(model), HooksMixin.disable_hooks():
+                fp16_output = self._forward_input_with_kwargs(
+                    module=module2inspect,
+                    inputs=inp,
+                    input_kwargs=_sanitize_kwargs(self._module_kwargs, module2inspect),
+                )
+                fp16_output = fp16_output.clip(
+                    torch.finfo(fp16_output.dtype).min,
+                    torch.finfo(fp16_output.dtype).max,
+                )
 
             # [STEP 4]: Compute loss
             best_scales = self._compute_best_scale(
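The reference (fp16) output is clipped to its dtype's finite range so any inf produced during the forward pass cannot dominate the loss in the next step. In isolation the pattern is simply:

# Clamp a tensor to the finite range of its own dtype; +/-inf becomes the
# dtype's max/min instead of propagating into the MSE comparison.
import torch

def clamp_to_finite(t: torch.Tensor) -> torch.Tensor:
    info = torch.finfo(t.dtype)
    return t.clip(info.min, info.max)

# clamp_to_finite(torch.tensor([float("inf"), 1.0], dtype=torch.float16))
# -> tensor([65504., 1.], dtype=torch.float16)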
@@ -556,12 +562,16 @@ def _compute_best_scale(
             )
 
             # W * X
-            int_w_output = self._forward_input_with_kwargs(
-                module=module2inspect, inputs=x, input_kwargs=self._module_kwargs
-            )
-            int_w_output = int_w_output.clip(
-                torch.finfo(int_w_output.dtype).min, torch.finfo(int_w_output.dtype).max
-            )
+            with calibration_forward_context(
+                module2inspect
+            ), HooksMixin.disable_hooks():
+                int_w_output = self._forward_input_with_kwargs(
+                    module=module2inspect, inputs=x, input_kwargs=self._module_kwargs
+                )
+                int_w_output = int_w_output.clip(
+                    torch.finfo(int_w_output.dtype).min,
+                    torch.finfo(int_w_output.dtype).max,
+                )
 
             # compute mean squared error (L2 norm)
             loss = self._compute_loss(fp16_output, int_w_output, device)
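`_compute_loss` scores each candidate scale by comparing the original output against the output produced with quantized weights; the ratio with the smallest error wins. A self-contained sketch of that objective (the exact reduction used by the modifier is not assumed):

# Sketch of the scale-search objective: mean squared error between the
# original output and the quantized-weight output; lower is better.
import torch

def mse_loss(fp16_output: torch.Tensor, int_w_output: torch.Tensor) -> float:
    return (fp16_output - int_w_output).float().pow(2).mean().item()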
@@ -666,7 +676,8 @@ def forward(self, *args, **kwargs):
         # patch layer 0 to catch input and kwargs
         modules[0] = Catcher(modules[0])
         try:
-            model(samples.to(next(model.parameters()).device))
+            with calibration_forward_context(model):
+                model(samples.to(next(model.parameters()).device))
         except ValueError:  # work with early exit
             pass
         modules[0] = modules[0].module  # restore
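For context, the surrounding code uses a "catcher" trick: layer 0 is wrapped so a single forward pass records the inputs and keyword arguments that later per-module forwards reuse, then raises to abort the rest of the run (the `except ValueError` above). A generic sketch of that pattern, independent of this file's actual Catcher class:

# Generic sketch of the catcher pattern: record the first layer's inputs and
# kwargs, then raise so the full model forward stops early.
import torch

class Catcher(torch.nn.Module):
    def __init__(self, module: torch.nn.Module, storage: dict):
        super().__init__()
        self.module = module
        self.storage = storage

    def forward(self, *args, **kwargs):
        self.storage["inputs"] = args
        self.storage["kwargs"] = kwargs
        raise ValueError  # caught by the caller; only this one call is needed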
