
Commit 7c03bcb

[https://nvbugs/5613089][fix] Fix the rank used to access all_rank_chunk_size_list when chunked MoE is used
Signed-off-by: Jinyang Yuan <[email protected]>
1 parent ae57738 · commit 7c03bcb

File tree

6 files changed (+10, -9 lines):

tensorrt_llm/_torch/autotuner.py
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
tensorrt_llm/_torch/modules/fused_moe/interface.py

tensorrt_llm/_torch/autotuner.py

Lines changed: 1 addition & 1 deletion

@@ -603,10 +603,10 @@ def choose_one(
         new_tuning_failure_occured = False

         for p in profiles:
-            tensors = self._prepare_input_tensors(p, inputs)
             is_cache_hit, *_ = self.profiling_cache.search_cache(
                 custom_op, runners, p.get_opt_shapes(), tuning_config)
             if not is_cache_hit:
+                tensors = self._prepare_input_tensors(p, inputs)
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
                 best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
                     custom_op, runners, tensors, p, tuning_config, **kwargs)
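
This hunk moves tensor preparation inside the cache-miss branch, so profiles that already have a cached tuning result skip the allocation entirely. A minimal sketch of the deferred-preparation pattern, with hypothetical names standing in for the autotuner's internals:

# Hypothetical reduction of the change: probe the cache first, and only
# materialize the (potentially large) input tensors on a miss.
def tune(profiles, cache, prepare_tensors, profile_runners):
    for p in profiles:
        if cache.contains(p):            # cheap lookup, no tensors needed
            continue
        tensors = prepare_tensors(p)     # expensive allocation, misses only
        profile_runners(tensors, p)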

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 3 additions & 3 deletions

@@ -626,7 +626,7 @@ def forward_impl(
             all_rank_num_tokens_list = [[
                 val[idx_chunk] for val in all_rank_chunk_size_list
             ] for idx_chunk in range(num_chunks)]
-            chunk_size_list = all_rank_chunk_size_list[self.rank]
+            chunk_size_list = all_rank_chunk_size_list[self.parallel_rank]
         else:
             all_rank_num_tokens_list = [None] * num_chunks
             chunk_size_list = self.split_chunk(x.shape[0], num_chunks)

@@ -685,7 +685,7 @@ def _reducescatter_or_allreduce(x_, idx):
         outputs = torch.cat(outputs_list)

         if self.use_dp and self.parallel_size > 1:
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
         return outputs

@@ -714,7 +714,7 @@ def forward_fake(
         is_nvfp4_input = isinstance(x, Fp4QuantizedTensor)
         data_type = output_dtype if is_nvfp4_input else x.dtype
         num_tokens = all_rank_num_tokens[
-            self.mapping.tp_rank] if all_rank_num_tokens else x.shape[0]
+            self.parallel_rank] if all_rank_num_tokens else x.shape[0]
         hidden_size = x.shape[1] * (2 if is_nvfp4_input else 1)
         top_k = self.routing_method.experts_per_token
         return x.new_empty((num_tokens, top_k, hidden_size),
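
All three hunks route the per-rank index through the new self.parallel_rank attribute (defined in interface.py below) instead of self.rank or self.mapping.tp_rank. Lists such as all_rank_chunk_size_list and all_rank_num_tokens carry one entry per rank in the parallel group, so a mismatched rank attribute reads another rank's sizes. A toy illustration with made-up values:

# Made-up values: 4 ranks, each with its own token count.
all_rank_num_tokens = [5, 8, 3, 6]                # one entry per rank
parallel_rank = 2                                 # this process's group rank
num_tokens = all_rank_num_tokens[parallel_rank]   # -> 3, rank 2's own count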

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 2 additions & 2 deletions

@@ -706,7 +706,7 @@ def forward_impl(
             all_rank_num_tokens_list = [[
                 val[idx_chunk] for val in all_rank_chunk_size_list
             ] for idx_chunk in range(num_chunks)]
-            chunk_size_list = all_rank_chunk_size_list[self.rank]
+            chunk_size_list = all_rank_chunk_size_list[self.parallel_rank]
         else:
             all_rank_num_tokens_list = [None] * num_chunks
             chunk_size_list = self.split_chunk(x.shape[0], num_chunks)

@@ -778,6 +778,6 @@ def _reducescatter_or_allreduce(x_, idx):
         outputs = torch.cat(outputs_list)

         if self.use_dp and self.parallel_size > 1:
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
         return outputs
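
The DeepGEMM path repeats the same fix. Note the layout the surrounding comprehension relies on: all_rank_chunk_size_list is indexed [rank][chunk], and the comprehension transposes it to [chunk][rank] so each chunk's collective sees every rank's size. A small worked example with made-up sizes:

# Made-up sizes: 2 ranks, 3 chunks. Rows are ranks, columns are chunks.
all_rank_chunk_size_list = [[4, 4, 3],   # rank 0 splits 11 tokens
                            [3, 3, 2]]   # rank 1 splits 8 tokens
num_chunks = 3
all_rank_num_tokens_list = [[val[idx_chunk] for val in all_rank_chunk_size_list]
                            for idx_chunk in range(num_chunks)]
# -> [[4, 3], [4, 3], [3, 2]]: one size per rank, per chunk
chunk_size_list = all_rank_chunk_size_list[1]    # rank 1's own sizes: [3, 3, 2]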

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 1 addition & 1 deletion

@@ -661,7 +661,7 @@ def forward_impl(
         )

         if use_dp_padding:
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             final_hidden_states = final_hidden_states[:
                                                       all_rank_num_tokens[rank]]
         return final_hidden_states
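
With use_dp_padding, each rank's output is padded to a common row count so the data-parallel collective operates on uniform shapes; the slice then trims the padding back to this rank's true all_rank_num_tokens[rank]. A hedged sketch with made-up shapes:

import torch

# Made-up shapes: hidden size 16, outputs padded to the largest rank's 8 rows.
all_rank_num_tokens = [5, 8, 3]
rank = 0
final_hidden_states = torch.zeros(max(all_rank_num_tokens), 16)
final_hidden_states = final_hidden_states[:all_rank_num_tokens[rank]]  # keep 5 real rows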

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 2 additions & 2 deletions

@@ -830,7 +830,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             ] for idx_chunk in range(num_chunks)]
             all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
                                                        num_chunks)
-            chunk_size_list = all_rank_chunk_size_list[self.rank]
+            chunk_size_list = all_rank_chunk_size_list[self.parallel_rank]
             if use_all_to_all:
                 all_rank_num_tokens_list = [[
                     1 if val == 0 else val for val in val_list

@@ -918,7 +918,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                 self.event_dict[EventType.MoeChunkingOverlap].record()
                 self.event_dict[EventType.MoeChunkingOverlap].wait()
             outputs = torch.cat(outputs_list)
-            rank = self.mapping.tp_rank
+            rank = self.parallel_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
             self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
             return outputs
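
The hunk headers show split_chunk as the enclosing context; it divides a token count across a number of chunks as evenly as possible. A plausible sketch of that helper, assuming even-split semantics; the real implementation may differ:

def split_chunk(split_token_num: int, split_num_chunks: int):
    # Distribute tokens so the first (split_token_num % n) chunks get one extra.
    val_div = split_token_num // split_num_chunks
    val_mod = split_token_num % split_num_chunks
    return [val_div + 1] * val_mod + [val_div] * (split_num_chunks - val_mod)

# split_chunk(11, 3) -> [4, 4, 3]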

tensorrt_llm/_torch/modules/fused_moe/interface.py

Lines changed: 1 addition & 0 deletions

@@ -181,6 +181,7 @@ def __init__(

         # All ranks participate in allreduce regardless of EP/TP combination
         self.mapping = model_config.mapping
+        self.parallel_rank = self.mapping.tp_rank
         self.parallel_size = self.mapping.tp_size
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
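
Defining parallel_rank once, next to parallel_size and from the same mapping, gives every call site a single source of truth for per-rank indexing. A self-contained sketch of the design choice, with a namedtuple standing in for the real Mapping class:

from collections import namedtuple

Mapping = namedtuple("Mapping", ["tp_rank", "tp_size"])  # stand-in only

class MoESketch:
    def __init__(self, mapping: Mapping):
        self.mapping = mapping
        self.parallel_rank = mapping.tp_rank   # one place to change the rank source
        self.parallel_size = mapping.tp_size

    def trim(self, outputs, all_rank_num_tokens):
        # call sites index per-rank lists through the shared attribute
        return outputs[:all_rank_num_tokens[self.parallel_rank]]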
