@@ -137,6 +137,11 @@ class AWQModifier(Modifier, QuantizationMixin):

    @model_validator(mode="after")
    def validate_model_after(model: "AWQModifier") -> "AWQModifier":
+        """
+        Confirm only one configuration for group_size, symmetric, and num_bits,
+        as the AWQ algorithm depends on it.
+        Confirm no activation quantization, as AWQ only works with WNA16.
+        """
        if not model.config_groups and not model.scheme:
            raise ValueError("AWQ requires either a config_groups or a scheme")

@@ -178,6 +183,8 @@ def validate_model_after(model: "AWQModifier") -> "AWQModifier":

        model._group_size = next(iter(group_size_set))

+        # TODO confirm no activation quantization
+
        return model

    def on_initialize(self, state: State, **kwargs) -> bool:
@@ -195,8 +202,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:

        self._set_resolved_mappings(state.model)

-        with calibration_forward_context(state.model):
-            self._set_module_kwargs(state.model, state.data.calib)
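+        # NOTE: _set_module_kwargs now enters calibration_forward_context
+        # around its own forward pass (see the last hunk below)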
+        self._set_module_kwargs(state.model, state.data.calib)

        return True

@@ -222,13 +228,11 @@ def on_event(self, state: State, event: Event, **kwargs):

        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            # Run smoothing in case of sequential pipeline
-            with calibration_forward_context(state.model), HooksMixin.disable_hooks():
-                self._apply_smoothing(state.model)
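+            # NOTE: _apply_smoothing now wraps its own forward passes in
+            # calibration_forward_context and HooksMixin.disable_hooks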
+            self._apply_smoothing(state.model)

        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
            # Run smoothing in case of basic pipeline
-            with calibration_forward_context(state.model), HooksMixin.disable_hooks():
-                self._apply_smoothing(state.model)
+            self._apply_smoothing(state.model)

        if not self.ended_:
            self.on_end(state, None)
@@ -385,8 +389,8 @@ def _apply_smoothing(self, model: Module) -> None:
        :param model: model to apply smoothing to
        """
        for mapping in tqdm(self._resolved_mappings, desc="Smoothing"):
-            # When using SequentialPipeline, not all the mappings will have
-            # activations not found in this segment
+            # NOTE: When using SequentialPipeline, not all the mappings
+            # will have cached activations in the segment being updated
            if mapping.smooth_name not in self._activations:
                continue

@@ -437,14 +441,16 @@ def _apply_smoothing(self, model: Module) -> None:
            x_mean = (x_sum / num_elements).to(inp.dtype)

            # [STEP 3]: Compute output of module
-            fp16_output = self._forward_input_with_kwargs(
-                module=module2inspect,
-                inputs=inp,
-                input_kwargs=_sanitize_kwargs(self._module_kwargs, module2inspect),
-            )
-            fp16_output = fp16_output.clip(
-                torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max
-            )
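+            # compute the unquantized reference output with quantization
+            # hooks disabled, so observers do not run during this pass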
+            with calibration_forward_context(model), HooksMixin.disable_hooks():
+                fp16_output = self._forward_input_with_kwargs(
+                    module=module2inspect,
+                    inputs=inp,
+                    input_kwargs=_sanitize_kwargs(self._module_kwargs, module2inspect),
+                )
+                fp16_output = fp16_output.clip(
+                    torch.finfo(fp16_output.dtype).min,
+                    torch.finfo(fp16_output.dtype).max,
+                )

            # [STEP 4]: Compute loss
            best_scales = self._compute_best_scale(
@@ -556,12 +562,16 @@ def _compute_best_scale(
            )

            # W * X
-            int_w_output = self._forward_input_with_kwargs(
-                module=module2inspect, inputs=x, input_kwargs=self._module_kwargs
-            )
-            int_w_output = int_w_output.clip(
-                torch.finfo(int_w_output.dtype).min, torch.finfo(int_w_output.dtype).max
-            )
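+            # quantized-weight output, computed under the same context
+            # as the fp16 reference above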
+            with calibration_forward_context(
+                module2inspect
+            ), HooksMixin.disable_hooks():
+                int_w_output = self._forward_input_with_kwargs(
+                    module=module2inspect, inputs=x, input_kwargs=self._module_kwargs
+                )
+                int_w_output = int_w_output.clip(
+                    torch.finfo(int_w_output.dtype).min,
+                    torch.finfo(int_w_output.dtype).max,
+                )

            # compute mean squared error (L2 norm)
            loss = self._compute_loss(fp16_output, int_w_output, device)
@@ -666,7 +676,8 @@ def forward(self, *args, **kwargs):
        # patch layer 0 to catch input and kwargs
        modules[0] = Catcher(modules[0])
        try:
-            model(samples.to(next(model.parameters()).device))
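+            # the early-exit pass that captures layer-0 inputs and kwargs
+            # now also runs under calibration_forward_context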
+            with calibration_forward_context(model):
+                model(samples.to(next(model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        modules[0] = modules[0].module  # restore