diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py
index 37614db7c230..021bae7f62c6 100644
--- a/src/transformers/models/llama4/modeling_llama4.py
+++ b/src/transformers/models/llama4/modeling_llama4.py
@@ -159,14 +159,12 @@ def __init__(self, config):
     def forward(self, hidden_states):
         batch, seq_len, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_dim)
-        router_logits = self.router(hidden_states).transpose(0, 1)
+        router_logits = self.router(hidden_states)
         tokens_per_expert = batch * seq_len

-        router_top_value, router_indices = torch.topk(router_logits.transpose(0, 1), self.top_k, dim=1)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
         router_scores = (
-            torch.full_like(router_logits.transpose(0, 1), float("-inf"))
-            .scatter_(1, router_indices, router_top_value)
-            .transpose(0, 1)
+            torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1)
         )
         # We do this to make sure we have -inf for non topK tokens before going through the !
         # Here we are just creating a tensor to index each and every single one of the hidden states. Let s maybe register a buffer for this!
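
The change above is a layout cleanup: self.router(hidden_states) already produces a (tokens, num_experts) tensor, so the top-k selection and the -inf scatter can operate on it directly, with a single final .transpose(0, 1) yielding the (num_experts, tokens) score matrix used afterwards. Below is a minimal standalone sketch, not part of the patch, that checks the old and new paths agree on synthetic logits; the shapes and the expert/token layout are assumptions read off the hunk above.

import torch

tokens, num_experts, top_k = 6, 4, 2
raw_logits = torch.randn(tokens, num_experts)  # stand-in for self.router(hidden_states)

# Old path: transpose to (num_experts, tokens) up front, then transpose back at every use.
old_logits = raw_logits.transpose(0, 1)
old_top, old_idx = torch.topk(old_logits.transpose(0, 1), top_k, dim=1)
old_scores = (
    torch.full_like(old_logits.transpose(0, 1), float("-inf"))
    .scatter_(1, old_idx, old_top)
    .transpose(0, 1)
)

# New path: keep the (tokens, num_experts) layout and transpose once at the end.
new_top, new_idx = torch.topk(raw_logits, top_k, dim=1)
new_scores = (
    torch.full_like(raw_logits, float("-inf")).scatter_(1, new_idx, new_top).transpose(0, 1)
)

# Same top-k values/indices and the same (num_experts, tokens) score matrix.
assert torch.equal(old_top, new_top) and torch.equal(old_idx, new_idx)
assert torch.equal(old_scores, new_scores)

Non-top-k positions are left at -inf in both paths, which is exactly the masking the in-diff comment refers to; only the number of transposes changes, not the result.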