Skip to content

Commit 09e4576

Browse files
authored
[Kernel] Add non-gated support for NVFP4 CUTLASS MoE (vllm-project#37320)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 3ed7b1e commit 09e4576

File tree

8 files changed

+53
-26
lines changed

8 files changed

+53
-26
lines changed

csrc/ops.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,8 @@ void get_cutlass_moe_mm_data(
262262
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
263263
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
264264
const int64_t num_experts, const int64_t n, const int64_t k,
265-
const std::optional<torch::Tensor>& blockscale_offsets);
265+
const std::optional<torch::Tensor>& blockscale_offsets,
266+
const bool is_gated);
266267

267268
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
268269
const torch::Tensor& expert_first_token_offset,

csrc/quantization/w8a8/cutlass/moe/moe_data.cu

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
1717
int32_t* problem_sizes2,
1818
int32_t* atomic_buffer,
1919
const int topk_length, const int n,
20-
const int k) {
20+
const int k, const bool is_gated) {
2121
int expert_id = blockIdx.x;
22+
// For gated activations (gate + up), first GEMM output is 2*n.
23+
// For non-gated activations (up only), first GEMM output is n.
24+
int const n1 = is_gated ? 2 * n : n;
2225

2326
int occurrences = 0;
2427
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
@@ -31,13 +34,13 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
3134
int final_occurrences = atomic_buffer[expert_id];
3235
if constexpr (!SWAP_AB) {
3336
problem_sizes1[expert_id * 3] = final_occurrences;
34-
problem_sizes1[expert_id * 3 + 1] = 2 * n;
37+
problem_sizes1[expert_id * 3 + 1] = n1;
3538
problem_sizes1[expert_id * 3 + 2] = k;
3639
problem_sizes2[expert_id * 3] = final_occurrences;
3740
problem_sizes2[expert_id * 3 + 1] = k;
3841
problem_sizes2[expert_id * 3 + 2] = n;
3942
} else {
40-
problem_sizes1[expert_id * 3] = 2 * n;
43+
problem_sizes1[expert_id * 3] = n1;
4144
problem_sizes1[expert_id * 3 + 1] = final_occurrences;
4245
problem_sizes1[expert_id * 3 + 2] = k;
4346
problem_sizes2[expert_id * 3] = k;
@@ -107,13 +110,11 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
107110
}
108111

109112
namespace {
110-
inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
111-
torch::Tensor& problem_sizes1,
112-
torch::Tensor& problem_sizes2,
113-
torch::Tensor& atomic_buffer,
114-
int64_t num_experts, int64_t n,
115-
int64_t k, cudaStream_t stream,
116-
const bool swap_ab) {
113+
inline void launch_compute_problem_sizes(
114+
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
115+
torch::Tensor& problem_sizes2, torch::Tensor& atomic_buffer,
116+
int64_t num_experts, int64_t n, int64_t k, cudaStream_t stream,
117+
const bool swap_ab, const bool is_gated) {
117118
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
118119

119120
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
@@ -125,7 +126,7 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
125126
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
126127
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
127128
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
128-
static_cast<int>(k));
129+
static_cast<int>(k), is_gated);
129130
});
130131
}
131132
} // namespace
@@ -222,7 +223,8 @@ void get_cutlass_moe_mm_data_caller(
222223
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
223224
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
224225
const int64_t num_experts, const int64_t n, const int64_t k,
225-
const std::optional<torch::Tensor>& blockscale_offsets) {
226+
const std::optional<torch::Tensor>& blockscale_offsets,
227+
const bool is_gated) {
226228
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
227229
auto options_int32 =
228230
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
@@ -236,7 +238,7 @@ void get_cutlass_moe_mm_data_caller(
236238

237239
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
238240
atomic_buffer, num_experts, n, k, stream,
239-
may_swap_ab);
241+
may_swap_ab, is_gated);
240242

241243
if (blockscale_offsets.has_value()) {
242244
// fp4 path

csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ void get_cutlass_moe_mm_data_caller(
7575
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
7676
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
7777
const int64_t num_experts, const int64_t n, const int64_t k,
78-
const std::optional<torch::Tensor>& blockscale_offsets);
78+
const std::optional<torch::Tensor>& blockscale_offsets,
79+
const bool is_gated);
7980

8081
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
8182
const torch::Tensor& expert_first_token_offset,
@@ -278,7 +279,8 @@ void get_cutlass_moe_mm_data(
278279
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
279280
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
280281
const int64_t num_experts, const int64_t n, const int64_t k,
281-
const std::optional<torch::Tensor>& blockscale_offsets) {
282+
const std::optional<torch::Tensor>& blockscale_offsets,
283+
const bool is_gated) {
282284
// This function currently gets compiled only if we have a valid cutlass moe
283285
// mm to run it for.
284286
int32_t version_num = get_sm_version_num();
@@ -288,7 +290,7 @@ void get_cutlass_moe_mm_data(
288290
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
289291
problem_sizes2, input_permutation,
290292
output_permutation, num_experts, n, k,
291-
blockscale_offsets);
293+
blockscale_offsets, is_gated);
292294
return;
293295
#endif
294296
TORCH_CHECK_NOT_IMPLEMENTED(

csrc/torch_bindings.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,8 +489,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
489489
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
490490
" Tensor! input_permutation, "
491491
" Tensor! output_permutation, int num_experts, "
492-
" int n, int k, Tensor? blockscale_offsets) -> "
493-
"()");
492+
" int n, int k, Tensor? blockscale_offsets, "
493+
" bool is_gated) -> ()");
494494
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
495495

496496
// compute per-expert problem sizes from expert_first_token_offset
tests/evals/gsm8k/configs/Nemotron-Nano-30B-NvFp4-ModelOpt-vllm-cutlass.yaml (filename header lost in page extraction; inferred from the new entry added to config-b200.txt below)

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
model_name: "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4"
2+
accuracy_threshold: 0.29
3+
num_questions: 1319
4+
num_fewshot: 5
5+
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=cutlass"

tests/evals/gsm8k/configs/moe-refactor/config-b200.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ Mixtral-8x7B-BF16-fi-cutlass.yaml
1515
Mixtral-8x7B-BF16-triton.yaml
1616
Nemotron-Nano-30B-Fp8-ModelOpt-fi-trtllm.yaml
1717
Nemotron-Nano-30B-NvFp4-ModelOpt-fi-cutlass.yaml
18+
Nemotron-Nano-30B-NvFp4-ModelOpt-vllm-cutlass.yaml

vllm/_custom_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,7 @@ def get_cutlass_moe_mm_data(
989989
n: int,
990990
k: int,
991991
blockscale_offsets: torch.Tensor | None = None,
992+
is_gated: bool = True,
992993
):
993994
"""
994995
Prepare data necessary to perform CUTLASS grouped matrix multiplications
@@ -1012,6 +1013,8 @@ def get_cutlass_moe_mm_data(
10121013
its computation. The number of block scale rows
10131014
computed with expert E is blockscale_offsets[E + 1] -
10141015
blockscale_offsets[E]
1016+
- is_gated: Whether the activation is gated (gate + up). When True, the
1017+
first GEMM N dimension is 2*n; when False, it is n.
10151018
"""
10161019
return torch.ops._C.get_cutlass_moe_mm_data(
10171020
topk_ids,
@@ -1024,6 +1027,7 @@ def get_cutlass_moe_mm_data(
10241027
n,
10251028
k,
10261029
blockscale_offsets,
1030+
is_gated,
10271031
)
10281032

10291033

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -507,11 +507,12 @@ def run_cutlass_moe_fp4(
507507
# Gemm 1
508508
a: Input tensor: [m, k] (half/bfloat16)
509509
a1_gscale: Activation scale per expert: [e] (float32)
510-
w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
511-
w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
510+
w1 (not an argument to cutlass_moe_fp4): [e, w1_n, k]
511+
w1_fp4: [e, w1_n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
512+
where w1_n = 2*n for gated activations (gate+up), n for non-gated (up only).
512513
(Note: `n` is the up projection output dim, `k` is the input dim in
513514
full precision)
514-
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
515+
w1_blockscale: [e, w1_n, k // block_size] (float8_e4m3)
515516
(Block size = 16 for NVFP4)
516517
517518
# Gemm 2
@@ -528,6 +529,11 @@ def run_cutlass_moe_fp4(
528529
529530
assumes that topk < k < n to satisfy - up/down projection expectations.
530531
"""
532+
is_gated = activation.is_gated
533+
# For gated activations (e.g. SiLU), w1 output is 2*n (gate + up).
534+
# For non-gated activations (e.g. SiLU_NO_MUL), w1 output is n (up only).
535+
w1_n = n * 2 if is_gated else n
536+
531537
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
532538
assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
533539
assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
@@ -538,7 +544,7 @@ def run_cutlass_moe_fp4(
538544
and w2_blockscale.ndim == 3
539545
), "All Weights must be of rank 3 for cutlass_moe_fp4"
540546
m_a, k_a = a.shape
541-
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
547+
e_w1, w1_n_actual, half_k_w1 = w1_fp4.shape
542548
e_w2, k_w2, half_n_w2 = w2_fp4.shape
543549

544550
assert e_w1 == e_w2 and e_w1 == e, (
@@ -548,7 +554,7 @@ def run_cutlass_moe_fp4(
548554
assert k_a == half_k_w1 * 2 and k == k_w2, (
549555
"Hidden size mismatch between a, w1 and w2"
550556
)
551-
assert nx2_w1 == n * 2 and half_n_w2 * 2 == n, "mismatch in expected `n`"
557+
assert w1_n_actual == w1_n and half_n_w2 * 2 == n, "mismatch in expected `n`"
552558
assert m == m_a, "input shape mismatch"
553559
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
554560
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
@@ -589,6 +595,7 @@ def run_cutlass_moe_fp4(
589595
n,
590596
k,
591597
blockscale_offsets,
598+
is_gated=is_gated,
592599
)
593600

594601
a = ops.shuffle_rows(a, a_map)
@@ -599,7 +606,7 @@ def run_cutlass_moe_fp4(
599606
blockscale_offsets,
600607
num_topk,
601608
)
602-
c1 = _resize_cache(workspace13, (m * topk, n * 2))
609+
c1 = _resize_cache(workspace13, (m * topk, w1_n))
603610
c2 = _resize_cache(workspace2, (m * topk, n))
604611
c3 = _resize_cache(workspace13, (m * topk, k))
605612
ops.cutlass_fp4_moe_mm(
@@ -681,7 +688,7 @@ def _supports_current_device() -> bool:
681688

682689
@staticmethod
683690
def _supports_no_act_and_mul() -> bool:
684-
return False
691+
return True
685692

686693
@staticmethod
687694
def _supports_quant_scheme(
@@ -695,11 +702,16 @@ def _supports_activation(activation: MoEActivation) -> bool:
695702
# SILU uses a fused silu+mul+fp4_quant kernel path.
696703
# Other gated activations use the generic apply_moe_activation()
697704
# fallback + separate fp4 quantization in run_cutlass_moe_fp4().
705+
# Non-gated activations (_NO_MUL) are also supported for models
706+
# like Nemotron-Nano that don't use gated MLP.
698707
return activation in [
699708
MoEActivation.SILU,
700709
MoEActivation.GELU,
701710
MoEActivation.SWIGLUOAI,
702711
MoEActivation.SWIGLUSTEP,
712+
MoEActivation.SILU_NO_MUL,
713+
MoEActivation.GELU_NO_MUL,
714+
MoEActivation.RELU2_NO_MUL,
703715
]
704716

705717
@staticmethod

0 commit comments

Comments
 (0)