@@ -507,11 +507,12 @@ def run_cutlass_moe_fp4(
507507 # Gemm 1
508508 a: Input tensor: [m, k] (half/bfloat16)
509509 a1_gscale: Activation scale per expert: [e] (float32)
510- w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
511- w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
510+ w1 (not an argument to cutlass_moe_fp4): [e, w1_n, k]
511+ w1_fp4: [e, w1_n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
512+ where w1_n = 2*n for gated activations (gate+up), n for non-gated (up only).
512513 (Note: `n` is the up projection output dim, `k` is the input dim in
513514 full precision)
514- w1_blockscale: [e, 2 * n , k // block_size] (float8_e4m3)
515+ w1_blockscale: [e, w1_n, k // block_size] (float8_e4m3)
515516 (Block size = 16 for NVFP4)
516517
517518 # Gemm 2
@@ -528,6 +529,11 @@ def run_cutlass_moe_fp4(
528529
529530 assumes that topk < k < n to satisfy - up/down projection expectations.
530531 """
532+ is_gated = activation.is_gated
533+ # For gated activations (e.g. SiLU), w1 output is 2*n (gate + up).
534+ # For non-gated activations (e.g. SiLU_NO_MUL), w1 output is n (up only).
535+ w1_n = n * 2 if is_gated else n
536+
531537 assert topk_weights .shape == topk_ids .shape , "topk shape mismatch"
532538 assert w1_fp4 .dtype == torch .uint8 , "weight 1 must be uint8"
533539 assert w2_fp4 .dtype == torch .uint8 , "weight 2 must be uint8"
@@ -538,7 +544,7 @@ def run_cutlass_moe_fp4(
538544 and w2_blockscale .ndim == 3
539545 ), "All Weights must be of rank 3 for cutlass_moe_fp4"
540546 m_a , k_a = a .shape
541- e_w1 , nx2_w1 , half_k_w1 = w1_fp4 .shape
547+ e_w1 , w1_n_actual , half_k_w1 = w1_fp4 .shape
542548 e_w2 , k_w2 , half_n_w2 = w2_fp4 .shape
543549
544550 assert e_w1 == e_w2 and e_w1 == e , (
@@ -548,7 +554,7 @@ def run_cutlass_moe_fp4(
548554 assert k_a == half_k_w1 * 2 and k == k_w2 , (
549555 "Hidden size mismatch between a, w1 and w2"
550556 )
551- assert nx2_w1 == n * 2 and half_n_w2 * 2 == n , "mismatch in expected `n`"
557+ assert w1_n_actual == w1_n and half_n_w2 * 2 == n , "mismatch in expected `n`"
552558 assert m == m_a , "input shape mismatch"
553559 assert 2 * half_k_w1 == k_w2 , "Hidden size mismatch w2 and w1"
554560 assert a .dtype in [torch .half , torch .bfloat16 ], "Invalid input dtype"
@@ -589,6 +595,7 @@ def run_cutlass_moe_fp4(
589595 n ,
590596 k ,
591597 blockscale_offsets ,
598+ is_gated=is_gated,
592599 )
593600
594601 a = ops .shuffle_rows (a , a_map )
@@ -599,7 +606,7 @@ def run_cutlass_moe_fp4(
599606 blockscale_offsets ,
600607 num_topk ,
601608 )
602- c1 = _resize_cache (workspace13 , (m * topk , n * 2 ))
609+ c1 = _resize_cache(workspace13, (m * topk, w1_n))
603610 c2 = _resize_cache (workspace2 , (m * topk , n ))
604611 c3 = _resize_cache (workspace13 , (m * topk , k ))
605612 ops .cutlass_fp4_moe_mm (
@@ -681,7 +688,7 @@ def _supports_current_device() -> bool:
681688
682689 @staticmethod
683690 def _supports_no_act_and_mul () -> bool :
684- return False
691+ return True
685692
686693 @staticmethod
687694 def _supports_quant_scheme (
@@ -695,11 +702,16 @@ def _supports_activation(activation: MoEActivation) -> bool:
695702 # SILU uses a fused silu+mul+fp4_quant kernel path.
696703 # Other gated activations use the generic apply_moe_activation()
697704 # fallback + separate fp4 quantization in run_cutlass_moe_fp4().
705+ # Non-gated activations (_NO_MUL) are also supported for models
706+ # like Nemotron-Nano that don't use gated MLP.
698707 return activation in [
699708 MoEActivation .SILU ,
700709 MoEActivation .GELU ,
701710 MoEActivation .SWIGLUOAI ,
702711 MoEActivation .SWIGLUSTEP ,
712+ MoEActivation.SILU_NO_MUL,
713+ MoEActivation.GELU_NO_MUL,
714+ MoEActivation.RELU2_NO_MUL,
703715 ]
704716
705717 @staticmethod
0 commit comments