diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 885a12582..ad071cf09 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -478,7 +478,7 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
     if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
         return QUANTIZATION_NVFP4_AWQ
-    if getattr(layer, "fused_with_layernorm", False):
+    if getattr(layer, "fused_with_prequant", False):
         return QUANTIZATION_NVFP4_AWQ
     assert input_quantizer is not None, (
         f"input_quantizer is None for {quantizer_attr_names}"
@@ -923,18 +923,138 @@ def all_items_same(item_list):
     return all(x == item_list[0] for x in item_list)
 
 
+# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
+PQS_FUSE_MODULE_MAPPING = [
+    # Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension
+    # Mathematical equivalence:
+    # Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
+    # After:  o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
+    (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
+    # MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
+    # Mathematical equivalence:
+    # Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
+    # After:  down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
+    (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
+]
+
+
+# TODO: make this more general instead of rule based
+def pattern_fuse_prequant(model: torch.nn.Module):
+    """Fuse pre_quant_scale into the weights of the preceding linear layers.
+
+    For example, the pre_quant_scale of o_proj can be fused into the output dimension of v_proj,
+    so that the results are mathematically equivalent to the following::
+
+        out_proj.input = (attn_weights @ v_proj.output)
+        out_proj.output = (out_proj.input * pre_quant_scale) * out_proj.weight
+                        = attn_weights @ (v_proj.output * pre_quant_scale) * out_proj.weight
+
+    For GQA/MQA models, where the v_proj output dimension is smaller than the o_proj input
+    dimension, the pre_quant_scale is averaged across the repeated head groups and o_proj's
+    pre_quant_scale is then updated to maintain mathematical equivalence.
+
+    Note:
+        This is an experimental feature; it may increase the quantization error
+        of the fused linear modules.
+    """
+    for _, module in model.named_modules():
+        for module_map in PQS_FUSE_MODULE_MAPPING:
+            target_module_list = module_map[0]
+            linear_pair = module_map[1]
+            if any(module_name in type(module).__name__ for module_name in target_module_list):
+                linear_to = module.get_submodule(linear_pair[0])
+                linear_from = module.get_submodule(linear_pair[1])
+                if hasattr(linear_from, "input_quantizer") and hasattr(
+                    linear_from.input_quantizer, "_pre_quant_scale"
+                ):
+                    pre_quant_scale = linear_from.input_quantizer._pre_quant_scale
+
+                    # For GQA/MQA models, average the pre_quant_scale across repeated head groups
+                    if pre_quant_scale.numel() != linear_to.weight.shape[0]:
+                        if "attention" not in type(module).__name__.lower():
+                            continue
+                        else:
+                            config = module.config
+                            num_kv_heads = config.num_key_value_heads
+                            kv_head_dim = linear_to.weight.shape[0] // num_kv_heads
+                            n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim
+
+                            # Reshape to (num_kv_heads, n_rep, kv_head_dim) and average over n_rep
+                            averaged_scale = pre_quant_scale.view(
+                                num_kv_heads, n_rep, kv_head_dim
+                            ).mean(dim=1)
+
+                            # To update o_proj, repeat the averaged scale back to the original shape
+                            repeated_scale = (
+                                averaged_scale.unsqueeze(1)  # (num_kv_heads, 1, kv_head_dim)
+                                .expand(num_kv_heads, n_rep, kv_head_dim)
+                                .reshape(-1)  # (num_kv_heads * n_rep * kv_head_dim,)
+                            )
+
+                            def _update_pre_quant_scale(module, new_pre_quant_scale):
+                                old_pre_quant_scale = module.input_quantizer._pre_quant_scale
+                                module.weight = nn.Parameter(
+                                    module.weight
+                                    * old_pre_quant_scale.to(
+                                        dtype=module.weight.dtype, device=module.weight.device
+                                    )
+                                    / new_pre_quant_scale.to(
+                                        dtype=module.weight.dtype, device=module.weight.device
+                                    )
+                                )
+                                module.input_quantizer.pre_quant_scale = new_pre_quant_scale
+
+                                # Recollect weight quantizer statistics for the rescaled weight
+                                module.weight_quantizer.reset_amax()
+                                enable_stats_collection(module.weight_quantizer)
+                                module.weight_quantizer(module.weight)
+                                finish_stats_collection(module.weight_quantizer)
+
+                            # Update o_proj's pre_quant_scale
+                            _update_pre_quant_scale(linear_from, repeated_scale)
+
+                            # Use the averaged scale (flattened) for the v_proj fusion
+                            pre_quant_scale = averaged_scale.reshape(-1)
+
+                    # Fuse the pre_quant_scale into the weight of linear_to (e.g. v_proj)
+                    # linear_to.weight shape: (out_features, in_features)
+                    # We scale the output dimension (the first dimension)
+                    linear_to.weight = torch.nn.Parameter(
+                        linear_to.weight * pre_quant_scale.view(-1, 1)
+                    )
+                    if hasattr(linear_to, "bias") and linear_to.bias is not None:
+                        linear_to.bias = torch.nn.Parameter(linear_to.bias * pre_quant_scale)
+
+                    delattr(linear_from.input_quantizer, "_pre_quant_scale")
+                    setattr(linear_from, "fused_with_prequant", True)
+
+
 def fuse_prequant_layernorm(
     layernorm_module: torch.nn.Module,
     modules: list[torch.Tensor],
 ):
-    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted."""
+    """Scales layernorm weights with the avg_pre_quant_scale of the modules list and deletes their pre_quant_scales.
+
+    original:
+        layernorm_output = (normalization(input) * weight) + bias
+        layernorm_output_scaled = layernorm_output * pre_quant_scale
+
+    fused:
+        fused_weight = weight * avg_pre_quant_scale
+        fused_bias = bias * avg_pre_quant_scale
+        layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias
+    """
     layernorm_module.weight = torch.nn.Parameter(
         layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale")
     )
+    if getattr(layernorm_module, "bias", None) is not None:
+        layernorm_module.bias = torch.nn.Parameter(
+            layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale")
+        )
 
     # Pre_quant_scales of modules must not be exported, since they have been fused with layernorm
     for module in modules:
         delattr(module.input_quantizer, "_pre_quant_scale")
-        setattr(module, "fused_with_layernorm", True)
+        setattr(module, "fused_with_prequant", True)
 
 
 def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False):
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index f966ffac6..2ebb93111 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -64,6 +64,7 @@
     get_weight_scaling_factor,
     get_weight_scaling_factor_2,
     maybe_transpose_expert_weight_dimensions,
+    pattern_fuse_prequant,
     postprocess_state_dict,
     preprocess_linear_fusion,
     to_quantized_weight,
@@ -173,6 +174,8 @@ def _output_hook(module, input, output):
             # Pre quant scale of modules is already updated to avg_pre_quant_scale
             fuse_prequant_layernorm(output_to_layernorm[tensor], modules)
 
+    pattern_fuse_prequant(model)
+
     # The dummy forward may not be able to activate all the experts.
     # Process experts by naming rules like experts.0, experts.1, etc.
     for name, modules_fused in fused_linears.items():
diff --git a/tests/gpu/torch/export/test_quant_utils.py b/tests/gpu/torch/export/test_quant_utils.py
new file mode 100644
index 000000000..81e750227
--- /dev/null
+++ b/tests/gpu/torch/export/test_quant_utils.py
@@ -0,0 +1,99 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+
+pytest.importorskip("transformers")
+
+from transformers import LlamaConfig, LlamaForCausalLM
+
+import modelopt.torch.quantization as mtq
+from modelopt.torch.export.quant_utils import pattern_fuse_prequant
+
+
+def get_tiny_llama(attention_heads=4, key_value_heads=4):
+    """Create a tiny Llama model for testing."""
+    config = LlamaConfig(
+        hidden_size=64,
+        intermediate_size=128,
+        num_hidden_layers=2,
+        num_attention_heads=attention_heads,
+        num_key_value_heads=key_value_heads,
+        max_position_embeddings=128,
+        vocab_size=256,
+    )
+    return LlamaForCausalLM(config)
+
+
+@pytest.mark.parametrize(
+    "quant_config",
+    [
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_AWQ_LITE_CFG,
+    ],
+)
+@pytest.mark.parametrize(
+    "attention_kv_heads_pair",
+    [
+        (4, 4),  # MHA
+        (4, 2),  # GQA
+        (4, 1),  # MQA
+    ],
+)
+def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair):
+    """Test pattern_fuse_prequant on modules from a tiny Llama model."""
+    model = get_tiny_llama(attention_kv_heads_pair[0], attention_kv_heads_pair[1]).to("cuda")
+
+    # Quantize the model
+    dummy_input = torch.randint(0, 256, (1, 16), device="cuda")
+    mtq.quantize(model, quant_config, lambda m: m(dummy_input))
+
+    # Run forward pass before fusion
+    model.eval()
+    with torch.no_grad():
+        output_before_fuse = model(dummy_input)
+
+    target_module_name_list = [
+        "model.layers.0.self_attn.o_proj",
+        "model.layers.0.mlp.down_proj",
+        "model.layers.1.self_attn.o_proj",
+        "model.layers.1.mlp.down_proj",
+    ]
+
+    # Apply fusion
+    pattern_fuse_prequant(model)
+
+    # Check that pre_quant_scale was removed and the fused_with_prequant flag was set
+    for target_module_name in target_module_name_list:
+        target_module = model.get_submodule(target_module_name)
+
+        # Verify pre_quant_scale was removed
+        assert not hasattr(target_module.input_quantizer, "_pre_quant_scale"), (
+            f"{target_module_name}: pre_quant_scale should be removed after fusion"
+        )
+
+        # Verify fused_with_prequant flag was set
+        assert (
+            hasattr(target_module, "fused_with_prequant") and target_module.fused_with_prequant
+        ), f"{target_module_name}: fused_with_prequant flag should be set"
+
+    # Verify output is close to the original output
+    with torch.no_grad():
+        output_after_fuse = model(dummy_input)
+    # Small differences are expected due to quantization error after fusing pre_quant_scale into the weights
+    assert torch.allclose(
+        output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1
+    ), "Outputs should be close before and after fusion"
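
For reference, a minimal standalone sketch (not part of the patch) of the two identities that pattern_fuse_prequant relies on, assuming plain torch.nn.Linear layers; the names v_proj, o_proj, scale, s_old, and s_rep are illustrative stand-ins, not the modelopt API.

import torch
import torch.nn as nn

torch.manual_seed(0)

# Identity 1: applying a per-channel scale to o_proj's input is equivalent to
# folding that scale into v_proj's output rows and dropping the input scale.
hidden, kv_dim = 64, 64
v_proj = nn.Linear(hidden, kv_dim, bias=False)
o_proj = nn.Linear(kv_dim, hidden, bias=False)
scale = torch.rand(kv_dim) + 0.5  # stand-in for o_proj's pre_quant_scale
x = torch.randn(8, hidden)

ref = o_proj(v_proj(x) * scale)
with torch.no_grad():
    v_proj.weight.mul_(scale.view(-1, 1))  # scale the output dimension of v_proj
fused = o_proj(v_proj(x))
assert torch.allclose(ref, fused, atol=1e-4)

# Identity 2 (GQA/MQA): average the scale over the n_rep repeated KV groups and
# rescale o_proj's weight columns by old/new, as _update_pre_quant_scale does.
num_kv_heads, n_rep, kv_head_dim = 2, 2, 16
in_features = num_kv_heads * n_rep * kv_head_dim
s_old = torch.rand(in_features) + 0.5
s_avg = s_old.view(num_kv_heads, n_rep, kv_head_dim).mean(dim=1)
s_rep = s_avg.unsqueeze(1).expand(num_kv_heads, n_rep, kv_head_dim).reshape(-1)

o_proj2 = nn.Linear(in_features, hidden, bias=False)
y = torch.randn(8, in_features)
ref2 = o_proj2(y * s_old)
with torch.no_grad():
    o_proj2.weight.mul_((s_old / s_rep).view(1, -1))  # columns follow the input dim
fused2 = o_proj2(y * s_rep)
assert torch.allclose(ref2, fused2, atol=1e-4)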