@@ -6,6 +6,7 @@
    COLMMWeight,
    MultiROWMMWeight,
    ROWBMMWeight,
+    AWQMultiMMWeightTpl,
)
from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
from .fused_moe_weight_tp import FusedMoeWeightTP
@@ -44,9 +44,14 @@ def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
        # load quantization scale
        pass

+    def _load_zero_points(self, weights: Dict[str, torch.Tensor]) -> None:
+        # load quantization zero points
+        pass
+
    def load_hf_weights(self, weights):
        self._load_weights(weights)
        self._load_scales(weights)
+        self._load_zero_points(weights)
        return

    def verify_load(self):

[Large diff not rendered by default.]

@@ -1,6 +1,7 @@
from .mm_weight import (
    MMWeightTpl,
    MultiMMWeightTpl,
+    AWQMultiMMWeightTpl,
)
from .rowmm_weight import (
    ROWMMWeight,
@@ -3,6 +3,7 @@
    MMWeight,
    MMWeightTpl,
    generate_scale_name,
+    AWQMMWeightTpl,
)
from lightllm.common.quantization import Quantcfg
from lightllm.utils.dist_utils import get_current_device_id
@@ -15,8 +16,7 @@ class COLMMWeight(MMWeight):
    def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
        if quant_method is None or not quantized_weight:
            return UnquantizedCOLMMWeight
-        else:
-            return W8A8B128COLMMWeight
+        return COLBMM_WEIGHT_CLS_MAP[quant_method.get_name()]
Review comment on lines 16 to +19 (severity: high):

Consider adding a default case or an error message if quant_method.get_name() doesn't match any key in COLBMM_WEIGHT_CLS_MAP. This can prevent unexpected behavior if a new quantization method is added but not included in the map.

Suggested change:
        if quant_method is None or not quantized_weight:
            return UnquantizedCOLMMWeight
-        return COLBMM_WEIGHT_CLS_MAP[quant_method.get_name()]
+        return COLBMM_WEIGHT_CLS_MAP.get(quant_method.get_name(), None)  # None or raise error
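One way to act on this suggestion, sketched as a hypothetical variant (raising a ValueError instead of returning None is an assumption on my part, not something this PR does):

    @classmethod
    def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
        if quant_method is None or not quantized_weight:
            return UnquantizedCOLMMWeight
        mmcls = COLBMM_WEIGHT_CLS_MAP.get(quant_method.get_name())
        if mmcls is None:
            # Fail loudly so a newly added quantization method that is missing
            # from the map surfaces as a clear error instead of a silent None.
            raise ValueError(f"unsupported quant method for COLMMWeight: {quant_method.get_name()}")
        return mmcls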



class UnquantizedCOLMMWeight(MMWeightTpl):
@@ -97,3 +97,78 @@ def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
None,
]
return


class AWQCOLMMWeight(AWQMMWeightTpl):
def __init__(
self,
weight_name: str,
data_type: torch.dtype,
bias_name: Optional[str] = None,
quant_method: QuantizationMethod = None,
tp_rank: int = None,
tp_world_size: int = None,
) -> None:
super().__init__(data_type, quant_method, tp_rank, tp_world_size)
self.weight_name = weight_name.replace("weight", quant_method.weight_suffix)
self.weight_scale_name = weight_name.replace("weight", quant_method.weight_scale_suffix)
self.weight_zero_point_name = weight_name.replace("weight", quant_method.weight_zero_point_suffix)
self.bias_name = bias_name
self.weight_scale: Optional[torch.Tensor] = None
self.quantized_weight = True
self.weight = [None, None, None]

def _slice_weight(self, weight: torch.Tensor):
assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}"
tp_size = weight.shape[0] // self.tp_world_size_
return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1), :]

def _slice_bias(self, bias):
assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
tp_size = bias.shape[0] // self.tp_world_size_
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1), :]

def _slice_weight_scale(self, weight_scale: torch.Tensor):
tp_size = weight_scale.shape[0] // self.tp_world_size_
scale_start = tp_size * self.tp_rank_
scale_end = tp_size * (self.tp_rank_ + 1)
return weight_scale[scale_start:scale_end, :].to(torch.half)

def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor):
tp_size = weight_zero_point.shape[0] // self.tp_world_size_
zero_point_start = tp_size * self.tp_rank_
zero_point_end = tp_size * (self.tp_rank_ + 1)
return weight_zero_point[zero_point_start:zero_point_end, :]


class AWQMARLINCOLMMWeight(AWQCOLMMWeight):
def __init__(
self,
weight_name: str,
data_type: torch.dtype,
bias_name: Optional[str] = None,
quant_method: QuantizationMethod = None,
tp_rank: int = None,
tp_world_size: int = None,
) -> None:
super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size)

def _process_weight(self, weight: torch.Tensor) -> torch.Tensor:
return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id()))

def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
return self.quant_method._process_weight_scale_after_loading(
weight_scale.cuda(get_current_device_id()).to(self.data_type_)
)

def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
return self.quant_method._process_weight_zero_point_after_loading(
weight_zero_point.cuda(get_current_device_id())
)


COLBMM_WEIGHT_CLS_MAP = {
"fp8w8a8b128": W8A8B128COLMMWeight,
"awq": AWQCOLMMWeight,
"awq_marlin": AWQMARLINCOLMMWeight,
}
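
For context on the slicing directions above, here is a shape sanity-check, assuming AWQ's usual packed GEMM layout (tensors stored in x out, int32-packed by a factor of 8 for 4-bit weights, with grouped scales); the sizes are illustrative only:

    import torch

    in_features, out_features, group_size, pack = 4096, 4096, 128, 8
    qweight = torch.zeros(in_features, out_features // pack, dtype=torch.int32)
    scales = torch.zeros(in_features // group_size, out_features, dtype=torch.float16)
    qzeros = torch.zeros(in_features // group_size, out_features // pack, dtype=torch.int32)

    # Because all three tensors put the input dimension first, the dim-0 slices
    # in AWQCOLMMWeight shard the input dimension across tensor-parallel ranks.
    tp_world_size, tp_rank = 2, 0
    shard = qweight.shape[0] // tp_world_size
    qweight_shard = qweight[shard * tp_rank : shard * (tp_rank + 1), :]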
@@ -163,6 +163,137 @@ def _process_weight(self, weight) -> None:
self.weight = weight.cuda(get_current_device_id())


class AWQMMWeightTpl(MMWeightTpl):
def __init__(
self,
data_type: torch.dtype,
quant_method: QuantizationMethod = None,
tp_rank: int = None,
tp_world_size: int = None,
) -> None:
super().__init__(data_type, quant_method, tp_rank, tp_world_size)
self.weight = [None, None, None]

def verify_load(self) -> bool:
load_ok = True
        # Verify weight. Every element must not be None.
weight_ok = all(w is not None for w in self.weight)
load_ok = load_ok and weight_ok
# Verify bias. If bias_name is set, it must be not None.
if self.has_bias:
load_ok = load_ok and self.bias is not None
return load_ok

def _process_weight(self, weight: torch.Tensor) -> torch.Tensor:
return weight.cuda(get_current_device_id())

def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
return weight_scale.cuda(get_current_device_id()).to(self.data_type_)

def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
return weight_zero_point.cuda(get_current_device_id())

def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
if self.weight_name is not None and self.weight_name in weights:
weight = weights[self.weight_name]
weight = self._slice_weight(weight)
self.weight[0] = self._process_weight(weight)
if self.bias_name is not None and self.bias_name in weights:
bias = weights[self.bias_name]
bias = self._slice_bias(bias)
self.bias = bias.cuda(get_current_device_id())

def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
if self.weight_scale_name is not None and self.weight_scale_name in weights:
weight_scale = weights[self.weight_scale_name]
weight_scale = self._slice_weight_scale(weight_scale)
self.weight[1] = self._process_weight_scale(weight_scale)

def _load_zero_points(self, weights: Dict[str, torch.Tensor]) -> None:
if self.weight_zero_point_name is not None and self.weight_zero_point_name in weights:
weight_zero_point = weights[self.weight_zero_point_name]
weight_zero_point = self._slice_weight_zero_point(weight_zero_point)
self.weight[2] = self._process_weight_zero_point(weight_zero_point)


class AWQMultiMMWeightTpl(AWQMMWeightTpl):
def __init__(
self,
weight_names: List[str],
data_type: torch.dtype,
bias_names: Optional[List[str]] = None,
quant_method: QuantizationMethod = None,
tp_rank: int = None,
tp_world_size: int = None,
) -> None:
super().__init__(data_type, quant_method, tp_rank, tp_world_size)
self.weight_names = [weight.replace("weight", quant_method.weight_suffix) for weight in weight_names]
self.weight_scale_names = [
weight.replace("weight", quant_method.weight_scale_suffix) for weight in weight_names
]
self.weight_zero_point_names = [
weight.replace("weight", quant_method.weight_zero_point_suffix) for weight in weight_names
]
self.bias_names = bias_names
self.weights = [None] * len(self.weight_names)
self.weight_scales = [None] * len(self.weight_names)
self.weight_zero_points = [None] * len(self.weight_names)
if self.bias_names is not None:
self.biases = [None] * len(self.bias_names)
self.has_bias = all(b is not None for b in self.bias_names) and len(bias_names) > 0
else:
self.biases = None
self.has_bias = False

def _fuse(self) -> None:
if self.weight[0] is None and (None not in self.weights):
weight = torch.cat(self.weights, dim=1)
self.weight[0] = self._process_weight(weight)
delattr(self, "weights")

if self.weight[1] is None and (None not in self.weight_scales):
            # AWQ stores its quantized parameters with weight shape in x out, so the cat dim here is 1.
weight_scale = torch.cat(self.weight_scales, dim=1).cuda(get_current_device_id())
self.weight[1] = self._process_weight_scale(weight_scale)
delattr(self, "weight_scales")

if self.weight[2] is None and (None not in self.weight_zero_points):
weight_zero_point = torch.cat(self.weight_zero_points, dim=1)
self.weight[2] = self._process_weight_zero_point(weight_zero_point)
delattr(self, "weight_zero_points")

if self.has_bias and self.bias is None and (None not in self.biases):
self.bias = torch.cat(self.biases, dim=0).cuda(get_current_device_id())
delattr(self, "biases")
return self

def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
for i in range(len(self.weight_names)):
if self.weight_names[i] is not None and self.weight_names[i] in weights:
weight = weights[self.weight_names[i]]
weight = self._slice_weight(weight)
self.weights[i] = weight.cuda(get_current_device_id())

def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
for i in range(len(self.weight_names)):
if self.weight_scale_names[i] is not None and self.weight_scale_names[i] in weights:
weight_scale = weights[self.weight_scale_names[i]]
weight_scale = self._slice_weight_scale(weight_scale)
self.weight_scales[i] = weight_scale.cuda(get_current_device_id()).to(self.data_type_)

def _load_zero_points(self, weights: Dict[str, torch.Tensor]) -> None:
for i in range(len(self.weight_names)):
if self.weight_zero_point_names[i] is not None and self.weight_zero_point_names[i] in weights:
weight_zero_point = weights[self.weight_zero_point_names[i]]
weight_zero_point = self._slice_weight_zero_point(weight_zero_point)
self.weight_zero_points[i] = weight_zero_point.cuda(get_current_device_id())

def load_hf_weights(self, weights):
super().load_hf_weights(weights)
self._fuse()
return
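
Taken together with the base-class load_hf_weights shown in the second hunk of this diff (assuming that is the implementation inherited here), the load path for a fused AWQ weight works out to the following order; this is a descriptive sketch, not code from the PR:

    # weight_obj = AWQMultiMMWeightTpl([...q/k/v weight names...], torch.float16, ...)
    # weight_obj.load_hf_weights(checkpoint_shard) does, per shard:
    #   1. _load_weights(...)      -> TP-sliced qweight for each projection into self.weights
    #   2. _load_scales(...)       -> sliced scales into self.weight_scales
    #   3. _load_zero_points(...)  -> sliced qzeros into self.weight_zero_points
    #   4. _fuse()                 -> once every piece is present, torch.cat(..., dim=1)
    #                                 each list into self.weight[0..2] and drop the
    #                                 per-projection lists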


class MMWeight:
def __new__(cls, **kwargs):
quant_cfg = kwargs.pop("quant_cfg", None)
@@ -178,6 +309,9 @@ def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> Q
        if quant_cfg is None:
            return None, False
        quant_method = quant_cfg.get_quant_method(layer_num_, name)
+        if quant_method is None:
+            return None, False
        quant_method.hf_quantization_config = quant_cfg.hf_quantization_config
        quantized_weight = quant_cfg.quantized_weight
        return quant_method, quantized_weight
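
The new guard matters when a config quantizes only some layers; a hypothetical illustration of the assumed behavior (get_quant_method returning None for layers outside the configured set):

    # hypothetical: layer 0 quantized, layer 1 left in fp16
    # quant_cfg.get_quant_method(0, "q_proj")  -> quant method instance
    # quant_cfg.get_quant_method(1, "q_proj")  -> None, so _get_quant_method returns
    #     (None, False) and MMWeight falls back to the unquantized class via _get_mmcls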
