Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions vllm_ascend/eplb/adaptor/vllm_adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,20 +194,34 @@ def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
json.dump(record, f, indent=4)

def do_update_expert_map(self, layer_id, updated_expert_map):
self.expert_map_per_layer[layer_id] = updated_expert_map.clone()
self.expert_map_per_layer_cpu[layer_id] = updated_expert_map.clone()
pad_len = self.expert_map_per_layer[layer_id].shape[
0] - updated_expert_map.shape[0]
updated_expert_map_padded = torch.nn.functional.pad(updated_expert_map,
pad=(0, pad_len),
mode='constant',
value=-1)
self.expert_map_per_layer[layer_id].copy_(updated_expert_map_padded)
self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The update to self.expert_map_per_layer_cpu using copy_ might lead to a runtime error if updated_expert_map has a different shape than the existing tensor in self.expert_map_per_layer_cpu[layer_id]. The logic to calculate pad_len just before this suggests that the shape of updated_expert_map can indeed be smaller. While the device tensor self.expert_map_per_layer is correctly updated with a padded version, this CPU copy is not, which is inconsistent and unsafe.

To prevent potential crashes, I recommend reverting this line to use clone(), which was the original implementation and is safer as it replaces the tensor entirely, handling any shape changes gracefully.

Suggested change
self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
self.expert_map_per_layer_cpu[layer_id] = updated_expert_map.clone()


def do_update_expert_weight(self, layer_id, local_expert_to_replace,
buffer_tensor_id):
for expert_tensor, buffer_tensor in zip(
self.expert_param_per_layer[layer_id][local_expert_to_replace],
self.buffer_tensor_list[buffer_tensor_id]):
expert_tensor = buffer_tensor.clone()
expert_tensor.copy_(buffer_tensor)
logger.debug(f"Expert tensor shape is :{expert_tensor.shape}")

def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
if self.log2phy_map_per_layer[layer_id] is not None:
self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map)
pad_len = self.log2phy_map_per_layer[layer_id].shape[
0] - updated_log2phy_map.shape[0]
updated_log2phy_map_padded = torch.nn.functional.pad(
updated_log2phy_map,
pad=(0, pad_len),
mode='constant',
value=-1)
self.log2phy_map_per_layer[layer_id].copy_(
updated_log2phy_map_padded)

def global2local(self, placement: torch.Tensor,
E_local: int) -> torch.Tensor:
Expand Down
4 changes: 0 additions & 4 deletions vllm_ascend/eplb/core/eplb_device_transfer_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ def generate_expert_d2d_transfer_task(self, expert_send_info,
)
return

# If neither send nor receive task is needed for this layer on this rank, return
if not (expert_send_info or expert_recv_info):
return

self.updated_expert_map = updated_expert_map

self.layer_id = layer_id
Expand Down
5 changes: 4 additions & 1 deletion vllm_ascend/ops/moe/moe_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,17 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list=group_list,
output_dtype=torch.int32)[0]
# act_fn: swiglu
group_diff = torch.diff(group_list)
new_group = torch.cat([group_list[0].unsqueeze(0), group_diff],
dim=0)
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale,
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
group_index=new_group,
activate_left=True,
quant_mode=1,
)
Expand Down
2 changes: 1 addition & 1 deletion vllm_ascend/quantization/w8a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def apply(
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale_fp32,
w1_scale=layer.w13_weight_scale.to(torch.float32),
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
Expand Down
Loading