@@ -1,24 +1,21 @@
 import inspect
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 from compressed_tensors.quantization import disable_quantization
-from compressed_tensors.utils import (
-    align_module_device,
-    get_execution_device,
-    update_offload_parameter,
-)
+from compressed_tensors.utils import align_module_device, update_offload_parameter
 from loguru import logger
 from pydantic import ConfigDict, PrivateAttr, model_validator
 from torch.nn import Module
-from torch.utils.hooks import RemovableHandle
 from tqdm import tqdm

 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
+from llmcompressor.modifiers.awq.helpers import accumulate_mean
 from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale
 from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
 from llmcompressor.modifiers.utils.hooks import HooksMixin
+from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import (
@@ -131,9 +128,11 @@ class AWQModifier(Modifier, QuantizationMixin):

     # Private vars set during initialization, cleared during finalization
     _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
-    _activations: Dict[str, List[torch.Tensor]] = PrivateAttr(default_factory=dict)
-    _activation_hooks: Set[RemovableHandle] = PrivateAttr(default_factory=set)
-    _module_kwargs: Dict = PrivateAttr(default_factory=dict)
+    _samples: Dict[Module, IntermediatesCache] = PrivateAttr(
+        default_factory=IntermediatesCache
+    )
+    _sample_means: Dict[Module, float] = PrivateAttr(default_factory=dict)
+    _num_samples: Dict[Module, int] = PrivateAttr(default_factory=dict)

     @model_validator(mode="after")
     def validate_model_after(model: "AWQModifier") -> "AWQModifier":
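Note: the new `_sample_means` / `_num_samples` pair supports a running per-module statistic that is folded in batch by batch, instead of concatenating every cached activation and averaging once at the end. The `accumulate_mean` helper imported from `llmcompressor.modifiers.awq.helpers` is not shown in this diff, so the sketch below only illustrates the incremental-mean arithmetic such a helper would perform; its name, signature, and return shape here are assumptions, not the actual implementation.

```python
from typing import Optional, Tuple

import torch


def accumulate_mean_sketch(
    inp: torch.Tensor,
    prev_mean: Optional[torch.Tensor],
    prev_count: int,
) -> Tuple[torch.Tensor, int]:
    """Fold one batch of activations into a running per-channel mean.

    Equivalent to concatenating every batch and averaging once, but only the
    current mean and an element count are kept in memory.
    """
    # flatten everything except the channel (last) dimension
    flat = inp.abs().reshape(-1, inp.shape[-1]).to(torch.float32)
    batch_count = flat.shape[0]
    batch_sum = flat.sum(dim=0)

    if prev_mean is None:
        return batch_sum / batch_count, batch_count

    new_count = prev_count + batch_count
    new_mean = (prev_mean * prev_count + batch_sum) / new_count
    return new_mean, new_count
```

With this bookkeeping, only the current mean and a sample count need to be held per module, which is what allows the chunked activation-mean computation to be deleted later in this diff.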
@@ -214,8 +213,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:

         self._set_resolved_mappings(state.model)

-        self._set_module_kwargs(state.model, state.data.calib)
-
         return True

     def on_start(self, state: State, event: Event, **kwargs):
@@ -262,8 +259,7 @@ def on_end(self, state: State, event: Event, **kwargs):
         QuantizationMixin.end_calibration(self, state.model)

         # remove activation hooks
-        self.remove_hooks(self._activation_hooks)
-        self._activation_hooks.clear()
+        self.remove_hooks()

     def on_finalize(self, state: State, **kwargs) -> bool:
         """
@@ -275,7 +271,9 @@ def on_finalize(self, state: State, **kwargs) -> bool:
         if not self.ended_:
             self.on_end(state, None)

-        self._activations.clear()
+        self._samples.clear()
+        self._sample_means.clear()
+        self._num_samples.clear()
         self._resolved_mappings.clear()

         return True
@@ -363,30 +361,24 @@ def _setup_activation_cache_hooks(self) -> None:
         calculate the dynamic range during calibration
         """

-        def create_cache_activation_hook(smooth_layer_name):
-            def cache_activation_hook_fn(
-                _module: torch.nn.Module,
-                args: Tuple[torch.Tensor, ...],
-                _output: torch.Tensor,
-            ):
-                # Assume that first argument is the input
-                inp = args[0].cpu().detach()
-
-                if smooth_layer_name in self._activations:
-                    self._activations[smooth_layer_name].append(inp)
-                else:
-                    self._activations[smooth_layer_name] = [inp]
+        def cache_activation_hook_fn(
+            _module: torch.nn.Module,
+            args: Tuple[torch.Tensor, ...],
+            kwargs: Dict[str, Any],
+        ):
+            sample = args[0]  # assume input is first arg
+            values = inspect.signature(_module.forward).bind(*args, **kwargs)

-            return cache_activation_hook_fn
+            self._samples[_module].append(values)
+            self._sample_means, self._num_samples = accumulate_mean(
+                sample, self._sample_means, self._num_samples
+            )

         for mapping in self._resolved_mappings:
             # storing inputs to first balance layer is sufficient
             # other balance layers get the same input
-            layer = mapping.balance_layers[0]
-            hook = self.register_hook(
-                layer, create_cache_activation_hook(mapping.smooth_name), "forward"
-            )
-            self._activation_hooks.add(hook)
+            for parent in mapping.parent:
+                self.register_hook(parent, cache_activation_hook_fn, "forward_pre")

     @torch.no_grad()
     def _apply_smoothing(self, model: Module) -> None:
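Note: the caching hook is now registered as a "forward_pre" hook and takes `(module, args, kwargs)`, which matches PyTorch's pre-hook contract when keyword arguments are requested. Whether `HooksMixin.register_hook` forwards the `with_kwargs` flag is not visible in this diff; the standalone sketch below only illustrates the underlying PyTorch API (the toy `Linear` module is not from this file).

```python
import torch


def cache_inputs_pre_hook(module, args, kwargs):
    # runs before module.forward; args/kwargs are the exact call inputs
    print(type(module).__name__, args[0].shape, sorted(kwargs))


linear = torch.nn.Linear(8, 4)
handle = linear.register_forward_pre_hook(cache_inputs_pre_hook, with_kwargs=True)
linear(torch.randn(2, 8))  # prints: Linear torch.Size([2, 8]) []
handle.remove()
```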
@@ -398,18 +390,15 @@ def _apply_smoothing(self, model: Module) -> None:
         :param model: model to apply smoothing to
         """
         for mapping in tqdm(self._resolved_mappings, desc="Smoothing"):
+            smooth_layer = mapping.smooth_layer
+            balance_layers = mapping.balance_layers
+            parent_layer = mapping.parent
+
             # NOTE: When using SequentialPipeline, not all the mappings
             # will have cached activations in the segment being udpated
-            if mapping.smooth_name not in self._activations:
+            if parent_layer not in self._num_samples:
                 continue

-            activations = torch.cat(self._activations[mapping.smooth_name], dim=0)
-            del self._activations[mapping.smooth_name]
-
-            smooth_layer = mapping.smooth_layer
-            balance_layers = mapping.balance_layers
-            module2inspect = mapping.parent
-
             # [STEP 1]: Compute per-channel mean of normalised weights
             # All layer weights are concatted together
             weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
@@ -425,45 +414,18 @@
             # Gets the average rescaled magnitude for each output channel
             w_mean = w_scale.mean(0)

-            # [STEP 2]: Compute per-channel mean of the input activation with chunking
-            # move inp to cpu to avoid memory leak
-            inp = activations.to(weight.device)
-            inp_flat = activations.cpu().abs().view(-1, inp.shape[-1])
-            num_elements = inp_flat.size(0)
-            num_channels = inp_flat.size(1)
-            element_size_bytes = inp_flat.element_size() * 2  # multiplied by 2 for FP32
-
-            # Calculate chunk size dynamically based on max_chunk_memory
-            chunk_size = int(
-                self.max_chunk_memory // (element_size_bytes * num_channels)
-            )
-            chunk_size = min(chunk_size, num_elements)
-
-            # Use float32 for sum calculation
-            x_sum = torch.zeros(num_channels, dtype=torch.float32, device=inp.device)
-
-            for i in range(0, num_elements, chunk_size):
-                end = min(i + chunk_size, num_elements)
-                chunk_sum = inp_flat[i:end].to(torch.float32).sum(dim=0)
-                x_sum += chunk_sum.to(inp.device)
-
-            x_mean = (x_sum / num_elements).to(inp.dtype)
-
             with calibration_forward_context(model), HooksMixin.disable_hooks():
                 # [STEP 3]: Compute output of module
-                fp16_output = self._forward_input_with_kwargs(
-                    module=module2inspect,
-                    inputs=inp,
-                    input_kwargs=_sanitize_kwargs(self._module_kwargs, module2inspect),
-                )
+                # could cache from hook, rather than recomputing here
+                fp16_output = self._run_samples(parent_layer)
                 fp16_output = fp16_output.clip(
                     torch.finfo(fp16_output.dtype).min,
                     torch.finfo(fp16_output.dtype).max,
                 )

                 # [STEP 4]: Compute loss
                 best_scales = self._compute_best_scale(
-                    inp, w_mean, x_mean, module2inspect, balance_layers, fp16_output
+                    w_mean, parent_layer, balance_layers, fp16_output
                 )

             scales = best_scales
@@ -504,14 +466,26 @@ def smooth(module):
                 smooth(layer)
             smooth(smooth_layer)

+            # remove caches needed to smooth this mapping
+            del self._samples[parent_layer]
+            del self._sample_means[parent_layer]
+            del self._num_samples[parent_layer]
+
         self._assert_all_activations_consumed()

+    def _run_samples(self, module: Module) -> torch.Tensor:
+        with align_module_device(module):
+            return torch.cat(
+                [module(**batch) for batch in self._samples[module]],
+                dim=0,
+            )
+
     def _compute_best_scale(
         self,
         x: torch.Tensor,
         w_mean: torch.Tensor,
         x_mean: torch.Tensor,
-        module2inspect: torch.nn.Module,
+        parent_layer: torch.nn.Module,
         linears2scale: List[torch.nn.Linear],
         fp16_output: torch.Tensor,
     ) -> torch.Tensor:
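Note: `_run_samples` replays each cached call through the parent module. The hook captures calls with `inspect.signature(...).bind(...)`; the resulting `BoundArguments` keeps a parameter-name-to-value mapping in `.arguments` that can be splatted back into a later call. A minimal sketch of that mechanism follows (the `forward` signature is illustrative, not taken from this file).

```python
import inspect

import torch


def forward(hidden_states, attention_mask=None, scale=1.0):
    return hidden_states * scale


# capture one call exactly as it was made
bound = inspect.signature(forward).bind(torch.ones(2, 4), scale=0.5)
print(dict(bound.arguments))  # {'hidden_states': tensor([...]), 'scale': 0.5}

# replay it later with the same arguments
out = forward(**bound.arguments)
```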
@@ -530,9 +504,10 @@ def _compute_best_scale(
         best_scales = None
         best_error = float("inf")

-        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
+        org_sd = {k: v.cpu() for k, v in parent_layer.state_dict().items()}

         device = x.device
+        x_mean = self._sample_means[parent_layer]
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)

@@ -571,9 +546,7 @@
             )

             # W * X
-            int_w_output = self._forward_input_with_kwargs(
-                module=module2inspect, inputs=x, input_kwargs=self._module_kwargs
-            )
+            int_w_output = self._run_samples(parent_layer)
             int_w_output = int_w_output.clip(
                 torch.finfo(int_w_output.dtype).min,
                 torch.finfo(int_w_output.dtype).max,
@@ -587,7 +560,7 @@
                 best_error = loss
                 best_ratio = ratio
                 best_scales = scales.clone()
-            module2inspect.load_state_dict(org_sd)
+            parent_layer.load_state_dict(org_sd)

         if best_ratio == -1:
             logger.debug(history)
@@ -642,123 +615,10 @@ def _assert_all_activations_consumed(self):
         Confirm all activations have been consumed
         If not, something has gone wrong
         """
-        if len(self._activations) > 0:
-            raise RuntimeError("Some cached activations were not used")
-
-    def _set_module_kwargs(self, model, dataloader) -> None:
-        _, modules = next(iter(get_layers("re:.*layers", model).items()))
-
-        samples = [batch["input_ids"] for batch in dataloader]
-
-        samples = torch.cat(samples, dim=0)
-
-        inps = []
-        layer_kwargs = {}
-
-        best_device = "cuda"
-        modules[0] = modules[0].to(best_device)
-
-        # get input and kwargs to layer 0
-        # with_kwargs is only supported in PyTorch 2.0
-        # use this Catcher hack for now
-        class Catcher(torch.nn.Module):
-            def __init__(self, module):
-                super().__init__()
-                self.module = module
-
-            def forward(self, *args, **kwargs):
-                # assume first input to forward is hidden states
-                if len(args) > 0:
-                    hidden_states = args[0]
-                    del args
-                else:
-                    first_key = list(kwargs.keys())[0]
-                    hidden_states = kwargs.pop(first_key)
-
-                inps.append(hidden_states)
-                layer_kwargs.update(kwargs)
-                raise ValueError  # early exit to break later inference
-
-        # patch layer 0 to catch input and kwargs
-        modules[0] = Catcher(modules[0])
-        try:
-            with calibration_forward_context(model):
-                model(samples.to(next(model.parameters()).device))
-        except ValueError:  # work with early exit
-            pass
-        modules[0] = modules[0].module  # restore
-
-        # Update the layer kwargs with `prepare_inputs_for_generation` method
-        # that takes care of everything to avoid unexpected errors.
-        layer_kwargs = model.prepare_inputs_for_generation(samples, **layer_kwargs)
-        # Pop the input_ids as they are not needed at all.
-        layer_kwargs.pop("input_ids")
-
-        del samples
-        inps = inps[0]
-
-        if layer_kwargs.get("attention_mask") is not None:
-            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
-                best_device
-            )
-
-        self._module_kwargs = layer_kwargs
-
-    def _forward_input_with_kwargs(
-        self,
-        module: Module,
-        inputs: torch.Tensor,
-        input_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
-        """
-        Forward pass with input arguments
-
-        :param module: module to run forward pass on
-        :param inputs: input tensor to pass to the module
-        :param input_kwargs: additional arguments to pass to the module
-        :return: the first output tensor from the forward pass
-        """
-        kwargs = input_kwargs or self._module_kwargs
-        kwargs = _sanitize_kwargs(kwargs, module)
-
-        inputs = inputs.to(get_execution_device(module))
-
-        return module(inputs, **kwargs)[0]
-
-
-def _sanitize_kwargs(input_kwargs: Dict[str, Any], module: Module) -> Dict[str, Any]:
-    """
-    Sanitize input keyword arguments to match the module's forward method signature,
-    excluding `use_cache` which is not desired to be passed into module.
-
-    Args:
-        inputs_kwargs (`dict`):
-            The input dictionary to pass to the model layer
-        module (`torch.nn.Module`):
-            Target module to quantize.
-    """
-
-    params = inspect.signature(module.forward).parameters
-
-    # Filter out any kwargs not in module.forward signature
-    sanitized_kwargs = {k: v for k, v in input_kwargs.items() if k in params}
-
-    # Edge Case: forward pass has optional dependencies that don't default to None.
-    # This is the case for `LlamaAttention.forward` which has input
-    # `attention_mask: Optional[torch.Tensor],` (with no `= None` default)
-    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L246
-    for k, v in params.items():
-        if (
-            k not in sanitized_kwargs
-            and v.default is inspect.Parameter.empty
-            and str(v.annotation).startswith("typing.Optional")
+        if not (
+            len(self._samples) == len(self._num_samples) == len(self._sample_means) == 0
         ):
-            sanitized_kwargs[k] = None
-
-    # Exclude `use_cache` entirely
-    sanitized_kwargs.pop("use_cache", None)
-
-    return sanitized_kwargs
+            raise RuntimeError("Some cached activations were not used")


 def _pseudo_quantize_tensor(