From 135e637aacc5a9563464269b3e03372d84e900b8 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 20 Oct 2025 21:53:25 +0800 Subject: [PATCH 01/10] add awq for llama/qwen dense --- .../layer_weights/meta_weights/base_weight.py | 5 + .../meta_weights/mm_weight/colmm_weight.py | 52 +++++++- .../meta_weights/mm_weight/mm_weight.py | 124 ++++++++++++++++++ .../meta_weights/mm_weight/rowmm_weight.py | 85 ++++++++++-- lightllm/common/quantization/__init__.py | 11 +- lightllm/common/quantization/awq_quant.py | 58 ++++++++ .../common/quantization/deepgemm_quant.py | 3 + .../common/quantization/quantize_method.py | 3 + 8 files changed, 327 insertions(+), 14 deletions(-) create mode 100644 lightllm/common/quantization/awq_quant.py diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index b67fc1b43..30a736b4f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -44,9 +44,14 @@ def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: # load quantization scale pass + def _load_zero_points(self, weights: Dict[str, torch.Tensor]) -> None: + # load quantization zero points + pass + def load_hf_weights(self, weights): self._load_weights(weights) self._load_scales(weights) + self._load_zero_points(weights) return def verify_load(self): diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py index d6d064cf4..59920a055 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py @@ -3,6 +3,7 @@ MMWeight, MMWeightTpl, generate_scale_name, + AWQMMWeightTpl, ) from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id @@ -15,8 +16,7 @@ class COLMMWeight(MMWeight): def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): if quant_method is None or not quantized_weight: return UnquantizedCOLMMWeight - else: - return W8A8B128COLMMWeight + return COLBMM_WEIGHT_CLS_MAP[quant_method.get_name()] class UnquantizedCOLMMWeight(MMWeightTpl): @@ -97,3 +97,51 @@ def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: None, ] return + + +class AWQCOLMMWeight(AWQMMWeightTpl): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + bias_name: Optional[str] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__(data_type, quant_method, tp_rank, tp_world_size) + self.weight_name = weight_name.replace("weight", quant_method.weight_suffix) + self.weight_scale_name = weight_name.replace("weight", quant_method.weight_scale_suffix) + self.weight_zero_point_name = weight_name.replace("weight", quant_method.weight_zero_point_suffix) + self.bias_name = bias_name + self.weight_scale: Optional[torch.Tensor] = None + self.quantized_weight = True + self.weight = [None, None, None] + + def _slice_weight(self, weight: torch.Tensor): + assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}" + tp_size = weight.shape[0] // self.tp_world_size_ + return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1), :] + + def _slice_bias(self, bias): + assert 
bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}" + tp_size = bias.shape[0] // self.tp_world_size_ + return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1), :] + + def _slice_weight_scale(self, weight_scale: torch.Tensor): + tp_size = weight_scale.shape[0] // self.tp_world_size_ + scale_start = tp_size * self.tp_rank_ + scale_end = tp_size * (self.tp_rank_ + 1) + return weight_scale[scale_start:scale_end, :].to(torch.half) + + def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor): + tp_size = weight_zero_point.shape[0] // self.tp_world_size_ + zero_point_start = tp_size * self.tp_rank_ + zero_point_end = tp_size * (self.tp_rank_ + 1) + return weight_zero_point[zero_point_start:zero_point_end, :] + + +COLBMM_WEIGHT_CLS_MAP = { + "fp8w8a8b128": W8A8B128COLMMWeight, + "awq": AWQCOLMMWeight, +} diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 79787bb18..4cd1224df 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -163,6 +163,129 @@ def _process_weight(self, weight) -> None: self.weight = weight.cuda(get_current_device_id()) +class AWQMMWeightTpl(MMWeightTpl): + def __init__( + self, + data_type: torch.dtype, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__(data_type, quant_method, tp_rank, tp_world_size) + self.weight = [None, None, None] + + def verify_load(self) -> bool: + load_ok = True + # Verify weight. The weight must be not None. + weight_ok = all(w is not None for w in self.weight) + load_ok = load_ok and weight_ok + # Verify bias. If bias_name is set, it must be not None. 
+        if self.has_bias:
+            load_ok = load_ok and self.bias is not None
+        return load_ok
+
+    def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
+        if self.weight_name is not None and self.weight_name in weights:
+            weight = weights[self.weight_name]
+            weight = self._slice_weight(weight)
+            self.weight[0] = weight.cuda(get_current_device_id())
+        if self.bias_name is not None and self.bias_name in weights:
+            bias = weights[self.bias_name]
+            bias = self._slice_bias(bias)
+            self.bias = bias.cuda(get_current_device_id())
+
+    def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
+        if self.weight_scale_name is not None and self.weight_scale_name in weights:
+            weight_scale = weights[self.weight_scale_name]
+            weight_scale = self._slice_weight_scale(weight_scale)
+            self.weight[1] = weight_scale.cuda(get_current_device_id())
+
+    def _load_zero_points(self, weights: Dict[str, torch.Tensor]) -> None:
+        if self.weight_zero_point_name is not None and self.weight_zero_point_name in weights:
+            weight_zero_point = weights[self.weight_zero_point_name]
+            weight_zero_point = self._slice_weight_zero_point(weight_zero_point)
+            self.weight[2] = weight_zero_point.cuda(get_current_device_id())
+
+
+class AWQMultiMMWeightTpl(AWQMMWeightTpl):
+    def __init__(
+        self,
+        weight_names: List[str],
+        data_type: torch.dtype,
+        bias_names: Optional[List[str]] = None,
+        quant_method: QuantizationMethod = None,
+        tp_rank: int = None,
+        tp_world_size: int = None,
+    ) -> None:
+        super().__init__(data_type, quant_method, tp_rank, tp_world_size)
+        self.weight_names = [weight.replace("weight", quant_method.weight_suffix) for weight in weight_names]
+        self.weight_scale_names = [
+            weight.replace("weight", quant_method.weight_scale_suffix) for weight in weight_names
+        ]
+        self.weight_zero_point_names = [
+            weight.replace("weight", quant_method.weight_zero_point_suffix) for weight in weight_names
+        ]
+        self.bias_names = bias_names
+        self.weights = [None] * len(self.weight_names)
+        self.weight_scales = [None] * len(self.weight_names)
+        self.weight_zero_points = [None] * len(self.weight_names)
+        if self.bias_names is not None:
+            self.biases = [None] * len(self.bias_names)
+            self.has_bias = all(b is not None for b in self.bias_names) and len(bias_names) > 0
+        else:
+            self.biases = None
+            self.has_bias = False
+
+    def _fuse(self) -> None:
+        if self.weight[0] is None and (None not in self.weights):
+            weight = torch.cat(self.weights, dim=1)
+            self.weight[0] = weight.cuda(get_current_device_id())
+            delattr(self, "weights")
+
+        if self.weight[1] is None and (None not in self.weight_scales):
+            # AWQ stores its quantized weights with shape in x out, so the cat dim here is 1
+            weight_scale = torch.cat(self.weight_scales, dim=1).cuda(get_current_device_id())
+            self.weight[1] = weight_scale.cuda(get_current_device_id())
+            delattr(self, "weight_scales")
+
+        if self.weight[2] is None and (None not in self.weight_zero_points):
+            weight_zero_point = torch.cat(self.weight_zero_points, dim=1)
+            self.weight[2] = weight_zero_point.cuda(get_current_device_id())
+            print("weight_zero_point", self.weight[2].dtype)
+            delattr(self, "weight_zero_points")
+
+        if self.has_bias and self.bias is None and (None not in self.biases):
+            self.bias = torch.cat(self.biases, dim=0).cuda(get_current_device_id())
+            delattr(self, "biases")
+        return self
+
+    def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
+        for i in range(len(self.weight_names)):
+            if self.weight_names[i] is not None and self.weight_names[i] in weights:
+                weight = weights[self.weight_names[i]]
+                
weight = self._slice_weight(weight) + self.weights[i] = weight.cuda(get_current_device_id()) + + def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: + for i in range(len(self.weight_names)): + if self.weight_scale_names[i] is not None and self.weight_scale_names[i] in weights: + weight_scale = weights[self.weight_scale_names[i]] + weight_scale = self._slice_weight_scale(weight_scale) + self.weight_scales[i] = weight_scale.cuda(get_current_device_id()) + + def _load_zero_points(self, weights: Dict[str, torch.Tensor]) -> None: + for i in range(len(self.weight_names)): + if self.weight_zero_point_names[i] is not None and self.weight_zero_point_names[i] in weights: + weight_zero_point = weights[self.weight_zero_point_names[i]] + weight_zero_point = self._slice_weight_zero_point(weight_zero_point) + self.weight_zero_points[i] = weight_zero_point.cuda(get_current_device_id()) + + def load_hf_weights(self, weights): + super().load_hf_weights(weights) + self._fuse() + return + + class MMWeight: def __new__(cls, **kwargs): quant_cfg = kwargs.pop("quant_cfg", None) @@ -178,6 +301,7 @@ def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> Q if quant_cfg is None: return None, False quant_method = quant_cfg.get_quant_method(layer_num_, name) + quant_method.hf_quantization_method = quant_cfg.hf_quantization_method quantized_weight = quant_cfg.quantized_weight return quant_method, quantized_weight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index c90d7c1a3..3f0720a18 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -4,6 +4,8 @@ MMWeightTpl, BMMWeightTpl, MultiMMWeightTpl, + AWQMMWeightTpl, + AWQMultiMMWeightTpl, generate_scale_name, ) from lightllm.common.quantization import Quantcfg @@ -17,10 +19,8 @@ class ROWMMWeight(MMWeight): def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): if quant_method is None or not quantized_weight: return UnquantizedROWMMWeight - else: - return W8A8B128ROWMMWeight - # TODO: Implement more quantization weight - return None + + return ROWBMM_WEIGHT_CLS_MAP[quant_method.get_name()] class MultiROWMMWeight(MMWeight): @@ -28,10 +28,8 @@ class MultiROWMMWeight(MMWeight): def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): if quant_method is None or not quantized_weight: return UnquantizedMultiROWMMWeight - else: - return W8A8B128MultiROWMMWeight - # TODO: Implement more quantization weight - return None + + return MULTI_ROWBMM_WEIGHT_CLS_MAP[quant_method.get_name()] class ROWBMMWeight(MMWeight): @@ -256,3 +254,74 @@ def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: self.weight_scale, None, ] + + +class AWQROWMMWeight(AWQMMWeightTpl): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + bias_name: Optional[str] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__(data_type, quant_method, tp_rank, tp_world_size) + self.weight_name = weight_name.replace("weight", quant_method.weight_suffix) + self.weight_scale_name = weight_name.replace("weight", quant_method.weight_scale_suffix) + self.weight_zero_point_name = weight_name.replace("weight", quant_method.weight_zero_point_suffix) + self.bias_name = bias_name + 
self.weight_scale: Optional[torch.Tensor] = None + self.quantized_weight = True + self.weight = [None, None, None] + + def _slice_weight(self, weight: torch.Tensor): + assert weight.shape[1] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[1]} % {self.tp_world_size_}" + tp_size = weight.shape[1] // self.tp_world_size_ + return weight[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + + def _slice_bias(self, bias): + assert bias.shape[1] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[1]} % {self.tp_world_size_}" + tp_size = bias.shape[1] // self.tp_world_size_ + return bias[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + + def _slice_weight_scale(self, weight_scale: torch.Tensor): + tp_size = weight_scale.shape[1] // self.tp_world_size_ + scale_start = tp_size * self.tp_rank_ + scale_end = tp_size * (self.tp_rank_ + 1) + return weight_scale[:, scale_start:scale_end].to(torch.half) + + def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor): + tp_size = weight_zero_point.shape[1] // self.tp_world_size_ + zero_point_start = tp_size * self.tp_rank_ + zero_point_end = tp_size * (self.tp_rank_ + 1) + return weight_zero_point[:, zero_point_start:zero_point_end] + + +class AWQMultiROWMMWeight(AWQMultiMMWeightTpl): + _slice_weight = AWQROWMMWeight._slice_weight + _slice_bias = AWQROWMMWeight._slice_bias + _slice_weight_scale = AWQROWMMWeight._slice_weight_scale + _slice_weight_zero_point = AWQROWMMWeight._slice_weight_zero_point + + def __init__( + self, + weight_names: List[str], + data_type: torch.dtype, + bias_names: Optional[List[str]] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__(weight_names, data_type, bias_names, quant_method, tp_rank, tp_world_size) + + +ROWBMM_WEIGHT_CLS_MAP = { + "fp8w8a8b128": W8A8B128ROWMMWeight, + "awq": AWQROWMMWeight, +} + +MULTI_ROWBMM_WEIGHT_CLS_MAP = { + "fp8w8a8b128": W8A8B128MultiROWMMWeight, + "awq": AWQMultiROWMMWeight, +} diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 611a3407a..4f05e96d2 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -5,6 +5,7 @@ from .w8a8_quant import * from .triton_quant.triton_quant import * from .deepgemm_quant import * +from .awq_quant import * from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -43,10 +44,12 @@ def _mapping_quant_method(self): else: self.quant_type = "vllm-fp8w8a8-b128" logger.info(f"select fp8w8a8-b128 quant way: {self.quant_type}") - - else: - # TODO: more quant method - pass + elif self.hf_quantization_method == "awq": + self.quant_type = "awq" + logger.info(f"select awq quant way: {self.quant_type}") + else: + # TODO: more quant method + pass def _parse_custom_cfg(self, custom_cfg_path): self.quant_cfg = collections.defaultdict(dict) diff --git a/lightllm/common/quantization/awq_quant.py b/lightllm/common/quantization/awq_quant.py new file mode 100644 index 000000000..b17ebef7c --- /dev/null +++ b/lightllm/common/quantization/awq_quant.py @@ -0,0 +1,58 @@ +import os +import torch +from .quantize_method import QuantizationMethod +from .registry import QUANTMETHODS +import torch.nn.functional as F +from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm +from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops + +if HAS_VLLM: + awq_dequantize = vllm_ops.awq_dequantize + awq_gemm = vllm_ops.awq_gemm 
+ + +class AWQBaseQuantizationMethod(QuantizationMethod): + def __init__(self): + super().__init__() + assert HAS_VLLM, "vllm are not installed, you can't use quant api of them." + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + self.cache_manager = g_cache_manager + + def quantize(self, weight: torch.Tensor): + """ """ + pass + + def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): + """ """ + pass + + +@QUANTMETHODS.register("awq") +class AWQW4A16QuantizationMethod(AWQBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.pack_factor = 8 + self.weight_scale_suffix = "scales" + self.weight_zero_point_suffix = "qzeros" + self.weight_suffix = "qweight" + + def get_name(self): + return "awq" + + def quantize(self, weight: torch.Tensor): + raise NotImplementedError("AWQ online quantization is not supported yet.") + + def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): + qweight, weight_scale, qzeros = weights + + NEED_DEQUANT_WEIGHT = input_tensor.shape[:-1].numel() >= 256 + if NEED_DEQUANT_WEIGHT: + fpweight = awq_dequantize(qweight, weight_scale, qzeros, 0, 0, 0) + out = torch.matmul(input_tensor, fpweight) + else: + out = awq_gemm(input_tensor, qweight, weight_scale, qzeros, self.pack_factor) + + if bias is not None: + out.add_(bias) + return out diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py index 964db67cf..6b94002d0 100644 --- a/lightllm/common/quantization/deepgemm_quant.py +++ b/lightllm/common/quantization/deepgemm_quant.py @@ -40,6 +40,9 @@ def __init__(self): self.weight_scale_suffix = "weight_scale_inv" self.act_scale_suffix = None # no support for static input tensor scale for ds model. 
+ def get_name(self): + return "fp8w8a8b128" + def quantize(self, weight: torch.Tensor): from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index b7b4c3705..60326d92d 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -17,3 +17,6 @@ def quantize(self, weights: torch.Tensor): @abstractmethod def apply(self, input_tensor, weight, bias=None, out=None, use_custom_tensor_mananger=True): pass + + def get_name(self): + return self.__class__.__name__ From 20fdf1af827e2f9a800318ad4abfabf1a04bf2e0 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 20 Oct 2025 21:57:46 +0800 Subject: [PATCH 02/10] fix loadworker > 1 for awq --- .../common/basemodel/layer_weights/meta_weights/__init__.py | 1 + .../layer_weights/meta_weights/mm_weight/__init__.py | 1 + .../basemodel/layer_weights/transformer_layer_weight.py | 4 ++-- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index 9db819dfd..ec9d707f8 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -6,6 +6,7 @@ COLMMWeight, MultiROWMMWeight, ROWBMMWeight, + AWQMultiMMWeightTpl, ) from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight from .fused_moe_weight_tp import FusedMoeWeightTP diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py index 263112435..c273ee456 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py @@ -1,6 +1,7 @@ from .mm_weight import ( MMWeightTpl, MultiMMWeightTpl, + AWQMultiMMWeightTpl, ) from .rowmm_weight import ( ROWMMWeight, diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 48167a067..2e557e0e2 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -2,7 +2,7 @@ # from lightllm.common.layers.mm import MM from .base_layer_weight import BaseLayerWeight -from .meta_weights import BaseWeight, MultiMMWeightTpl +from .meta_weights import BaseWeight, MultiMMWeightTpl, AWQMultiMMWeightTpl from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -36,7 +36,7 @@ def load_hf_weights(self, weights): """ for attr_name in dir(self): attr = getattr(self, attr_name, None) - if isinstance(attr, MultiMMWeightTpl): + if isinstance(attr, MultiMMWeightTpl) or isinstance(attr, AWQMultiMMWeightTpl): with self.lock: attr.load_hf_weights(weights) elif isinstance(attr, BaseWeight): From 18081a682ff609e8cdee98dcff793c3ebe077a6d Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 20 Oct 2025 22:07:29 +0800 Subject: [PATCH 03/10] remove unused print --- .../basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py 
b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 4cd1224df..f243df419 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -251,7 +251,6 @@ def _fuse(self) -> None: if self.weight[2] is None and (None not in self.weight_zero_points): weight_zero_point = torch.cat(self.weight_zero_points, dim=1) self.weight[2] = weight_zero_point.cuda(get_current_device_id()) - print("weight_zero_point", self.weight[2].dtype) delattr(self, "weight_zero_points") if self.has_bias and self.bias is None and (None not in self.biases): From 358821e73829ff51975df9045194c39a9fd1c573 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 21 Oct 2025 14:42:09 +0800 Subject: [PATCH 04/10] add awq marlin --- .../meta_weights/mm_weight/colmm_weight.py | 25 ++++ .../meta_weights/mm_weight/mm_weight.py | 23 +++- .../meta_weights/mm_weight/rowmm_weight.py | 50 +++++++ lightllm/common/quantization/__init__.py | 2 + lightllm/common/quantization/awq_quant.py | 124 ++++++++++++++++++ 5 files changed, 217 insertions(+), 7 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py index 59920a055..550636332 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py @@ -141,7 +141,32 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor): return weight_zero_point[zero_point_start:zero_point_end, :] +class AWQMARLINCOLMMWeight(AWQCOLMMWeight): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + bias_name: Optional[str] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size) + + def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: + return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + return self.quant_method._process_weight_scale_after_loading(weight_scale.cuda(get_current_device_id())) + + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + return self.quant_method._process_weight_zero_point_after_loading( + weight_zero_point.cuda(get_current_device_id()) + ) + + COLBMM_WEIGHT_CLS_MAP = { "fp8w8a8b128": W8A8B128COLMMWeight, "awq": AWQCOLMMWeight, + "awq_marlin": AWQMARLINCOLMMWeight, } diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index f243df419..b0ccd8064 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -184,11 +184,20 @@ def verify_load(self) -> bool: load_ok = load_ok and self.bias is not None return load_ok + def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: + return weight.cuda(get_current_device_id()) + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + return weight_scale.cuda(get_current_device_id()) + + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) 
-> torch.Tensor:
+        return weight_zero_point.cuda(get_current_device_id())
+
     def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
         if self.weight_name is not None and self.weight_name in weights:
             weight = weights[self.weight_name]
             weight = self._slice_weight(weight)
-            self.weight[0] = weight.cuda(get_current_device_id())
+            self.weight[0] = self._process_weight(weight)
         if self.bias_name is not None and self.bias_name in weights:
             bias = weights[self.bias_name]
             bias = self._slice_bias(bias)
@@ -198,13 +207,13 @@ def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
         if self.weight_scale_name is not None and self.weight_scale_name in weights:
             weight_scale = weights[self.weight_scale_name]
             weight_scale = self._slice_weight_scale(weight_scale)
-            self.weight[1] = weight_scale.cuda(get_current_device_id())
+            self.weight[1] = self._process_weight_scale(weight_scale)
 
     def _load_zero_points(self, weights: Dict[str, torch.Tensor]) -> None:
         if self.weight_zero_point_name is not None and self.weight_zero_point_name in weights:
             weight_zero_point = weights[self.weight_zero_point_name]
             weight_zero_point = self._slice_weight_zero_point(weight_zero_point)
-            self.weight[2] = weight_zero_point.cuda(get_current_device_id())
+            self.weight[2] = self._process_weight_zero_point(weight_zero_point)
 
 
 class AWQMultiMMWeightTpl(AWQMMWeightTpl):
@@ -239,18 +248,18 @@ def __init__(
     def _fuse(self) -> None:
         if self.weight[0] is None and (None not in self.weights):
             weight = torch.cat(self.weights, dim=1)
-            self.weight[0] = weight.cuda(get_current_device_id())
+            self.weight[0] = self._process_weight(weight)
             delattr(self, "weights")
 
         if self.weight[1] is None and (None not in self.weight_scales):
             # AWQ stores its quantized weights with shape in x out, so the cat dim here is 1
             weight_scale = torch.cat(self.weight_scales, dim=1).cuda(get_current_device_id())
-            self.weight[1] = weight_scale.cuda(get_current_device_id())
+            self.weight[1] = self._process_weight_scale(weight_scale)
             delattr(self, "weight_scales")
 
         if self.weight[2] is None and (None not in self.weight_zero_points):
             weight_zero_point = torch.cat(self.weight_zero_points, dim=1)
-            self.weight[2] = weight_zero_point.cuda(get_current_device_id())
+            self.weight[2] = self._process_weight_zero_point(weight_zero_point)
             delattr(self, "weight_zero_points")
 
         if self.has_bias and self.bias is None and (None not in self.biases):
@@ -300,7 +309,7 @@ def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> Q
     if quant_cfg is None:
         return None, False
     quant_method = quant_cfg.get_quant_method(layer_num_, name)
-    quant_method.hf_quantization_method = quant_cfg.hf_quantization_method
+    quant_method.hf_quantization_config = quant_cfg.hf_quantization_config
     quantized_weight = quant_cfg.quantized_weight
     return quant_method, quantized_weight
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py
index 3f0720a18..83d2ae9d7 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py
@@ -316,12 +316,62 @@ def __init__(
         super().__init__(weight_names, data_type, bias_names, quant_method, tp_rank, tp_world_size)
+
+class AWQMARLINROWMMWeight(AWQROWMMWeight):
+    def __init__(
+        self,
+        weight_name: str,
+        data_type: torch.dtype,
+        bias_name: Optional[str] = None,
+        quant_method: QuantizationMethod = None,
+        tp_rank: int = None,
+        tp_world_size: int = None,
+ ) -> None: + super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size) + + def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: + return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + return self.quant_method._process_weight_scale_after_loading(weight_scale.cuda(get_current_device_id())) + + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + return self.quant_method._process_weight_zero_point_after_loading( + weight_zero_point.cuda(get_current_device_id()) + ) + + +class AWQMARLINMultiROWMMWeight(AWQMultiROWMMWeight): + def __init__( + self, + weight_names: List[str], + data_type: torch.dtype, + bias_names: Optional[List[str]] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__(weight_names, data_type, bias_names, quant_method, tp_rank, tp_world_size) + + def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: + return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + return self.quant_method._process_weight_scale_after_loading(weight_scale.cuda(get_current_device_id())) + + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + return self.quant_method._process_weight_zero_point_after_loading( + weight_zero_point.cuda(get_current_device_id()) + ) + + ROWBMM_WEIGHT_CLS_MAP = { "fp8w8a8b128": W8A8B128ROWMMWeight, "awq": AWQROWMMWeight, + "awq_marlin": AWQMARLINROWMMWeight, } MULTI_ROWBMM_WEIGHT_CLS_MAP = { "fp8w8a8b128": W8A8B128MultiROWMMWeight, "awq": AWQMultiROWMMWeight, + "awq_marlin": AWQMARLINMultiROWMMWeight, } diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 4f05e96d2..7e4a08218 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -46,6 +46,8 @@ def _mapping_quant_method(self): logger.info(f"select fp8w8a8-b128 quant way: {self.quant_type}") elif self.hf_quantization_method == "awq": self.quant_type = "awq" + if is_awq_marlin_compatible(self.hf_quantization_config): + self.quant_type = "awq_marlin" logger.info(f"select awq quant way: {self.quant_type}") else: # TODO: more quant method diff --git a/lightllm/common/quantization/awq_quant.py b/lightllm/common/quantization/awq_quant.py index b17ebef7c..7221a454e 100644 --- a/lightllm/common/quantization/awq_quant.py +++ b/lightllm/common/quantization/awq_quant.py @@ -5,10 +5,25 @@ import torch.nn.functional as F from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops +from typing import Any if HAS_VLLM: awq_dequantize = vllm_ops.awq_dequantize awq_gemm = vllm_ops.awq_gemm + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, + marlin_permute_scales, + awq_to_marlin_zero_points, + should_use_atomic_add_reduce, + marlin_make_empty_g_idx, + marlin_make_workspace_new, + ) + from vllm.scalar_type import scalar_types + + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } class AWQBaseQuantizationMethod(QuantizationMethod): @@ -56,3 +71,112 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ if bias is not None: 
out.add_(bias) return out + + +@QUANTMETHODS.register("awq_marlin") +class AWQMARLINW4A16QuantizationMethod(AWQBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.pack_factor = 8 + self.weight_scale_suffix = "scales" + self.weight_zero_point_suffix = "qzeros" + self.weight_suffix = "qweight" + self.g_idx = marlin_make_empty_g_idx(torch.device("cuda")) + self.g_idx_sort_indices = marlin_make_empty_g_idx(torch.device("cuda")) + self.workspace = marlin_make_workspace_new(torch.device("cuda")) + + def get_name(self): + return "awq_marlin" + + def quantize(self, weight: torch.Tensor): + raise NotImplementedError("AWQ online quantization is not supported yet.") + + def _process_weight_after_loading(self, weight: torch.Tensor) -> torch.Tensor: + assert self.hf_quantization_config is not None, "hf_quantization_config is not set" + self.k = weight.shape[0] + self.n = weight.shape[1] * self.pack_factor + return vllm_ops.awq_marlin_repack( + weight, + size_k=weight.shape[0], + size_n=weight.shape[1] * self.pack_factor, + num_bits=self.hf_quantization_config["bits"], + ) + + def _process_weight_scale_after_loading(self, weight_scale: torch.Tensor) -> torch.Tensor: + assert self.hf_quantization_config is not None, "hf_quantization_config is not set" + group_size = self.hf_quantization_config["group_size"] + return marlin_permute_scales( + weight_scale, + size_k=weight_scale.shape[0] * group_size, + size_n=weight_scale.shape[1], + group_size=self.hf_quantization_config["group_size"], + ) + + def _process_weight_zero_point_after_loading(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + return awq_to_marlin_zero_points( + weight_zero_point, + size_k=weight_zero_point.shape[0], + size_n=weight_zero_point.shape[1] * self.pack_factor, + num_bits=self.hf_quantization_config["bits"], + ) + + def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): + qweight, weight_scale, qzeros = weights + reshaped_x = input_tensor.reshape(-1, input_tensor.shape[-1]) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=self.n, + k=self.k, + device=input_tensor.device, + dtype=input_tensor.dtype, + ) + + out = vllm_ops.gptq_marlin_gemm( + reshaped_x, + None, + qweight, + bias, + weight_scale, + None, + qzeros, + self.g_idx, + self.g_idx_sort_indices, + self.workspace, + TYPE_MAP[self.hf_quantization_config["bits"]], + size_m=reshaped_x.shape[0], + size_n=self.n, + size_k=self.k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=True, + is_zp_float=False, + ) + + if bias is not None: + out.add_(bias) + return out + + +# adapted from +# https://github.com/vllm-project/vllm/blob/aef368aa08572505b820db01da82e2fbb3d43a72/vllm/model_executor/layers/quantization/awq_marlin.py#L211-L212 +def is_awq_marlin_compatible(quantization_config: dict[str, Any]): + # Extract data from quant config. + quant_method = quantization_config.get("quant_method", "").lower() + num_bits = quantization_config.get("bits") + group_size = quantization_config.get("group_size") + zero_point = quantization_config.get("zero_point") + + if not torch.cuda.is_available(): + return False + + if quant_method != "awq": + return False + + # If we cannot find the info needed in the config, cannot convert. 
+ if num_bits is None or group_size is None or zero_point is None: + return False + + if num_bits not in TYPE_MAP: + return False + + return check_marlin_supported(quant_type=TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point) From 4cc4174680967bc196d5f2fa93f4678c67ab5efa Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 21 Oct 2025 20:22:08 +0800 Subject: [PATCH 05/10] add awq for qwen3 moe --- .../meta_weights/fused_moe_weight_tp.py | 386 +++++++++++++++++- .../meta_weights/mm_weight/colmm_weight.py | 4 +- .../meta_weights/mm_weight/mm_weight.py | 6 +- .../meta_weights/mm_weight/rowmm_weight.py | 8 +- lightllm/common/quantization/__init__.py | 1 - lightllm/common/quantization/awq_quant.py | 4 +- 6 files changed, 400 insertions(+), 9 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py index 3e61178f3..6d3f7855a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py @@ -7,7 +7,18 @@ from lightllm.common.quantization import Quantcfg -class FusedMoeWeightTP(BaseWeight): +class FusedMoeWeightTP: + def __new__(cls, **kwargs): + quant_cfg = kwargs.get("quant_cfg", None) + layer_num = kwargs.get("layer_num", None) + quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") + if quant_method.get_name() == "awq_marlin": + return FusedAWQMARLINMoeWeightTP(**kwargs) + else: + return FusedBaseMoeWeightTP(**kwargs) + + +class FusedBaseMoeWeightTP(BaseWeight): def __init__( self, gate_proj_name: str, @@ -39,7 +50,7 @@ def __init__( self.n_routed_experts = n_routed_experts + num_fused_shared_experts self.num_fused_shared_experts = num_fused_shared_experts self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) - self.split_inter_size = split_inter_size + self.split_inter_size = split_inter_size // self.pack_factor self.data_type_ = data_type self.tp_rank_ = get_current_rank_in_dp() self.experts_up_projs = [None] * self.n_routed_experts @@ -245,3 +256,374 @@ def _cuda(self, cpu_tensor): def verify_load(self): return self.w1 is not None and self.w2 is not None + + +class FusedAWQMARLINMoeWeightTP(BaseWeight): + def __init__( + self, + gate_proj_name: str, + down_proj_name: str, + up_proj_name: str, + e_score_correction_bias_name: str, + weight_prefix: str, + n_routed_experts: int, + num_fused_shared_experts: int, + split_inter_size: int, + data_type: torch.dtype, + network_config: Dict[str, Any], + layer_num: int, + quant_cfg: Quantcfg = None, + ) -> None: + super().__init__() + self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") + self.quantized_weight = quant_cfg.quantized_weight + if self.quant_method is not None: + self.weight_scale_suffix = self.quant_method.weight_scale_suffix + self.weight_zero_point_suffix = self.quant_method.weight_zero_point_suffix + self.quant_method.is_moe = True + hf_quantization_config = network_config.get("quantization_config", None) + self.num_bits = hf_quantization_config.get("bits", 4) + self.group_size = hf_quantization_config.get("group_size", 128) + self.pack_factor = 32 // self.num_bits + self.has_processed_weight = False + assert self.quant_method.get_name() == "awq_marlin" + + self.w1_weight_name = gate_proj_name + self.w2_weight_name = down_proj_name + self.w3_weight_name = up_proj_name + + self.e_score_correction_bias_name = 
e_score_correction_bias_name + self.weight_prefix = weight_prefix + assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." + self.n_routed_experts = n_routed_experts + num_fused_shared_experts + self.num_fused_shared_experts = num_fused_shared_experts + self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) + self.split_inter_size = split_inter_size + self.data_type_ = data_type + self.tp_rank_ = get_current_rank_in_dp() + self.experts_up_projs = [None] * self.n_routed_experts + self.experts_gate_projs = [None] * self.n_routed_experts + self.experts_up_proj_scales = [None] * self.n_routed_experts + self.experts_up_proj_zero_points = [None] * self.n_routed_experts + self.experts_gate_proj_scales = [None] * self.n_routed_experts + self.experts_gate_proj_zero_points = [None] * self.n_routed_experts + self.e_score_correction_bias = None + self.w2_list = [None] * self.n_routed_experts + self.w2_scale_list = [None] * self.n_routed_experts + self.w2_zero_point_list = [None] * self.n_routed_experts + self.scoring_func = network_config.get("scoring_func", "softmax") + self.w1 = [None, None, None] # weight, weight_scale, zero_point + self.w2 = [None, None, None] # weight, weight_scale, zero_point + self.lock = threading.Lock() + + def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + from lightllm.common.fused_moe.topk_select import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=input_tensor, + router_logits=router_logits, + correction_bias=self.e_score_correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=self.scoring_func, + ) + topk_weights.mul_(self.routed_scaling_factor) + if self.num_fused_shared_experts > 0: + pad_topk_ids = ( + torch.arange( + start=self.n_routed_experts - self.num_fused_shared_experts, + end=self.n_routed_experts, + step=1, + dtype=topk_ids.dtype, + device="cuda", + ) + .view(1, self.num_fused_shared_experts) + .repeat(topk_ids.shape[0], 1) + ) + pad_topk_weights = torch.full( + (topk_weights.shape[0], self.num_fused_shared_experts), + fill_value=1.0, + device="cuda", + dtype=topk_weights.dtype, + ) + + topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) + topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) + + w1, w1_scale, w1_zero_point = self.w1 + w2, w2_scale, w2_zero_point = self.w2 + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe + + fused_marlin_moe( + input_tensor, + w1, + w2, + None, + None, + w1_scale, + w2_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=self.quant_method.vllm_quant_type.id, + apply_router_weight_on_input=False, + global_num_experts=-1, + expert_map=None, + w1_zeros=w1_zero_point, + w2_zeros=w2_zero_point, + workspace=self.workspace, + inplace=True, + ) + + return + + def _fuse(self): + self._fuse_weight() + self._fuse_weight_scale() + self._fuse_weight_zero_point() + + def _fuse_weight(self): + with self.lock: + if ( + hasattr(self, "experts_up_projs") + and None not in self.experts_up_projs + and None not in self.experts_gate_projs + and None not in self.w2_list + ): + gate_in_dim, gate_out_dim = self.experts_gate_projs[0].shape + up_in_dim, up_out_dim = self.experts_up_projs[0].shape + assert gate_in_dim == up_in_dim + total_expert_num = self.n_routed_experts + + w1 = torch.empty( + 
(total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=torch.int32, device="cpu" + ) + + for i_experts in range(self.n_routed_experts): + w1[i_experts, :, 0:gate_out_dim] = self.experts_gate_projs[i_experts] + w1[i_experts, :, gate_out_dim:] = self.experts_up_projs[i_experts] + + inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] + w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) + self.w1[0] = self._cuda(w1) + self.w2[0] = self._cuda(w2) + delattr(self, "w2_list") + delattr(self, "experts_up_projs") + delattr(self, "experts_gate_projs") + + def _fuse_weight_scale(self): + with self.lock: + if ( + hasattr(self, "experts_up_proj_scales") + and None not in self.experts_up_proj_scales + and None not in self.experts_gate_proj_scales + and None not in self.w2_scale_list + ): + gate_in_dim, gate_out_dim = self.experts_gate_proj_scales[0].shape + up_in_dim, up_out_dim = self.experts_up_proj_scales[0].shape + dtype = self.experts_gate_proj_scales[0].dtype + assert gate_in_dim == up_in_dim + total_expert_num = self.n_routed_experts + w1_scale = torch.empty( + (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=dtype, device="cpu" + ) + for i_experts in range(self.n_routed_experts): + w1_scale[i_experts, :, 0:gate_out_dim] = self.experts_gate_proj_scales[i_experts] + w1_scale[i_experts, :, gate_out_dim:] = self.experts_up_proj_scales[i_experts] + inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] + w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( + len(self.w2_scale_list), inter_shape, hidden_size + ) + self.w1[1] = self._cuda(w1_scale).to(self.data_type_) + self.w2[1] = self._cuda(w2_scale).to(self.data_type_) + delattr(self, "w2_scale_list") + delattr(self, "experts_up_proj_scales") + delattr(self, "experts_gate_proj_scales") + + def _fuse_weight_zero_point(self): + with self.lock: + if ( + hasattr(self, "experts_up_proj_zero_points") + and None not in self.experts_up_proj_zero_points + and None not in self.experts_gate_proj_zero_points + and None not in self.w2_zero_point_list + ): + gate_in_dim, gate_out_dim = self.experts_gate_proj_zero_points[0].shape + up_in_dim, up_out_dim = self.experts_up_proj_zero_points[0].shape + assert gate_in_dim == up_in_dim + total_expert_num = self.n_routed_experts + w1_zero_point = torch.empty( + (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=torch.int32, device="cpu" + ) + for i_experts in range(self.n_routed_experts): + w1_zero_point[i_experts, :, 0:gate_out_dim] = self.experts_gate_proj_zero_points[i_experts] + w1_zero_point[i_experts, :, gate_out_dim:] = self.experts_up_proj_zero_points[i_experts] + inter_shape, hidden_size = self.w2_zero_point_list[0].shape[0], self.w2_zero_point_list[0].shape[1] + w2_zero_point = torch._utils._flatten_dense_tensors(self.w2_zero_point_list).view( + len(self.w2_zero_point_list), inter_shape, hidden_size + ) + self.w1[2] = self._cuda(w1_zero_point) + self.w2[2] = self._cuda(w2_zero_point) + delattr(self, "w2_zero_point_list") + delattr(self, "experts_up_proj_zero_points") + delattr(self, "experts_gate_proj_zero_points") + + def load_hf_weights(self, weights): + self._load_weight(weights) + self._load_weight_scale(weights) + self._load_weight_zero_point(weights) + self._fuse() + self._process_weight_after_loading() + + def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: + # awq quantization weight shape: in x out + if 
self.e_score_correction_bias_name in weights: + self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) + for i_experts in range(self.n_routed_experts): + w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.qweight" + w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.qweight" + w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.qweight" + + if w1_weight in weights: + self.experts_gate_projs[i_experts] = weights[w1_weight][ + :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) + ] + if w3_weight in weights: + self.experts_up_projs[i_experts] = weights[w3_weight][ + :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) + ] + + if w2_weight in weights: + self.w2_list[i_experts] = weights[w2_weight][ + self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : + ] + + def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: + for i_experts in range(self.n_routed_experts): + w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}" + w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}" + w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}" + split_inter_size = self.split_inter_size * self.pack_factor + if w1_scale in weights: + self.experts_gate_proj_scales[i_experts] = weights[w1_scale][ + :, + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), + ] + if w3_scale in weights: + self.experts_up_proj_scales[i_experts] = weights[w3_scale][ + :, + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), + ] + + if w2_scale in weights: + self.w2_scale_list[i_experts] = weights[w2_scale][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), + :, + ] + + def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: + for i_experts in range(self.n_routed_experts): + w1_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_zero_point_suffix}" + w2_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_zero_point_suffix}" + w3_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_zero_point_suffix}" + if w1_zero_point in weights: + self.experts_gate_proj_zero_points[i_experts] = weights[w1_zero_point][ + :, + self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), + ] + if w3_zero_point in weights: + self.experts_up_proj_zero_points[i_experts] = weights[w3_zero_point][ + :, + self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), + ] + if w2_zero_point in weights: + self.w2_zero_point_list[i_experts] = weights[w2_zero_point][ + self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), + :, + ] + + def _process_weight_after_loading(self): + with self.lock: + if None in self.w1 or None in self.w2 or self.has_processed_weight: + return + self.has_processed_weight = True + from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops + + assert HAS_VLLM, "moe awq marlin quantization requires kernels of vllm" + + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_moe_permute_scales, + moe_awq_to_marlin_zero_points, + marlin_make_workspace_new, + ) + + num_experts = self.n_routed_experts + device = self.w1[0].device + + 
self.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + self.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + self.w1[0] = vllm_ops.awq_marlin_moe_repack( + self.w1[0], + self.w13_g_idx_sort_indices, + size_k=self.w1[0].shape[1], + size_n=self.w1[0].shape[2] * self.pack_factor, + num_bits=self.num_bits, + ) + + self.w2[0] = vllm_ops.awq_marlin_moe_repack( + self.w2[0], + self.w2_g_idx_sort_indices, + size_k=self.w2[0].shape[1], + size_n=self.w2[0].shape[2] * self.pack_factor, + num_bits=self.num_bits, + ) + + # Why does this take the intermediate size for size_k? + self.w1[1] = marlin_moe_permute_scales( + s=self.w1[1], + size_k=self.split_inter_size * self.pack_factor, + size_n=self.w1[1].shape[2], + group_size=self.group_size, + ) + + self.w2[1] = marlin_moe_permute_scales( + s=self.w2[1], + size_k=self.split_inter_size * self.pack_factor, + size_n=self.w2[1].shape[2], + group_size=self.group_size, + ) + + self.w1[2] = moe_awq_to_marlin_zero_points( + self.w1[2], + size_k=self.w1[2].shape[1], + size_n=self.w1[2].shape[2] * self.pack_factor, + num_bits=self.num_bits, + ) + + self.w2[2] = moe_awq_to_marlin_zero_points( + self.w2[2], + size_k=self.w2[2].shape[1], + size_n=self.w2[2].shape[2] * self.pack_factor, + num_bits=self.num_bits, + ) + + self.workspace = marlin_make_workspace_new(device, 4) + + def _cuda(self, cpu_tensor): + device_id = get_current_device_id() + if self.quantized_weight: + return cpu_tensor.cuda(device_id) + return cpu_tensor.cuda(device_id) + + def verify_load(self): + return self.w1 is not None and self.w2 is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py index 550636332..a5a27eefb 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py @@ -157,7 +157,9 @@ def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: - return self.quant_method._process_weight_scale_after_loading(weight_scale.cuda(get_current_device_id())) + return self.quant_method._process_weight_scale_after_loading( + weight_scale.cuda(get_current_device_id()).to(self.data_type_) + ) def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: return self.quant_method._process_weight_zero_point_after_loading( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index b0ccd8064..47f7b8949 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -188,7 +188,7 @@ def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: return weight.cuda(get_current_device_id()) def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: - return weight_scale.cuda(get_current_device_id()) + return weight_scale.cuda(get_current_device_id()).to(self.data_type_) def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> 
torch.Tensor: return weight_zero_point.cuda(get_current_device_id()) @@ -279,7 +279,7 @@ def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: if self.weight_scale_names[i] is not None and self.weight_scale_names[i] in weights: weight_scale = weights[self.weight_scale_names[i]] weight_scale = self._slice_weight_scale(weight_scale) - self.weight_scales[i] = weight_scale.cuda(get_current_device_id()) + self.weight_scales[i] = weight_scale.cuda(get_current_device_id()).to(self.data_type_) def _load_zero_points(self, weights: Dict[str, torch.Tensor]) -> None: for i in range(len(self.weight_names)): @@ -309,6 +309,8 @@ def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> Q if quant_cfg is None: return None, False quant_method = quant_cfg.get_quant_method(layer_num_, name) + if quant_method is None: + return None, False quant_method.hf_quantization_config = quant_cfg.hf_quantization_config quantized_weight = quant_cfg.quantized_weight return quant_method, quantized_weight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index 83d2ae9d7..bd57f1e7a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -332,7 +332,9 @@ def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: - return self.quant_method._process_weight_scale_after_loading(weight_scale.cuda(get_current_device_id())) + return self.quant_method._process_weight_scale_after_loading( + weight_scale.cuda(get_current_device_id()).to(self.data_type_) + ) def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: return self.quant_method._process_weight_zero_point_after_loading( @@ -356,7 +358,9 @@ def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: - return self.quant_method._process_weight_scale_after_loading(weight_scale.cuda(get_current_device_id())) + return self.quant_method._process_weight_scale_after_loading( + weight_scale.cuda(get_current_device_id()).to(self.data_type_) + ) def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: return self.quant_method._process_weight_zero_point_after_loading( diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 7e4a08218..26f59258c 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -63,7 +63,6 @@ def _parse_custom_cfg(self, custom_cfg_path): self.quant_type = data["quant_type"] for layer_quant_cfg in data.get("mix_bits", []): - print(layer_quant_cfg) name = layer_quant_cfg["name"] layer_nums = layer_quant_cfg.get("layer_nums", range(self.layer_num)) layer_quant_type = layer_quant_cfg["quant_type"] diff --git a/lightllm/common/quantization/awq_quant.py b/lightllm/common/quantization/awq_quant.py index 7221a454e..2827481a4 100644 --- a/lightllm/common/quantization/awq_quant.py +++ b/lightllm/common/quantization/awq_quant.py @@ -78,12 +78,14 @@ class 
AWQMARLINW4A16QuantizationMethod(AWQBaseQuantizationMethod): def __init__(self): super().__init__() self.pack_factor = 8 + self.nbits = 4 self.weight_scale_suffix = "scales" self.weight_zero_point_suffix = "qzeros" self.weight_suffix = "qweight" self.g_idx = marlin_make_empty_g_idx(torch.device("cuda")) self.g_idx_sort_indices = marlin_make_empty_g_idx(torch.device("cuda")) self.workspace = marlin_make_workspace_new(torch.device("cuda")) + self.vllm_quant_type = TYPE_MAP[self.nbits] def get_name(self): return "awq_marlin" @@ -143,7 +145,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ self.g_idx, self.g_idx_sort_indices, self.workspace, - TYPE_MAP[self.hf_quantization_config["bits"]], + self.vllm_quant_type, size_m=reshaped_x.shape[0], size_n=self.n, size_k=self.k, From 362149d2ef81405ec60d0b4f4b959d539e1f59a9 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 5 Nov 2025 18:51:48 +0800 Subject: [PATCH 06/10] fix --- .../meta_weights/fused_moe_weight_tp.py | 52 ++++++++++++++++--- .../common/quantization/quantize_method.py | 3 +- lightllm/common/quantization/registry.py | 6 ++- 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py index 6d3f7855a..fb6f693e5 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py @@ -8,14 +8,54 @@ class FusedMoeWeightTP: - def __new__(cls, **kwargs): - quant_cfg = kwargs.get("quant_cfg", None) - layer_num = kwargs.get("layer_num", None) + def __new__( + cls, + gate_proj_name: str, + down_proj_name: str, + up_proj_name: str, + e_score_correction_bias_name: str, + weight_prefix: str, + n_routed_experts: int, + num_fused_shared_experts: int, + split_inter_size: int, + data_type: torch.dtype, + network_config: Dict[str, Any], + layer_num: int, + quant_cfg: Quantcfg = None, + ): quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") + if quant_method is None: + return FusedBaseMoeWeightTP( + gate_proj_name=gate_proj_name, + down_proj_name=down_proj_name, + up_proj_name=up_proj_name, + e_score_correction_bias_name=e_score_correction_bias_name, + weight_prefix=weight_prefix, + n_routed_experts=n_routed_experts, + num_fused_shared_experts=num_fused_shared_experts, + split_inter_size=split_inter_size, + data_type=data_type, + network_config=network_config, + layer_num=layer_num, + quant_cfg=quant_cfg, + ) if quant_method.get_name() == "awq_marlin": - return FusedAWQMARLINMoeWeightTP(**kwargs) + return FusedAWQMARLINMoeWeightTP( + gate_proj_name=gate_proj_name, + down_proj_name=down_proj_name, + up_proj_name=up_proj_name, + e_score_correction_bias_name=e_score_correction_bias_name, + weight_prefix=weight_prefix, + n_routed_experts=n_routed_experts, + num_fused_shared_experts=num_fused_shared_experts, + split_inter_size=split_inter_size, + data_type=data_type, + network_config=network_config, + layer_num=layer_num, + quant_cfg=quant_cfg, + ) else: - return FusedBaseMoeWeightTP(**kwargs) + raise ValueError(f"Unsupported quant method: {quant_method.get_name()}") class FusedBaseMoeWeightTP(BaseWeight): @@ -50,7 +90,7 @@ def __init__( self.n_routed_experts = n_routed_experts + num_fused_shared_experts self.num_fused_shared_experts = num_fused_shared_experts self.routed_scaling_factor = network_config.get("routed_scaling_factor", 
1.0) - self.split_inter_size = split_inter_size // self.pack_factor + self.split_inter_size = split_inter_size self.data_type_ = data_type self.tp_rank_ = get_current_rank_in_dp() self.experts_up_projs = [None] * self.n_routed_experts diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 60326d92d..8dc54a871 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -18,5 +18,6 @@ def quantize(self, weights: torch.Tensor): def apply(self, input_tensor, weight, bias=None, out=None, use_custom_tensor_mananger=True): pass + @abstractmethod def get_name(self): - return self.__class__.__name__ + pass diff --git a/lightllm/common/quantization/registry.py b/lightllm/common/quantization/registry.py index 350f7fd1c..674a22b60 100644 --- a/lightllm/common/quantization/registry.py +++ b/lightllm/common/quantization/registry.py @@ -1,3 +1,7 @@ +from .quantize_method import QuantizationMethod +from typing import Type + + class QuantMethodFactory: def __init__(self): self._quant_methods = {} @@ -13,7 +17,7 @@ def decorator(cls): return decorator - def get(self, key, *args, **kwargs): + def get(self, key, *args, **kwargs) -> Type[QuantizationMethod]: if key == "none": return None quant_method_class = self._quant_methods.get(key) From 684f4c09743a69fdb92fdacdb191446c1a0312d7 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 5 Nov 2025 19:38:09 +0800 Subject: [PATCH 07/10] add quant_method property method_name --- .../meta_weights/fused_moe_weight_tp.py | 12 +++---- .../meta_weights/mm_weight/colmm_weight.py | 4 +-- .../meta_weights/mm_weight/rowmm_weight.py | 8 ++--- lightllm/common/quantization/awq_quant.py | 10 ++++-- .../common/quantization/deepgemm_quant.py | 9 +++-- .../common/quantization/quantize_method.py | 3 +- lightllm/common/quantization/torchao_quant.py | 36 +++++++++++++++++++ lightllm/common/quantization/w8a8_quant.py | 16 +++++++++ 8 files changed, 80 insertions(+), 18 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py index fb6f693e5..ece9de8b8 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py @@ -24,8 +24,8 @@ def __new__( quant_cfg: Quantcfg = None, ): quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - if quant_method is None: - return FusedBaseMoeWeightTP( + if quant_method is not None and quant_method.method_name == "awq_marlin": + return FusedAWQMARLINMoeWeightTP( gate_proj_name=gate_proj_name, down_proj_name=down_proj_name, up_proj_name=up_proj_name, @@ -39,8 +39,8 @@ def __new__( layer_num=layer_num, quant_cfg=quant_cfg, ) - if quant_method.get_name() == "awq_marlin": - return FusedAWQMARLINMoeWeightTP( + else: + return FusedBaseMoeWeightTP( gate_proj_name=gate_proj_name, down_proj_name=down_proj_name, up_proj_name=up_proj_name, @@ -54,8 +54,6 @@ def __new__( layer_num=layer_num, quant_cfg=quant_cfg, ) - else: - raise ValueError(f"Unsupported quant method: {quant_method.get_name()}") class FusedBaseMoeWeightTP(BaseWeight): @@ -326,7 +324,7 @@ def __init__( self.group_size = hf_quantization_config.get("group_size", 128) self.pack_factor = 32 // self.num_bits self.has_processed_weight = False - assert self.quant_method.get_name() == "awq_marlin" + assert 
self.quant_method.method_name == "awq_marlin" self.w1_weight_name = gate_proj_name self.w2_weight_name = down_proj_name diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py index a5a27eefb..093b89285 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py @@ -16,7 +16,7 @@ class COLMMWeight(MMWeight): def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): if quant_method is None or not quantized_weight: return UnquantizedCOLMMWeight - return COLBMM_WEIGHT_CLS_MAP[quant_method.get_name()] + return COLBMM_WEIGHT_CLS_MAP[quant_method.method_name] class UnquantizedCOLMMWeight(MMWeightTpl): @@ -168,7 +168,7 @@ def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.T COLBMM_WEIGHT_CLS_MAP = { - "fp8w8a8b128": W8A8B128COLMMWeight, + "deepgemm-fp8w8a8-b128": W8A8B128COLMMWeight, "awq": AWQCOLMMWeight, "awq_marlin": AWQMARLINCOLMMWeight, } diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index bd57f1e7a..5cfdcfd5b 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -20,7 +20,7 @@ def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): if quant_method is None or not quantized_weight: return UnquantizedROWMMWeight - return ROWBMM_WEIGHT_CLS_MAP[quant_method.get_name()] + return ROWBMM_WEIGHT_CLS_MAP[quant_method.method_name] class MultiROWMMWeight(MMWeight): @@ -29,7 +29,7 @@ def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): if quant_method is None or not quantized_weight: return UnquantizedMultiROWMMWeight - return MULTI_ROWBMM_WEIGHT_CLS_MAP[quant_method.get_name()] + return MULTI_ROWBMM_WEIGHT_CLS_MAP[quant_method.method_name] class ROWBMMWeight(MMWeight): @@ -369,13 +369,13 @@ def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.T ROWBMM_WEIGHT_CLS_MAP = { - "fp8w8a8b128": W8A8B128ROWMMWeight, + "deepgemm-fp8w8a8-b128": W8A8B128ROWMMWeight, "awq": AWQROWMMWeight, "awq_marlin": AWQMARLINROWMMWeight, } MULTI_ROWBMM_WEIGHT_CLS_MAP = { - "fp8w8a8b128": W8A8B128MultiROWMMWeight, + "deepgemm-fp8w8a8-b128": W8A8B128MultiROWMMWeight, "awq": AWQMultiROWMMWeight, "awq_marlin": AWQMARLINMultiROWMMWeight, } diff --git a/lightllm/common/quantization/awq_quant.py b/lightllm/common/quantization/awq_quant.py index 2827481a4..c758c545b 100644 --- a/lightllm/common/quantization/awq_quant.py +++ b/lightllm/common/quantization/awq_quant.py @@ -42,6 +42,10 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): """ """ pass + @property + def method_name(self): + return "awq-base" + @QUANTMETHODS.register("awq") class AWQW4A16QuantizationMethod(AWQBaseQuantizationMethod): @@ -52,7 +56,8 @@ def __init__(self): self.weight_zero_point_suffix = "qzeros" self.weight_suffix = "qweight" - def get_name(self): + @property + def method_name(self): return "awq" def quantize(self, weight: torch.Tensor): @@ -87,7 +92,8 @@ def __init__(self): self.workspace = marlin_make_workspace_new(torch.device("cuda")) self.vllm_quant_type = TYPE_MAP[self.nbits] - def get_name(self): + @property + def 
method_name(self): return "awq_marlin" def quantize(self, weight: torch.Tensor): diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py index 6b94002d0..6816c8f51 100644 --- a/lightllm/common/quantization/deepgemm_quant.py +++ b/lightllm/common/quantization/deepgemm_quant.py @@ -31,6 +31,10 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): """ """ pass + @property + def method_name(self): + return "deepgemm-base" + @QUANTMETHODS.register(["deepgemm-fp8w8a8-b128"]) class DeepGEMMFP8w8a8B128QuantizationMethod(DeepGEMMBaseQuantizationMethod): @@ -40,8 +44,9 @@ def __init__(self): self.weight_scale_suffix = "weight_scale_inv" self.act_scale_suffix = None # no support for static input tensor scale for ds model. - def get_name(self): - return "fp8w8a8b128" + @property + def method_name(self): + return "deepgemm-fp8w8a8-b128" def quantize(self, weight: torch.Tensor): from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 8dc54a871..80dad1fe2 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -18,6 +18,7 @@ def quantize(self, weights: torch.Tensor): def apply(self, input_tensor, weight, bias=None, out=None, use_custom_tensor_mananger=True): pass + @property @abstractmethod - def get_name(self): + def method_name(self): pass diff --git a/lightllm/common/quantization/torchao_quant.py b/lightllm/common/quantization/torchao_quant.py index 67677b50c..1a75492b1 100644 --- a/lightllm/common/quantization/torchao_quant.py +++ b/lightllm/common/quantization/torchao_quant.py @@ -43,6 +43,10 @@ def quantize(self, weight: torch.Tensor): def apply(self, input_tensor, weights, bias=None, out=None, use_custom_tensor_mananger=True): return F.linear(input_tensor, weights, bias) + @property + def method_name(self): + return "ao-base" + @QUANTMETHODS.register(["ao-w4a16-256"]) class AOW4A16QuantizationMethodGroup256(AOBaseQuantizationMethod): @@ -51,6 +55,10 @@ def __init__(self): self.group_size = 256 self.quant_func = int4_weight_only(group_size=self.group_size) + @property + def method_name(self): + return "ao-w4a16-256" + @QUANTMETHODS.register(["ao-w4a16-128"]) class AOW4A16QuantizationMethodGroup128(AOBaseQuantizationMethod): @@ -59,6 +67,10 @@ def __init__(self): self.group_size = 128 self.quant_func = int4_weight_only(group_size=self.group_size) + @property + def method_name(self): + return "ao-w4a16-128" + @QUANTMETHODS.register(["ao-w4a16-64"]) class AOW4A16QuantizationMethodGroup64(AOBaseQuantizationMethod): @@ -67,6 +79,10 @@ def __init__(self): self.group_size = 64 self.quant_func = int4_weight_only(group_size=self.group_size) + @property + def method_name(self): + return "ao-w4a16-64" + @QUANTMETHODS.register(["ao-w4a16-32"]) class AOW4A16QuantizationMethodGroup32(AOBaseQuantizationMethod): @@ -75,6 +91,10 @@ def __init__(self): self.group_size = 32 self.quant_func = int4_weight_only(group_size=self.group_size) + @property + def method_name(self): + return "ao-w4a16-32" + @QUANTMETHODS.register("ao-w8a8") class AOW8A8QuantizationMethod(AOBaseQuantizationMethod): @@ -82,6 +102,10 @@ def __init__(self): super().__init__() self.quant_func = int8_dynamic_activation_int8_weight() + @property + def method_name(self): + return "ao-w8a8" + @QUANTMETHODS.register("ao-w8a16") class 
AOW8A16QuantizationMethod(AOBaseQuantizationMethod): @@ -89,6 +113,10 @@ def __init__(self): super().__init__() self.quant_func = int8_weight_only() + @property + def method_name(self): + return "ao-w8a16" + @QUANTMETHODS.register("ao-fp8w8a16") class AOFP8W8A16QuantizationMethod(AOBaseQuantizationMethod): @@ -98,6 +126,10 @@ def __init__(self): assert is_cuda_8_9, "FP8 requires GPU with compute capability >= 8.9" self.quant_func = float8_weight_only() + @property + def method_name(self): + return "ao-fp8w8a16" + @QUANTMETHODS.register("ao-fp6w6a16") class AOFP6W6A16QuantizationMethod(AOBaseQuantizationMethod): @@ -105,3 +137,7 @@ def __init__(self): super().__init__() assert TORCH_VERSION_AT_LEAST_2_5, "torchao fp6 requires torch >=2.5" self.quant_func = fpx_weight_only(3, 2) + + @property + def method_name(self): + return "ao-fp6w6a16" diff --git a/lightllm/common/quantization/w8a8_quant.py b/lightllm/common/quantization/w8a8_quant.py index 1c38b625f..c07f6b208 100644 --- a/lightllm/common/quantization/w8a8_quant.py +++ b/lightllm/common/quantization/w8a8_quant.py @@ -34,6 +34,10 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): """ """ pass + @property + def method_name(self): + return "w8a8-base" + @QUANTMETHODS.register(["vllm-w8a8", "w8a8"]) class w8a8QuantizationMethod(BaseQuantizationMethod): @@ -71,6 +75,10 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) return out + @property + def method_name(self): + return "vllm-w8a8" + @QUANTMETHODS.register(["vllm-fp8w8a8", "fp8w8a8"]) class FP8w8a8QuantizationMethod(BaseQuantizationMethod): @@ -114,6 +122,10 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias) return out + @property + def method_name(self): + return "vllm-fp8w8a8" + @QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"]) class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod): @@ -152,3 +164,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ input_scale = input_scale.t().contiguous().t() cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias) return out + + @property + def method_name(self): + return "vllm-fp8w8a8-b128" From 36862fe3db0b7e503b90dc9fa791670d69814d61 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 6 Nov 2025 20:06:40 +0800 Subject: [PATCH 08/10] refactor mm_weight --- .../layer_weights/meta_weights/base_weight.py | 34 +- .../meta_weights/mm_weight/__init__.py | 11 +- .../meta_weights/mm_weight/colmm_weight.py | 132 +--- .../meta_weights/mm_weight/mm_factory.py | 83 +++ .../meta_weights/mm_weight/mm_slicer.py | 117 ++++ .../meta_weights/mm_weight/mm_weight.py | 613 ++++++++++++------ .../meta_weights/mm_weight/rowmm_weight.py | 313 +++------ .../layer_weights/transformer_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 2 +- .../layer_weights/transformer_layer_weight.py | 2 +- 10 files changed, 750 insertions(+), 561 deletions(-) create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index 30a736b4f..544dcb2fa 100644 --- 
a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -24,35 +24,8 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, data_type: to self.device_id_ = get_current_device_id() self.data_type_ = data_type - def _slice_weight(self, weight: torch.Tensor): - # slice weight - return weight.to(self.data_type_) - - def _slice_bias(self, bias: torch.Tensor): - # slice bias - return bias.to(self.data_type_) - - def _slice_weight_scale(self, weight_scale: torch.Tensor): - # slice weight scale and zero point - return weight_scale - - def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None: - # load weight - pass - - def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: - # load quantization scale - pass - - def _load_zero_points(self, weights: Dict[str, torch.Tensor]) -> None: - # load quantization zero points - pass - def load_hf_weights(self, weights): - self._load_weights(weights) - self._load_scales(weights) - self._load_zero_points(weights) - return + raise NotImplementedError("load_hf_weights must implement this method") def verify_load(self): - pass + raise NotImplementedError("verify_load must implement this method") diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py index c273ee456..ed82aa559 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py @@ -3,15 +3,10 @@ MultiMMWeightTpl, AWQMultiMMWeightTpl, ) -from .rowmm_weight import ( +from .mm_factory import ( + MMWeight, ROWMMWeight, - ROWBMMWeight, MultiROWMMWeight, - W8A8B128ROWMMWeight, - W8A8B128ROWBMMWeight, - W8A8B128MultiROWMMWeight, -) -from .colmm_weight import ( + ROWBMMWeight, COLMMWeight, - W8A8B128COLMMWeight, ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py index 093b89285..a43933ce6 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py @@ -1,25 +1,17 @@ import torch from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( - MMWeight, - MMWeightTpl, - generate_scale_name, + SingleMMWeightTpl, + DeepGemmFP8W8A8B128MMWeight, AWQMMWeightTpl, ) from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional +from .mm_slicer import ColSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin -class COLMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): - if quant_method is None or not quantized_weight: - return UnquantizedCOLMMWeight - return COLBMM_WEIGHT_CLS_MAP[quant_method.method_name] - - -class UnquantizedCOLMMWeight(MMWeightTpl): +class UnquantizedCOLMMWeight(SingleMMWeightTpl): def __init__( self, 
weight_name: str, @@ -29,24 +21,18 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - super().__init__(data_type, quant_method, tp_rank, tp_world_size) - self.weight_name = weight_name - self.bias_name = bias_name - self.has_bias = bias_name is not None - - def _slice_weight(self, tensor): - assert tensor.shape[1] % self.tp_world_size_ == 0, f"tp slice error {tensor.shape[1]} % {self.tp_world_size_}" - tp_size = tensor.shape[1] // self.tp_world_size_ - return tensor[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)].to(self.data_type_) - - def _slice_bias(self, bias): - """ - 因为 Colmm 列 tp 切分的计算,最后会有一个 reduce 操作,直接将 bias / tp_world_size 可以节省一步计算。 - """ - return (bias / self.tp_world_size_).to(self.data_type_) + super().__init__( + weight_name=weight_name, + data_type=data_type, + bias_name=bias_name, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) + self.param_slicer = ColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) -class W8A8B128COLMMWeight(MMWeightTpl): +class DeepGemmFP8W8A8B128COLMMWeight(DeepGemmFP8W8A8B128MMWeight): def __init__( self, weight_name: str, @@ -56,47 +42,15 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - super().__init__(data_type, quant_method, tp_rank, tp_world_size) - self.weight_name = weight_name - self.bias_name = bias_name - self.has_bias = bias_name is not None - - self.weight_scale_name, self.act_scale_name = generate_scale_name( - weight_name, quant_method.weight_scale_suffix, quant_method.act_scale_suffix + super().__init__( + weight_name=weight_name, + data_type=data_type, + bias_name=bias_name, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, ) - self.weight_scale: Optional[torch.Tensor] = None - self.block_size = self.quant_method.block_size - self.quantized_weight = True - - def _slice_weight(self, tensor): - assert tensor.shape[1] % self.tp_world_size_ == 0, f"tp slice error {tensor.shape[1]} % {self.tp_world_size_}" - tp_size = tensor.shape[1] // self.tp_world_size_ - return tensor[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] - - def _slice_weight_scale(self, weight_scale: torch.Tensor): - assert ( - weight_scale.shape[1] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[1]} % {self.tp_world_size_}" - tp_size = weight_scale.shape[1] // self.tp_world_size_ - scale_start = tp_size * self.tp_rank_ - scale_end = tp_size * (self.tp_rank_ + 1) - return weight_scale[:, scale_start:scale_end].to(torch.float) - - def _process_weight_scale(self, weight_scale) -> None: - self.weight_scale = weight_scale.cuda(get_current_device_id()).transpose(0, 1) - - def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: - if self.weight_scale_name in weights: - weight_scale = self._slice_weight_scale(weights[self.weight_scale_name]) - self._process_weight_scale(weight_scale) - if self.weight_scale is not None and isinstance(self.weight, torch.Tensor): - # weight 中保存的 None 是为 激活静态量化 scale 预留的扩展位置。 - self.weight = [ - self.weight, - self.weight_scale, - None, - ] - return + self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) class AWQCOLMMWeight(AWQMMWeightTpl): @@ -110,35 +64,8 @@ def __init__( tp_world_size: int = None, ) -> None: super().__init__(data_type, quant_method, tp_rank, tp_world_size) - self.weight_name = weight_name.replace("weight", quant_method.weight_suffix) - self.weight_scale_name = weight_name.replace("weight", 
quant_method.weight_scale_suffix) - self.weight_zero_point_name = weight_name.replace("weight", quant_method.weight_zero_point_suffix) - self.bias_name = bias_name - self.weight_scale: Optional[torch.Tensor] = None - self.quantized_weight = True - self.weight = [None, None, None] - - def _slice_weight(self, weight: torch.Tensor): - assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}" - tp_size = weight.shape[0] // self.tp_world_size_ - return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1), :] - - def _slice_bias(self, bias): - assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}" - tp_size = bias.shape[0] // self.tp_world_size_ - return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1), :] - - def _slice_weight_scale(self, weight_scale: torch.Tensor): - tp_size = weight_scale.shape[0] // self.tp_world_size_ - scale_start = tp_size * self.tp_rank_ - scale_end = tp_size * (self.tp_rank_ + 1) - return weight_scale[scale_start:scale_end, :].to(torch.half) - - def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor): - tp_size = weight_zero_point.shape[0] // self.tp_world_size_ - zero_point_start = tp_size * self.tp_rank_ - zero_point_end = tp_size * (self.tp_rank_ + 1) - return weight_zero_point[zero_point_start:zero_point_end, :] + # 注意这里不是错误,因为awq的weight是按inxout存的 + self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) class AWQMARLINCOLMMWeight(AWQCOLMMWeight): @@ -151,7 +78,14 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size) + super().__init__( + weight_name=weight_name, + data_type=data_type, + bias_name=bias_name, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) @@ -168,7 +102,7 @@ def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.T COLBMM_WEIGHT_CLS_MAP = { - "deepgemm-fp8w8a8-b128": W8A8B128COLMMWeight, + "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128COLMMWeight, "awq": AWQCOLMMWeight, "awq_marlin": AWQMARLINCOLMMWeight, } diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py new file mode 100644 index 000000000..1b993c1fa --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py @@ -0,0 +1,83 @@ +from lightllm.common.quantization import Quantcfg +from lightllm.common.quantization.quantize_method import QuantizationMethod +from typing import Type, Union, Dict +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( + MMWeightTpl, + MultiMMWeightTpl, + BMMWeightTpl, +) +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import ( + UnquantizedROWMMWeight, + UnquantizedROWBMMWeight, + UnquantizedMultiROWMMWeight, + ROWMM_WEIGHT_CLS_MAP, + MULTI_ROWMM_WEIGHT_CLS_MAP, +) +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.colmm_weight import ( + UnquantizedCOLMMWeight, + COLBMM_WEIGHT_CLS_MAP, +) + + +class MMWeight: + def __new__(cls, **kwargs): + quant_cfg = kwargs.pop("quant_cfg", None) + layer_num_ = kwargs.pop("layer_num", None) + name = 
kwargs.pop("name", None) + quant_method, quantized_weight = cls._get_quant_method(quant_cfg, layer_num_, name) + kwargs["quant_method"] = quant_method + mmcls = cls._get_mmcls(quant_method, quantized_weight) + return mmcls(**kwargs) + + @classmethod + def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> QuantizationMethod: + if quant_cfg is None: + return None, False + quant_method = quant_cfg.get_quant_method(layer_num_, name) + if quant_method is None: + return None, False + quant_method.hf_quantization_config = quant_cfg.hf_quantization_config + quantized_weight = quant_cfg.quantized_weight + return quant_method, quantized_weight + + @classmethod + def _get_mmcls( + cls, quant_method: QuantizationMethod, quantized_weight: bool + ) -> Type[Union[MMWeightTpl, MultiMMWeightTpl, BMMWeightTpl]]: + raise NotImplementedError("Subclasses must implement _get_mmcls method") + + +class ROWMMWeight(MMWeight): + @classmethod + def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): + if quant_method is None or not quantized_weight: + return UnquantizedROWMMWeight + + return ROWMM_WEIGHT_CLS_MAP[quant_method.method_name] + + +class MultiROWMMWeight(MMWeight): + @classmethod + def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): + if quant_method is None or not quantized_weight: + return UnquantizedMultiROWMMWeight + + return MULTI_ROWMM_WEIGHT_CLS_MAP[quant_method.method_name] + + +class ROWBMMWeight(MMWeight): + @classmethod + def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): + if quant_method is None or not quantized_weight: + return UnquantizedROWBMMWeight + else: + # TODO: Implement more quantization weight + raise NotImplementedError("ROWBMMWeight is not implemented") + + +class COLMMWeight(MMWeight): + @classmethod + def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): + if quant_method is None or not quantized_weight: + return UnquantizedCOLMMWeight + return COLBMM_WEIGHT_CLS_MAP[quant_method.method_name] diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py new file mode 100644 index 000000000..6c90deaa7 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -0,0 +1,117 @@ +import torch +from typing import Optional +from abc import ABC, abstractmethod +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size + + +class SliceMixinBase(ABC): + """切片操作的Mixin基类""" + + def __init__(self, tp_rank: int = None, tp_world_size: int = None): + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + + @abstractmethod + def _slice_weight(self, weight: torch.Tensor): + pass + + @abstractmethod + def _slice_bias(self, bias): + pass + + +class SliceMixinTpl(SliceMixinBase): + def __init__(self, tp_rank: int = None, tp_world_size: int = None): + super().__init__(tp_rank, tp_world_size) + + def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("slice_weight must implement this method") + + def _slice_bias(self, bias) -> torch.Tensor: + raise NotImplementedError("slice_bias must implement this method") + + def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("slice_weight_scale must implement this method") + 
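# A minimal illustrative sketch (assumed toy shapes, doctest-style) of what the concrete
# slicers defined below compute for tensor-parallel loading: row-wise slicing shards dim 0
# of an out x in weight, column-wise slicing shards dim 1, and the quantized variants apply
# the same split to scales and zero points.
#   >>> w = torch.arange(32).reshape(8, 4)
#   >>> RowSliceMixin(tp_rank=1, tp_world_size=2)._slice_weight(w).shape
#   torch.Size([4, 4])
#   >>> ColSliceMixin(tp_rank=1, tp_world_size=2)._slice_weight(w).shape
#   torch.Size([8, 2])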
+ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("slice_weight_zero_point must implement this method") + + +# 默认weight 的shape是 outxin,这也是目前最通用的约定。 +# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。 +class RowSliceMixin(SliceMixinTpl): + def __init__(self, tp_rank: int = None, tp_world_size: int = None): + super().__init__(tp_rank, tp_world_size) + + def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: + assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}" + tp_size = weight.shape[0] // self.tp_world_size_ + return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + + def _slice_bias(self, bias) -> torch.Tensor: + assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}" + tp_size = bias.shape[0] // self.tp_world_size_ + return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + + +# 量化切片默认实现方式是group-wise的量化,所以weight_scale 和weight_zero_point ndims跟weight一样。 +# 后续按需要,扩展per-tensor、per-channel的量化方式。 +class QuantizedRowSliceMixin(RowSliceMixin): + def __init__(self, tp_rank: int = None, tp_world_size: int = None): + super().__init__(tp_rank, tp_world_size) + + def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + assert ( + weight_scale.shape[0] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}" + tp_size = weight_scale.shape[0] // self.tp_world_size_ + scale_start = tp_size * self.tp_rank_ + scale_end = tp_size * (self.tp_rank_ + 1) + return weight_scale[scale_start:scale_end] + + def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + assert ( + weight_zero_point.shape[0] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}" + tp_size = weight_zero_point.shape[0] // self.tp_world_size_ + zero_point_start = tp_size * self.tp_rank_ + zero_point_end = tp_size * (self.tp_rank_ + 1) + return weight_zero_point[zero_point_start:zero_point_end] + + +class ColSliceMixin(SliceMixinTpl): + def __init__(self, tp_rank: int = None, tp_world_size: int = None): + super().__init__(tp_rank, tp_world_size) + + def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: + assert weight.shape[1] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[1]} % {self.tp_world_size_}" + tp_size = weight.shape[1] // self.tp_world_size_ + return weight[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + + def _slice_bias(self, bias) -> torch.Tensor: + assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}" + tp_size = bias.shape[0] // self.tp_world_size_ + return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + + +class QuantizedColSliceMixin(ColSliceMixin): + def __init__(self, tp_rank: int = None, tp_world_size: int = None): + super().__init__(tp_rank, tp_world_size) + + def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: + assert ( + weight_scale.shape[1] % self.tp_world_size_ == 0 + ), f"tp slice error {weight_scale.shape[1]} % {self.tp_world_size_}" + tp_size = weight_scale.shape[1] // self.tp_world_size_ + scale_start = tp_size * self.tp_rank_ + scale_end = tp_size * (self.tp_rank_ + 1) + return weight_scale[:, scale_start:scale_end] + + def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: + assert ( + weight_zero_point.shape[1] % 
self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[1]} % {self.tp_world_size_}" + tp_size = weight_zero_point.shape[1] // self.tp_world_size_ + zero_point_start = tp_size * self.tp_rank_ + zero_point_end = tp_size * (self.tp_rank_ + 1) + return weight_zero_point[:, zero_point_start:zero_point_end] diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 47f7b8949..a56bedc50 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -1,6 +1,7 @@ import os import torch from abc import abstractmethod +from dataclasses import dataclass from typing import Optional, Tuple, List, Dict, Union, Type from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.quantization.quantize_method import QuantizationMethod @@ -8,18 +9,29 @@ from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.log_utils import init_logger +from .mm_slicer import SliceMixinTpl logger = init_logger(__name__) -def generate_scale_name(name, weight_scale_suffix, act_scale_suffix): - weight_scale_name = None - act_scale_name = None - if weight_scale_suffix is not None: - weight_scale_name = ".".join(name.split(".")[:-1] + [weight_scale_suffix]) - if act_scale_suffix is not None: - act_scale_name = ".".join(name.split(".")[:-1] + [act_scale_suffix]) - return weight_scale_name, act_scale_name +@dataclass +class MMWeightPack: + weight: Optional[torch.Tensor] = None + bias: Optional[torch.Tensor] = None + weight_scale: Optional[torch.Tensor] = None + weight_zero_point: Optional[torch.Tensor] = None + + has_bias: bool = False + has_weight_scale: bool = False + has_weight_zero_point: bool = False + + def is_ready(self) -> bool: + return ( + self.weight is not None + and (not self.has_bias or (self.has_bias and self.bias is not None)) + and (not self.has_weight_scale or (self.has_weight_scale and self.weight_scale is not None)) + and (not self.has_weight_zero_point or (self.has_weight_zero_point and self.weight_zero_point is not None)) + ) class MMWeightTpl(BaseWeightTpl): @@ -29,59 +41,125 @@ def __init__( quant_method: QuantizationMethod = None, tp_rank: int = None, tp_world_size: int = None, + has_bias: bool = False, + has_weight_scale: bool = False, + has_weight_zero_point: bool = False, ) -> None: super().__init__(tp_rank, tp_world_size, data_type) self.quant_method = quant_method - self.weight: Optional[torch.Tensor] = None - self.bias: Optional[torch.Tensor] = None - # quantized_weight 用于标记加载的权重是已经量化的权重格式 - # 不需要做在线量化 - self.quantized_weight: bool = False - # 标记是否存在 bias, 由子类初始化 - self.has_bias: bool = None + self.mm_param: MMWeightPack = MMWeightPack( + has_bias=has_bias, + has_weight_scale=has_weight_scale, + has_weight_zero_point=has_weight_zero_point, + ) + self.param_slicer: SliceMixinTpl = None def mm( self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True ) -> torch.Tensor: if self.quant_method is not None: return self.quant_method.apply( - input_tensor, self.weight, self.bias, out, use_custom_tensor_mananger=use_custom_tensor_mananger + input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger ) if out is None: - shape = (input_tensor.shape[0], self.weight.shape[1]) + 
shape = (input_tensor.shape[0], self.mm_param.weight.shape[1]) dtype = input_tensor.dtype device = input_tensor.device if use_custom_tensor_mananger: out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) else: out = torch.empty(shape, dtype=dtype, device=device) - if self.bias is None: - return torch.mm(input_tensor, self.weight, out=out) - return torch.addmm(self.bias, input_tensor, self.weight, out=out) + if self.mm_param.bias is None: + return torch.mm(input_tensor, self.mm_param.weight, out=out) + return torch.addmm(self.mm_param.bias, input_tensor, self.mm_param.weight, out=out) + + def load_hf_weights(self, weights): + raise NotImplementedError("load_hf_weights must implement this method") def verify_load(self) -> bool: - load_ok = True - # Verify weight. The weight must be not None. - load_ok = load_ok and self.weight is not None - # Verify bias. If bias_name is set, it must be not None. - if self.has_bias: - load_ok = load_ok and self.bias is not None - return load_ok + return self.mm_param.is_ready() - def _process_weight(self, weight) -> None: - if self.quant_method is not None and not self.quantized_weight: - self.weight = self.quant_method.quantize(weight.to(self.data_type_).cuda(get_current_device_id())) + def _process_weight(self, weight: torch.Tensor) -> None: + # 由于所有的量化算法,都会产生一个scale,所以只要没有scale,就说明需要在线对weight进行量化 + if self.quant_method is not None and not self.mm_param.has_weight_scale: + quantized_weight, weight_scale, weight_zero_point = self.quant_method.quantize( + weight.to(self.data_type_).cuda(get_current_device_id()) + ) + self.mm_param.weight = quantized_weight + self.mm_param.weight_scale = weight_scale + self.mm_param.weight_zero_point = weight_zero_point return # 让 k dim 更连续,大多数split k 算法的算子可能能更快 - self.weight = weight.cuda(get_current_device_id()).transpose(0, 1) + self.mm_param.weight = weight.to(self.data_type_).cuda(get_current_device_id()).transpose(0, 1) + return + + def _process_bias(self, bias: torch.Tensor) -> None: + self.mm_param.bias = bias.to(self.data_type_).cuda(get_current_device_id()) + return + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> None: + raise NotImplementedError("process_weight_scale must implement this method") + + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> None: + raise NotImplementedError("process_weight_zero_point must implement this method") + + def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: + raise NotImplementedError("load_weight_scale must implement this method") + + def _load_bias(self, weights: Dict[str, torch.Tensor]) -> None: + raise NotImplementedError("load_bias must implement this method") + + def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: + raise NotImplementedError("load_weight_scale must implement this method") + + def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: + raise NotImplementedError("load_weight_zero_point must implement this method") + + def _fuse_weights(self, dim: int = 0) -> None: + raise NotImplementedError("fuse_weights must implement this method") - def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None: + +class SingleMMWeightTpl(MMWeightTpl): + def __init__( + self, + weight_name: str, + bias_name: Optional[str] = None, + data_type: torch.dtype = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + has_weight_scale: bool = False, + has_weight_zero_point: bool = False, + ) -> None: + 
super().__init__( + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_bias=bias_name is not None, + has_weight_scale=has_weight_scale, + has_weight_zero_point=has_weight_zero_point, + ) + self.weight_name = weight_name + self.bias_name = bias_name + return + + def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: if self.weight_name in weights: - weight = self._slice_weight(weights[self.weight_name]) + weight = weights[self.weight_name] + weight = self.param_slicer._slice_weight(weight) self._process_weight(weight) + return + def _load_bias(self, weights: Dict[str, torch.Tensor]) -> None: if self.bias_name in weights: - self.bias = self._slice_bias(weights[self.bias_name]).cuda(get_current_device_id()) + bias = self.param_slicer._slice_bias(weights[self.bias_name]) + self._process_bias(bias) + return + + def load_hf_weights(self, weights): + self._load_weight(weights) + self._load_bias(weights) return @@ -89,54 +167,77 @@ class MultiMMWeightTpl(MMWeightTpl): def __init__( self, weight_names: List[str], - data_type: torch.dtype, bias_names: Optional[List[str]] = None, + data_type: torch.dtype = None, quant_method: QuantizationMethod = None, tp_rank: int = None, tp_world_size: int = None, + has_weight_scale: bool = False, + has_weight_zero_point: bool = False, ) -> None: - super().__init__(data_type, quant_method, tp_rank, tp_world_size) - + has_bias = bias_names is not None and any(b is not None for b in bias_names) + super().__init__( + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_bias=has_bias, + has_weight_scale=has_weight_scale, + has_weight_zero_point=has_weight_zero_point, + ) self.weight_names = weight_names self.bias_names = bias_names - self.weights = [None] * len(self.weight_names) - if self.bias_names is not None: - self.biases = [None] * len(self.bias_names) - self.has_bias = all(b is not None for b in self.bias_names) and len(bias_names) > 0 - else: - self.biases = None - self.has_bias = False - - def _pre_porcess_weights(self, weights: Dict[str, torch.Tensor]) -> None: + self.mm_params: List[MMWeightPack] = [ + MMWeightPack( + weight=None, + bias=None, + weight_scale=None, + weight_zero_point=None, + has_bias=has_bias, + has_weight_scale=has_weight_scale, + has_weight_zero_point=has_weight_zero_point, + ) + for _ in range(len(weight_names)) + ] + + def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: for i in range(len(self.weight_names)): if self.weight_names[i] in weights: - weight = weights[self.weight_names[i]] - self.weights[i] = self._slice_weight(weight) - if self.has_bias and self.bias_names[i] in weights: - bias = weights[self.bias_names[i]] - self.biases[i] = self._slice_bias(bias) - - def _fuse_weights(self) -> None: - if self.weight is None and (None not in self.weights): - weight = torch.cat(self.weights, dim=0) - self._process_weight(weight) - delattr(self, "weights") + weight_i = weights[self.weight_names[i]] + weight_i = self.param_slicer._slice_weight(weight_i) + self.mm_params[i].weight = weight_i + return - if self.has_bias and self.bias is None and (None not in self.biases): - self.bias = torch.cat(self.biases, dim=0).cuda(get_current_device_id()) - delattr(self, "biases") - return self + def _load_bias(self, weights: Dict[str, torch.Tensor]) -> None: + for i in range(len(self.bias_names)): + if self.bias_names[i] in weights: + bias_i = weights[self.bias_names[i]] + bias_i = self.param_slicer._slice_bias(bias_i) + 
self.mm_params[i].bias = bias_i.to(self.data_type_) + return - def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None: - self._pre_porcess_weights(weights) + def _fuse_weights(self, dim: int = 0) -> None: + if self.mm_param.weight is None and all(p.weight is not None for p in self.mm_params): + weight = torch.cat([p.weight for p in self.mm_params], dim=dim) + self._process_weight(weight) + for p in self.mm_params: + p.weight = None + + if self.mm_param.has_bias and self.mm_param.bias is None and all(p.bias is not None for p in self.mm_params): + bias = torch.cat([p.bias for p in self.mm_params], dim=dim) + self._process_bias(bias) + for p in self.mm_params: + p.bias = None + return def load_hf_weights(self, weights): - super().load_hf_weights(weights) - self._fuse_weights() + self._load_weight(weights) + self._load_bias(weights) + self._fuse_weights(dim=0) return -class BMMWeightTpl(MMWeightTpl): +class BMMWeightTpl(SingleMMWeightTpl): def mm( self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True ) -> torch.Tensor: @@ -146,7 +247,7 @@ def bmm( self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True ) -> torch.Tensor: # 目前 bmm 不支持量化运算操作 - fpweight = self.weight + fpweight = self.mm_param.weight if out is None: shape = (input_tensor.shape[0], input_tensor.shape[1], fpweight.shape[2]) dtype = input_tensor.dtype @@ -155,168 +256,304 @@ def bmm( out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) else: out = torch.empty(shape, dtype=dtype, device=device) - if self.bias is None: + if self.mm_param.bias is None: return torch.bmm(input_tensor, fpweight, out=out) - return torch.addbmm(self.bias, input_tensor, fpweight, out=out) + return torch.addbmm(self.mm_param.bias, input_tensor, fpweight, out=out) def _process_weight(self, weight) -> None: - self.weight = weight.cuda(get_current_device_id()) + self.mm_param.weight = weight.cuda(get_current_device_id()) -class AWQMMWeightTpl(MMWeightTpl): +class SingleQuantizedMMWeightTpl(SingleMMWeightTpl): def __init__( self, - data_type: torch.dtype, + weight_name: str, + bias_name: Optional[str] = None, + data_type: torch.dtype = None, quant_method: QuantizationMethod = None, tp_rank: int = None, tp_world_size: int = None, + has_weight_scale: bool = True, + has_weight_zero_point: bool = False, # 目前较多的是对称量化,所以默认没有zero_point ) -> None: - super().__init__(data_type, quant_method, tp_rank, tp_world_size) - self.weight = [None, None, None] + super().__init__( + weight_name=weight_name, + bias_name=bias_name, + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_weight_scale=has_weight_scale, + has_weight_zero_point=has_weight_zero_point, + ) + assert quant_method is not None, "quant_method is not set" + assert quant_method.weight_scale_suffix is not None, "weight_scale_suffix is not set" + self.weight_scale_name = weight_name.replace("weight", quant_method.weight_scale_suffix) + if has_weight_zero_point: + assert quant_method.weight_zero_point_suffix is not None, "weight_zero_point_suffix is not set" + self.weight_zero_point_name = weight_name.replace("weight", quant_method.weight_zero_point_suffix) + if quant_method.weight_suffix is not None: + self.weight_name = weight_name.replace("weight", quant_method.weight_suffix) + return - def verify_load(self) -> bool: - load_ok = True - # Verify weight. The weight must be not None. 
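# The hand-written readiness checks being deleted in this hunk are superseded by
# MMWeightPack.is_ready() introduced above: a pack only reports ready once its weight and
# every declared companion tensor are present. A small doctest-style sketch (assumed toy
# tensors):
#   >>> pack = MMWeightPack(has_weight_scale=True, has_weight_zero_point=True)
#   >>> pack.weight = torch.zeros(1); pack.is_ready()
#   False
#   >>> pack.weight_scale = torch.zeros(1); pack.weight_zero_point = torch.zeros(1); pack.is_ready()
#   True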
- weight_ok = all(w is not None for w in self.weight) - load_ok = load_ok and weight_ok - # Verify bias. If bias_name is set, it must be not None. - if self.has_bias: - load_ok = load_ok and self.bias is not None - return load_ok - - def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: - return weight.cuda(get_current_device_id()) - - def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: - return weight_scale.cuda(get_current_device_id()).to(self.data_type_) - - def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: - return weight_zero_point.cuda(get_current_device_id()) - - def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None: - if self.weight_name is not None and self.weight_name in weights: - weight = weights[self.weight_name] - weight = self._slice_weight(weight) - self.weight[0] = self._process_weight(weight) - if self.bias_name is not None and self.bias_name in weights: - bias = weights[self.bias_name] - bias = self._slice_bias(bias) - self.bias = bias.cuda(get_current_device_id()) - - def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: + def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: if self.weight_scale_name is not None and self.weight_scale_name in weights: weight_scale = weights[self.weight_scale_name] - weight_scale = self._slice_weight_scale(weight_scale) - self.weight[1] = self._process_weight_scale(weight_scale) + weight_scale = self.param_slicer._slice_weight_scale(weight_scale) + self._process_weight_scale(weight_scale) - def _load_zero_points(self, weights: Dict[str, torch.Tensor]) -> None: + def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: if self.weight_zero_point_name is not None and self.weight_zero_point_name in weights: weight_zero_point = weights[self.weight_zero_point_name] - weight_zero_point = self._slice_weight_zero_point(weight_zero_point) - self.weight[2] = self._process_weight_zero_point(weight_zero_point) + weight_zero_point = self.param_slicer._slice_weight_zero_point(weight_zero_point) + self._process_weight_zero_point(weight_zero_point) + + def load_hf_weights(self, weights): + self._load_weight(weights) + self._load_bias(weights) + self._load_weight_scale(weights) + self._load_weight_zero_point(weights) + return + + # 不同的量化算法,往往需要不同的处理方式,所以强制要求实现这些方法 + def _process_weight(self, weight: torch.Tensor) -> None: + raise NotImplementedError("Quantized weight process_weight must implement this method") + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> None: + raise NotImplementedError("Quantized weight process_weight_scale must implement this method") + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> None: + raise NotImplementedError("Quantized weight process_weight_zero_point must implement this method") -class AWQMultiMMWeightTpl(AWQMMWeightTpl): + +class MultiQuantizedMMWeightTpl(MultiMMWeightTpl): def __init__( self, weight_names: List[str], - data_type: torch.dtype, bias_names: Optional[List[str]] = None, + data_type: torch.dtype = None, quant_method: QuantizationMethod = None, tp_rank: int = None, tp_world_size: int = None, + has_weight_scale: bool = True, + has_weight_zero_point: bool = False, ) -> None: - super().__init__(data_type, quant_method, tp_rank, tp_world_size) - self.weight_names = [weight.replace("weight", quant_method.weight_suffix) for weight in weight_names] + super().__init__( + weight_names=weight_names, + bias_names=bias_names, + 
data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_weight_scale=has_weight_scale, + has_weight_zero_point=has_weight_zero_point, + ) + assert quant_method is not None, "quant_method is not set" + assert quant_method.weight_scale_suffix is not None, "weight_scale_suffix is not set" self.weight_scale_names = [ - weight.replace("weight", quant_method.weight_scale_suffix) for weight in weight_names + weight_name.replace("weight", quant_method.weight_scale_suffix) for weight_name in weight_names ] - self.weight_zero_point_names = [ - weight.replace("weight", quant_method.weight_zero_point_suffix) for weight in weight_names - ] - self.bias_names = bias_names - self.weights = [None] * len(self.weight_names) - self.weight_scales = [None] * len(self.weight_names) - self.weight_zero_points = [None] * len(self.weight_names) - if self.bias_names is not None: - self.biases = [None] * len(self.bias_names) - self.has_bias = all(b is not None for b in self.bias_names) and len(bias_names) > 0 - else: - self.biases = None - self.has_bias = False - - def _fuse(self) -> None: - if self.weight[0] is None and (None not in self.weights): - weight = torch.cat(self.weights, dim=1) - self.weight[0] = self._process_weight(weight) - delattr(self, "weights") - - if self.weight[1] is None and (None not in self.weight_scales): - # awq 保存的量化参数,weight shape 是 in x out。所以这里的cat dim 是 1 - weight_scale = torch.cat(self.weight_scales, dim=1).cuda(get_current_device_id()) - self.weight[1] = self._process_weight_scale(weight_scale) - delattr(self, "weight_scales") - - if self.weight[2] is None and (None not in self.weight_zero_points): - weight_zero_point = torch.cat(self.weight_zero_points, dim=1) - self.weight[2] = self._process_weight_zero_point(weight_zero_point) - delattr(self, "weight_zero_points") - - if self.has_bias and self.bias is None and (None not in self.biases): - self.bias = torch.cat(self.biases, dim=0).cuda(get_current_device_id()) - delattr(self, "biases") - return self + if has_weight_zero_point: + assert quant_method.weight_zero_point_suffix is not None, "weight_zero_point_suffix is not set" + self.weight_zero_point_names = [ + weight_name.replace("weight", quant_method.weight_zero_point_suffix) for weight_name in weight_names + ] + if quant_method.weight_suffix is not None: + self.weight_names = [ + weight_name.replace("weight", quant_method.weight_suffix) for weight_name in weight_names + ] + return - def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None: + def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: for i in range(len(self.weight_names)): - if self.weight_names[i] is not None and self.weight_names[i] in weights: + if self.weight_names[i] in weights: weight = weights[self.weight_names[i]] - weight = self._slice_weight(weight) - self.weights[i] = weight.cuda(get_current_device_id()) + weight = self.param_slicer._slice_weight(weight) + self.mm_params[i].weight = weight - def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: + def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: for i in range(len(self.weight_names)): if self.weight_scale_names[i] is not None and self.weight_scale_names[i] in weights: weight_scale = weights[self.weight_scale_names[i]] - weight_scale = self._slice_weight_scale(weight_scale) - self.weight_scales[i] = weight_scale.cuda(get_current_device_id()).to(self.data_type_) + weight_scale = self.param_slicer._slice_weight_scale(weight_scale) + 
self.mm_params[i].weight_scale = weight_scale.to(self.data_type_) - def _load_zero_points(self, weights: Dict[str, torch.Tensor]) -> None: + def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: for i in range(len(self.weight_names)): if self.weight_zero_point_names[i] is not None and self.weight_zero_point_names[i] in weights: weight_zero_point = weights[self.weight_zero_point_names[i]] - weight_zero_point = self._slice_weight_zero_point(weight_zero_point) - self.weight_zero_points[i] = weight_zero_point.cuda(get_current_device_id()) + weight_zero_point = self.param_slicer._slice_weight_zero_point(weight_zero_point) + self.mm_params[i].weight_zero_point = weight_zero_point + return - def load_hf_weights(self, weights): - super().load_hf_weights(weights) - self._fuse() + def _fuse_weights(self, dim: int = 0) -> None: + super()._fuse_weights(dim=dim) + if self.mm_param.weight_scale is None and (None not in [p.weight_scale for p in self.mm_params]): + # awq 保存的量化参数,weight shape 是 in x out。所以这里的cat dim 是 1 + weight_scale = torch.cat([p.weight_scale for p in self.mm_params], dim=dim).cuda(get_current_device_id()) + self._process_weight_scale(weight_scale) + for p in self.mm_params: + p.weight_scale = None + + if self.mm_param.weight_zero_point is None and (None not in [p.weight_zero_point for p in self.mm_params]): + weight_zero_point = torch.cat([p.weight_zero_point for p in self.mm_params], dim=dim) + self._process_weight_zero_point(weight_zero_point) + for p in self.mm_params: + p.weight_zero_point = None + torch.cuda.empty_cache() + return + + # 不同的量化算法,往往需要不同的处理方式,所以强制要求实现这些方法 + def _process_weight(self, weight: torch.Tensor) -> None: + raise NotImplementedError("Quantized weight process_weight must implement this method") + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> None: + raise NotImplementedError("Quantized weight process_weight_scale must implement this method") + + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> None: + raise NotImplementedError("Quantized weight process_weight_zero_point must implement this method") + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + self._load_weight(weights) + self._load_bias(weights) + self._load_weight_scale(weights) + self._load_weight_zero_point(weights) + self._fuse_weights(dim=0) + return + + +class DeepGemmFP8W8A8B128MMWeight(SingleQuantizedMMWeightTpl): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + bias_name: Optional[str] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__( + weight_name=weight_name, + bias_name=bias_name, + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_weight_scale=True, + has_weight_zero_point=False, + ) + + def _process_weight_scale(self, weight_scale) -> None: + self.mm_param.weight_scale = weight_scale.to(torch.float).cuda(get_current_device_id()).transpose(0, 1) return + def _process_weight(self, weight) -> None: + self.mm_param.weight = weight.cuda(get_current_device_id()).transpose(0, 1) + return + + +class DeepGemmFP8W8A8B128MultiMMWeight(MultiQuantizedMMWeightTpl): + def __init__( + self, + weight_names: str, + data_type: torch.dtype, + bias_names: Optional[str] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__( + weight_names=weight_names, + bias_names=bias_names, + 
data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_weight_scale=True, + has_weight_zero_point=False, + ) + + def _process_weight_scale(self, weight_scale) -> None: + self.mm_param.weight_scale = weight_scale.cuda(get_current_device_id()).transpose(0, 1) + return + + def _process_weight(self, weight) -> None: + self.mm_param.weight = weight.cuda(get_current_device_id()).transpose(0, 1) + return -class MMWeight: - def __new__(cls, **kwargs): - quant_cfg = kwargs.pop("quant_cfg", None) - layer_num_ = kwargs.pop("layer_num", None) - name = kwargs.pop("name", None) - quant_method, quantized_weight = cls._get_quant_method(quant_cfg, layer_num_, name) - kwargs["quant_method"] = quant_method - mmcls = cls._get_mmcls(quant_method, quantized_weight) - return mmcls(**kwargs) - - @classmethod - def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> QuantizationMethod: - if quant_cfg is None: - return None, False - quant_method = quant_cfg.get_quant_method(layer_num_, name) - if quant_method is None: - return None, False - quant_method.hf_quantization_config = quant_cfg.hf_quantization_config - quantized_weight = quant_cfg.quantized_weight - return quant_method, quantized_weight - - @classmethod - def _get_mmcls( - cls, quant_method: QuantizationMethod, quantized_weight: bool - ) -> Type[Union[MMWeightTpl, MultiMMWeightTpl, BMMWeightTpl]]: - raise NotImplementedError("Subclasses must implement _get_mmcls method") + +class AWQMMWeightTpl(SingleQuantizedMMWeightTpl): + def __init__( + self, + weight_name: str, + bias_name: Optional[str] = None, + data_type: torch.dtype = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__( + weight_name=weight_name, + bias_name=bias_name, + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_weight_scale=True, + has_weight_zero_point=True, + ) + + def _process_weight(self, weight: torch.Tensor) -> None: + self.mm_param.weight = weight.cuda(get_current_device_id()) + return + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> None: + self.mm_param.weight_scale = weight_scale.to(self.data_type_).cuda(get_current_device_id()) + return + + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> None: + self.mm_param.weight_zero_point = weight_zero_point.cuda(get_current_device_id()) + return + + +class AWQMultiMMWeightTpl(MultiQuantizedMMWeightTpl): + def __init__( + self, + weight_names: List[str], + bias_names: Optional[List[str]] = None, + data_type: torch.dtype = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__( + weight_names=weight_names, + bias_names=bias_names, + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + has_weight_scale=True, + has_weight_zero_point=True, + ) + + def _process_weight(self, weight: torch.Tensor) -> None: + self.mm_param.weight = weight.cuda(get_current_device_id()) + return + + def _process_weight_scale(self, weight_scale: torch.Tensor) -> None: + self.mm_param.weight_scale = weight_scale.to(self.data_type_).cuda(get_current_device_id()) + return + + def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> None: + self.mm_param.weight_zero_point = weight_zero_point.cuda(get_current_device_id()) + return + + def load_hf_weights(self, weights): + 
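# For context: AutoAWQ-style GEMM checkpoints typically store qweight as
# [in_features, out_features // pack_factor] packed int32, scales as
# [in_features // group_size, out_features] and qzeros as
# [in_features // group_size, out_features // pack_factor], i.e. in x out rather than the
# out x in layout assumed by the unquantized paths. That is why this class fuses the
# sub-weights along dim=1 below and why the AWQ column-parallel weight earlier reuses the
# row slicer.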
self._load_weight(weights) + self._load_bias(weights) + self._load_weight_scale(weights) + self._load_weight_zero_point(weights) + # 由于awq的储存格式是inxout,所以拼接dim是 1 + self._fuse_weights(dim=1) + return diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index 5cfdcfd5b..9123c3e8e 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -1,49 +1,21 @@ import torch from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( - MMWeight, - MMWeightTpl, - BMMWeightTpl, + SingleMMWeightTpl, MultiMMWeightTpl, + DeepGemmFP8W8A8B128MMWeight, + DeepGemmFP8W8A8B128MultiMMWeight, AWQMMWeightTpl, AWQMultiMMWeightTpl, - generate_scale_name, + BMMWeightTpl, ) from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional +from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin -class ROWMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): - if quant_method is None or not quantized_weight: - return UnquantizedROWMMWeight - - return ROWBMM_WEIGHT_CLS_MAP[quant_method.method_name] - - -class MultiROWMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): - if quant_method is None or not quantized_weight: - return UnquantizedMultiROWMMWeight - - return MULTI_ROWBMM_WEIGHT_CLS_MAP[quant_method.method_name] - - -class ROWBMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): - if quant_method is None or not quantized_weight: - return UnquantizedROWBMMWeight - else: - return W8A8B128ROWBMMWeight - # TODO: Implement more quantization weight - return None - - -class UnquantizedROWMMWeight(MMWeightTpl): +class UnquantizedROWMMWeight(SingleMMWeightTpl): def __init__( self, weight_name: str, @@ -53,84 +25,18 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - self.weight_name = weight_name - self.bias_name = bias_name - self.has_bias = bias_name is not None - super().__init__(data_type, quant_method, tp_rank, tp_world_size) - - def _slice_weight(self, weight: torch.Tensor): - assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}" - tp_size = weight.shape[0] // self.tp_world_size_ - return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)].to(self.data_type_) - - def _slice_bias(self, bias): - assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}" - tp_size = bias.shape[0] // self.tp_world_size_ - return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)].to(self.data_type_) - - -class W8A8B128ROWMMWeight(UnquantizedROWMMWeight): - def __init__( - self, - weight_name: str, - data_type: torch.dtype, - bias_name: Optional[str] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size) - - self.weight_scale_name, _ = generate_scale_name( - weight_name, quant_method.weight_scale_suffix, 
quant_method.act_scale_suffix + super().__init__( + weight_name=weight_name, + bias_name=bias_name, + data_type=data_type, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, ) - self.weight_scale: Optional[torch.Tensor] = None - self.quantized_weight = True - - def _slice_weight(self, weight: torch.Tensor): - assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}" - tp_size = weight.shape[0] // self.tp_world_size_ - return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] - - def _slice_bias(self, bias): - assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}" - tp_size = bias.shape[0] // self.tp_world_size_ - return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] - - def _slice_weight_scale(self, weight_scale: torch.Tensor): - assert ( - weight_scale.shape[0] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}" - tp_size = weight_scale.shape[0] // self.tp_world_size_ - scale_start = tp_size * self.tp_rank_ - scale_end = tp_size * (self.tp_rank_ + 1) - return weight_scale.to(torch.float)[scale_start:scale_end] - - def _process_weight_scale(self, weight_scale) -> None: - self.weight_scale = weight_scale.cuda(get_current_device_id()).transpose(0, 1) - - def _process_weight(self, weight) -> None: - self.weight = weight.cuda(get_current_device_id()).transpose(0, 1) - - def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None: - if self.weight_scale_name in weights: - weight_scale = weights[self.weight_scale_name] - weight_scale = self._slice_weight_scale(weight_scale) - self._process_weight_scale(weight_scale) - - if self.weight_scale is not None and isinstance(self.weight, torch.Tensor): - self.weight = [ - self.weight, - self.weight_scale, - None, # placeholder for input scale - ] - return + self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) class UnquantizedMultiROWMMWeight(MultiMMWeightTpl): - _slice_weight = UnquantizedROWMMWeight._slice_weight - _slice_bias = UnquantizedROWMMWeight._slice_bias - def __init__( self, weight_names: str, @@ -140,85 +46,61 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - super().__init__(weight_names, data_type, bias_names, quant_method, tp_rank, tp_world_size) - + super().__init__( + weight_names=weight_names, + data_type=data_type, + bias_names=bias_names, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) + self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) -class W8A8B128MultiROWMMWeight(UnquantizedMultiROWMMWeight): - _slice_weight = W8A8B128ROWMMWeight._slice_weight - _slice_bias = W8A8B128ROWMMWeight._slice_bias - _slice_weight_scale = W8A8B128ROWMMWeight._slice_weight_scale +class DeepGemmFP8W8A8B128ROWMMWeight(DeepGemmFP8W8A8B128MMWeight): def __init__( self, - weight_names: str, + weight_name: str, data_type: torch.dtype, - bias_names: Optional[str] = None, + bias_name: Optional[str] = None, quant_method: QuantizationMethod = None, tp_rank: int = None, tp_world_size: int = None, ) -> None: - super().__init__(weight_names, data_type, bias_names, quant_method, tp_rank, tp_world_size) - self.weight_scale_names = [] - self.weight_scale: Optional[torch.Tensor] = None - self.weight_scales = [None] * len(self.weight_names) - for weight_name in weight_names: - weight_scale_name, act_scale_name = generate_scale_name( - 
weight_name, quant_method.weight_scale_suffix, quant_method.act_scale_suffix - ) - self.weight_scale_names.append(weight_scale_name) - self.quantized_weight = True - - def _load_scales(self, weights): - for i in range(len(self.weight_names)): - if self.weight_scale_names[i] in weights: - weight_scale = weights[self.weight_scale_names[i]] - weight_scale = self._slice_weight_scale(weight_scale) - self.weight_scales[i] = weight_scale - - def _process_weight_scale(self, weight_scale) -> None: - self.weight_scale = weight_scale.cuda(get_current_device_id()).transpose(0, 1) - - def _process_weight(self, weight) -> None: - self.weight = weight.cuda(get_current_device_id()).transpose(0, 1) - - def _fuse_weights(self) -> None: - super()._fuse_weights() - if self.weight_scale is None and (None not in self.weight_scales): - weight_scale = torch.cat(self.weight_scales, dim=0).cuda(get_current_device_id()) - self._process_weight_scale(weight_scale) - delattr(self, "weight_scales") - - if self.weight_scale is not None and isinstance(self.weight, torch.Tensor): - self.weight = [ - self.weight, - self.weight_scale, - None, - ] - + super().__init__( + weight_name=weight_name, + data_type=data_type, + bias_name=bias_name, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) + self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) + return -class UnquantizedROWBMMWeight(BMMWeightTpl): - _slice_weight = UnquantizedROWMMWeight._slice_weight - _slice_bias = UnquantizedROWMMWeight._slice_bias +class DeepGemmFP8W8A8B128MultiROWMMWeight(DeepGemmFP8W8A8B128MultiMMWeight): def __init__( self, - weight_name: str, + weight_names: str, data_type: torch.dtype, - bias_name: Optional[str] = None, + bias_names: Optional[str] = None, quant_method: QuantizationMethod = None, tp_rank: int = None, tp_world_size: int = None, ) -> None: - self.weight_name = weight_name - self.bias_name = bias_name - self.has_bias = bias_name is not None - super().__init__(data_type, quant_method, tp_rank, tp_world_size) - + super().__init__( + weight_names=weight_names, + data_type=data_type, + bias_names=bias_names, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) + self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) -class W8A8B128ROWBMMWeight(UnquantizedROWBMMWeight): - _slice_weight = W8A8B128ROWMMWeight._slice_weight - _slice_bias = W8A8B128ROWMMWeight._slice_bias +class UnquantizedROWBMMWeight(BMMWeightTpl): def __init__( self, weight_name: str, @@ -227,33 +109,16 @@ def __init__( quant_method: QuantizationMethod = None, tp_rank: int = None, tp_world_size: int = None, - weight_scale_suffix: Optional[str] = None, - act_scale_suffix: Optional[str] = None, ) -> None: - super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size) - self.weight_scale_name, self.act_scale_name = generate_scale_name( - weight_name, weight_scale_suffix, act_scale_suffix + super().__init__( + weight_name=weight_name, + data_type=data_type, + bias_name=bias_name, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, ) - self.weight_scale: Optional[torch.Tensor] = None - self.quantized_weight = True - - def _slice_weight_scale(self, weight_scale: torch.Tensor): - tp_size = weight_scale.shape[0] // self.tp_world_size_ - scale_start = tp_size * self.tp_rank_ - scale_end = tp_size * (self.tp_rank_ + 1) - return weight_scale[scale_start:scale_end].to(torch.float) - - def 
_load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
-        if self.weight_scale_name is not None and self.weight_scale_name in weights:
-            weight_scale = weights[self.weight_scale_name]
-            weight_scale = self._slice_weight_scale(weight_scale)
-
-        if self.weight_scale is not None and isinstance(self.weight, torch.Tensor):
-            self.weight = [
-                self.weight,
-                self.weight_scale,
-                None,
-            ]
+        super().__init__(
+            weight_name=weight_name,
+            data_type=data_type,
+            bias_name=bias_name,
+            quant_method=quant_method,
+            tp_rank=tp_rank,
+            tp_world_size=tp_world_size,
        )
+        self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)


class AWQROWMMWeight(AWQMMWeightTpl):
@@ -266,44 +131,19 @@ def __init__(
        tp_rank: int = None,
        tp_world_size: int = None,
    ) -> None:
-        super().__init__(data_type, quant_method, tp_rank, tp_world_size)
-        self.weight_name = weight_name.replace("weight", quant_method.weight_suffix)
-        self.weight_scale_name = weight_name.replace("weight", quant_method.weight_scale_suffix)
-        self.weight_zero_point_name = weight_name.replace("weight", quant_method.weight_zero_point_suffix)
-        self.bias_name = bias_name
-        self.weight_scale: Optional[torch.Tensor] = None
-        self.quantized_weight = True
-        self.weight = [None, None, None]
-
-    def _slice_weight(self, weight: torch.Tensor):
-        assert weight.shape[1] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[1]} % {self.tp_world_size_}"
-        tp_size = weight.shape[1] // self.tp_world_size_
-        return weight[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
-
-    def _slice_bias(self, bias):
-        assert bias.shape[1] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[1]} % {self.tp_world_size_}"
-        tp_size = bias.shape[1] // self.tp_world_size_
-        return bias[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
-
-    def _slice_weight_scale(self, weight_scale: torch.Tensor):
-        tp_size = weight_scale.shape[1] // self.tp_world_size_
-        scale_start = tp_size * self.tp_rank_
-        scale_end = tp_size * (self.tp_rank_ + 1)
-        return weight_scale[:, scale_start:scale_end].to(torch.half)
-
-    def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor):
-        tp_size = weight_zero_point.shape[1] // self.tp_world_size_
-        zero_point_start = tp_size * self.tp_rank_
-        zero_point_end = tp_size * (self.tp_rank_ + 1)
-        return weight_zero_point[:, zero_point_start:zero_point_end]
+        super().__init__(
+            weight_name=weight_name,
+            data_type=data_type,
+            bias_name=bias_name,
+            quant_method=quant_method,
+            tp_rank=tp_rank,
+            tp_world_size=tp_world_size,
+        )
+        # Note that this is not a mistake: AWQ stores the weight as in x out
+        self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)


class AWQMultiROWMMWeight(AWQMultiMMWeightTpl):
-    _slice_weight = AWQROWMMWeight._slice_weight
-    _slice_bias = AWQROWMMWeight._slice_bias
-    _slice_weight_scale = AWQROWMMWeight._slice_weight_scale
-    _slice_weight_zero_point = AWQROWMMWeight._slice_weight_zero_point
-
    def __init__(
        self,
        weight_names: List[str],
@@ -313,7 +153,16 @@ def __init__(
        tp_rank: int = None,
        tp_world_size: int = None,
    ) -> None:
-        super().__init__(weight_names, data_type, bias_names, quant_method, tp_rank, tp_world_size)
+        super().__init__(
+            weight_names=weight_names,
+            data_type=data_type,
+            bias_names=bias_names,
+            quant_method=quant_method,
+            tp_rank=tp_rank,
+            tp_world_size=tp_world_size,
+        )
+        # Note that this is not a mistake: AWQ stores the weight as in x out
+        self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)


class AWQMARLINROWMMWeight(AWQROWMMWeight):
@@ -368,14 +217,14 @@ def _process_weight_zero_point(self, weight_zero_point: torch.T
        )

-ROWBMM_WEIGHT_CLS_MAP = {
-    "deepgemm-fp8w8a8-b128": W8A8B128ROWMMWeight,
+ROWMM_WEIGHT_CLS_MAP = { + "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128ROWMMWeight, "awq": AWQROWMMWeight, "awq_marlin": AWQMARLINROWMMWeight, } -MULTI_ROWBMM_WEIGHT_CLS_MAP = { - "deepgemm-fp8w8a8-b128": W8A8B128MultiROWMMWeight, +MULTI_ROWMM_WEIGHT_CLS_MAP = { + "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128MultiROWMMWeight, "awq": AWQMultiROWMMWeight, "awq_marlin": AWQMARLINMultiROWMMWeight, } diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 2e557e0e2..48167a067 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -2,7 +2,7 @@ # from lightllm.common.layers.mm import MM from .base_layer_weight import BaseLayerWeight -from .meta_weights import BaseWeight, MultiMMWeightTpl, AWQMultiMMWeightTpl +from .meta_weights import BaseWeight, MultiMMWeightTpl from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -36,7 +36,7 @@ def load_hf_weights(self, weights): """ for attr_name in dir(self): attr = getattr(self, attr_name, None) - if isinstance(attr, MultiMMWeightTpl) or isinstance(attr, AWQMultiMMWeightTpl): + if isinstance(attr, MultiMMWeightTpl): with self.lock: attr.load_hf_weights(weights) elif isinstance(attr, BaseWeight): diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index 0d2a3084d..b6de41140 100644 --- a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py @@ -1,6 +1,6 @@ import torch import numpy as np -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import ROWMMWeight +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import NormWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index a67cd5ac2..661c450f0 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -3,7 +3,7 @@ import numpy as np from lightllm.common.basemodel.layer_weights.meta_weights.gpt_oss_fused_moe_weight_tp import GPTOSSFusedMoeWeightTP -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import ROWMMWeight +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import NormWeight, TpNormWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.utils.log_utils import init_logger From dfd2c022d01826e270d5028d2282be5ab1016f74 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 6 Nov 2025 22:41:33 +0800 Subject: [PATCH 09/10] fix awq --- .../layer_weights/meta_weights/__init__.py | 1 + .../meta_weights/mm_weight/__init__.py | 1 + .../meta_weights/mm_weight/colmm_weight.py | 23 ++++++++++++++---- .../meta_weights/mm_weight/mm_factory.py | 4 ++-- .../meta_weights/mm_weight/mm_weight.py | 8 +++---- .../meta_weights/mm_weight/rowmm_weight.py | 24 ++++++++++++++----- 6 
files changed, 44 insertions(+), 17 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index ec9d707f8..396f1fc11 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -1,5 +1,6 @@ from .base_weight import BaseWeight from .mm_weight import ( + MMWeightPack, MMWeightTpl, MultiMMWeightTpl, ROWMMWeight, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py index ed82aa559..ea343b41d 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py @@ -1,4 +1,5 @@ from .mm_weight import ( + MMWeightPack, MMWeightTpl, MultiMMWeightTpl, AWQMultiMMWeightTpl, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py index a43933ce6..1b4e3e815 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py @@ -63,7 +63,14 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - super().__init__(data_type, quant_method, tp_rank, tp_world_size) + super().__init__( + weight_name=weight_name, + data_type=data_type, + bias_name=bias_name, + quant_method=quant_method, + tp_rank=tp_rank, + tp_world_size=tp_world_size, + ) # 注意这里不是错误,因为awq的weight是按inxout存的 self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) @@ -88,20 +95,26 @@ def __init__( ) def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: - return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) + new_weight = self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) + self.mm_param.weight = new_weight + return def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: - return self.quant_method._process_weight_scale_after_loading( + new_weight_scale = self.quant_method._process_weight_scale_after_loading( weight_scale.cuda(get_current_device_id()).to(self.data_type_) ) + self.mm_param.weight_scale = new_weight_scale + return def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: - return self.quant_method._process_weight_zero_point_after_loading( + new_weight_zero_point = self.quant_method._process_weight_zero_point_after_loading( weight_zero_point.cuda(get_current_device_id()) ) + self.mm_param.weight_zero_point = new_weight_zero_point + return -COLBMM_WEIGHT_CLS_MAP = { +COLMM_WEIGHT_CLS_MAP = { "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128COLMMWeight, "awq": AWQCOLMMWeight, "awq_marlin": AWQMARLINCOLMMWeight, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py index 1b993c1fa..a6486bfa8 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py @@ -15,7 +15,7 @@ ) from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.colmm_weight import ( UnquantizedCOLMMWeight, - COLBMM_WEIGHT_CLS_MAP, + 
COLMM_WEIGHT_CLS_MAP, ) @@ -80,4 +80,4 @@ class COLMMWeight(MMWeight): def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool): if quant_method is None or not quantized_weight: return UnquantizedCOLMMWeight - return COLBMM_WEIGHT_CLS_MAP[quant_method.method_name] + return COLMM_WEIGHT_CLS_MAP[quant_method.method_name] diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index a56bedc50..a98c47f52 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -186,7 +186,7 @@ def __init__( has_weight_zero_point=has_weight_zero_point, ) self.weight_names = weight_names - self.bias_names = bias_names + self.bias_names = bias_names if bias_names is not None else [] self.mm_params: List[MMWeightPack] = [ MMWeightPack( weight=None, @@ -303,7 +303,7 @@ def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: self._process_weight_scale(weight_scale) def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: - if self.weight_zero_point_name is not None and self.weight_zero_point_name in weights: + if self.mm_param.has_weight_zero_point and self.weight_zero_point_name in weights: weight_zero_point = weights[self.weight_zero_point_name] weight_zero_point = self.param_slicer._slice_weight_zero_point(weight_zero_point) self._process_weight_zero_point(weight_zero_point) @@ -380,7 +380,7 @@ def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: for i in range(len(self.weight_names)): - if self.weight_zero_point_names[i] is not None and self.weight_zero_point_names[i] in weights: + if self.mm_params[i].has_weight_zero_point and self.weight_zero_point_names[i] in weights: weight_zero_point = weights[self.weight_zero_point_names[i]] weight_zero_point = self.param_slicer._slice_weight_zero_point(weight_zero_point) self.mm_params[i].weight_zero_point = weight_zero_point @@ -474,7 +474,7 @@ def __init__( ) def _process_weight_scale(self, weight_scale) -> None: - self.mm_param.weight_scale = weight_scale.cuda(get_current_device_id()).transpose(0, 1) + self.mm_param.weight_scale = weight_scale.to(torch.float).cuda(get_current_device_id()).transpose(0, 1) return def _process_weight(self, weight) -> None: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index 9123c3e8e..599162c64 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -178,17 +178,23 @@ def __init__( super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size) def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: - return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) + new_weight = self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) + self.mm_param.weight = new_weight + return def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: - return self.quant_method._process_weight_scale_after_loading( + new_weight_scale = self.quant_method._process_weight_scale_after_loading( 
weight_scale.cuda(get_current_device_id()).to(self.data_type_) ) + self.mm_param.weight_scale = new_weight_scale + return def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: - return self.quant_method._process_weight_zero_point_after_loading( + new_weight_zero_point = self.quant_method._process_weight_zero_point_after_loading( weight_zero_point.cuda(get_current_device_id()) ) + self.mm_param.weight_zero_point = new_weight_zero_point + return class AWQMARLINMultiROWMMWeight(AWQMultiROWMMWeight): @@ -204,17 +210,23 @@ def __init__( super().__init__(weight_names, data_type, bias_names, quant_method, tp_rank, tp_world_size) def _process_weight(self, weight: torch.Tensor) -> torch.Tensor: - return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) + new_weight = self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id())) + self.mm_param.weight = new_weight + return def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: - return self.quant_method._process_weight_scale_after_loading( + new_weight_scale = self.quant_method._process_weight_scale_after_loading( weight_scale.cuda(get_current_device_id()).to(self.data_type_) ) + self.mm_param.weight_scale = new_weight_scale + return def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: - return self.quant_method._process_weight_zero_point_after_loading( + new_weight_zero_point = self.quant_method._process_weight_zero_point_after_loading( weight_zero_point.cuda(get_current_device_id()) ) + self.mm_param.weight_zero_point = new_weight_zero_point + return ROWMM_WEIGHT_CLS_MAP = { From cf5d7fccb7f3ee61d4166d833111ceb1564a579e Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 6 Nov 2025 22:49:55 +0800 Subject: [PATCH 10/10] refactor quantization --- lightllm/common/quantization/awq_quant.py | 47 ++++++++--- .../common/quantization/deepgemm_quant.py | 37 ++++++--- .../common/quantization/quantize_method.py | 16 +++- lightllm/common/quantization/torchao_quant.py | 22 ++++-- .../quantization/triton_quant/triton_quant.py | 33 ++++++-- lightllm/common/quantization/w8a8_quant.py | 78 +++++++++++++------ 6 files changed, 177 insertions(+), 56 deletions(-) diff --git a/lightllm/common/quantization/awq_quant.py b/lightllm/common/quantization/awq_quant.py index c758c545b..6a00bcf80 100644 --- a/lightllm/common/quantization/awq_quant.py +++ b/lightllm/common/quantization/awq_quant.py @@ -6,6 +6,10 @@ from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops from typing import Any +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack if HAS_VLLM: awq_dequantize = vllm_ops.awq_dequantize @@ -35,12 +39,17 @@ def __init__(self): self.cache_manager = g_cache_manager def quantize(self, weight: torch.Tensor): - """ """ - pass + raise NotImplementedError("AWQ online quantization is not supported yet.") - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): - """ """ - pass + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + raise NotImplementedError("AWQ online quantization is not supported yet.") @property def 
method_name(self): @@ -63,8 +72,18 @@ def method_name(self): def quantize(self, weight: torch.Tensor): raise NotImplementedError("AWQ online quantization is not supported yet.") - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): - qweight, weight_scale, qzeros = weights + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + qzeros = weight_pack.weight_zero_point + bias = weight_pack.bias NEED_DEQUANT_WEIGHT = input_tensor.shape[:-1].numel() >= 256 if NEED_DEQUANT_WEIGHT: @@ -128,8 +147,18 @@ def _process_weight_zero_point_after_loading(self, weight_zero_point: torch.Tens num_bits=self.hf_quantization_config["bits"], ) - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): - qweight, weight_scale, qzeros = weights + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + qzeros = weight_pack.weight_zero_point + bias = weight_pack.bias reshaped_x = input_tensor.reshape(-1, input_tensor.shape[-1]) use_atomic_add = should_use_atomic_add_reduce( diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py index 6816c8f51..7d3fc5358 100644 --- a/lightllm/common/quantization/deepgemm_quant.py +++ b/lightllm/common/quantization/deepgemm_quant.py @@ -7,7 +7,10 @@ per_token_group_quant_fp8, tma_align_input_scale, ) +from typing import TYPE_CHECKING, Optional +if TYPE_CHECKING: + from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack try: HAS_DEEPGEMM = True import deep_gemm @@ -27,9 +30,15 @@ def quantize(self, weight: torch.Tensor): """ """ pass - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): - """ """ - pass + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + raise NotImplementedError("Not implemented") @property def method_name(self): @@ -41,8 +50,9 @@ class DeepGEMMFP8w8a8B128QuantizationMethod(DeepGEMMBaseQuantizationMethod): def __init__(self): super().__init__() self.block_size = 128 + self.weight_suffix = None + self.weight_zero_point_suffix = None self.weight_scale_suffix = "weight_scale_inv" - self.act_scale_suffix = None # no support for static input tensor scale for ds model. 
@property def method_name(self): @@ -53,15 +63,20 @@ def quantize(self, weight: torch.Tensor): return weight_quant(weight, self.block_size) - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): - if len(weights) == 3: - qweight, weight_scale, input_scale = weights - else: - qweight, weight_scale = weights - input_scale = None + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + input_scale = None alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty m, k = input_tensor.shape - n = weights[0].shape[1] + n = qweight.shape[1] if input_scale is None: qinput_tensor, input_scale = per_token_group_quant_fp8( input_tensor, diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 80dad1fe2..5a7db15fc 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -1,13 +1,19 @@ import torch from abc import ABC, abstractmethod from lightllm.utils.dist_utils import get_current_device_id +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack class QuantizationMethod(ABC): def __init__(self): super().__init__() self.device_id_ = get_current_device_id() + self.weight_suffix = None self.weight_scale_suffix = None + self.weight_zero_point_suffix = None self.act_scale_suffix = None @abstractmethod @@ -15,7 +21,15 @@ def quantize(self, weights: torch.Tensor): pass @abstractmethod - def apply(self, input_tensor, weight, bias=None, out=None, use_custom_tensor_mananger=True): + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + bias: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: pass @property diff --git a/lightllm/common/quantization/torchao_quant.py b/lightllm/common/quantization/torchao_quant.py index 1a75492b1..df8d1319d 100644 --- a/lightllm/common/quantization/torchao_quant.py +++ b/lightllm/common/quantization/torchao_quant.py @@ -3,12 +3,13 @@ from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS import torch.nn.functional as F +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack try: HAS_TORCH_AO = True - from torchao.dtypes import to_affine_quantized_intx, AffineQuantizedTensor - from torchao.dtypes import TensorCoreTiledLayoutType - from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain from torchao.quantization import ( int4_weight_only, int8_weight_only, @@ -38,9 +39,18 @@ def quantize(self, weight: torch.Tensor): dummy_linear = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) dummy_linear.weight = torch.nn.Parameter(weight.cuda(self.device_id_)) quantize_(dummy_linear, self.quant_func) - return dummy_linear.weight - - def apply(self, input_tensor, weights, bias=None, out=None, use_custom_tensor_mananger=True): + return dummy_linear.weight, None, None + + def apply( + self, + input_tensor: torch.Tensor, + 
weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + weights = weight_pack.weight + bias = weight_pack.bias return F.linear(input_tensor, weights, bias) @property diff --git a/lightllm/common/quantization/triton_quant/triton_quant.py b/lightllm/common/quantization/triton_quant/triton_quant.py index a8d6a0055..a79e3f65a 100644 --- a/lightllm/common/quantization/triton_quant/triton_quant.py +++ b/lightllm/common/quantization/triton_quant/triton_quant.py @@ -5,6 +5,10 @@ from lightllm.common.quantization.registry import QUANTMETHODS from .fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul from .fp8.fp8act_quant_kernel import per_token_group_quant_fp8 +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack class TritonBaseQuantizationMethod(QuantizationMethod): @@ -15,12 +19,17 @@ def __init__(self): self.cache_manager = g_cache_manager def quantize(self, weight: torch.Tensor): - """ """ pass - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): - """ """ - pass + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + raise NotImplementedError("Not implemented") @QUANTMETHODS.register(["triton-fp8w8a8-block128"]) @@ -29,13 +38,25 @@ def __init__(self): super().__init__() self.is_moe = False self.block_size = 128 + self.weight_suffix = None + self.weight_zero_point_suffix = None + self.weight_scale_suffix = "weight_scale_inv" def quantize(self, weight: torch.Tensor): # TODO block-wise quant kernel pass - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): - qweight, weight_scale, input_scale = weights + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + input_scale = None m, k = input_tensor.shape n = qweight.shape[1] alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty diff --git a/lightllm/common/quantization/w8a8_quant.py b/lightllm/common/quantization/w8a8_quant.py index c07f6b208..ea5b66bce 100644 --- a/lightllm/common/quantization/w8a8_quant.py +++ b/lightllm/common/quantization/w8a8_quant.py @@ -3,11 +3,15 @@ from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS import torch.nn.functional as F +from typing import Optional, TYPE_CHECKING from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops +if TYPE_CHECKING: + from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack + if HAS_LIGHTLLM_KERNEL: def scaled_fp8_quant(tensor, *args, **kwargs): @@ -27,12 +31,17 @@ def __init__(self): self.cache_manager = g_cache_manager def quantize(self, weight: torch.Tensor): 
- """ """ pass - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): - """ """ - pass + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + raise NotImplementedError("Not implemented") @property def method_name(self): @@ -51,17 +60,21 @@ def quantize(self, weight: torch.Tensor): scale = weight.abs().max(dim=-1)[0] / 127 weight = weight.transpose(0, 1) / scale.reshape(1, -1) weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8) - return weight.cuda(self.device_id_), scale.cuda(self.device_id_) - - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): + return weight.cuda(self.device_id_), scale.cuda(self.device_id_), None + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: input_scale = None - if len(weights) == 3: - qweight, weight_scale, input_scale = weights - elif len(weights) == 2: - qweight, weight_scale = weights - else: - raise ValueError("vllm-quant Weights must be a tuple of length 2 or 3.") - + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + bias = weight_pack.bias + input_scale = None # dynamic quantization for input tensor x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True) m = input_tensor.shape[0] n = qweight.shape[1] @@ -92,9 +105,9 @@ def quantize(self, weight: torch.Tensor): qweight, weight_scale = scaled_fp8_quant( weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True ) - return qweight.transpose(0, 1), weight_scale + return qweight.transpose(0, 1), weight_scale, None - def quantize_moe(self, weight): + def quantize_moe(self, weight: torch.Tensor): num_experts = weight.shape[0] qweights = [] weight_scales = [] @@ -108,10 +121,20 @@ def quantize_moe(self, weight): weight_scale = torch.stack(weight_scales, dim=0).contiguous() return qweights, weight_scale - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + bias = weight_pack.bias x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) m = input_tensor.shape[0] - n = weights[0].shape[1] + n = qweight.shape[1] if out is None: if use_custom_tensor_mananger: out = self.cache_manager.alloc_tensor( @@ -119,7 +142,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ ) else: out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias) + cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) return out @property @@ -133,16 +156,25 @@ def __init__(self): super().__init__() self.block_size = 128 self.weight_scale_suffix = "weight_scale_inv" - self.act_scale_suffix = None # no support for static input tensor scale for ds model. 
def quantize(self, weight: torch.Tensor): raise Exception("Not implemented") - def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): - qweight, weight_scale, input_scale = weights + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "MMWeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + bias = weight_pack.bias + input_scale = None # dynamic quantization for input tensor m, k = input_tensor.shape - n = weights[0].shape[1] + n = qweight.shape[1] alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty if input_scale is None: qinput_tensor, input_scale = per_token_group_quant_fp8(