@@ -365,9 +365,9 @@ def __init__(self, prefix, config: MixtralConfig, weights):
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
         # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t()
+        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
         self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
-        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t()
+        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)
 
         self.offsets = None
         self.offsets_block_rows = 0
@@ -467,8 +467,7 @@ def indices_and_padded_bins(self, selected_experts: torch.Tensor):
 
         return indices, bin_ids, bins, padded_bins, tokens_per_expert
 
-    @torch.inference_mode()
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         x: (sequence_length, model_dim)
         gate_logits: (sequence_length, n_experts)
@@ -502,8 +501,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # (top_k * sequence_length + padding, ffn_dim * n_experts)
         x = stk.Matrix(
             topo.size(),
-            self.act(stk.ops.sdd(x, self.w1, topo).data)
-            * stk.ops.sdd(x, self.w3, topo).data,
+            self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
+            * stk.ops.sdd(x, self.w3.t(), topo).data,
             topo.row_indices,
             topo.column_indices,
             topo.offsets,
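
Why the transpose moved (my reading of the two hunks above, sketched with plain torch rather than the stk API): dropping the load-time .t() keeps w1/w3 in the merged (n_experts * ffn_dim, hidden_dim) layout, which the new dense path below can view per expert, while the sparse path simply transposes at the sdd call site.

    import torch

    n_experts, ffn_dim, hidden_dim = 2, 8, 4
    w1 = torch.randn(n_experts * ffn_dim, hidden_dim)    # layout produced by _load_experts

    per_expert = w1.view(n_experts, ffn_dim, hidden_dim).permute(0, 2, 1)  # dense_forward view
    sparse_operand = w1.t()                              # what stk.ops.sdd now receives
    print(per_expert.shape, sparse_operand.shape)        # torch.Size([2, 4, 8]) torch.Size([4, 16])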
@@ -534,6 +533,156 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         return x.view(*input_shape)
 
+    def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (sequence_length, model_dim)
+        gate_logits: (sequence_length, n_experts)
+        """
+        # optional reshape
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1])
+
+        # gate_logits: (sequence_length, n_experts)
+        gate_logits = self.gate(x)
+        # all_probs: (sequence_length, n_experts) and upcast for softmax
+        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+
+        if self.top_k < self.num_experts:
+            _, not_selected_experts = torch.topk(
+                all_probs,
+                self.num_experts - self.top_k,
+                largest=False,
+                sorted=False,
+                dim=1,
+            )
+            # Mask not selected experts
+            all_probs.scatter_(1, not_selected_experts, 0)
+
+        # Re-normalize
+        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+
+        # Expand to [num_experts, sequence_length, model_dim]
+        x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
+
+        # Permute to [num_experts, model_dim, ffn_dim]
+        w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
+            0, 2, 1
+        )
+        w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
+            0, 2, 1
+        )
+
+        inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)
+
+        out = torch.bmm(
+            inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
+        )
+        # Mask not selected experts
+        out *= weights.t().view(self.num_experts, -1, 1)
+
+        # Sum experts
+        out = out.sum(0)
+
+        # Reduce sum
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if len(x) > 256:
+            return self.sparse_forward(x)
+        # This is faster when there are not a lot of tokens
+        return self.dense_forward(x)
+
+
+class DenseMoE(nn.Module):
+    def __init__(self, prefix, config: MixtralConfig, weights):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size // weights.process_group.size()
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        act = config.hidden_act
+        if "gelu" in act:
+            self.act = lambda x: torch.nn.functional.gelu(
+                x,
+                approximate="tanh"
+                if act in ["gelu_fast", "gelu_pytorch_tanh"]
+                else "none",
+            )
+        elif "silu" in act:
+            self.act = torch.nn.functional.silu
+        else:
+            self.act = ACT2FN[act]
+
+        # gating
+        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+        self.w1 = [
+            TensorParallelColumnLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+        self.w3 = [
+            TensorParallelColumnLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+        self.w2 = [
+            TensorParallelRowLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+
+        self.process_group = weights.process_group
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (sequence_length, model_dim)
+        gate_logits: (sequence_length, n_experts)
+        """
+        # optional reshape
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1])
+
+        # gate_logits: (sequence_length, n_experts)
+        gate_logits = self.gate(x)
+        # all_probs: (sequence_length, n_experts) and upcast for softmax
+        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+
+        if self.top_k < self.num_experts:
+            _, not_selected_experts = torch.topk(
+                all_probs,
+                self.num_experts - self.top_k,
+                largest=False,
+                sorted=False,
+                dim=1,
+            )
+            # Mask not selected experts
+            all_probs.scatter_(1, not_selected_experts, 0)
+
+        # Re-normalize
+        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+
+        # Final output tensor
+        out = x.new_zeros(x.shape[0], self.hidden_dim)
+        for i in range(self.num_experts):
+            h = self.act(self.w1[i](x)) * self.w3[i](x)
+            h = self.w2[i](h, reduce=False)
+            # Add expert output to out with masking
+            out += h * weights[:, i].view(-1, 1)
+
+        # Reduce sum
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
 
 class MixtralLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
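
The routing math shared by dense_forward and the new DenseMoE.forward can be checked in isolation. A toy walkthrough (my own numbers, not from the PR): softmax the gate logits over experts, zero everything outside the top_k with scatter_, then renormalize so the kept weights sum to 1 per token.

    import torch

    num_experts, top_k = 4, 2
    gate_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])               # one token
    all_probs = torch.softmax(gate_logits, dim=1, dtype=torch.float)

    _, not_selected = torch.topk(all_probs, num_experts - top_k, largest=False, dim=1)
    all_probs.scatter_(1, not_selected, 0)                            # drop the weakest experts
    weights = all_probs / all_probs.sum(dim=1, keepdim=True)          # renormalize
    print(weights)                                                    # approx [[0.73, 0.27, 0.00, 0.00]]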
@@ -543,9 +692,9 @@ def __init__(self, layer_id, config, weights):
         self.self_attn = MixtralAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-        self.block_sparse_moe = BlockSparseMoE(
-            f"{prefix}.block_sparse_moe", config, weights
-        )
+
+        moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
+        self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
 
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
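
Note on the dispatch above (my reading of the diff, not text from the PR): BlockSparseMoE consumes the raw merged expert tensors via _load_experts, while DenseMoE builds per-expert TensorParallelColumnLinear/TensorParallelRowLinear layers that go through the regular linear-loading path, so a quantized config falls back to the dense implementation.

    # illustrative restatement of the selection rule added above
    moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
    self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)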
@@ -591,9 +740,9 @@ def forward(
             attn_output, res
         )
 
-        block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output)
+        moe_output = self.moe(normed_attn_res_output)
 
-        return block_sparse_moe_output, attn_res
+        return moe_output, attn_res
 
 
 class MixtralModel(torch.nn.Module):
@@ -675,8 +824,6 @@ def __init__(self, config, weights):
             weights=weights,
         )
         self.max_past = config.sliding_window
-        if self.max_past is None:
-            raise ValueError("max_past cannot be None")
 
     def forward(
         self,
@@ -695,7 +842,7 @@ def forward(
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        else:
+        elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
             max_s = min(self.max_past, max_s)
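
A minimal sketch of the resulting max_s handling (hypothetical helper, my own code, not part of the PR): prefill keeps the true value, decode clamps to the sliding window only when one is configured, and models without sliding_window (max_past is None, previously a hard error) simply skip the clamp.

    def clamp_max_s(max_past, max_s, prefill_cache_indices=None):
        if prefill_cache_indices is not None:
            return max_s                     # prefill: flash attention needs the true value
        if max_past is not None:
            return min(max_past, max_s)      # decode: paged attention expects clamped values
        return max_s                         # no sliding window configured

    print(clamp_max_s(4096, 10000))   # 4096
    print(clamp_max_s(None, 10000))   # 10000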