diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index ca8fa68a..e688e500 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -129,6 +129,7 @@ def apply_quantization_config(
     target_to_scheme = OrderedDict()
     config = process_quantization_config(config)
     names_to_scheme = OrderedDict()
+
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
@@ -152,7 +153,6 @@ def apply_quantization_config(
             continue  # layer matches ignore list, continue
 
         targets = find_name_or_class_matches(name, submodule, target_to_scheme)
-
         if targets:
             # mark modules to be quantized by adding
             # quant scheme to the matching layers
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 6886423a..8ec01dab 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -80,7 +80,6 @@ def initialize_module_for_quantization(
         _initialize_attn_scales(module)
 
     else:
-
         if scheme.input_activations is not None:
             _initialize_scale_zero_point(
                 module,
@@ -109,8 +108,14 @@ def initialize_module_for_quantization(
 
         if scheme.output_activations is not None:
             if not is_kv_cache_quant_scheme(scheme):
+                weight_shape = None
+                if isinstance(module, torch.nn.Linear) and hasattr(module, "weight"):
+                    weight_shape = module.weight.shape
                 _initialize_scale_zero_point(
-                    module, "output", scheme.output_activations
+                    module,
+                    "output",
+                    scheme.output_activations,
+                    weight_shape=weight_shape,
                 )
 
     module.quantization_scheme = scheme
@@ -153,13 +158,18 @@ def _initialize_scale_zero_point(
     expected_shape = 1
 
     if base_name == "weight" and weight_shape is not None:
+
        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1)
            expected_shape = (weight_shape[0], 1)
+
        elif quantization_args.strategy == QuantizationStrategy.GROUP:
            num_groups = weight_shape[1] // quantization_args.group_size
            expected_shape = (weight_shape[0], max(num_groups, 1))
 
+    if base_name == "output" and weight_shape is not None:
+        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            expected_shape = weight_shape[0]
+
     scale_dtype = module.weight.dtype
     if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
         scale_dtype = torch.float16
@@ -190,10 +200,13 @@ def _initialize_scale_zero_point(
         register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
 
 
-def _initialize_attn_scales(module: Module) -> None:
+def _initialize_attn_scales(
+    module: Module,
+) -> None:
     """Initlaize k_scale, v_scale for self_attn"""
 
-    expected_shape = 1  # per tensor
+    # per token for each layer
+    expected_shape = 1
 
     param = next(module.parameters())
     scale_dtype = param.dtype
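
For context: the `initialize.py` changes pass the owning `torch.nn.Linear` module's `weight.shape` into `_initialize_scale_zero_point` when initializing output-activation parameters, so a channel-wise output strategy allocates one scale per output channel (`weight_shape[0]`) rather than a single per-tensor scale. The snippet below is a minimal standalone sketch of that shape logic, not library code: the helper name `expected_scale_shape` is hypothetical, and plain strings stand in for the library's `QuantizationStrategy` enum values.

```python
# Hypothetical sketch mirroring the expected_shape logic in this diff.
import torch


def expected_scale_shape(base_name, strategy, weight_shape=None, group_size=None):
    expected_shape = 1  # per-tensor default

    if base_name == "weight" and weight_shape is not None:
        if strategy == "channel":
            # one (scale, zero_point) row per output channel
            expected_shape = (weight_shape[0], 1)
        elif strategy == "group":
            num_groups = weight_shape[1] // group_size
            expected_shape = (weight_shape[0], max(num_groups, 1))

    if base_name == "output" and weight_shape is not None:
        if strategy == "channel":
            # new in this diff: one output-activation scale per output channel
            expected_shape = weight_shape[0]

    return expected_shape


if __name__ == "__main__":
    linear = torch.nn.Linear(in_features=16, out_features=32)
    print(expected_scale_shape("output", "channel", linear.weight.shape))   # 32
    print(expected_scale_shape("weight", "group", linear.weight.shape, 8))  # (32, 2)
```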