From 77712781b5b5fb88e29340abfbfcda443a97ee37 Mon Sep 17 00:00:00 2001
From: Pavel Belevich
Date: Sat, 12 Apr 2025 13:05:26 -0400
Subject: [PATCH 1/2] Llama4: remove redundant transpose of router_logits

---
 src/transformers/models/llama4/modeling_llama4.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py
index 37614db7c230..6eea4454dc30 100644
--- a/src/transformers/models/llama4/modeling_llama4.py
+++ b/src/transformers/models/llama4/modeling_llama4.py
@@ -159,12 +159,12 @@ def __init__(self, config):
     def forward(self, hidden_states):
         batch, seq_len, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_dim)
-        router_logits = self.router(hidden_states).transpose(0, 1)
+        router_logits = self.router(hidden_states)
         tokens_per_expert = batch * seq_len
 
-        router_top_value, router_indices = torch.topk(router_logits.transpose(0, 1), self.top_k, dim=1)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
         router_scores = (
-            torch.full_like(router_logits.transpose(0, 1), float("-inf"))
+            torch.full_like(router_logits, float("-inf"))
             .scatter_(1, router_indices, router_top_value)
             .transpose(0, 1)
         )

From b935ef02a5f1875d007964344e5ebb1f8c7f5041 Mon Sep 17 00:00:00 2001
From: Pavel Belevich
Date: Sat, 12 Apr 2025 15:10:02 -0400
Subject: [PATCH 2/2] Fix formatting

---
 src/transformers/models/llama4/modeling_llama4.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py
index 6eea4454dc30..021bae7f62c6 100644
--- a/src/transformers/models/llama4/modeling_llama4.py
+++ b/src/transformers/models/llama4/modeling_llama4.py
@@ -164,9 +164,7 @@ def forward(self, hidden_states):
 
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
         router_scores = (
-            torch.full_like(router_logits, float("-inf"))
-            .scatter_(1, router_indices, router_top_value)
-            .transpose(0, 1)
+            torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1)
         )
         # We do this to make sure we have -inf for non topK tokens before going through the !
         # Here we are just creating a tensor to index each and every single one of the hidden states. Let s maybe register a buffer for this!
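
---

Note (not part of the patches above): the old code transposed router_logits to
(experts, tokens) immediately after the router projection, then transposed it
back to (tokens, experts) at every use site, so the two transposes canceled
out. The patch keeps the (tokens, experts) layout that nn.Linear already
produces and applies a single final transpose when building router_scores.
Below is a minimal standalone sketch checking that both paths produce the same
router_scores; the tensor sizes, variable names, and top_k value are invented
for illustration and do not come from the patch.

import torch

tokens, num_experts, top_k = 6, 4, 2
# Same shape nn.Linear would produce for the router projection: (tokens, experts)
router_logits = torch.randn(tokens, num_experts)

# Old path: transpose on creation, then transpose back at every use.
old_logits = router_logits.transpose(0, 1)  # (experts, tokens)
old_top, old_idx = torch.topk(old_logits.transpose(0, 1), top_k, dim=1)
old_scores = (
    torch.full_like(old_logits.transpose(0, 1), float("-inf"))
    .scatter_(1, old_idx, old_top)
    .transpose(0, 1)
)

# New path: stay in (tokens, experts) throughout, one final transpose.
new_top, new_idx = torch.topk(router_logits, top_k, dim=1)
new_scores = (
    torch.full_like(router_logits, float("-inf")).scatter_(1, new_idx, new_top).transpose(0, 1)
)

# Both paths fill non-top-k entries with -inf and end up as (experts, tokens).
assert torch.equal(old_scores, new_scores)

Since transpose returns a view, the change is less about compute than about
readability: the logits now keep a single, consistent layout until the final
transpose that downstream code expects.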