126 changes: 123 additions & 3 deletions modelopt/torch/export/quant_utils.py
@@ -478,7 +478,7 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames

if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
return QUANTIZATION_NVFP4_AWQ
if getattr(layer, "fused_with_layernorm", False):
if getattr(layer, "fused_with_prequant", False):
return QUANTIZATION_NVFP4_AWQ
assert input_quantizer is not None, (
f"input_quantizer is None for {quantizer_attr_names}"
@@ -923,18 +923,138 @@ def all_items_same(item_list):
return all(x == item_list[0] for x in item_list)


# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
PQS_FUSE_MODULE_MAPPING = [
# Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension
# Mathematical equivalence:
# Before: o_proj_out = [(attn @ (v_proj_in @ v_proj.W^T)) * scale] @ o_proj.W^T
# After: o_proj_out = [attn @ (v_proj_in @ (scale.view(-1, 1) * v_proj.W)^T)] @ o_proj.W^T
(["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
# MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
# Mathematical equivalence:
# Before: down_proj_out = {[act_fn(gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
# After: down_proj_out = {act_fn(gate_proj(x)) * [up_proj(x) * scale]} @ down_proj.W^T
(["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
]
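The equivalences above can be checked numerically. Below is a minimal standalone sketch (not part of this diff) for the MLP case; the sizes, the `SiLU` activation, and the module names `gate`/`up`/`down` are illustrative stand-ins rather than the repository's modules:

```python
import torch

torch.manual_seed(0)
hidden, inter = 8, 16
gate = torch.nn.Linear(hidden, inter, bias=False)  # stand-in for gate_proj
up = torch.nn.Linear(hidden, inter, bias=False)    # stand-in for up_proj
down = torch.nn.Linear(inter, hidden, bias=False)  # stand-in for down_proj
act = torch.nn.SiLU()

with torch.no_grad():
    x = torch.randn(4, hidden)
    scale = torch.rand(inter) + 0.5  # stand-in for down_proj's pre_quant_scale

    # Before: scale applied to down_proj's input
    before = down((act(gate(x)) * up(x)) * scale)

    # After: scale folded into up_proj's output dimension (rows of its weight)
    up.weight.mul_(scale.view(-1, 1))
    after = down(act(gate(x)) * up(x))

assert torch.allclose(before, after, atol=1e-5)
```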


# TODO: make this more general instead of rule-based
def pattern_fuse_prequant(model: torch.nn.Module):
"""Fuse pre_quant_scale into the linear weights.

For example, the pre_quant_scale of o_proj can be fused into the output dimension of v_proj, so that
the results are mathematically equivalent to the following::

out_proj.input = (attn_weights @ v_proj.output)
out_proj.output = (out_proj.input * pre_quant_scale) * out_proj.weight
= attn_weights @ (v_proj.output * pre_quant_scale) * out_proj.weight

For GQA/MQA models where v_proj output dimension < o_proj input dimension,
the pre_quant_scale is averaged across the repeated head groups and then the
o_proj's pre_quant_scale is updated to maintain mathematical equivalence.

Note:
This is an experimental feature; it may increase the quantization error of the
fused linear modules.
"""
for _, module in model.named_modules():
for module_map in PQS_FUSE_MODULE_MAPPING:
target_module_list = module_map[0]
linear_pair = module_map[1]
if any(module_name in type(module).__name__ for module_name in target_module_list):
linear_to = module.get_submodule(linear_pair[0])
linear_from = module.get_submodule(linear_pair[1])
if hasattr(linear_from, "input_quantizer") and hasattr(
linear_from.input_quantizer, "_pre_quant_scale"
):
pre_quant_scale = linear_from.input_quantizer._pre_quant_scale

# for GQA/MQA models, we apply averaging to the pre_quant_scale
if pre_quant_scale.numel() != linear_to.weight.shape[0]:
if "attention" not in type(module).__name__.lower():
continue
else:
config = module.config
num_kv_heads = config.num_key_value_heads
kv_head_dim = linear_to.weight.shape[0] // num_kv_heads
n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim

# Reshape to (num_kv_heads, n_rep, kv_head_dim) and average over the repeated head groups
averaged_scale = pre_quant_scale.view(
num_kv_heads, n_rep, kv_head_dim
).mean(dim=1)

# To update o_proj, repeat the averaged scale back to the original shape
repeated_scale = (
averaged_scale.unsqueeze(1)  # (num_kv_heads, 1, kv_head_dim)
.expand(num_kv_heads, n_rep, kv_head_dim)  # (num_kv_heads, n_rep, kv_head_dim)
.reshape(-1)  # (num_kv_heads * n_rep * kv_head_dim,)
)

def _update_pre_quant_scale(module, new_pre_quant_scale):
old_pre_quant_scale = module.input_quantizer._pre_quant_scale
module.weight = nn.Parameter(
module.weight
* old_pre_quant_scale.to(
dtype=module.weight.dtype, device=module.weight.device
)
/ new_pre_quant_scale.to(
dtype=module.weight.dtype, device=module.weight.device
)
)
module.input_quantizer.pre_quant_scale = new_pre_quant_scale

# Redo weights collection
module.weight_quantizer.reset_amax()
enable_stats_collection(module.weight_quantizer)
module.weight_quantizer(module.weight)
finish_stats_collection(module.weight_quantizer)

# Update o_proj's pre_quant_scale
_update_pre_quant_scale(linear_from, repeated_scale)

# Use averaged scale (flattened) for v_proj fusion
pre_quant_scale = averaged_scale.reshape(-1)

# Fuse the pre_quant_scale into the weight of linear_to (e.g. v_proj)
# linear_to.weight shape: (out_features, in_features)
# We scale the output dimension (first dimension)
linear_to.weight = torch.nn.Parameter(
linear_to.weight * pre_quant_scale.view(-1, 1)
)
if hasattr(linear_to, "bias") and linear_to.bias is not None:
linear_to.bias = torch.nn.Parameter(linear_to.bias * pre_quant_scale)

delattr(linear_from.input_quantizer, "_pre_quant_scale")
setattr(linear_from, "fused_with_prequant", True)
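As a shape-level illustration of the GQA/MQA branch above (hypothetical sizes, not taken from the repository): the o_proj scale with `num_kv_heads * n_rep * kv_head_dim` entries is averaged over the repeated head groups before being fused into v_proj, and the repeated version becomes o_proj's new pre_quant_scale.

```python
import torch

num_kv_heads, n_rep, kv_head_dim = 2, 2, 4  # hypothetical GQA sizes
pre_quant_scale = torch.rand(num_kv_heads * n_rep * kv_head_dim) + 0.5  # o_proj scale, 16 entries

# Average over the repeated head groups -> one scale per v_proj output channel
averaged_scale = pre_quant_scale.view(num_kv_heads, n_rep, kv_head_dim).mean(dim=1)

# Repeat back so o_proj keeps one scale per input channel
repeated_scale = (
    averaged_scale.unsqueeze(1)
    .expand(num_kv_heads, n_rep, kv_head_dim)
    .reshape(-1)
)

print(averaged_scale.reshape(-1).shape)  # torch.Size([8])  -> fused into v_proj's weight rows
print(repeated_scale.shape)              # torch.Size([16]) -> o_proj's new pre_quant_scale
```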


def fuse_prequant_layernorm(
layernorm_module: torch.nn.Module,
modules: list[torch.Tensor],
):
"""Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted."""
"""Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted.

original:
layernorm_output = (normalization(input) * weight) + bias
layernorm_output_scaled = layernorm_output * pre_quant_scale

fused:
fused_weight = weight * avg_pre_quant_scale
fused_bias = bias * avg_pre_quant_scale
layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias
"""
layernorm_module.weight = torch.nn.Parameter(
layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale")
)
if hasattr(layernorm_module, "bias"):
layernorm_module.bias = torch.nn.Parameter(
layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale")
)
# Pre_quant_scales of modules must not be exported, since they have been fused with layernorm
for module in modules:
delattr(module.input_quantizer, "_pre_quant_scale")
setattr(module, "fused_with_layernorm", True)
setattr(module, "fused_with_prequant", True)


def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False):
3 changes: 3 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -64,6 +64,7 @@
get_weight_scaling_factor,
get_weight_scaling_factor_2,
maybe_transpose_expert_weight_dimensions,
pattern_fuse_prequant,
postprocess_state_dict,
preprocess_linear_fusion,
to_quantized_weight,
@@ -173,6 +174,8 @@ def _output_hook(module, input, output):
# Pre quant scale of modules is already updated to avg_pre_quant_scale
fuse_prequant_layernorm(output_to_layernorm[tensor], modules)

pattern_fuse_prequant(model)

# The dummy forward may not be able to activate all the experts.
# Process experts by naming rules like experts.0, experts.1, etc.
for name, modules_fused in fused_linears.items():
99 changes: 99 additions & 0 deletions tests/gpu/torch/export/test_quant_utils.py
@@ -0,0 +1,99 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

pytest.importorskip("transformers")

from transformers import LlamaConfig, LlamaForCausalLM

import modelopt.torch.quantization as mtq
from modelopt.torch.export.quant_utils import pattern_fuse_prequant


def get_tiny_llama(attention_heads=4, key_value_heads=4):
"""Create a tiny Llama model for testing."""
config = LlamaConfig(
hidden_size=64,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=attention_heads,
num_key_value_heads=key_value_heads,
max_position_embeddings=128,
vocab_size=256,
)
return LlamaForCausalLM(config)


@pytest.mark.parametrize(
"quant_config",
[
mtq.INT4_AWQ_CFG,
mtq.NVFP4_AWQ_LITE_CFG,
],
)
@pytest.mark.parametrize(
"attention_kv_heads_pair",
[
(4, 4), # MHA
(4, 2), # GQA
(4, 1), # MQA
],
)
def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair):
"""Test pattern_fuse_prequant on modules from a tiny Llama model."""
model = get_tiny_llama(attention_kv_heads_pair[0], attention_kv_heads_pair[1]).to("cuda")

# Quantize the model
dummy_input = torch.randint(0, 256, (1, 16), device="cuda")
mtq.quantize(model, quant_config, lambda m: m(dummy_input))

# Run forward pass before fusion
model.eval()
with torch.no_grad():
output_before_fuse = model(dummy_input)

target_module_name_list = [
"model.layers.0.self_attn.o_proj",
"model.layers.0.mlp.down_proj",
"model.layers.1.self_attn.o_proj",
"model.layers.1.mlp.down_proj",
]

# Apply fusion
pattern_fuse_prequant(model)

# Check that pre_quant_scale was removed and the fused_with_prequant flag was set
for target_module_name in target_module_name_list:
target_module = model.get_submodule(target_module_name)

# Verify pre_quant_scale was removed
assert not hasattr(target_module.input_quantizer, "_pre_quant_scale"), (
f"{target_module_name}: pre_quant_scale should be removed after fusion"
)

# Verify fused_with_prequant flag was set
assert (
hasattr(target_module, "fused_with_prequant") and target_module.fused_with_prequant
), f"{target_module_name}: fused_with_prequant flag should be set"

# Verify output is close to the original output
with torch.no_grad():
output_after_fuse = model(dummy_input)
# Small differences are expected because fusing pre_quant_scale into the weights changes the quantization error
assert torch.allclose(
output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1
), "Outputs before and after fusion should closely match"