2 changes: 1 addition & 1 deletion vllm_ascend/attention/mla_v1.py
@@ -652,7 +652,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):

         # Function `get_and_maybe_dequant_weights` will cast the weights to
         # FRACTAL_ND. So we need to cast to FRACTAL_NZ again.
-        if is_enable_nz():
+        if is_enable_nz(self.kv_b_proj.weight.data.dtype):
             self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
                 self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)
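
Every call site in this PR follows the same pattern: the weight tensor's dtype is passed to is_enable_nz(), so fp16/bf16 weights are left in ND format and only other dtypes (e.g. int8) may be recast to FRACTAL_NZ. A minimal sketch of that pattern, assuming an Ascend environment with torch_npu installed (maybe_cast_to_nz is a hypothetical helper, not part of the PR; the import path follows the PR's call sites):

    import torch
    import torch_npu  # Ascend adapter for PyTorch; NPU-only environment assumed
    from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz

    def maybe_cast_to_nz(weight: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper illustrating the PR's call-site pattern:
        # fp16/bf16 weights stay in ND; other dtypes follow the NZ env flag.
        if is_enable_nz(weight.dtype):
            return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)
        return weight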

4 changes: 2 additions & 2 deletions vllm_ascend/models/qwen2_5_vl.py
@@ -284,7 +284,7 @@ def pad_qkv_weight(self, data):
             dim=2)
         qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)

-        if is_enable_nz():
+        if is_enable_nz(qkv_weight_final.dtype):
             qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
                 qkv_weight_final)
             qkv_weight_final_copy = torch_npu.npu_format_cast(
@@ -300,7 +300,7 @@ def pad_proj_weight(self, data):
             (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
                 self.hidden_size, -1)

-        if is_enable_nz():
+        if is_enable_nz(out_weight.dtype):
             out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
             out_weight_copy = torch_npu.npu_format_cast(
                 out_weight_copy, ACL_FORMAT_FRACTAL_ND)
4 changes: 2 additions & 2 deletions vllm_ascend/models/qwen2_vl.py
@@ -268,7 +268,7 @@ def pad_qkv_weight(self, data):
             dim=2)
         qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)

-        if is_enable_nz():
+        if is_enable_nz(qkv_weight_final.dtype):
             qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
                 qkv_weight_final)
             qkv_weight_final_copy = torch_npu.npu_format_cast(
@@ -284,7 +284,7 @@ def pad_proj_weight(self, data):
             (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
                 self.hidden_size, -1)

-        if is_enable_nz():
+        if is_enable_nz(out_weight.dtype):
             out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
             out_weight_copy = torch_npu.npu_format_cast(
                 out_weight_copy, ACL_FORMAT_FRACTAL_ND)
2 changes: 1 addition & 1 deletion vllm_ascend/ops/common_fused_moe.py
@@ -76,7 +76,7 @@ def process_weights_after_loading(self, layer):
         w2_data = self._maybe_pad_weight(layer.w2_weight.data)
         layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

-        if not is_310p() and is_enable_nz():
+        if not is_310p() and is_enable_nz(layer.w13_weight.data.dtype):
             layer.w13_weight.data = torch_npu.npu_format_cast(
                 layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
             layer.w2_weight.data = torch_npu.npu_format_cast(
3 changes: 1 addition & 2 deletions vllm_ascend/ops/linear.py
@@ -45,8 +45,7 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
-        if (is_enable_nz() and layer.weight.data.dtype
-                in [torch.float16, torch.bfloat16]):
+        if is_enable_nz(layer.weight.data.dtype):
             layer.weight.data = torch_npu.npu_format_cast(
                 layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
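
Note that the dtype condition at this call site is effectively inverted: before the PR, the NZ cast ran only for fp16/bf16 weights; the dtype-aware helper now returns False for exactly those dtypes, so unquantized fp16/bf16 weights stay in ND format here.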

4 changes: 2 additions & 2 deletions vllm_ascend/torchair/torchair_sfa.py
@@ -842,7 +842,7 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
         wd_qkv = wd_qkv.t().contiguous()
         wd_qkv = transdata(wd_qkv,
                            block_size=(16, 32)).unsqueeze(0).contiguous()
-        if is_enable_nz():
+        if is_enable_nz(wd_qkv.dtype):
             self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)

         kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone()
@@ -876,7 +876,7 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
             self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
             -1)
         wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
-        if is_enable_nz():
+        if is_enable_nz(wu_q.dtype):
             self.wu_q = torch_npu.npu_format_cast(wu_q, 29)

         qb_deq_scl = self.q_proj.deq_scale.data.clone()
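
Note: the literal 29 passed to npu_format_cast in these two hunks is the raw ACL format enum; by assumption (from the torch_npu format constants, not from this diff) it is the same value that vllm_ascend.utils exposes as ACL_FORMAT_FRACTAL_NZ.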
2 changes: 1 addition & 1 deletion vllm_ascend/torchair/utils.py
@@ -146,7 +146,7 @@ def converting_weight_acl_format(model, format):
         if torch_npu.get_npu_format(module.w13_weight.data) == format:
             return
         if format == ACL_FORMAT_FRACTAL_NZ \
-                and not is_enable_nz():
+                and not is_enable_nz(module.w13_weight.data.dtype):
             return
         module.w13_weight.data = torch_npu.npu_format_cast(
             module.w13_weight.data, format)
5 changes: 4 additions & 1 deletion vllm_ascend/utils.py
@@ -71,13 +71,16 @@ def is_310p():
     return _IS_310P


-def is_enable_nz(vllm_config: Optional[VllmConfig] = None) -> bool:
+def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8,
+                 vllm_config: Optional[VllmConfig] = None) -> bool:
     global _ENABLE_NZ
     if _ENABLE_NZ is None:
         if not vllm_config:
             raise ValueError(
                 "vllm_config must be provided when _ENABLE_NZ is None")
         _ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next"
+    if dtype in [torch.float16, torch.bfloat16]:
+        return False
     return _ENABLE_NZ
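
The new signature makes the helper dtype-aware: fp16/bf16 always return False, while other dtypes (the int8 default covers quantized weights) fall through to the cached env/config flag. A self-contained sketch of these semantics, with the envs_ascend/VllmConfig lookup stubbed out as an assumption:

    from typing import Optional

    import torch

    _ENABLE_NZ: Optional[bool] = None  # cached once per worker

    def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8,
                     vllm_config=None) -> bool:
        global _ENABLE_NZ
        if _ENABLE_NZ is None:
            if not vllm_config:
                raise ValueError(
                    "vllm_config must be provided when _ENABLE_NZ is None")
            # Stub for: envs_ascend.VLLM_ASCEND_ENABLE_NZ and
            # model_type != "qwen3_next"
            _ENABLE_NZ = True
        # The dtype filter runs after initialization, so the first call
        # still needs a vllm_config even for fp16/bf16.
        if dtype in [torch.float16, torch.bfloat16]:
            return False
        return _ENABLE_NZ

    class _Cfg:  # minimal stand-in for VllmConfig
        pass

    is_enable_nz(vllm_config=_Cfg())             # initializes the cache
    assert is_enable_nz(torch.int8) is True      # quantized weights may use NZ
    assert is_enable_nz(torch.float16) is False  # fp16 always stays ND
    assert is_enable_nz(torch.bfloat16) is False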


2 changes: 1 addition & 1 deletion vllm_ascend/worker/model_runner_v1.py
@@ -2676,7 +2676,7 @@ def load_model(self) -> None:

     def _convert_torch_format(self, tensor):
         if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \
-                and not is_enable_nz():
+                and not is_enable_nz(tensor.dtype):
             return tensor
         tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
         return tensor
2 changes: 1 addition & 1 deletion vllm_ascend/worker/worker_v1.py
@@ -81,7 +81,7 @@ def __init__(
         # register patch for vllm
         from vllm_ascend.utils import adapt_patch
         adapt_patch()
-        is_enable_nz(vllm_config)
+        is_enable_nz(vllm_config=vllm_config)
         # Register ops when worker init.
         from vllm_ascend import ops
         ops.register_dummy_fusion_op()
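
Because `dtype` is now the first positional parameter of is_enable_nz(), the old call is_enable_nz(vllm_config) would bind the config object to dtype and leave vllm_config as None; the first call at worker init would then hit the "vllm_config must be provided when _ENABLE_NZ is None" error. Passing it as a keyword keeps the initialization path intact.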