From a05a02f6df2c95997e7246bc82a8007fd533cfb8 Mon Sep 17 00:00:00 2001
From: Che Ruan
Date: Thu, 27 Nov 2025 15:12:26 +0800
Subject: [PATCH 1/2] quick fix eplb

Signed-off-by: Che Ruan
---
 vllm_ascend/eplb/adaptor/vllm_adaptor.py      | 22 +++++++++++++++----
 .../eplb/core/eplb_device_transfer_loader.py  |  4 ----
 vllm_ascend/ops/moe/moe_mlp.py                |  4 +++-
 vllm_ascend/quantization/w8a8_dynamic.py      |  2 +-
 4 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
index 726763013f4..1fb17c42fc8 100644
--- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py
+++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
@@ -194,20 +194,34 @@ def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
             json.dump(record, f, indent=4)
 
     def do_update_expert_map(self, layer_id, updated_expert_map):
-        self.expert_map_per_layer[layer_id] = updated_expert_map.clone()
-        self.expert_map_per_layer_cpu[layer_id] = updated_expert_map.clone()
+        pad_len = self.expert_map_per_layer[layer_id].shape[0] - updated_expert_map.shape[0]
+        updated_expert_map_padded = torch.nn.functional.pad(
+            updated_expert_map,
+            pad=(0,pad_len),
+            mode='constant',
+            value=-1
+        )
+        self.expert_map_per_layer[layer_id].copy_(updated_expert_map_padded)
+        self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
 
     def do_update_expert_weight(self, layer_id, local_expert_to_replace,
                                 buffer_tensor_id):
         for expert_tensor, buffer_tensor in zip(
                 self.expert_param_per_layer[layer_id][local_expert_to_replace],
                 self.buffer_tensor_list[buffer_tensor_id]):
-            expert_tensor = buffer_tensor.clone()
+            expert_tensor.copy_(buffer_tensor)
             logger.debug(f"Expert tensor shape is :{expert_tensor.shape}")
 
     def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
         if self.log2phy_map_per_layer[layer_id] is not None:
-            self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map)
+            pad_len = self.log2phy_map_per_layer[layer_id].shape[0] - updated_log2phy_map.shape[0]
+            updated_log2phy_map_padded = torch.nn.functional.pad(
+                updated_log2phy_map,
+                pad=(0,pad_len),
+                mode='constant',
+                value=-1
+            )
+            self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map_padded)
 
     def global2local(self, placement: torch.Tensor,
                      E_local: int) -> torch.Tensor:
diff --git a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py
index 9a8a323f718..67e4d562f90 100644
--- a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py
+++ b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py
@@ -50,10 +50,6 @@ def generate_expert_d2d_transfer_task(self, expert_send_info,
             )
             return
 
-        # If neither send nor receive task is needed for this layer on this rank, return
-        if not (expert_send_info or expert_recv_info):
-            return
-
         self.updated_expert_map = updated_expert_map
 
         self.layer_id = layer_id
diff --git a/vllm_ascend/ops/moe/moe_mlp.py b/vllm_ascend/ops/moe/moe_mlp.py
index 5ee7d70d549..32655cd7da9 100644
--- a/vllm_ascend/ops/moe/moe_mlp.py
+++ b/vllm_ascend/ops/moe/moe_mlp.py
@@ -105,6 +105,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             group_list=group_list,
             output_dtype=torch.int32)[0]
         # act_fn: swiglu
+        group_diff = torch.diff( group_list)
+        new_group = torch.cat( [ group_list[0].unsqueeze(0), group_diff ],dim=0)
         hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
             x=hidden_states,
             weight_scale=w1_scale,
@@ -112,7 +114,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             bias=None,
             quant_scale=None,
             quant_offset=None,
-            group_index=group_list,
+            group_index=new_group,
             activate_left=True,
             quant_mode=1,
         )
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 0f96e8cf541..6e3f42d00eb 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -249,7 +249,7 @@ def apply(
         return moe_comm_method.fused_experts(
             hidden_states=x,
             w1=layer.w13_weight,
-            w1_scale=layer.w13_weight_scale_fp32,
+            w1_scale=layer.w13_weight_scale.to(torch.float32),
             w2=layer.w2_weight,
             w2_scale=layer.w2_weight_scale,
             topk_weights=topk_weights,

From 59b55368ce964dfa7eff5528e6cbf35db70f2dab Mon Sep 17 00:00:00 2001
From: Che Ruan
Date: Thu, 27 Nov 2025 15:40:22 +0800
Subject: [PATCH 2/2] format fix

Signed-off-by: Che Ruan
---
 vllm_ascend/eplb/adaptor/vllm_adaptor.py | 28 ++++++++++++------------
 vllm_ascend/ops/moe/moe_mlp.py           |  5 +++--
 2 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
index 1fb17c42fc8..4e1dab3edb5 100644
--- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py
+++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
@@ -194,13 +194,12 @@ def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
             json.dump(record, f, indent=4)
 
     def do_update_expert_map(self, layer_id, updated_expert_map):
-        pad_len = self.expert_map_per_layer[layer_id].shape[0] - updated_expert_map.shape[0]
-        updated_expert_map_padded = torch.nn.functional.pad(
-            updated_expert_map,
-            pad=(0,pad_len),
-            mode='constant',
-            value=-1
-        )
+        pad_len = self.expert_map_per_layer[layer_id].shape[
+            0] - updated_expert_map.shape[0]
+        updated_expert_map_padded = torch.nn.functional.pad(updated_expert_map,
+                                                            pad=(0, pad_len),
+                                                            mode='constant',
+                                                            value=-1)
         self.expert_map_per_layer[layer_id].copy_(updated_expert_map_padded)
         self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
 
@@ -214,14 +213,15 @@ def do_update_expert_weight(self, layer_id, local_expert_to_replace,
 
     def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
         if self.log2phy_map_per_layer[layer_id] is not None:
-            pad_len = self.log2phy_map_per_layer[layer_id].shape[0] - updated_log2phy_map.shape[0]
+            pad_len = self.log2phy_map_per_layer[layer_id].shape[
+                0] - updated_log2phy_map.shape[0]
             updated_log2phy_map_padded = torch.nn.functional.pad(
-                updated_log2phy_map,
-                pad=(0,pad_len),
-                mode='constant',
-                value=-1
-            )
-            self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map_padded)
+                updated_log2phy_map,
+                pad=(0, pad_len),
+                mode='constant',
+                value=-1)
+            self.log2phy_map_per_layer[layer_id].copy_(
+                updated_log2phy_map_padded)
 
     def global2local(self, placement: torch.Tensor,
                      E_local: int) -> torch.Tensor:
diff --git a/vllm_ascend/ops/moe/moe_mlp.py b/vllm_ascend/ops/moe/moe_mlp.py
index 32655cd7da9..2c734211f84 100644
--- a/vllm_ascend/ops/moe/moe_mlp.py
+++ b/vllm_ascend/ops/moe/moe_mlp.py
@@ -105,8 +105,9 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             group_list=group_list,
             output_dtype=torch.int32)[0]
         # act_fn: swiglu
-        group_diff = torch.diff( group_list)
-        new_group = torch.cat( [ group_list[0].unsqueeze(0), group_diff ],dim=0)
+        group_diff = torch.diff(group_list)
+        new_group = torch.cat([group_list[0].unsqueeze(0), group_diff],
+                              dim=0)
         hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
             x=hidden_states,
             weight_scale=w1_scale,
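
Reviewer note: a standalone sketch of the two tensor manipulations this series relies on. The first part mirrors the group_index change in quant_apply_mlp, which appears to convert a cumulative group_list into per-group token counts via torch.diff before passing it to npu_dequant_swiglu_quant; the second mirrors the -1 right-padding used in do_update_expert_map so a shorter updated map can be copied in-place into the existing buffer. The values below are hypothetical, the snippet runs on CPU with plain PyTorch, and it does not call any NPU ops.

    import torch
    import torch.nn.functional as F

    # Hypothetical cumulative token counts per expert group (the assumed
    # layout of group_list): expert i owns tokens up to group_list[i].
    group_list = torch.tensor([3, 3, 7, 12], dtype=torch.int64)

    # Same conversion as the patch: per-group counts come from diffing the
    # cumulative list and re-prepending the first cumulative entry.
    group_diff = torch.diff(group_list)
    new_group = torch.cat([group_list[0].unsqueeze(0), group_diff], dim=0)
    print(new_group)  # tensor([3, 0, 4, 5]) -> tokens handled by each expert

    # Padding logic from do_update_expert_map: right-pad a shorter updated
    # map with -1, then copy_ it into the full-size on-device buffer.
    existing_map = torch.full((8,), -1, dtype=torch.int64)
    updated_map = torch.tensor([0, 2, 1, 3], dtype=torch.int64)
    pad_len = existing_map.shape[0] - updated_map.shape[0]
    existing_map.copy_(
        F.pad(updated_map, pad=(0, pad_len), mode='constant', value=-1))
    print(existing_map)  # tensor([ 0,  2,  1,  3, -1, -1, -1, -1])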