@@ -158,7 +158,14 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
case CutlassGemmType::Fp8:
if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) {
if (sm == 89 || sm >= 120) {
if (sm == 89) {
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
Review comment:
Are there any SM89 GPUs that can support the CtaShape16x256x128_WarpShape16x64x128, or is this an SM120 addition?
Having a small M value like this helps with low-latency cases, so I'd want to understand why it's not supported before disabling it.

Review comment:

At the very least, can you leave a comment saying what the difference between the two lists is, so people don't have to manually compare the items?

amirkl94 (Contributor, Author):

> Are there any SM89 GPUs that can support the CtaShape16x256x128_WarpShape16x64x128, or is this an SM120 addition? Having a small M value like this helps with low-latency cases, so I'd want to understand why it's not supported before disabling it.

This is a removal from the SM89 path. When I tested it on an L40 GPU I got `Assertion failed: GPU lacks the shared memory resources to run GroupedGEMM kernel`.
It might be that this tile config passes on other SM89 GPUs; the main issue is that it was the default tactic chosen when trying to use FusedMoE. I believe that moving it to the end of the list would also fix my issue.
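For illustration, a minimal sketch of what moving that tile to the end of the SM89 candidate list could look like, so it is only tried after the safer shapes. The enum here is a stub for illustration, not the real cutlass_extensions header:

```cpp
// Illustrative stub only: the real CutlassTileConfig enum lives in cutlass_extensions;
// the point is the ordering of the returned candidates.
#include <vector>

enum class CutlassTileConfig {
  CtaShape16x256x128_WarpShape16x64x128,
  CtaShape32x128x64_WarpShape32x32x64,
  CtaShape64x128x64_WarpShape64x32x64,
  CtaShape64x64x128_WarpShape32x64x64,
  CtaShape128x64x64_WarpShape64x32x64,
  CtaShape128x256x64_WarpShape64x64x64,
  CtaShape256x128x64_WarpShape64x64x64,
};

std::vector<CutlassTileConfig> sm89_fp8_grouped_candidates() {
  return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
          CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
          CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
          CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
          CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
          CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64,
          // Tried last: needs more shared memory than some SM89 parts (e.g. L40) provide.
          CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128};
}
```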

amirkl94 (Contributor, Author), Oct 27, 2025:

I tried moving this tile config to be the last one, and now the default tactic won't fail on L40. The issue is that if the autotuner is on, the tactics that use this tile config will report an error with a stack trace, which looks bad.
@yzh119 do you think it would be OK to change the errors that happen in the autotuner to debug logs? Otherwise users will get spammed with error messages when they run autotuning on L40 FusedMoE.

Review comment:
I think the autotuner should still output warnings, but just make them say "Skipping tactic x due to error. This tactic may not be supported by the current GPU architecture".
That said, I know there is a difference of opinion on whether we should proactively filter tactics as you have done here; the argument is that we should do the due diligence to determine which tactics are supported, so that we can raise an error when a tactic fails that shouldn't. So I can see either side.
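For illustration, a rough sketch of the skip-and-warn behaviour discussed above. All names here (Tactic, profile_tactic, the message format) are hypothetical placeholders, not the flashinfer or TensorRT-LLM autotuner API:

```cpp
// Hypothetical autotuning loop: tactics that throw (e.g. due to insufficient shared
// memory on the current GPU) are skipped with a warning instead of an error + stack trace.
#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

struct Tactic {
  std::string name;
  bool fits_smem;  // stand-in for "this GPU has enough shared memory for the tile"
};

// Stand-in for launching/benchmarking one tactic; throws when the GPU cannot run it.
float profile_tactic(const Tactic& t) {
  if (!t.fits_smem) {
    throw std::runtime_error("GPU lacks the shared memory resources to run GroupedGEMM kernel");
  }
  return 1.0f;  // pretend runtime in milliseconds
}

int main() {
  std::vector<Tactic> tactics = {{"CtaShape32x128x64", true}, {"CtaShape16x256x128", false}};
  float best_ms = 1e30f;
  for (const auto& t : tactics) {
    try {
      best_ms = std::min(best_ms, profile_tactic(t));
    } catch (const std::exception& e) {
      // Warn and continue instead of failing the whole tuning pass.
      std::cerr << "Skipping tactic " << t.name
                << " due to error. This tactic may not be supported by the "
                   "current GPU architecture (" << e.what() << ")\n";
    }
  }
  std::cout << "best tactic time: " << best_ms << " ms\n";
  return 0;
}
```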

} else if (sm >= 120) {
return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128,
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
@@ -688,28 +688,28 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
TLLM_THROW("FP4 data type is not supported on SM < 90");
#endif
} else if (sm_ >= 80 && sm_ < 90) {
#ifdef ENABLE_FP4
if constexpr (!std::is_same_v<WeightType, __nv_fp4_e2m1>) {
if constexpr (use_fp8 || use_w4afp8) {
if constexpr (use_fp8 || use_w4afp8) {
#if defined(ENABLE_FP8)
static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3> &&
!std::is_same_v<OutputType, __nv_fp8_e5m2>,
"FP8 GEMM Output not supported");
static_assert(
!std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>,
"FP8 GEMM Output not supported");
#endif
TLLM_CHECK_WITH_INFO(sm_ == 89,
"For sm >= 80 and < 90, fp8 is only supported with sm == 89");
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(
inputs, multi_processor_count_);
TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89");
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(
inputs, multi_processor_count_);
} else {
#ifdef ENABLE_FP4
if constexpr (std::is_same_v<WeightType, __nv_fp4_e2m1>) {
TLLM_THROW("FP4 data type is not supported on SM < 90");
} else {
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(
inputs, multi_processor_count_);
}
} else {
TLLM_THROW("FP4 data type is not supported on SM < 90");
}
#else
TLLM_THROW("FP4 data type is not supported on SM < 90");
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(
inputs, multi_processor_count_);
#endif
}
} else if (sm_ >= 90) {
// For SM120+ FP8 MoE, redirect to SM89 (Ada) FP8 kernel implementations.
if constexpr (use_fp8) {
@@ -32,7 +32,8 @@ template <typename T, typename WeightType,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion Fusion =
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE>
constexpr bool isValidSM120MOESpecialisation() {
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) // TODO Is there a better choice
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) && \
defined(ENABLE_FP4) // TODO Is there a better choice
return cutlass::platform::is_same<T, __nv_fp4_e2m1>::value &&
cutlass::platform::is_same<T, WeightType>::value &&
cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value &&
@@ -47,7 +48,8 @@ template <typename T, typename WeightType,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion Fusion =
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE>
constexpr bool isValidBlackwellMOESpecialisation() {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) // TODO Is there a better choice
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
defined(ENABLE_FP4) // TODO Is there a better choice
Comment on lines +51 to +52
coderabbitai bot (Contributor), Oct 23, 2025:
⚠️ Potential issue | 🟠 Major

Overly restrictive guard blocks non-FP4 Blackwell specializations.

Requiring ENABLE_FP4 at the top level prevents ALL Blackwell MOE specializations from being valid when FP4 support is not compiled, including non-FP4 cases like FP8/FP8 or BF16/BF16 that should work regardless. Line 53 explicitly supports cutlass::platform::is_same<T, WeightType>::value, which matches any same-type pair.

The Hopper implementation (lines 68-89) uses a better pattern: conditional compilation guards only the FP4-specific parts inside the return statement, allowing non-FP4 specializations to work without ENABLE_FP4.

Apply this diff to follow the Hopper pattern:

-#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
-    defined(ENABLE_FP4)  // TODO Is there a better choice
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
   return (cutlass::platform::is_same<T, WeightType>::value ||
+#ifdef ENABLE_FP4
           (cutlass::platform::is_same<T, __nv_fp8_e4m3>::value &&
-           cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value)) &&
+           cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value) ||
+#else
+          false ||
+#endif
+          false) &&
          cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value &&
          Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h
around lines 51-52, the current top-level guard requires ENABLE_FP4 and blocks
all Blackwell (SM100) MOE specializations when FP4 is not enabled; change this
to only guard FP4-specific logic: keep the CUTLASS_ARCH_MMA_SM100_SUPPORTED
check on the specialization but remove ENABLE_FP4 from the #if, then wrap only
the FP4-specific condition(s) inside the specialization (the return expression)
with #if defined(ENABLE_FP4) / #endif as the Hopper implementation does (lines
~68-89), ensuring the same-type WeightType check remains unguarded so non-FP4
types (FP8/BF16/etc.) compile normally.

amirkl94 (Contributor, Author):

The issue here is:

cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value

which I think isn't defined unless ENABLE_FP4 is defined, because __nv_fp4_e2m1 itself is only available then.

coderabbitai bot (Contributor):

@amirkl94, you're absolutely right that __nv_fp4_e2m1 is only defined when ENABLE_FP4 is set! That's why my suggestion included the #ifdef ENABLE_FP4 guard specifically around the line that references that type:

#ifdef ENABLE_FP4
          (cutlass::platform::is_same<T, __nv_fp8_e4m3>::value &&
           cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value) ||
#endif

The key insight is that the first condition in the return statement:

cutlass::platform::is_same<T, WeightType>::value

doesn't reference __nv_fp4_e2m1 at all—it just checks if T and WeightType are the same type (e.g., FP8/FP8, BF16/BF16). This condition should work regardless of whether ENABLE_FP4 is defined.

By moving the top-level ENABLE_FP4 requirement into the return expression (guarding only the FP4-specific logic), non-FP4 Blackwell specializations can compile and work correctly even when FP4 support is disabled. This matches the pattern already used in isValidHopperMOESpecialisation at lines 68-89.

Let me know if you'd like me to refine the suggestion further!

return (cutlass::platform::is_same<T, WeightType>::value ||
(cutlass::platform::is_same<T, __nv_fp8_e4m3>::value &&
cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value)) &&
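For reference, a sketch of how the guard could be restructured along the lines suggested above, keeping ENABLE_FP4 around only the clause that names __nv_fp4_e2m1. This assumes the surrounding CUTLASS/TensorRT-LLM headers and template parameters and is not a compiled, drop-in patch:

```cpp
// Sketch: only the clause referencing __nv_fp4_e2m1 is hidden behind ENABLE_FP4,
// so same-type pairs (FP8/FP8, BF16/BF16, ...) remain valid without FP4 support.
template <typename T, typename WeightType,
          typename EpilogueTag = cutlass_extensions::EpilogueOpDefault,
          TmaWarpSpecializedGroupedGemmInput::EpilogueFusion Fusion =
              TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE>
constexpr bool isValidBlackwellMOESpecialisation() {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  return (cutlass::platform::is_same<T, WeightType>::value
#ifdef ENABLE_FP4
          || (cutlass::platform::is_same<T, __nv_fp8_e4m3>::value &&
              cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value)
#endif
          ) &&
         cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value &&
         Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
#else
  return false;  // architecture not supported by this build
#endif
}
```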
3 changes: 3 additions & 0 deletions flashinfer/fused_moe/core.py
@@ -37,6 +37,7 @@
gen_cutlass_fused_moe_sm120_module,
gen_cutlass_fused_moe_sm100_module,
gen_cutlass_fused_moe_sm90_module,
gen_cutlass_fused_moe_sm89_module,
gen_trtllm_gen_fused_moe_sm100_module,
)
from ..utils import (
@@ -285,6 +286,8 @@ def get_cutlass_fused_moe_module(backend: str = "100", use_fast_build: bool = False):
module = gen_cutlass_fused_moe_sm100_module(use_fast_build).build_and_load()
elif backend == "90":
module = gen_cutlass_fused_moe_sm90_module(use_fast_build).build_and_load()
elif backend == "89":
module = gen_cutlass_fused_moe_sm89_module(use_fast_build).build_and_load()
else:
raise ValueError(f"Invalid backend: {backend}")

4 changes: 4 additions & 0 deletions flashinfer/jit/core.py
@@ -90,6 +90,10 @@ def clear_cache_dir():
"-DFLASHINFER_ENABLE_FP8_E8M0",
"-DFLASHINFER_ENABLE_FP4_E2M1",
]
sm89_nvcc_flags = [
"-gencode=arch=compute_89,code=sm_89",
"-DFLASHINFER_ENABLE_FP8_E8M0",
]
sm90a_nvcc_flags = ["-gencode=arch=compute_90a,code=sm_90a"] + common_nvcc_flags
sm100a_nvcc_flags = ["-gencode=arch=compute_100a,code=sm_100a"] + common_nvcc_flags
sm103a_nvcc_flags = ["-gencode=arch=compute_103a,code=sm_103a"] + common_nvcc_flags
18 changes: 17 additions & 1 deletion flashinfer/jit/fused_moe.py
@@ -18,7 +18,13 @@

from . import env as jit_env
from ..artifacts import ArtifactPath, CheckSumHash
from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags
from .core import (
JitSpec,
gen_jit_spec,
current_compilation_context,
sm90a_nvcc_flags,
sm89_nvcc_flags,
)
from .cpp_ext import is_cuda_version_at_least
from .cubin_loader import get_cubin, get_meta_hash
from .gemm.cutlass.generate_kernels import generate_gemm_operations
Expand Down Expand Up @@ -71,6 +77,16 @@ def gen_cutlass_fused_moe_sm90_module(use_fast_build: bool = False) -> JitSpec:
return gen_cutlass_fused_moe_module(nvcc_flags, "90", use_fast_build)


def gen_cutlass_fused_moe_sm89_module(use_fast_build: bool = False) -> JitSpec:
nvcc_flags = sm89_nvcc_flags + [
"-DENABLE_BF16",
"-DENABLE_FP8",
"-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
"-DUSING_OSS_CUTLASS_MOE_GEMM",
]
Comment on lines +81 to +86
Review comment (Contributor):

medium

Conditionally adding an empty string to the list of compiler flags can be slightly confusing and less clean. It's better to conditionally append the flag to the list to avoid adding empty elements. This improves code readability and maintainability.

Suggested change
nvcc_flags = sm89_nvcc_flags + [
"-DENABLE_BF16",
"-DENABLE_FP8",
"-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
"-DUSING_OSS_CUTLASS_MOE_GEMM",
]
nvcc_flags = sm89_nvcc_flags + [
"-DENABLE_BF16",
"-DENABLE_FP8",
"-DUSING_OSS_CUTLASS_MOE_GEMM",
]
if is_cuda_version_at_least("12.8"):
nvcc_flags.append("-DENABLE_FP8_BLOCK_SCALE")

return gen_cutlass_fused_moe_module(nvcc_flags, "89", use_fast_build)


def gen_cutlass_fused_moe_module(
nvcc_flags: List[str], device_arch: str, use_fast_build: bool = False
) -> JitSpec:
Expand Down