Skip to content

[Quantization] Channel wise output activation quantization for QKV Attention layers #270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def apply_quantization_config(
target_to_scheme = OrderedDict()
config = process_quantization_config(config)
names_to_scheme = OrderedDict()

for scheme in config.config_groups.values():
for target in scheme.targets:
target_to_scheme[target] = scheme
Expand All @@ -152,7 +153,6 @@ def apply_quantization_config(
continue # layer matches ignore list, continue

targets = find_name_or_class_matches(name, submodule, target_to_scheme)

if targets:
# mark modules to be quantized by adding
# quant scheme to the matching layers
Expand Down
23 changes: 18 additions & 5 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def initialize_module_for_quantization(
_initialize_attn_scales(module)

else:

if scheme.input_activations is not None:
_initialize_scale_zero_point(
module,
Expand Down Expand Up @@ -109,8 +108,14 @@ def initialize_module_for_quantization(

if scheme.output_activations is not None:
if not is_kv_cache_quant_scheme(scheme):
weight_shape = None
if isinstance(module, torch.nn.Linear) and hasattr(module, "weight"):
weight_shape = module.weight.shape
_initialize_scale_zero_point(
module, "output", scheme.output_activations
module,
"output",
scheme.output_activations,
weight_shape=weight_shape,
)

module.quantization_scheme = scheme
Expand Down Expand Up @@ -153,13 +158,18 @@ def _initialize_scale_zero_point(
expected_shape = 1

if base_name == "weight" and weight_shape is not None:

if quantization_args.strategy == QuantizationStrategy.CHANNEL:
# (output_channels, 1)
expected_shape = (weight_shape[0], 1)

elif quantization_args.strategy == QuantizationStrategy.GROUP:
num_groups = weight_shape[1] // quantization_args.group_size
expected_shape = (weight_shape[0], max(num_groups, 1))

if base_name == "output" and weight_shape is not None:
if quantization_args.strategy == QuantizationStrategy.CHANNEL:
expected_shape = weight_shape[0]

scale_dtype = module.weight.dtype
if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
scale_dtype = torch.float16
Expand Down Expand Up @@ -190,10 +200,13 @@ def _initialize_scale_zero_point(
register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)


def _initialize_attn_scales(module: Module) -> None:
def _initialize_attn_scales(
module: Module,
) -> None:
"""Initlaize k_scale, v_scale for self_attn"""

expected_shape = 1 # per tensor
# per token for each layer
expected_shape = 1

param = next(module.parameters())
scale_dtype = param.dtype
Expand Down