
Commit 7273b1c

accumulation
Signed-off-by: Kyle Sayers <[email protected]>
1 parent c4cd97c commit 7273b1c

File tree: 3 files changed, +97 -215 lines changed


src/llmcompressor/modifiers/awq/base.py

Lines changed: 54 additions & 194 deletions
@@ -1,24 +1,21 @@
 import inspect
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization import disable_quantization
-from compressed_tensors.utils import (
-    align_module_device,
-    get_execution_device,
-    update_offload_parameter,
-)
+from compressed_tensors.utils import align_module_device, update_offload_parameter
 from loguru import logger
 from pydantic import ConfigDict, PrivateAttr, model_validator
 from torch.nn import Module
-from torch.utils.hooks import RemovableHandle
 from tqdm import tqdm
 
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
+from llmcompressor.modifiers.awq.helpers import accumulate_mean
 from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale
 from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
 from llmcompressor.modifiers.utils.hooks import HooksMixin
+from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import (
@@ -131,9 +128,11 @@ class AWQModifier(Modifier, QuantizationMixin):
 
     # Private vars set during initialization, cleared during finalization
     _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
-    _activations: Dict[str, List[torch.Tensor]] = PrivateAttr(default_factory=dict)
-    _activation_hooks: Set[RemovableHandle] = PrivateAttr(default_factory=set)
-    _module_kwargs: Dict = PrivateAttr(default_factory=dict)
+    _samples: Dict[Module, IntermediatesCache] = PrivateAttr(
+        default_factory=IntermediatesCache
+    )
+    _sample_means: Dict[Module, float] = PrivateAttr(default_factory=dict)
+    _num_samples: Dict[Module, int] = PrivateAttr(default_factory=dict)
 
     @model_validator(mode="after")
     def validate_model_after(model: "AWQModifier") -> "AWQModifier":
@@ -214,8 +213,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         self._set_resolved_mappings(state.model)
 
-        self._set_module_kwargs(state.model, state.data.calib)
-
         return True
 
     def on_start(self, state: State, event: Event, **kwargs):
@@ -262,8 +259,7 @@ def on_end(self, state: State, event: Event, **kwargs):
         QuantizationMixin.end_calibration(self, state.model)
 
         # remove activation hooks
-        self.remove_hooks(self._activation_hooks)
-        self._activation_hooks.clear()
+        self.remove_hooks()
 
     def on_finalize(self, state: State, **kwargs) -> bool:
         """
@@ -275,7 +271,9 @@ def on_finalize(self, state: State, **kwargs) -> bool:
         if not self.ended_:
             self.on_end(state, None)
 
-        self._activations.clear()
+        self._samples.clear()
+        self._sample_means.clear()
+        self._num_samples.clear()
         self._resolved_mappings.clear()
 
         return True
@@ -363,30 +361,24 @@ def _setup_activation_cache_hooks(self) -> None:
         calculate the dynamic range during calibration
         """
 
-        def create_cache_activation_hook(smooth_layer_name):
-            def cache_activation_hook_fn(
-                _module: torch.nn.Module,
-                args: Tuple[torch.Tensor, ...],
-                _output: torch.Tensor,
-            ):
-                # Assume that first argument is the input
-                inp = args[0].cpu().detach()
-
-                if smooth_layer_name in self._activations:
-                    self._activations[smooth_layer_name].append(inp)
-                else:
-                    self._activations[smooth_layer_name] = [inp]
+        def cache_activation_hook_fn(
+            _module: torch.nn.Module,
+            args: Tuple[torch.Tensor, ...],
+            kwargs: Dict[str, Any],
+        ):
+            sample = args[0]  # assume input is first arg
+            values = inspect.signature(_module.forward).bind(*args, **kwargs)
 
-            return cache_activation_hook_fn
+            self._samples[_module].append(values)
+            self._sample_means, self._num_samples = accumulate_mean(
+                sample, self._sample_means, self._num_samples
+            )
 
         for mapping in self._resolved_mappings:
             # storing inputs to first balance layer is sufficient
             # other balance layers get the same input
-            layer = mapping.balance_layers[0]
-            hook = self.register_hook(
-                layer, create_cache_activation_hook(mapping.smooth_name), "forward"
-            )
-            self._activation_hooks.add(hook)
+            for parent in mapping.parent:
+                self.register_hook(parent, cache_activation_hook_fn, "forward_pre")
 
     @torch.no_grad()
     def _apply_smoothing(self, model: Module) -> None:
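
For orientation, here is a minimal, self-contained sketch (not the repository's code) of the pattern this hunk switches to: a forward_pre hook registered with kwargs captures each call's bound arguments so the module can be replayed later, while only a running statistic of the input needs to be kept. The SampleCache name and the toy Linear module are illustrative assumptions.

import inspect
from collections import defaultdict

import torch


class SampleCache:
    """Illustrative stand-in for an intermediates cache: stores bound forward args."""

    def __init__(self):
        self.batches = defaultdict(list)  # module -> list of kwargs dicts

    def hook(self, module, args, kwargs):
        # Bind positional and keyword args against the module's forward signature
        # so the call can be replayed later with module(**bound).
        bound = inspect.signature(module.forward).bind(*args, **kwargs)
        self.batches[module].append(dict(bound.arguments))

    def replay(self, module):
        # Re-run all cached batches and concatenate outputs along the batch dim.
        with torch.no_grad():
            return torch.cat([module(**b) for b in self.batches[module]], dim=0)


# Toy usage
layer = torch.nn.Linear(4, 4)
cache = SampleCache()
handle = layer.register_forward_pre_hook(cache.hook, with_kwargs=True)

for _ in range(3):
    layer(torch.randn(2, 4))

handle.remove()
print(cache.replay(layer).shape)  # torch.Size([6, 4])

PyTorch's register_forward_pre_hook(..., with_kwargs=True), available since PyTorch 2.0, is what lets the hook see both positional and keyword arguments without a wrapper module.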
@@ -398,18 +390,15 @@ def _apply_smoothing(self, model: Module) -> None:
         :param model: model to apply smoothing to
         """
         for mapping in tqdm(self._resolved_mappings, desc="Smoothing"):
+            smooth_layer = mapping.smooth_layer
+            balance_layers = mapping.balance_layers
+            parent_layer = mapping.parent
+
             # NOTE: When using SequentialPipeline, not all the mappings
             # will have cached activations in the segment being updated
-            if mapping.smooth_name not in self._activations:
+            if parent_layer not in self._num_samples:
                 continue
 
-            activations = torch.cat(self._activations[mapping.smooth_name], dim=0)
-            del self._activations[mapping.smooth_name]
-
-            smooth_layer = mapping.smooth_layer
-            balance_layers = mapping.balance_layers
-            module2inspect = mapping.parent
-
             # [STEP 1]: Compute per-channel mean of normalised weights
             # All layer weights are concatted together
             weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
@@ -425,45 +414,18 @@ def _apply_smoothing(self, model: Module) -> None:
             # Gets the average rescaled magnitude for each output channel
             w_mean = w_scale.mean(0)
 
-            # [STEP 2]: Compute per-channel mean of the input activation with chunking
-            # move inp to cpu to avoid memory leak
-            inp = activations.to(weight.device)
-            inp_flat = activations.cpu().abs().view(-1, inp.shape[-1])
-            num_elements = inp_flat.size(0)
-            num_channels = inp_flat.size(1)
-            element_size_bytes = inp_flat.element_size() * 2  # multiplied by 2 for FP32
-
-            # Calculate chunk size dynamically based on max_chunk_memory
-            chunk_size = int(
-                self.max_chunk_memory // (element_size_bytes * num_channels)
-            )
-            chunk_size = min(chunk_size, num_elements)
-
-            # Use float32 for sum calculation
-            x_sum = torch.zeros(num_channels, dtype=torch.float32, device=inp.device)
-
-            for i in range(0, num_elements, chunk_size):
-                end = min(i + chunk_size, num_elements)
-                chunk_sum = inp_flat[i:end].to(torch.float32).sum(dim=0)
-                x_sum += chunk_sum.to(inp.device)
-
-            x_mean = (x_sum / num_elements).to(inp.dtype)
-
             with calibration_forward_context(model), HooksMixin.disable_hooks():
                 # [STEP 3]: Compute output of module
-                fp16_output = self._forward_input_with_kwargs(
-                    module=module2inspect,
-                    inputs=inp,
-                    input_kwargs=_sanitize_kwargs(self._module_kwargs, module2inspect),
-                )
+                # could cache from hook, rather than recomputing here
+                fp16_output = self._run_samples(parent_layer)
                 fp16_output = fp16_output.clip(
                     torch.finfo(fp16_output.dtype).min,
                     torch.finfo(fp16_output.dtype).max,
                 )
 
                 # [STEP 4]: Compute loss
                 best_scales = self._compute_best_scale(
-                    inp, w_mean, x_mean, module2inspect, balance_layers, fp16_output
+                    w_mean, parent_layer, balance_layers, fp16_output
                 )
 
                 scales = best_scales
@@ -504,14 +466,26 @@ def smooth(module):
                 smooth(layer)
             smooth(smooth_layer)
 
+            # remove caches needed to smooth this mapping
+            del self._samples[parent_layer]
+            del self._sample_means[parent_layer]
+            del self._num_samples[parent_layer]
+
         self._assert_all_activations_consumed()
 
+    def _run_samples(self, module: Module) -> torch.Tensor:
+        with align_module_device(module):
+            return torch.cat(
+                [module(**batch) for batch in self._samples[module]],
+                dim=0,
+            )
+
     def _compute_best_scale(
         self,
         x: torch.Tensor,
         w_mean: torch.Tensor,
         x_mean: torch.Tensor,
-        module2inspect: torch.nn.Module,
+        parent_layer: torch.nn.Module,
         linears2scale: List[torch.nn.Linear],
         fp16_output: torch.Tensor,
     ) -> torch.Tensor:
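
As background for the smooth(layer) / smooth(smooth_layer) calls in the context above, the following standalone sketch shows the AWQ/SmoothQuant-style rescaling they perform in general: the producing layer's output channels are divided by the scales and the consuming layers' matching input channels are multiplied by them, so the composed computation is unchanged while activation magnitudes shrink. It is a generic illustration with assumed Linear layers, not the repository's smooth() implementation.

from typing import List

import torch


@torch.no_grad()
def apply_balance_scales(
    smooth_layer: torch.nn.Linear,
    balance_layers: List[torch.nn.Linear],
    scales: torch.Tensor,  # one positive scale per smoothed channel
) -> None:
    # Divide the producing layer's output channels by the scales...
    smooth_layer.weight.div_(scales.view(-1, 1))
    if smooth_layer.bias is not None:
        smooth_layer.bias.div_(scales)
    # ...and multiply the consuming layers' input channels by the same scales,
    # so balance(smooth(x)) is numerically unchanged (up to rounding).
    for layer in balance_layers:
        layer.weight.mul_(scales.view(1, -1))


# Toy check that the rescaling is function-preserving
smooth = torch.nn.Linear(8, 8)
balance = [torch.nn.Linear(8, 16), torch.nn.Linear(8, 16)]
x = torch.randn(2, 8)
before = balance[0](smooth(x))
apply_balance_scales(smooth, balance, torch.rand(8) + 0.5)
after = balance[0](smooth(x))
print(torch.allclose(before, after, atol=1e-5))  # True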
@@ -530,9 +504,10 @@ def _compute_best_scale(
         best_scales = None
         best_error = float("inf")
 
-        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
+        org_sd = {k: v.cpu() for k, v in parent_layer.state_dict().items()}
 
         device = x.device
+        x_mean = self._sample_means[parent_layer]
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)
 
@@ -571,9 +546,7 @@ def _compute_best_scale(
             )
 
             # W * X
-            int_w_output = self._forward_input_with_kwargs(
-                module=module2inspect, inputs=x, input_kwargs=self._module_kwargs
-            )
+            int_w_output = self._run_samples(parent_layer)
             int_w_output = int_w_output.clip(
                 torch.finfo(int_w_output.dtype).min,
                 torch.finfo(int_w_output.dtype).max,
@@ -587,7 +560,7 @@ def _compute_best_scale(
                 best_error = loss
                 best_ratio = ratio
                 best_scales = scales.clone()
-            module2inspect.load_state_dict(org_sd)
+            parent_layer.load_state_dict(org_sd)
 
         if best_ratio == -1:
             logger.debug(history)
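
For context on the loop these hunks touch, here is a compact sketch of the grid search that _compute_best_scale performs in AWQ-style methods: each candidate ratio blends the activation and weight channel means into a scale vector, the scaled-and-quantized output is compared against the fp16 reference, and the lowest-error scales win. The quantize_and_run callback, the 20-point grid, and the toy usage are assumptions for illustration, not the repository's exact code.

from typing import Callable

import torch


def search_best_scales(
    x_mean: torch.Tensor,       # per-channel activation magnitude mean
    w_mean: torch.Tensor,       # per-channel weight magnitude mean
    fp16_output: torch.Tensor,  # reference output of the unquantized parent module
    quantize_and_run: Callable[[torch.Tensor], torch.Tensor],
    n_grid: int = 20,
) -> torch.Tensor:
    best_error, best_scales = float("inf"), None
    for i in range(n_grid):
        ratio = i / n_grid
        # Balance between activation-driven and weight-driven scaling.
        scales = (x_mean.pow(ratio) / w_mean.pow(1 - ratio)).clamp(min=1e-4)
        scales = scales / (scales.max() * scales.min()).sqrt()  # normalize spread

        int_w_output = quantize_and_run(scales)
        loss = (fp16_output - int_w_output).float().pow(2).mean().item()
        if loss < best_error:
            best_error, best_scales = loss, scales.clone()
    return best_scales


# Toy usage with a stand-in "quantizer" (real use would fake-quantize the weights)
ref = torch.randn(4, 8)
best = search_best_scales(
    x_mean=torch.rand(8) + 0.1,
    w_mean=torch.rand(8) + 0.1,
    fp16_output=ref,
    quantize_and_run=lambda s: ref + 0.01 * torch.randn_like(ref),
)
print(best.shape)  # torch.Size([8])

Passing the quantize-and-forward step as a callback keeps the search itself independent of how calibration samples are cached and replayed.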
@@ -642,123 +615,10 @@ def _assert_all_activations_consumed(self):
         Confirm all activations have been consumed
         If not, something has gone wrong
         """
-        if len(self._activations) > 0:
-            raise RuntimeError("Some cached activations were not used")
-
-    def _set_module_kwargs(self, model, dataloader) -> None:
-        _, modules = next(iter(get_layers("re:.*layers", model).items()))
-
-        samples = [batch["input_ids"] for batch in dataloader]
-
-        samples = torch.cat(samples, dim=0)
-
-        inps = []
-        layer_kwargs = {}
-
-        best_device = "cuda"
-        modules[0] = modules[0].to(best_device)
-
-        # get input and kwargs to layer 0
-        # with_kwargs is only supported in PyTorch 2.0
-        # use this Catcher hack for now
-        class Catcher(torch.nn.Module):
-            def __init__(self, module):
-                super().__init__()
-                self.module = module
-
-            def forward(self, *args, **kwargs):
-                # assume first input to forward is hidden states
-                if len(args) > 0:
-                    hidden_states = args[0]
-                    del args
-                else:
-                    first_key = list(kwargs.keys())[0]
-                    hidden_states = kwargs.pop(first_key)
-
-                inps.append(hidden_states)
-                layer_kwargs.update(kwargs)
-                raise ValueError  # early exit to break later inference
-
-        # patch layer 0 to catch input and kwargs
-        modules[0] = Catcher(modules[0])
-        try:
-            with calibration_forward_context(model):
-                model(samples.to(next(model.parameters()).device))
-        except ValueError:  # work with early exit
-            pass
-        modules[0] = modules[0].module  # restore
-
-        # Update the layer kwargs with `prepare_inputs_for_generation` method
-        # that takes care of everything to avoid unexpected errors.
-        layer_kwargs = model.prepare_inputs_for_generation(samples, **layer_kwargs)
-        # Pop the input_ids as they are not needed at all.
-        layer_kwargs.pop("input_ids")
-
-        del samples
-        inps = inps[0]
-
-        if layer_kwargs.get("attention_mask") is not None:
-            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
-                best_device
-            )
-
-        self._module_kwargs = layer_kwargs
-
-    def _forward_input_with_kwargs(
-        self,
-        module: Module,
-        inputs: torch.Tensor,
-        input_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
-        """
-        Forward pass with input arguments
-
-        :param module: module to run forward pass on
-        :param inputs: input tensor to pass to the module
-        :param input_kwargs: additional arguments to pass to the module
-        :return: the first output tensor from the forward pass
-        """
-        kwargs = input_kwargs or self._module_kwargs
-        kwargs = _sanitize_kwargs(kwargs, module)
-
-        inputs = inputs.to(get_execution_device(module))
-
-        return module(inputs, **kwargs)[0]
-
-
-def _sanitize_kwargs(input_kwargs: Dict[str, Any], module: Module) -> Dict[str, Any]:
-    """
-    Sanitize input keyword arguments to match the module's forward method signature,
-    excluding `use_cache` which is not desired to be passed into module.
-
-    Args:
-        inputs_kwargs (`dict`):
-            The input dictionary to pass to the model layer
-        module (`torch.nn.Module`):
-            Target module to quantize.
-    """
-
-    params = inspect.signature(module.forward).parameters
-
-    # Filter out any kwargs not in module.forward signature
-    sanitized_kwargs = {k: v for k, v in input_kwargs.items() if k in params}
-
-    # Edge Case: forward pass has optional dependencies that don't default to None.
-    # This is the case for `LlamaAttention.forward` which has input
-    # `attention_mask: Optional[torch.Tensor],` (with no `= None` default)
-    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L246
-    for k, v in params.items():
-        if (
-            k not in sanitized_kwargs
-            and v.default is inspect.Parameter.empty
-            and str(v.annotation).startswith("typing.Optional")
-        ):
-            sanitized_kwargs[k] = None
-
-    # Exclude `use_cache` entirely
-    sanitized_kwargs.pop("use_cache", None)
-
-    return sanitized_kwargs
+        if not (
+            len(self._samples) == len(self._num_samples) == len(self._sample_means) == 0
+        ):
+            raise RuntimeError("Some cached activations were not used")
 
 
 def _pseudo_quantize_tensor(
src/llmcompressor/modifiers/awq/helpers.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+from typing import Tuple
+
+import torch
+
+AWQ_PRECISION = torch.float32
+
+
+def accumulate_mean(
+    inp: torch.Tensor, prev_mean: float, num_samples: int
+) -> Tuple[float, int]:
+    num_added = inp.size(0)
+    input_sum = inp.to(AWQ_PRECISION).sum()
+
+    return ((prev_mean * num_samples) + input_sum) / (num_samples + num_added)
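
As a complement to accumulate_mean above, here is a standalone sketch of the same running-mean identity, written per channel and returning both the updated mean and the updated count (as the Tuple[float, int] annotation and the hook-side unpacking suggest). The names and shapes are illustrative assumptions, not the file's code.

from typing import Tuple

import torch


def update_running_mean(
    batch: torch.Tensor,      # activations of shape (..., hidden_dim)
    prev_mean: torch.Tensor,  # running per-channel mean, shape (hidden_dim,)
    prev_count: int,
) -> Tuple[torch.Tensor, int]:
    flat = batch.abs().reshape(-1, batch.shape[-1]).to(torch.float32)
    num_added = flat.shape[0]
    batch_sum = flat.sum(dim=0)
    new_count = prev_count + num_added
    # Re-weight the old mean by its count, fold in the new sum, renormalize.
    new_mean = (prev_mean * prev_count + batch_sum) / new_count
    return new_mean, new_count


# Streaming over three batches matches the mean computed over all of them at once.
hidden = 16
batches = [torch.randn(2, 5, hidden) for _ in range(3)]
mean, count = torch.zeros(hidden), 0
for b in batches:
    mean, count = update_running_mean(b, mean, count)
full = torch.cat([b.abs().reshape(-1, hidden) for b in batches]).to(torch.float32).mean(dim=0)
print(torch.allclose(mean, full, atol=1e-6))  # True

Accumulating the mean batch by batch removes the need to keep every cached activation tensor around just to compute the activation mean, which is the direction the commit message ("accumulation") points at.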
