1515 FusedMoEParallelConfig ,
1616 RoutingMethodType ,
1717)
18+ from vllm .model_executor .layers .quantization .utils .flashinfer_utils import (
19+ activation_to_flashinfer_int ,
20+ align_fp4_moe_weights_for_fi ,
21+ )
1822from vllm .model_executor .layers .quantization .utils .nvfp4_utils import (
1923 swizzle_blockscale ,
2024)
@@ -50,8 +54,8 @@ def _supports_current_device() -> bool:
5054
5155
5256def _supports_no_act_and_mul () -> bool :
53- """Does not support non-gated MoE (i.e. Nemotron-Nano) ."""
54- return False
57+ """Supports non-gated MoE."""
58+ return True
5559
5660
5761def _supports_quant_scheme (
@@ -66,8 +70,7 @@ def _supports_quant_scheme(
6670
6771
def _supports_activation(activation: MoEActivation) -> bool:
    """Return True if *activation* is usable with this MoE backend.

    Supported: SiLU (gated) and ReLU^2 without gating (RELU2_NO_MUL,
    used by non-gated MoE models such as Nemotron-Nano).
    """
    return activation in (MoEActivation.SILU, MoEActivation.RELU2_NO_MUL)
7174
7275
7376def _supports_routing_method (
@@ -150,6 +153,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
150153 hidden_size ,
151154 intermediate_size ,
152155 num_experts ,
156+ is_gated_activation : bool ,
153157):
154158 from flashinfer import nvfp4_block_scale_interleave
155159 from flashinfer .fused_moe .core import (
@@ -160,15 +164,18 @@ def prepare_static_weights_for_trtllm_fp4_moe(
160164 _cache_permute_indices : dict [torch .Size , torch .Tensor ] = {}
161165 """Prepare quantized weights for kernel (done offline with weights)."""
162166 epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
167+ gemm1_intermediate_size = (
168+ 2 * intermediate_size if is_gated_activation else intermediate_size
169+ )
163170
164171 # Convert quantized weights to proper formats
165172 gemm1_weights_fp4 = gemm1_weights .view (torch .float8_e4m3fn ).reshape (
166- num_experts , 2 * intermediate_size , hidden_size // 2
173+ num_experts , gemm1_intermediate_size , hidden_size // 2
167174 ) # packed fp4
168175 gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes .view (
169176 torch .float8_e4m3fn
170177 ).reshape (
171- num_experts , 2 * intermediate_size , hidden_size // 16
178+ num_experts , gemm1_intermediate_size , hidden_size // 16
172179 ) # fp8 scaling factors
173180
174181 gemm2_weights_fp4 = gemm2_weights .view (torch .float8_e4m3fn ).reshape (
@@ -191,6 +198,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
191198 _cache_permute_indices ,
192199 gemm1_weights_fp4 [i ].view (torch .uint8 ),
193200 epilogue_tile_m ,
201+ is_gated_act_gemm = is_gated_activation ,
194202 )
195203 gemm1_weights_fp4_shuffled .append (
196204 gemm1_weights_fp4 [i ]
@@ -203,6 +211,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
203211 gemm1_scales_linear_fp4 [i ].view (torch .uint8 ),
204212 epilogue_tile_m ,
205213 num_elts_per_sf = 16 ,
214+ is_gated_act_gemm = is_gated_activation ,
206215 )
207216 gemm1_scales_fp4_shuffled .append (
208217 nvfp4_block_scale_interleave (
@@ -246,7 +255,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
246255 gemm1_scales_fp4_shuffled = (
247256 torch .stack (gemm1_scales_fp4_shuffled )
248257 .view (torch .float8_e4m3fn )
249- .reshape (num_experts , 2 * intermediate_size , hidden_size // 16 )
258+ .reshape (num_experts , gemm1_intermediate_size , hidden_size // 16 )
250259 )
251260
252261 gemm2_weights_fp4_shuffled = torch .stack (gemm2_weights_fp4_shuffled )
@@ -297,10 +306,10 @@ def flashinfer_trtllm_fp4_moe(
297306
298307 from vllm .model_executor .models .llama4 import Llama4MoE
299308
300- # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404
301- assert activation == MoEActivation . SILU , (
302- "Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. "
303- f"{ activation } found instead."
309+ SUPPORTED_ACTIVATIONS = [ MoEActivation . SILU , MoEActivation . RELU2_NO_MUL ]
310+ assert activation in SUPPORTED_ACTIVATIONS , (
311+ f "Only { SUPPORTED_ACTIVATIONS } activations are supported for FlashInfer "
312+ f"TRTLLM FP4 MoE, { activation } found instead."
304313 )
305314
306315 # Quantize input to FP4
@@ -325,6 +334,9 @@ def flashinfer_trtllm_fp4_moe(
325334 else router_logits
326335 )
327336
337+ # Determine activation type
338+ activation_type = activation_to_flashinfer_int (layer .activation )
339+
328340 # Call TRT-LLM FP4 block-scale MoE kernel
329341 out = flashinfer .fused_moe .trtllm_fp4_block_scale_moe (
330342 routing_logits = router_logits ,
@@ -355,6 +367,7 @@ def flashinfer_trtllm_fp4_moe(
355367 routed_scaling_factor = None ,
356368 routing_method_type = routing_method_type ,
357369 do_finalize = True ,
370+ activation_type = activation_type ,
358371 )[0 ]
359372
360373 return out
@@ -479,10 +492,16 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
479492 ]
480493
481494 # Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels.
482- if is_act_and_mul and backend in [
483- NvFp4MoeBackend .FLASHINFER_CUTLASS ,
484- NvFp4MoeBackend .FLASHINFER_TRTLLM ,
485- ]:
495+ is_gated = layer .activation .is_gated
496+ if (
497+ is_gated
498+ and is_act_and_mul
499+ and backend
500+ in [
501+ NvFp4MoeBackend .FLASHINFER_CUTLASS ,
502+ NvFp4MoeBackend .FLASHINFER_TRTLLM ,
503+ ]
504+ ):
486505 w13 , w13_scale = reorder_w1w3_to_w3w1 (w13 , w13_scale )
487506
488507 # For some FI kernels, the input scales are shared by all experts.
@@ -495,19 +514,32 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
495514
496515 # Shuffle weights and scales for FI TRTLLM NVFP4 MoE kernels.
497516 if backend == NvFp4MoeBackend .FLASHINFER_TRTLLM :
517+ # Align weights for FI NVFP4 MoE kernels.
518+ min_alignment = 16 if is_gated else 128
519+ w13 , w13_scale , w2 , w2_scale , padded_intermediate = (
520+ align_fp4_moe_weights_for_fi (
521+ w13 , w13_scale , w2 , w2_scale , is_act_and_mul , min_alignment
522+ )
523+ )
524+ layer .intermediate_size_per_partition = padded_intermediate
525+
498526 w13 , w13_scale , w2 , w2_scale = prepare_static_weights_for_trtllm_fp4_moe (
499527 w13 ,
500528 w2 ,
501529 w13_scale ,
502530 w2_scale ,
503- w2 .size (- 2 ), # hidden_size
504- w13 .size (- 2 ) // 2 , # intermediate_size
505- w13 .size (0 ), # num_experts
531+ hidden_size = w2 .size (- 2 ),
532+ intermediate_size = w13 .size (- 2 ) // 2 if is_gated else w13 .size (- 2 ),
533+ num_experts = w13 .size (0 ),
534+ is_gated_activation = is_gated ,
506535 )
507536
508537 # We do not need to make this a parameter, because
509538 # it is not used during the weight (re)-loading process.
510- layer .g1_scale_c = a13_scale * w13_scale_2 / a2_scale
539+ if is_gated :
540+ layer .g1_scale_c = a13_scale * w13_scale_2 / a2_scale
541+ else :
542+ layer .g1_scale_c = torch .ones_like (a13_scale ) / a2_scale
511543 layer .a1_gscale = 1.0 / a13_scale
512544 layer .g1_alphas = a13_scale * w13_scale_2
513545 layer .g2_alphas = a2_scale * w2_scale_2
0 commit comments