Skip to content

Commit f120bd4

Browse files
authored
[Kernel] Support Flashinfer trtllm fused MoE non gated FP8 & NVFP4 (vllm-project#33506)
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
1 parent fac4e96 commit f120bd4

File tree

5 files changed

+197
-45
lines changed

5 files changed

+197
-45
lines changed

tests/kernels/moe/test_flashinfer.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ def quant_fp8_per_tensor_batches(a):
7171

7272
for i in range(num_batches):
7373
a_fp8, a_global_sf = input_to_float8(a[i])
74-
a_global_sf = 1.0 / a_global_sf
74+
if a_global_sf.numel() == 1:
75+
a_global_sf = a_global_sf.view(1, 1)
7576
a_quant.append(a_fp8)
7677
a_scales.append(a_global_sf)
7778

@@ -81,6 +82,20 @@ def quant_fp8_per_tensor_batches(a):
8182
return result_a_quant, result_a_scales
8283

8384

85+
def check_accuracy(ref_output, actual_output, atol=0.1, rtol=0.85, percent=0.925):
    """Assert that at least ``percent`` of elements are elementwise close.

    Unlike ``torch.testing.assert_close``, this tolerates a bounded fraction
    of outlier elements, which is appropriate for low-precision (FP8/NVFP4)
    fused-MoE kernel outputs.

    Args:
        ref_output: Reference output tensor.
        actual_output: Output tensor produced by the kernel under test.
        atol: Per-element absolute tolerance passed to ``torch.isclose``.
        rtol: Per-element relative tolerance passed to ``torch.isclose``.
        percent: Minimum fraction of elements that must be close.

    Raises:
        AssertionError: If the fraction of close elements is below ``percent``.
    """
    close = torch.isclose(ref_output, actual_output, atol=atol, rtol=rtol)
    match_ratio = close.float().mean().item()
    # NOTE: a second "mismatch_percent <= 1 - percent" assert was removed as
    # it is mathematically equivalent to this check.
    assert match_ratio >= percent, (
        f"Match ratio {match_ratio:.4f} is below the threshold {percent:.4f}"
    )
97+
98+
8499
@dataclass
85100
class TestData:
86101
hidden_states: torch.Tensor
@@ -104,14 +119,16 @@ def make_moe_tensors_8bit(
104119
is_gated = activation.is_gated
105120

106121
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
107-
w13 = torch.randn(
108-
(e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
122+
w13 = (
123+
torch.randn(
124+
(e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
125+
)
126+
/ 10
109127
)
110-
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
128+
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
111129

112130
# Scale to fp8
113131
_, a1_scale = input_to_float8(hidden_states)
114-
a1_scale = 1.0 / a1_scale
115132
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)
116133
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
117134
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
@@ -124,14 +141,16 @@ def make_moe_tensors_8bit(
124141
layer.w2_input_scale = a2_scale
125142
layer.w13_weight_scale = w13_weight_scale
126143
layer.w2_weight_scale = w2_weight_scale
144+
layer.activation = activation
127145
# Setup dummy config.
128146
layer.moe_parallel_config = mk.FusedMoEParallelConfig.make_no_parallel()
129147

130148
# flashinfer expects swapped rows for w13
131-
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
149+
if is_gated:
150+
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
132151
if is_trtllm:
133152
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
134-
layer.w13_weight, layer.w2_weight
153+
layer.w13_weight, layer.w2_weight, is_gated
135154
)
136155
register_scales_for_trtllm_fp8_per_tensor_moe(
137156
layer,
@@ -162,20 +181,24 @@ def make_moe_tensors_8bit(
162181
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
163182
@pytest.mark.parametrize("e", NUM_EXPERTS)
164183
@pytest.mark.parametrize("topk", TOP_KS)
184+
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
165185
def test_flashinfer_per_tensor_moe_fp8_no_graph(
166186
m: int,
167187
n: int,
168188
k: int,
169189
e: int,
170190
topk: int,
191+
activation: MoEActivation,
171192
monkeypatch,
172193
):
173194
if not current_platform.has_device_capability(100):
174195
pytest.skip("Test is only supported for sm >= 100")
175196
set_random_seed(7)
176197
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
177198
with set_current_vllm_config(vllm_config):
178-
td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True)
199+
td = TestData.make_moe_tensors_8bit(
200+
m, k, n, e, is_trtllm=True, activation=activation
201+
)
179202

180203
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
181204
topk_weights, topk_ids = Llama4MoE.custom_routing_function(
@@ -200,7 +223,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
200223
topk_weights=topk_weights,
201224
topk_ids=topk_ids,
202225
inplace=False,
203-
activation=MoEActivation.SILU,
226+
activation=activation,
204227
global_num_experts=e,
205228
expert_map=None,
206229
apply_router_weight_on_input=True,
@@ -219,7 +242,13 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
219242
apply_router_weight_on_input=True,
220243
)
221244

222-
torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
245+
check_accuracy(
246+
ref_output=output,
247+
actual_output=flashinfer_output,
248+
atol=0.1,
249+
rtol=0.85,
250+
percent=0.925,
251+
)
223252

224253

225254
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@@ -320,8 +349,13 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
320349
expert_map=None,
321350
apply_router_weight_on_input=True,
322351
)
323-
torch.testing.assert_close(
324-
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
352+
353+
check_accuracy(
354+
ref_output=output,
355+
actual_output=flashinfer_cutlass_output,
356+
atol=0.1,
357+
rtol=0.85,
358+
percent=0.925,
325359
)
326360

327361

vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def _supports_current_device() -> bool:
3535

3636

3737
def _supports_no_act_and_mul() -> bool:
38-
"""Does not support non-gated MoE (i.e. Nemotron-Nano)."""
39-
return False
38+
"""Supports non-gated MoE."""
39+
return True
4040

4141

4242
def _supports_quant_scheme(
@@ -52,8 +52,7 @@ def _supports_quant_scheme(
5252

5353

5454
def _supports_activation(activation: MoEActivation) -> bool:
55-
"""Supports silu activation only."""
56-
return activation == MoEActivation.SILU
55+
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
5756

5857

5958
def _supports_routing_method(
@@ -74,6 +73,7 @@ def _supports_routing_method(
7473
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
7574
# NOTE(dbari): as above, potentially allow others here.
7675
return routing_method in [
76+
RoutingMethodType.DeepSeekV3,
7777
RoutingMethodType.Llama4,
7878
RoutingMethodType.Renormalize,
7979
RoutingMethodType.RenormalizeNaive,
@@ -291,6 +291,7 @@ def fi_trtllm_fp8_per_tensor_moe(
291291
local_num_experts: int,
292292
use_routing_scales_on_input: bool,
293293
routing_method_type: int,
294+
activation_type: int,
294295
routed_scaling_factor: float = 1.0,
295296
) -> torch.Tensor:
296297
num_expert_group = num_expert_group if num_expert_group is not None else 0
@@ -326,9 +327,9 @@ def fi_trtllm_fp8_per_tensor_moe(
326327
routed_scaling_factor=routed_scaling_factor,
327328
use_routing_scales_on_input=use_routing_scales_on_input,
328329
routing_method_type=routing_method_type,
329-
# TODO: Required for flashinfer==0.6.3, remove with update
330+
# TODO: wrapping in the enum type is required for flashinfer==0.6.3; remove with update
330331
# https://github.com/flashinfer-ai/flashinfer/pull/2508
331-
activation_type=ActivationType.Swiglu,
332+
activation_type=ActivationType(activation_type),
332333
)
333334

334335

@@ -351,6 +352,7 @@ def fi_trtllm_fp8_per_tensor_moe_fake(
351352
local_num_experts: int,
352353
use_routing_scales_on_input: bool,
353354
routing_method_type: int,
355+
activation_type: int,
354356
routed_scaling_factor: float = 1.0,
355357
) -> torch.Tensor:
356358
return torch.empty_like(hidden_states)

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -937,10 +937,11 @@ def apply_monolithic(
937937
)
938938
# TODO(rob): this validation should happen at kernel selection
939939
# time in the oracle rather than here.
940-
assert layer.activation == MoEActivation.SILU, (
941-
f"Expected 'silu' activation but got {layer.activation}"
940+
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
941+
assert layer.activation in SUPPORTED_ACTIVATIONS, (
942+
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
943+
f"TRTLLM FP4 MoE, {layer.activation} found instead."
942944
)
943-
assert not layer.renormalize
944945
return apply_fi_trtllm_fp8_per_tensor_moe(
945946
layer=layer,
946947
hidden_states=x,

vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
FusedMoEParallelConfig,
1616
RoutingMethodType,
1717
)
18+
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
19+
activation_to_flashinfer_int,
20+
align_fp4_moe_weights_for_fi,
21+
)
1822
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
1923
swizzle_blockscale,
2024
)
@@ -50,8 +54,8 @@ def _supports_current_device() -> bool:
5054

5155

5256
def _supports_no_act_and_mul() -> bool:
53-
"""Does not support non-gated MoE (i.e. Nemotron-Nano)."""
54-
return False
57+
"""Supports non-gated MoE."""
58+
return True
5559

5660

5761
def _supports_quant_scheme(
@@ -66,8 +70,7 @@ def _supports_quant_scheme(
6670

6771

6872
def _supports_activation(activation: MoEActivation) -> bool:
69-
"""Supports silu activation only."""
70-
return activation in [MoEActivation.SILU]
73+
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
7174

7275

7376
def _supports_routing_method(
@@ -150,6 +153,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
150153
hidden_size,
151154
intermediate_size,
152155
num_experts,
156+
is_gated_activation: bool,
153157
):
154158
from flashinfer import nvfp4_block_scale_interleave
155159
from flashinfer.fused_moe.core import (
@@ -160,15 +164,18 @@ def prepare_static_weights_for_trtllm_fp4_moe(
160164
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
161165
"""Prepare quantized weights for kernel (done offline with weights)."""
162166
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
167+
gemm1_intermediate_size = (
168+
2 * intermediate_size if is_gated_activation else intermediate_size
169+
)
163170

164171
# Convert quantized weights to proper formats
165172
gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
166-
num_experts, 2 * intermediate_size, hidden_size // 2
173+
num_experts, gemm1_intermediate_size, hidden_size // 2
167174
) # packed fp4
168175
gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
169176
torch.float8_e4m3fn
170177
).reshape(
171-
num_experts, 2 * intermediate_size, hidden_size // 16
178+
num_experts, gemm1_intermediate_size, hidden_size // 16
172179
) # fp8 scaling factors
173180

174181
gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
@@ -191,6 +198,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
191198
_cache_permute_indices,
192199
gemm1_weights_fp4[i].view(torch.uint8),
193200
epilogue_tile_m,
201+
is_gated_act_gemm=is_gated_activation,
194202
)
195203
gemm1_weights_fp4_shuffled.append(
196204
gemm1_weights_fp4[i]
@@ -203,6 +211,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
203211
gemm1_scales_linear_fp4[i].view(torch.uint8),
204212
epilogue_tile_m,
205213
num_elts_per_sf=16,
214+
is_gated_act_gemm=is_gated_activation,
206215
)
207216
gemm1_scales_fp4_shuffled.append(
208217
nvfp4_block_scale_interleave(
@@ -246,7 +255,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
246255
gemm1_scales_fp4_shuffled = (
247256
torch.stack(gemm1_scales_fp4_shuffled)
248257
.view(torch.float8_e4m3fn)
249-
.reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
258+
.reshape(num_experts, gemm1_intermediate_size, hidden_size // 16)
250259
)
251260

252261
gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
@@ -297,10 +306,10 @@ def flashinfer_trtllm_fp4_moe(
297306

298307
from vllm.model_executor.models.llama4 import Llama4MoE
299308

300-
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404
301-
assert activation == MoEActivation.SILU, (
302-
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. "
303-
f"{activation} found instead."
309+
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
310+
assert activation in SUPPORTED_ACTIVATIONS, (
311+
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
312+
f"TRTLLM FP4 MoE, {activation} found instead."
304313
)
305314

306315
# Quantize input to FP4
@@ -325,6 +334,9 @@ def flashinfer_trtllm_fp4_moe(
325334
else router_logits
326335
)
327336

337+
# Determine activation type
338+
activation_type = activation_to_flashinfer_int(layer.activation)
339+
328340
# Call TRT-LLM FP4 block-scale MoE kernel
329341
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
330342
routing_logits=router_logits,
@@ -355,6 +367,7 @@ def flashinfer_trtllm_fp4_moe(
355367
routed_scaling_factor=None,
356368
routing_method_type=routing_method_type,
357369
do_finalize=True,
370+
activation_type=activation_type,
358371
)[0]
359372

360373
return out
@@ -479,10 +492,16 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
479492
]
480493

481494
# Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels.
482-
if is_act_and_mul and backend in [
483-
NvFp4MoeBackend.FLASHINFER_CUTLASS,
484-
NvFp4MoeBackend.FLASHINFER_TRTLLM,
485-
]:
495+
is_gated = layer.activation.is_gated
496+
if (
497+
is_gated
498+
and is_act_and_mul
499+
and backend
500+
in [
501+
NvFp4MoeBackend.FLASHINFER_CUTLASS,
502+
NvFp4MoeBackend.FLASHINFER_TRTLLM,
503+
]
504+
):
486505
w13, w13_scale = reorder_w1w3_to_w3w1(w13, w13_scale)
487506

488507
# For some FI kernels, the input scales are shared by all experts.
@@ -495,19 +514,32 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
495514

496515
# Shuffle weights and scales for FI TRTLLM NVFP4 MoE kernels.
497516
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
517+
# Align weights for FI NVFP4 MoE kernels.
518+
min_alignment = 16 if is_gated else 128
519+
w13, w13_scale, w2, w2_scale, padded_intermediate = (
520+
align_fp4_moe_weights_for_fi(
521+
w13, w13_scale, w2, w2_scale, is_act_and_mul, min_alignment
522+
)
523+
)
524+
layer.intermediate_size_per_partition = padded_intermediate
525+
498526
w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
499527
w13,
500528
w2,
501529
w13_scale,
502530
w2_scale,
503-
w2.size(-2), # hidden_size
504-
w13.size(-2) // 2, # intermediate_size
505-
w13.size(0), # num_experts
531+
hidden_size=w2.size(-2),
532+
intermediate_size=w13.size(-2) // 2 if is_gated else w13.size(-2),
533+
num_experts=w13.size(0),
534+
is_gated_activation=is_gated,
506535
)
507536

508537
# We do not need to make this a parameter, because
509538
# it is not used during the weight (re)-loading process.
510-
layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale
539+
if is_gated:
540+
layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale
541+
else:
542+
layer.g1_scale_c = torch.ones_like(a13_scale) / a2_scale
511543
layer.a1_gscale = 1.0 / a13_scale
512544
layer.g1_alphas = a13_scale * w13_scale_2
513545
layer.g2_alphas = a2_scale * w2_scale_2

0 commit comments

Comments
 (0)