From eb2c5b66c905f01546c1cb307b139ea611221200 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Thu, 9 Oct 2025 22:27:41 +0000 Subject: [PATCH 01/11] pattern-based fusion Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 65 +++++++++++++++++++++- modelopt/torch/export/unified_export_hf.py | 3 + 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 3e99a0e0a..1c4fd4160 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -488,7 +488,7 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"): return QUANTIZATION_NVFP4_AWQ - if getattr(layer, "fused_with_layernorm", False): + if getattr(layer, "fused_with_prequant", False): return QUANTIZATION_NVFP4_AWQ assert input_quantizer is not None, ( f"input_quantizer is None for {quantizer_attr_names}" @@ -935,18 +935,77 @@ def all_items_same(item_list): return all(x == item_list[0] for x in item_list) +PQS_FUSE_MODULE_MAPPING = [ + # format: (list of target modules, tuple of (linear_pqs_fuse_to, linear_pqs_from), dim to fuse) + (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj"), "input"), + (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj"), "output"), +] + + +# TODO: make this more general instead of rule based +def pattern_fuse_prequant(model: torch.nn.Module): + """Fuse pre_quant_scale to the linear weights. + + For example, we can fuse the pre_quant_scale of o_proj to the output_dimension of v_proj, such that + The results are mathematically equivalent to the following: + + out_proj.input = (attn_weights @ v_proj.output) + out_proj.output = (out_proj.input * pre_quant_scale) * out_proj.weight + = attn_weights @ (v_proj.output * pre_quant_scale) * out_proj.weight + + Note: This is an experimental feature, and it might mess up the quantization errors of fused linear modules. + """ + for _, module in model.named_modules(): + for module_map in PQS_FUSE_MODULE_MAPPING: + target_module_list = module_map[0] + linear_pair = module_map[1] + dim_to_fuse = module_map[2] + if any(module_name in type(module).__name__ for module_name in target_module_list): + linear_to = module.get_submodule(linear_pair[0]) + linear_from = module.get_submodule(linear_pair[1]) + if hasattr(linear_from, "input_quantizer") and hasattr( + linear_from.input_quantizer, "_pre_quant_scale" + ): + pre_quant_scale = linear_from.input_quantizer._pre_quant_scale + # check if we need to apply to the last dimension or the first dimension + pre_quant_scale = ( + pre_quant_scale.view(-1, 1) + if dim_to_fuse == "output" + else pre_quant_scale.view(1, -1) + ) + linear_to.weight = torch.nn.Parameter(linear_to.weight * pre_quant_scale) + if hasattr(linear_to, "bias") and linear_to.bias is not None: + linear_to.bias = torch.nn.Parameter(linear_to.bias * pre_quant_scale) + delattr(linear_from.input_quantizer, "_pre_quant_scale") + setattr(linear_from, "fused_with_prequant", True) + + def fuse_prequant_layernorm( layernorm_module: torch.nn.Module, modules: list[torch.Tensor], ): - """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted.""" + """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted. + + original: + layernorm_output = (normalization(input) * weight) + bias + layernorm_output_scaled = layernorm_output * pre_quant_scale + + fused: + fused_weight = weight * avg_pre_quant_scale + fused_bias = bias * avg_pre_quant_scale + layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias + """ layernorm_module.weight = torch.nn.Parameter( layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale") ) + if hasattr(layernorm_module, "bias"): + layernorm_module.bias = torch.nn.Parameter( + layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale") + ) # Pre_quant_scales of modules must not be exported, since they have been fused with layernorm for module in modules: delattr(module.input_quantizer, "_pre_quant_scale") - setattr(module, "fused_with_layernorm", True) + setattr(module, "fused_with_prequant", True) def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False): diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 7b102f4e0..c14c73013 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -67,6 +67,7 @@ get_weight_scaling_factor, get_weight_scaling_factor_2, maybe_transpose_expert_weight_dimensions, + pattern_fuse_prequant, postprocess_state_dict, preprocess_linear_fusion, to_quantized_weight, @@ -198,6 +199,8 @@ def _output_hook(module, input, output): with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]): fuse_prequant_layernorm(output_to_layernorm[tensor], modules) + pattern_fuse_prequant(model) + # The dummy forward may not be able to activate all the experts. # Process experts by naming rules like experts.0, experts.1, etc. for name, modules_fused in fused_linears.items(): From 6c83a7bb01f2d362cce2e9712d5f50ae527ddd0a Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Tue, 14 Oct 2025 04:26:19 +0000 Subject: [PATCH 02/11] fix GQA Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 76 ++++++++++++++++++++++++---- 1 file changed, 66 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 1c4fd4160..3bb7844bf 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -935,10 +935,19 @@ def all_items_same(item_list): return all(x == item_list[0] for x in item_list) +# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale)) PQS_FUSE_MODULE_MAPPING = [ - # format: (list of target modules, tuple of (linear_pqs_fuse_to, linear_pqs_from), dim to fuse) - (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj"), "input"), - (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj"), "output"), + # Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension + # Mathematical equivalence: + # Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T + # After: o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T + # note: for GQA models, TODO: + (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")), + # MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension + # Mathematical equivalence: + # Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T + # After: down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T + (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")), ] @@ -959,7 +968,6 @@ def pattern_fuse_prequant(model: torch.nn.Module): for module_map in PQS_FUSE_MODULE_MAPPING: target_module_list = module_map[0] linear_pair = module_map[1] - dim_to_fuse = module_map[2] if any(module_name in type(module).__name__ for module_name in target_module_list): linear_to = module.get_submodule(linear_pair[0]) linear_from = module.get_submodule(linear_pair[1]) @@ -967,15 +975,63 @@ def pattern_fuse_prequant(model: torch.nn.Module): linear_from.input_quantizer, "_pre_quant_scale" ): pre_quant_scale = linear_from.input_quantizer._pre_quant_scale - # check if we need to apply to the last dimension or the first dimension - pre_quant_scale = ( - pre_quant_scale.view(-1, 1) - if dim_to_fuse == "output" - else pre_quant_scale.view(1, -1) + + # for GQA/MQA models, we apply averaging to the pre_quant_scale + if pre_quant_scale.numel() != linear_to.weight.shape[0]: + if "attention" not in type(module).__name__.lower(): + continue + else: + config = module.config + num_kv_heads = config.num_key_value_heads + kv_head_dim = linear_to.weight.shape[0] // num_kv_heads + n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim + + # Reshape:(num_kv_heads, n_rep, kv_head_dim) + averaged_scale = pre_quant_scale.view( + num_kv_heads, n_rep, kv_head_dim + ).mean(dim=1) + + # To update o_proj, we need to repeat back to original shape + repeated_scale = ( + averaged_scale.unsqueeze(1) # (2, 1, 16) + .expand(num_kv_heads, n_rep, kv_head_dim) # (2, 2, 16) + .reshape(-1) # (64,) + ) + + def _update_pre_quant_scale(module, new_pre_quant_scale): + old_pre_quant_scale = module.input_quantizer._pre_quant_scale + module.weight = nn.Parameter( + module.weight + * old_pre_quant_scale.to( + dtype=module.weight.dtype, device=module.weight.device + ) + / new_pre_quant_scale.to( + dtype=module.weight.dtype, device=module.weight.device + ) + ) + module.input_quantizer.pre_quant_scale = new_pre_quant_scale + + # Redo weights collection + module.weight_quantizer.reset_amax() + enable_stats_collection(module.weight_quantizer) + module.weight_quantizer(module.weight) + finish_stats_collection(module.weight_quantizer) + + # Update o_proj's pre_quant_scale + _update_pre_quant_scale(linear_from, repeated_scale) + + # Use averaged scale (flattened) for v_proj fusion + pre_quant_scale = averaged_scale.reshape(-1) + + # Fuse the pre_quant_scale to v_proj weight (linear_to) + # v_proj.weight shape: (out_features, in_features) = (32, hidden_size) + # We scale the output dimension (first dimension) + linear_to.weight = torch.nn.Parameter( + linear_to.weight * pre_quant_scale.view(-1, 1) ) - linear_to.weight = torch.nn.Parameter(linear_to.weight * pre_quant_scale) if hasattr(linear_to, "bias") and linear_to.bias is not None: linear_to.bias = torch.nn.Parameter(linear_to.bias * pre_quant_scale) + delattr(linear_from.input_quantizer, "_pre_quant_scale") setattr(linear_from, "fused_with_prequant", True) From 2b320c1b91fc86369115da33e432f386adf4636c Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Tue, 14 Oct 2025 04:27:03 +0000 Subject: [PATCH 03/11] minor Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 3bb7844bf..5fdd0cc37 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -941,7 +941,6 @@ def all_items_same(item_list): # Mathematical equivalence: # Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T # After: o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T - # note: for GQA models, TODO: (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")), # MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension # Mathematical equivalence: From da951c3a815dd0a2b8bce8d16db98f7e6df2ad3d Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Tue, 14 Oct 2025 17:10:22 +0000 Subject: [PATCH 04/11] unit test Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- examples/vllm_serve/vllm_serve_fakequant.py | 4 +- tests/gpu/torch/export/test_quant_utils.py | 99 +++++++++++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 tests/gpu/torch/export/test_quant_utils.py diff --git a/examples/vllm_serve/vllm_serve_fakequant.py b/examples/vllm_serve/vllm_serve_fakequant.py index e96f2d3dc..680b97353 100644 --- a/examples/vllm_serve/vllm_serve_fakequant.py +++ b/examples/vllm_serve/vllm_serve_fakequant.py @@ -97,7 +97,8 @@ def disable_compilation(model): quant_config: dict[str, Any] = { "quant_dataset": "cnn_dailymail", "quant_num_samples": 512, - "quant_format": "NVFP4_DEFAULT_CFG", + # "quant_format": "NVFP4_DEFAULT_CFG", + "quant_format": "NVFP4_AWQ_LITE_CFG", "amax_file_path": None, # Optional: path to pre-computed amax values (e.g., "/path/to/amax.pt") } @@ -176,6 +177,7 @@ def calibrate_loop(model: Any = None) -> None: quant_cfg = getattr(mtq, quant_config["quant_format"]) + print(f"Quantizing model with {quant_config['quant_format']} format") with disable_compilation(self.model): mtq.quantize(self.model, quant_cfg, forward_loop=calibrate_loop) diff --git a/tests/gpu/torch/export/test_quant_utils.py b/tests/gpu/torch/export/test_quant_utils.py new file mode 100644 index 000000000..81e750227 --- /dev/null +++ b/tests/gpu/torch/export/test_quant_utils.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +pytest.importorskip("transformers") + +from transformers import LlamaConfig, LlamaForCausalLM + +import modelopt.torch.quantization as mtq +from modelopt.torch.export.quant_utils import pattern_fuse_prequant + + +def get_tiny_llama(attention_heads=4, key_value_heads=4): + """Create a tiny Llama model for testing.""" + config = LlamaConfig( + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=attention_heads, + num_key_value_heads=key_value_heads, + max_position_embeddings=128, + vocab_size=256, + ) + return LlamaForCausalLM(config) + + +@pytest.mark.parametrize( + "quant_config", + [ + mtq.INT4_AWQ_CFG, + mtq.NVFP4_AWQ_LITE_CFG, + ], +) +@pytest.mark.parametrize( + "attention_kv_heads_pair", + [ + (4, 4), # MHA + (4, 2), # GQA + (4, 1), # MQA + ], +) +def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair): + """Test pattern_fuse_prequant on modules from a tiny Llama model.""" + model = get_tiny_llama(attention_kv_heads_pair[0], attention_kv_heads_pair[1]).to("cuda") + + # Quantize the model + dummy_input = torch.randint(0, 256, (1, 16), device="cuda") + mtq.quantize(model, quant_config, lambda m: m(dummy_input)) + + # Run forward pass before fusion + model.eval() + with torch.no_grad(): + output_before_fuse = model(dummy_input) + + traget_module_name_list = [ + "model.layers.0.self_attn.o_proj", + "model.layers.0.mlp.down_proj", + "model.layers.1.self_attn.o_proj", + "model.layers.1.mlp.down_proj", + ] + + # Apply fusion + pattern_fuse_prequant(model) + + # Check if pre_quant_scale and fused_with_prequant flag are removed correctly + for target_module_name in traget_module_name_list: + target_module = model.get_submodule(target_module_name) + + # Verify pre_quant_scale was removed + assert not hasattr(target_module.input_quantizer, "_pre_quant_scale"), ( + f"{target_module_name}: pre_quant_scale should be removed after fusion" + ) + + # Verify fused_with_prequant flag was set + assert ( + hasattr(target_module, "fused_with_prequant") and target_module.fused_with_prequant + ), f"{target_module_name}: fused_with_prequant flag should be set" + + # Verify output is close to the original output + with torch.no_grad(): + output_after_fuse = model(dummy_input) + # There will be some small difference due to quantization errors after pre_quant_scale fusion to the weights + assert torch.allclose( + output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1 + ), "Output should be the same before and after fusion" From 1ad352cdecd377e5e95c1994704e887f27bd47e9 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Tue, 14 Oct 2025 17:15:08 +0000 Subject: [PATCH 05/11] fix doc Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 5fdd0cc37..8a0380f05 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -955,13 +955,19 @@ def pattern_fuse_prequant(model: torch.nn.Module): """Fuse pre_quant_scale to the linear weights. For example, we can fuse the pre_quant_scale of o_proj to the output_dimension of v_proj, such that - The results are mathematically equivalent to the following: + the results are mathematically equivalent to the following:: - out_proj.input = (attn_weights @ v_proj.output) - out_proj.output = (out_proj.input * pre_quant_scale) * out_proj.weight - = attn_weights @ (v_proj.output * pre_quant_scale) * out_proj.weight + out_proj.input = (attn_weights @ v_proj.output) + out_proj.output = (out_proj.input * pre_quant_scale) * out_proj.weight + = attn_weights @ (v_proj.output * pre_quant_scale) * out_proj.weight - Note: This is an experimental feature, and it might mess up the quantization errors of fused linear modules. + For GQA/MQA models where v_proj output dimension < o_proj input dimension, + the pre_quant_scale is averaged across the repeated head groups and then the + o_proj's pre_quant_scale is updated to maintain mathematical equivalence. + + Note: + This is an experimental feature, and it might mess up the quantization errors + of fused linear modules. """ for _, module in model.named_modules(): for module_map in PQS_FUSE_MODULE_MAPPING: From 0d013a06d7c4a2bb596ebfac8043f3345f096caa Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Tue, 14 Oct 2025 17:26:56 +0000 Subject: [PATCH 06/11] revert unintended change Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- examples/vllm_serve/vllm_serve_fakequant.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/vllm_serve/vllm_serve_fakequant.py b/examples/vllm_serve/vllm_serve_fakequant.py index 680b97353..e96f2d3dc 100644 --- a/examples/vllm_serve/vllm_serve_fakequant.py +++ b/examples/vllm_serve/vllm_serve_fakequant.py @@ -97,8 +97,7 @@ def disable_compilation(model): quant_config: dict[str, Any] = { "quant_dataset": "cnn_dailymail", "quant_num_samples": 512, - # "quant_format": "NVFP4_DEFAULT_CFG", - "quant_format": "NVFP4_AWQ_LITE_CFG", + "quant_format": "NVFP4_DEFAULT_CFG", "amax_file_path": None, # Optional: path to pre-computed amax values (e.g., "/path/to/amax.pt") } @@ -177,7 +176,6 @@ def calibrate_loop(model: Any = None) -> None: quant_cfg = getattr(mtq, quant_config["quant_format"]) - print(f"Quantizing model with {quant_config['quant_format']} format") with disable_compilation(self.model): mtq.quantize(self.model, quant_cfg, forward_loop=calibrate_loop) From be77c21bdcf06499dfdee6266bb57757a8f1aa70 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Fri, 17 Oct 2025 21:49:07 +0000 Subject: [PATCH 07/11] minor Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 42 +++++++++++++--------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 8a0380f05..da94f4648 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -488,8 +488,6 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"): return QUANTIZATION_NVFP4_AWQ - if getattr(layer, "fused_with_prequant", False): - return QUANTIZATION_NVFP4_AWQ assert input_quantizer is not None, ( f"input_quantizer is None for {quantizer_attr_names}" ) @@ -974,21 +972,21 @@ def pattern_fuse_prequant(model: torch.nn.Module): target_module_list = module_map[0] linear_pair = module_map[1] if any(module_name in type(module).__name__ for module_name in target_module_list): - linear_to = module.get_submodule(linear_pair[0]) - linear_from = module.get_submodule(linear_pair[1]) - if hasattr(linear_from, "input_quantizer") and hasattr( - linear_from.input_quantizer, "_pre_quant_scale" + linear_fuse_into = module.get_submodule(linear_pair[0]) + linear_pqs_from = module.get_submodule(linear_pair[1]) + if hasattr(linear_pqs_from, "input_quantizer") and hasattr( + linear_pqs_from.input_quantizer, "_pre_quant_scale" ): - pre_quant_scale = linear_from.input_quantizer._pre_quant_scale + pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale # for GQA/MQA models, we apply averaging to the pre_quant_scale - if pre_quant_scale.numel() != linear_to.weight.shape[0]: + if pre_quant_scale.numel() != linear_fuse_into.weight.shape[0]: if "attention" not in type(module).__name__.lower(): continue else: config = module.config num_kv_heads = config.num_key_value_heads - kv_head_dim = linear_to.weight.shape[0] // num_kv_heads + kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim # Reshape:(num_kv_heads, n_rep, kv_head_dim) @@ -998,9 +996,9 @@ def pattern_fuse_prequant(model: torch.nn.Module): # To update o_proj, we need to repeat back to original shape repeated_scale = ( - averaged_scale.unsqueeze(1) # (2, 1, 16) - .expand(num_kv_heads, n_rep, kv_head_dim) # (2, 2, 16) - .reshape(-1) # (64,) + averaged_scale.unsqueeze(1) + .expand(num_kv_heads, n_rep, kv_head_dim) + .reshape(-1) ) def _update_pre_quant_scale(module, new_pre_quant_scale): @@ -1023,22 +1021,22 @@ def _update_pre_quant_scale(module, new_pre_quant_scale): finish_stats_collection(module.weight_quantizer) # Update o_proj's pre_quant_scale - _update_pre_quant_scale(linear_from, repeated_scale) + _update_pre_quant_scale(linear_pqs_from, repeated_scale) # Use averaged scale (flattened) for v_proj fusion pre_quant_scale = averaged_scale.reshape(-1) - # Fuse the pre_quant_scale to v_proj weight (linear_to) - # v_proj.weight shape: (out_features, in_features) = (32, hidden_size) - # We scale the output dimension (first dimension) - linear_to.weight = torch.nn.Parameter( - linear_to.weight * pre_quant_scale.view(-1, 1) + # Fuse the pre_quant_scale to v_proj weight + linear_fuse_into.weight = torch.nn.Parameter( + linear_fuse_into.weight * pre_quant_scale.view(-1, 1) ) - if hasattr(linear_to, "bias") and linear_to.bias is not None: - linear_to.bias = torch.nn.Parameter(linear_to.bias * pre_quant_scale) + if hasattr(linear_fuse_into, "bias") and linear_fuse_into.bias is not None: + linear_fuse_into.bias = torch.nn.Parameter( + linear_fuse_into.bias * pre_quant_scale + ) - delattr(linear_from.input_quantizer, "_pre_quant_scale") - setattr(linear_from, "fused_with_prequant", True) + delattr(linear_pqs_from.input_quantizer, "_pre_quant_scale") + setattr(linear_pqs_from, "fused_with_prequant", True) def fuse_prequant_layernorm( From d3f695e85f7df1bb66f15a01f863ec25601136c5 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Mon, 27 Oct 2025 19:28:41 +0000 Subject: [PATCH 08/11] resmooth Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 109 +++++++++++------- tests/gpu/torch/export/test_quant_utils.py | 126 ++++++++++++++++++++- 2 files changed, 193 insertions(+), 42 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index da94f4648..947304084 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -488,6 +488,8 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"): return QUANTIZATION_NVFP4_AWQ + if getattr(layer, "fused_with_prequant", False): + return QUANTIZATION_NVFP4_AWQ assert input_quantizer is not None, ( f"input_quantizer is None for {quantizer_attr_names}" ) @@ -949,7 +951,7 @@ def all_items_same(item_list): # TODO: make this more general instead of rule based -def pattern_fuse_prequant(model: torch.nn.Module): +def pattern_fuse_prequant(model: torch.nn.Module, fuse_mismatch_dim=False): """Fuse pre_quant_scale to the linear weights. For example, we can fuse the pre_quant_scale of o_proj to the output_dimension of v_proj, such that @@ -963,10 +965,29 @@ def pattern_fuse_prequant(model: torch.nn.Module): the pre_quant_scale is averaged across the repeated head groups and then the o_proj's pre_quant_scale is updated to maintain mathematical equivalence. + Args: + model: The model to fuse pre_quant_scale to. + fuse_mismatch_dim: If True, fuse the pre_quant_scale even if dimension between pre_quant_scale + and linear weights is not the same. This is useful for GQA/MQA models but may lead to accuracy + drop. + Note: This is an experimental feature, and it might mess up the quantization errors of fused linear modules. """ + # For MoE models, let's first resmooth the w1 and w3 in experts to get the average pre_quant_scale + for _, module in model.named_modules(): + if ( + hasattr(module, "experts") + and "Qwen3MoeSparseMoeBlock".lower() in type(module).__name__.lower() + ): + linear_list = [] + linear_list.extend([getattr(expert, "up_proj") for expert in module.experts]) + linear_list.extend([getattr(expert, "gate_proj") for expert in module.experts]) + preprocess_linear_fusion(linear_list, resmooth_only=True) + + # import pdb; pdb.set_trace() + # Fuse pre_quant_scale to the linear weights for _, module in model.named_modules(): for module_map in PQS_FUSE_MODULE_MAPPING: target_module_list = module_map[0] @@ -979,52 +1000,58 @@ def pattern_fuse_prequant(model: torch.nn.Module): ): pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale - # for GQA/MQA models, we apply averaging to the pre_quant_scale - if pre_quant_scale.numel() != linear_fuse_into.weight.shape[0]: - if "attention" not in type(module).__name__.lower(): - continue - else: - config = module.config - num_kv_heads = config.num_key_value_heads - kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads - n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim - - # Reshape:(num_kv_heads, n_rep, kv_head_dim) - averaged_scale = pre_quant_scale.view( - num_kv_heads, n_rep, kv_head_dim - ).mean(dim=1) - - # To update o_proj, we need to repeat back to original shape - repeated_scale = ( - averaged_scale.unsqueeze(1) - .expand(num_kv_heads, n_rep, kv_head_dim) - .reshape(-1) + # for GQA/MQA models, we apply averaging to the pre_quant_scale for shared head groups + if pre_quant_scale.numel() != linear_fuse_into.weight.shape[-2]: + if ( + not fuse_mismatch_dim + or "attention" not in type(module).__name__.lower() + ): + warn( + f"Skipping pattern fuse prequant for {type(module).__name__}" + f"pqs dim {pre_quant_scale.numel()} != out_ch dim {linear_fuse_into.weight.shape[-2]}" ) + continue + config = module.config + num_kv_heads = config.num_key_value_heads + kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads + n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim + + # Reshape:(num_kv_heads, n_rep, kv_head_dim) + averaged_scale = pre_quant_scale.view( + num_kv_heads, n_rep, kv_head_dim + ).mean(dim=1) + + # To update o_proj, we need to repeat back to original shape + repeated_scale = ( + averaged_scale.unsqueeze(1) + .expand(num_kv_heads, n_rep, kv_head_dim) + .reshape(-1) + ) - def _update_pre_quant_scale(module, new_pre_quant_scale): - old_pre_quant_scale = module.input_quantizer._pre_quant_scale - module.weight = nn.Parameter( - module.weight - * old_pre_quant_scale.to( - dtype=module.weight.dtype, device=module.weight.device - ) - / new_pre_quant_scale.to( - dtype=module.weight.dtype, device=module.weight.device - ) + def _update_pre_quant_scale(module, new_pre_quant_scale): + old_pre_quant_scale = module.input_quantizer._pre_quant_scale + module.weight = nn.Parameter( + module.weight + * old_pre_quant_scale.to( + dtype=module.weight.dtype, device=module.weight.device + ) + / new_pre_quant_scale.to( + dtype=module.weight.dtype, device=module.weight.device ) - module.input_quantizer.pre_quant_scale = new_pre_quant_scale + ) + module.input_quantizer.pre_quant_scale = new_pre_quant_scale - # Redo weights collection - module.weight_quantizer.reset_amax() - enable_stats_collection(module.weight_quantizer) - module.weight_quantizer(module.weight) - finish_stats_collection(module.weight_quantizer) + # Redo weights collection + module.weight_quantizer.reset_amax() + enable_stats_collection(module.weight_quantizer) + module.weight_quantizer(module.weight) + finish_stats_collection(module.weight_quantizer) - # Update o_proj's pre_quant_scale - _update_pre_quant_scale(linear_pqs_from, repeated_scale) + # Update o_proj's pre_quant_scale + _update_pre_quant_scale(linear_pqs_from, repeated_scale) - # Use averaged scale (flattened) for v_proj fusion - pre_quant_scale = averaged_scale.reshape(-1) + # Use averaged scale (flattened) for v_proj fusion + pre_quant_scale = averaged_scale.reshape(-1) # Fuse the pre_quant_scale to v_proj weight linear_fuse_into.weight = torch.nn.Parameter( diff --git a/tests/gpu/torch/export/test_quant_utils.py b/tests/gpu/torch/export/test_quant_utils.py index 81e750227..1bd213fd3 100644 --- a/tests/gpu/torch/export/test_quant_utils.py +++ b/tests/gpu/torch/export/test_quant_utils.py @@ -74,7 +74,7 @@ def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair): ] # Apply fusion - pattern_fuse_prequant(model) + pattern_fuse_prequant(model, fuse_mismatch_dim=True) # Check if pre_quant_scale and fused_with_prequant flag are removed correctly for target_module_name in traget_module_name_list: @@ -97,3 +97,127 @@ def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair): assert torch.allclose( output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1 ), "Output should be the same before and after fusion" + + +# TODO: add test for Qwen3MoeSparseMoeBlock MLP fusion + + +@pytest.mark.parametrize( + "quant_config", + [ + mtq.INT4_AWQ_CFG, + mtq.NVFP4_AWQ_LITE_CFG, + ], +) +def test_pattern_fuse_prequant_moe(quant_config): + """Test pattern_fuse_prequant on Qwen3 MoE sparse MLP.""" + pytest.importorskip("transformers", minversion="4.46.0") + from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM + + # Create a tiny Qwen3MoE model for testing + config = Qwen3MoeConfig( + hidden_size=128, + intermediate_size=256, + moe_intermediate_size=256, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + num_experts=4, + num_experts_per_tok=2, + max_position_embeddings=128, + vocab_size=256, + shared_expert_intermediate_size=256, + ) + model = Qwen3MoeForCausalLM(config).to("cuda") + + # Quantize the model + dummy_input = torch.randint(0, 256, (1, 16), device="cuda") + mtq.quantize(model, quant_config, lambda m: m(dummy_input)) + + # Collect MoE expert modules to verify (down_proj should be fused) + moe_down_proj_modules = [] + moe_gate_proj_modules = [] + moe_up_proj_modules = [] + for name, module in model.named_modules(): + if "mlp" in name and "experts" in name: + if "gate_proj" in name and not any(x in name for x in ["weight", "quantizer"]): + moe_gate_proj_modules.append((name, module)) + elif "down_proj" in name and not any(x in name for x in ["weight", "quantizer"]): + moe_down_proj_modules.append((name, module)) + elif "up_proj" in name and not any(x in name for x in ["weight", "quantizer"]): + moe_up_proj_modules.append((name, module)) + + # Verify experts have pre_quant_scale before fusion + for name, module in moe_gate_proj_modules: + if hasattr(module, "input_quantizer"): + assert hasattr(module.input_quantizer, "_pre_quant_scale"), ( + f"{name}: gate_proj should have pre_quant_scale before fusion" + ) + + for name, module in moe_up_proj_modules: + if hasattr(module, "input_quantizer"): + assert hasattr(module.input_quantizer, "_pre_quant_scale"), ( + f"{name}: up_proj should have pre_quant_scale before fusion" + ) + + for name, module in moe_down_proj_modules: + if hasattr(module, "input_quantizer"): + assert hasattr(module.input_quantizer, "_pre_quant_scale"), ( + f"{name}: down_proj should have pre_quant_scale before fusion" + ) + + # Run forward pass before fusion + model.eval() + with torch.no_grad(): + output_before_fuse = model(dummy_input) + + # Apply fusion (fuse_mismatch_dim only needed for GQA/MQA attention, not for MLP) + pattern_fuse_prequant(model) + + # Check if down_proj's pre_quant_scale was removed and fused into up_proj + for name, module in moe_down_proj_modules: + if hasattr(module, "input_quantizer"): + # Verify pre_quant_scale was removed from down_proj + assert not hasattr(module.input_quantizer, "_pre_quant_scale"), ( + f"{name}: down_proj pre_quant_scale should be removed after fusion" + ) + # Verify fused_with_prequant flag was set + assert hasattr(module, "fused_with_prequant") and module.fused_with_prequant, ( + f"{name}: down_proj should have fused_with_prequant flag set" + ) + + # Verify that gate_proj and up_proj still have pre_quant_scale and are resmoothed + for name, module in model.named_modules(): + if "Qwen3MoeSparseMoeBlock".lower() in type(module).__name__.lower(): + first_gate_scale = getattr( + getattr(module, "experts")[0], "gate_proj" + ).input_quantizer._pre_quant_scale + first_up_scale = getattr( + getattr(module, "experts")[0], "up_proj" + ).input_quantizer._pre_quant_scale + + # gate_proj and up_proj should have the same scale after resmoothing + assert torch.allclose(first_gate_scale, first_up_scale), ( + "gate_proj and up_proj should have the same pre_quant_scale after resmoothing" + ) + + # All experts should have the same gate_proj and up_proj scales + for i, expert in enumerate(getattr(module, "experts")): + gate_scale = getattr(expert, "gate_proj").input_quantizer._pre_quant_scale + up_scale = getattr(expert, "up_proj").input_quantizer._pre_quant_scale + + assert torch.allclose(gate_scale, first_gate_scale), ( + f"Expert {i} gate_proj scale should match expert 0" + ) + assert torch.allclose(up_scale, first_up_scale), ( + f"Expert {i} up_proj scale should match expert 0" + ) + + # Verify output is close to the original output + with torch.no_grad(): + output_after_fuse = model(dummy_input) + + # There will be some difference due to quantization errors after pre_quant_scale fusion + assert torch.allclose( + output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1 + ), "Output should be similar before and after Qwen3 MoE fusion" From d4ffb096dc847a92b6f155404d2a79e5c3067876 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Mon, 3 Nov 2025 20:56:54 +0000 Subject: [PATCH 09/11] fix moe fusion Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 28 ++++++---------------- modelopt/torch/export/unified_export_hf.py | 7 +++--- tests/gpu/torch/export/test_quant_utils.py | 6 ++--- 3 files changed, 14 insertions(+), 27 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 947304084..7404e1eef 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -950,9 +950,8 @@ def all_items_same(item_list): ] -# TODO: make this more general instead of rule based -def pattern_fuse_prequant(model: torch.nn.Module, fuse_mismatch_dim=False): - """Fuse pre_quant_scale to the linear weights. +def fuse_prequant_to_linear(model: torch.nn.Module, fuse_grouped_heads=False): + """Fuse pre_quant_scale to the linear weights if possible. For example, we can fuse the pre_quant_scale of o_proj to the output_dimension of v_proj, such that the results are mathematically equivalent to the following:: @@ -967,26 +966,13 @@ def pattern_fuse_prequant(model: torch.nn.Module, fuse_mismatch_dim=False): Args: model: The model to fuse pre_quant_scale to. - fuse_mismatch_dim: If True, fuse the pre_quant_scale even if dimension between pre_quant_scale + fuse_grouped_heads: If True, fuse the pre_quant_scale even if dimension between pre_quant_scale and linear weights is not the same. This is useful for GQA/MQA models but may lead to accuracy drop. Note: - This is an experimental feature, and it might mess up the quantization errors - of fused linear modules. + Fuse_grouped_heads is useful for GQA/MQA models but may lead to accuracy drop. """ - # For MoE models, let's first resmooth the w1 and w3 in experts to get the average pre_quant_scale - for _, module in model.named_modules(): - if ( - hasattr(module, "experts") - and "Qwen3MoeSparseMoeBlock".lower() in type(module).__name__.lower() - ): - linear_list = [] - linear_list.extend([getattr(expert, "up_proj") for expert in module.experts]) - linear_list.extend([getattr(expert, "gate_proj") for expert in module.experts]) - preprocess_linear_fusion(linear_list, resmooth_only=True) - - # import pdb; pdb.set_trace() # Fuse pre_quant_scale to the linear weights for _, module in model.named_modules(): for module_map in PQS_FUSE_MODULE_MAPPING: @@ -1000,10 +986,10 @@ def pattern_fuse_prequant(model: torch.nn.Module, fuse_mismatch_dim=False): ): pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale - # for GQA/MQA models, we apply averaging to the pre_quant_scale for shared head groups + # for GQA/MQA models, we can apply averaging to the pre_quant_scale for shared head groups if pre_quant_scale.numel() != linear_fuse_into.weight.shape[-2]: if ( - not fuse_mismatch_dim + not fuse_grouped_heads or "attention" not in type(module).__name__.lower() ): warn( @@ -1053,7 +1039,7 @@ def _update_pre_quant_scale(module, new_pre_quant_scale): # Use averaged scale (flattened) for v_proj fusion pre_quant_scale = averaged_scale.reshape(-1) - # Fuse the pre_quant_scale to v_proj weight + # Fuse the pre_quant_scale to weight linear_fuse_into.weight = torch.nn.Parameter( linear_fuse_into.weight * pre_quant_scale.view(-1, 1) ) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index c14c73013..5d1de5b81 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -60,6 +60,7 @@ from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only from .quant_utils import ( fuse_prequant_layernorm, + fuse_prequant_to_linear, get_activation_scaling_factor, get_quant_config, get_quantization_format, @@ -67,7 +68,6 @@ get_weight_scaling_factor, get_weight_scaling_factor_2, maybe_transpose_expert_weight_dimensions, - pattern_fuse_prequant, postprocess_state_dict, preprocess_linear_fusion, to_quantized_weight, @@ -108,6 +108,9 @@ def _output_hook(module, input, output): fused_linears = {} module_names = set() + # Fuse pre_quant_scale to the linear weights if possible + fuse_prequant_to_linear(model) + for name, module in model.named_modules(): module_names.add(name) @@ -199,8 +202,6 @@ def _output_hook(module, input, output): with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]): fuse_prequant_layernorm(output_to_layernorm[tensor], modules) - pattern_fuse_prequant(model) - # The dummy forward may not be able to activate all the experts. # Process experts by naming rules like experts.0, experts.1, etc. for name, modules_fused in fused_linears.items(): diff --git a/tests/gpu/torch/export/test_quant_utils.py b/tests/gpu/torch/export/test_quant_utils.py index 1bd213fd3..82a7f6e11 100644 --- a/tests/gpu/torch/export/test_quant_utils.py +++ b/tests/gpu/torch/export/test_quant_utils.py @@ -21,7 +21,7 @@ from transformers import LlamaConfig, LlamaForCausalLM import modelopt.torch.quantization as mtq -from modelopt.torch.export.quant_utils import pattern_fuse_prequant +from modelopt.torch.export.quant_utils import fuse_prequant_to_linear def get_tiny_llama(attention_heads=4, key_value_heads=4): @@ -74,7 +74,7 @@ def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair): ] # Apply fusion - pattern_fuse_prequant(model, fuse_mismatch_dim=True) + fuse_prequant_to_linear(model, fuse_grouped_heads=True) # Check if pre_quant_scale and fused_with_prequant flag are removed correctly for target_module_name in traget_module_name_list: @@ -172,7 +172,7 @@ def test_pattern_fuse_prequant_moe(quant_config): output_before_fuse = model(dummy_input) # Apply fusion (fuse_mismatch_dim only needed for GQA/MQA attention, not for MLP) - pattern_fuse_prequant(model) + fuse_prequant_to_linear(model) # Check if down_proj's pre_quant_scale was removed and fused into up_proj for name, module in moe_down_proj_modules: From 8b0bc54f5d5eaec63f891b273d44a6513b56c8d6 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Mon, 3 Nov 2025 23:46:16 +0000 Subject: [PATCH 10/11] fix test Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- tests/gpu/torch/export/test_quant_utils.py | 32 +--------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/tests/gpu/torch/export/test_quant_utils.py b/tests/gpu/torch/export/test_quant_utils.py index 82a7f6e11..16b4f524c 100644 --- a/tests/gpu/torch/export/test_quant_utils.py +++ b/tests/gpu/torch/export/test_quant_utils.py @@ -99,9 +99,6 @@ def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair): ), "Output should be the same before and after fusion" -# TODO: add test for Qwen3MoeSparseMoeBlock MLP fusion - - @pytest.mark.parametrize( "quant_config", [ @@ -111,7 +108,7 @@ def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair): ) def test_pattern_fuse_prequant_moe(quant_config): """Test pattern_fuse_prequant on Qwen3 MoE sparse MLP.""" - pytest.importorskip("transformers", minversion="4.46.0") + pytest.importorskip("transformers") from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM # Create a tiny Qwen3MoE model for testing @@ -186,33 +183,6 @@ def test_pattern_fuse_prequant_moe(quant_config): f"{name}: down_proj should have fused_with_prequant flag set" ) - # Verify that gate_proj and up_proj still have pre_quant_scale and are resmoothed - for name, module in model.named_modules(): - if "Qwen3MoeSparseMoeBlock".lower() in type(module).__name__.lower(): - first_gate_scale = getattr( - getattr(module, "experts")[0], "gate_proj" - ).input_quantizer._pre_quant_scale - first_up_scale = getattr( - getattr(module, "experts")[0], "up_proj" - ).input_quantizer._pre_quant_scale - - # gate_proj and up_proj should have the same scale after resmoothing - assert torch.allclose(first_gate_scale, first_up_scale), ( - "gate_proj and up_proj should have the same pre_quant_scale after resmoothing" - ) - - # All experts should have the same gate_proj and up_proj scales - for i, expert in enumerate(getattr(module, "experts")): - gate_scale = getattr(expert, "gate_proj").input_quantizer._pre_quant_scale - up_scale = getattr(expert, "up_proj").input_quantizer._pre_quant_scale - - assert torch.allclose(gate_scale, first_gate_scale), ( - f"Expert {i} gate_proj scale should match expert 0" - ) - assert torch.allclose(up_scale, first_up_scale), ( - f"Expert {i} up_proj scale should match expert 0" - ) - # Verify output is close to the original output with torch.no_grad(): output_after_fuse = model(dummy_input) From 234b7c2d83ece9a42057c447fda7d95998040943 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Tue, 4 Nov 2025 00:09:43 +0000 Subject: [PATCH 11/11] only fuse for nvfp4 awq Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/export/unified_export_hf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 5d1de5b81..09b19b595 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -109,7 +109,8 @@ def _output_hook(module, input, output): module_names = set() # Fuse pre_quant_scale to the linear weights if possible - fuse_prequant_to_linear(model) + if "NVFP4_AWQ" in quantization_format: + fuse_prequant_to_linear(model) for name, module in model.named_modules(): module_names.add(name)