Feature: Add support for L40 FusedMoE in cutlass path #1973
base: main
Conversation
Signed-off-by: Amir Klein <[email protected]>
Summary of Changes

Hello @amirkl94, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces support for L40 GPUs (SM89 architecture) in the CUTLASS FusedMoE kernels. It resolves compilation issues and a runtime crash related to shared memory allocation for a specific GEMM tactic on SM89, and updates the Python JIT compilation infrastructure so the SM89 module can be generated and loaded.
Note: CodeRabbit detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds explicit SM89 handling across the CUTLASS kernel heuristics, MoE GEMM dispatch, and the FlashInfer JIT layer: new SM89 NVCC flags, an SM89-specific JIT module generator, tightened FP4 guards, and adjusted FP8/FP4 dispatch behavior for the SM80–89 paths.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant PythonCore as flashinfer/fused_moe/core.py
    participant JITFactory as fused_moe.gen_cutlass_fused_moe_sm89_module
    participant JITCore as flashinfer/jit/core.py
    participant CppDispatch as moe_gemm_template_dispatch.h
    participant CUTLASS as CUTLASS Kernels
    User->>PythonCore: get_cutlass_fused_moe_module(backend="89")
    PythonCore->>JITFactory: build_and_load()
    JITFactory->>JITCore: use sm89_nvcc_flags + BF16/FP8 (+FP8_BLOCK_SCALE?)
    JITFactory->>PythonCore: return JitSpec (SM89)
    PythonCore-->>User: compiled module
    Note over CppDispatch,CUTLASS: SM 80–89 dispatch path
    alt FP8 or FP4 detected
        CppDispatch->>CppDispatch: if sm == 89 (new branch)
        CppDispatch->>CUTLASS: dispatch to SM89 implementation
    else No FP8/FP4
        CppDispatch->>CUTLASS: use SM80/S800 fallback
    end
```
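For concreteness, the user-facing side of this flow might look like the following sketch. The `get_cutlass_fused_moe_module` name and the `"89"` backend string come from the diagram above; the exact import path is an assumption.

```python
# Sketch only: obtain the JIT-compiled CUTLASS fused-MoE module for an L40 (SM89) GPU.
# Import location is assumed; see flashinfer/fused_moe/core.py in this PR for the real entry point.
from flashinfer.fused_moe.core import get_cutlass_fused_moe_module

module = get_cutlass_fused_moe_module(backend="89")  # builds (if needed) and loads the SM89 JitSpec
```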
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Code Review
This pull request adds support for L40 GPUs (sm_89) in the CUTLASS FusedMoE path. The changes include fixing compilation issues, removing a problematic GEMM tactic for sm_89 that was causing crashes, and adding the necessary build configurations for this architecture. The changes are logical and well-implemented. I have one suggestion to improve code clarity when constructing compiler flags.
```python
nvcc_flags = sm89_nvcc_flags + [
    "-DENABLE_BF16",
    "-DENABLE_FP8",
    "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
    "-DUSING_OSS_CUTLASS_MOE_GEMM",
]
```
Conditionally adding an empty string to the list of compiler flags can be slightly confusing and less clean. It's better to conditionally append the flag to the list to avoid adding empty elements. This improves code readability and maintainability.
Suggested change:

```diff
 nvcc_flags = sm89_nvcc_flags + [
     "-DENABLE_BF16",
     "-DENABLE_FP8",
-    "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
     "-DUSING_OSS_CUTLASS_MOE_GEMM",
 ]
+if is_cuda_version_at_least("12.8"):
+    nvcc_flags.append("-DENABLE_FP8_BLOCK_SCALE")
```
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️  Outside diff range comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (1)
42-42: Fix misleading comment about SM120 vs SM100.

The comment mentions CUTLASS_ARCH_MMA_SM100_SUPPORTED, but this function checks for CUTLASS_ARCH_MMA_SM120_SUPPORTED (line 35). Update the comment to accurately reflect the macro being checked. Apply this diff:

```diff
-  return false;  // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled
+  return false;  // CUTLASS_ARCH_MMA_SM120_SUPPORTED is set when Blackwell kernels are enabled
```
🧹 Nitpick comments (2)
flashinfer/jit/fused_moe.py (1)
80-88: SM89 module generation is correctly implemented.

The function appropriately:
- Uses sm89_nvcc_flags, which excludes FP4 support for L40
- Omits Hopper-specific TMA GEMM flags (correct for the Ada architecture)
- Includes conditional FP8 block scale support for CUDA ≥ 12.8

Optional: Consider iterable unpacking for cleaner syntax. As suggested by Ruff, you could use iterable unpacking instead of concatenation:

```diff
-    nvcc_flags = sm89_nvcc_flags + [
+    nvcc_flags = [
+        *sm89_nvcc_flags,
         "-DENABLE_BF16",
         "-DENABLE_FP8",
         "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
         "-DUSING_OSS_CUTLASS_MOE_GEMM",
     ]
```

This is a minor style improvement and can be deferred.
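One possible way to combine both review suggestions (iterable unpacking plus conditional append) is sketched below; it is not the code in this PR, and the import paths simply follow the files named in this review (flashinfer/jit/core.py and flashinfer/jit/cpp_ext.py).

```python
# Sketch only: merges the iterable-unpacking and conditional-append suggestions.
from flashinfer.jit.core import sm89_nvcc_flags          # assumed import location
from flashinfer.jit.cpp_ext import is_cuda_version_at_least

nvcc_flags = [
    *sm89_nvcc_flags,
    "-DENABLE_BF16",
    "-DENABLE_FP8",
    "-DUSING_OSS_CUTLASS_MOE_GEMM",
]
if is_cuda_version_at_least("12.8"):
    # Only newer CUDA toolchains get the FP8 block-scale define; no empty strings in the list.
    nvcc_flags.append("-DENABLE_FP8_BLOCK_SCALE")
```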
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)
696-709: LGTM: Correctly reorganizes SM89 dispatch to avoid shared memory issues.

The reorganized control flow properly addresses the L40 (SM89) issue by:
- Routing FP8 workloads to Sm89 kernels with runtime validation (line 703)
- Routing non-FP8 workloads to Sm80 kernels (lines 707-708)

This aligns with the kernel implementation in moe_cutlass_kernel.h, which shows the SM89 architecture reusing Sm80 kernels for non-FP8 types, and prevents the "GPU lacks the shared memory resources" assertion mentioned in the PR objectives.

Optional suggestion: Consider adding a brief comment explaining why non-FP8 on SM89 uses the Sm80 path, to help future maintainers understand the shared memory constraint that motivated this design. Apply this diff to add a clarifying comment:

```diff
   } else {
+    // Non-FP8 workloads on SM89 (L40) reuse Sm80 kernels to avoid
+    // Sm89-specific tactics that exceed shared memory limits
     dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(
         inputs, multi_processor_count_);
   }
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (2 hunks)
- flashinfer/fused_moe/core.py (2 hunks)
- flashinfer/jit/core.py (1 hunks)
- flashinfer/jit/fused_moe.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
flashinfer/fused_moe/core.py (2)
- flashinfer/jit/fused_moe.py (1)
  - gen_cutlass_fused_moe_sm89_module (80-87)
- flashinfer/jit/core.py (1)
  - build_and_load (272-284)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (2)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h (6)
  - __nv_fp8_e5m2 (91-93)
  - cutlass (114-116)
  - cutlass (120-122)
  - cutlass (127-129)
  - cutlass (132-134)
  - cutlass (140-142)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h (1)
  - cutlass (40-677)
flashinfer/jit/fused_moe.py (2)
- flashinfer/jit/core.py (2)
  - JitSpec (185-284)
  - gen_jit_spec (287-353)
- flashinfer/jit/cpp_ext.py (1)
  - is_cuda_version_at_least (86-87)
🪛 Ruff (0.14.1)
flashinfer/jit/fused_moe.py
81-86: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
🔇 Additional comments (6)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (1)
161-168: LGTM! Targeted fix for L40 shared memory constraints.

The separation of SM89 from the SM >= 120 handling correctly removes CtaShape16x256x128_WarpShape16x64x128 for L40 GPUs. This config's larger K dimension (128 vs 64) would exceed L40's shared memory capacity in GROUPED_GEMM mode, causing the assertion failure mentioned in the PR. The fix is minimal, well-scoped, and maintains full functionality for newer architectures.
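As a rough back-of-envelope illustration of why the larger K dimension matters, the sketch below assumes 1-byte FP8 operands and a 4-stage software pipeline; the real CUTLASS accounting also stages scales, epilogue storage, and so on, so treat the numbers as indicative only.

```python
# Rough shared-memory estimate per pipeline stage for a CTA tile (M, N, K),
# assuming 1-byte operands for both the A (M x K) and B (K x N) fragments.
def smem_per_stage_bytes(m, n, k, elem_bytes=1):
    return (m * k + k * n) * elem_bytes

for tile in [(16, 256, 128), (16, 256, 64)]:
    per_stage = smem_per_stage_bytes(*tile)
    print(tile, per_stage, "bytes/stage ->", 4 * per_stage, "bytes with 4 stages")
# (16, 256, 128): 34816 bytes/stage -> 139264 bytes with 4 stages (above L40's ~100 KB per-SM limit)
# (16, 256, 64):  17408 bytes/stage ->  69632 bytes with 4 stages
```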
flashinfer/jit/core.py (1)

93-96: SM89 flags correctly omit FP4 support.

The implementation appropriately excludes -DFLASHINFER_ENABLE_FP4_E2M1 for SM89 (L40), unlike other architectures that use common_nvcc_flags. This aligns with the PR objective of fixing compilation issues for L40 by removing problematic GEMM tactics.
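For orientation, a hypothetical sketch of what an SM89-specific flag list along these lines could look like is shown below; the gencode value and the list contents are assumptions, and the authoritative definition lives in flashinfer/jit/core.py.

```python
# Hypothetical sketch only; not the actual flashinfer definition.
sm89_nvcc_flags = [
    "-gencode=arch=compute_89,code=sm_89",  # target the Ada (L40) architecture
    # ... plus the usual common defines, but deliberately NOT
    # -DFLASHINFER_ENABLE_FP4_E2M1, since FP4 is unsupported in the SM89 CUTLASS MoE path.
]
```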
flashinfer/fused_moe/core.py (2)

40-40: LGTM! Import added correctly to support SM89 module generation.

289-290: LGTM! Backend dispatch for SM89 follows the established pattern and correctly invokes the new module generation function.
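A minimal sketch of the kind of backend branch this refers to is shown below; gen_cutlass_fused_moe_sm89_module and build_and_load come from this PR, while the surrounding function shape and import path are assumptions.

```python
# Sketch, not the actual flashinfer code: pick the SM89 generator when the
# compute-capability backend string is "89", then JIT-build and load it.
from flashinfer.jit.fused_moe import gen_cutlass_fused_moe_sm89_module  # added in this PR

def get_cutlass_fused_moe_module(backend: str):
    if backend == "89":
        spec = gen_cutlass_fused_moe_sm89_module()
        return spec.build_and_load()  # compiles (if needed) and loads the module
    raise NotImplementedError(f"sketch only covers the SM89 backend, got {backend!r}")
```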
flashinfer/jit/fused_moe.py (1)

21-27: LGTM! Import correctly updated to include sm89_nvcc_flags, needed for the new SM89 module generation function.
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)

692-694: LGTM: Proper FP4 validation for SM89. The explicit compile-time check correctly prevents FP4 usage on SM < 90, consistent with the earlier validation for SM75-80 at lines 685-688.
      Signed-off-by: Amir Klein <[email protected]>
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (1)
35-36: LGTM: Guard is appropriate for FP4-only SM120 specialization.

The ENABLE_FP4 guard is correct here. Since isValidSM120MOESpecialisation only supports __nv_fp4_e2m1 types (lines 37-40), requiring ENABLE_FP4 at the preprocessor level prevents compilation errors when FP4 support is not available.
```cpp
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
    defined(ENABLE_FP4)  // TODO Is there a better choice
```
Overly restrictive guard blocks non-FP4 Blackwell specializations.
Requiring ENABLE_FP4 at the top level prevents ALL Blackwell MOE specializations from being valid when FP4 support is not compiled, including non-FP4 cases like FP8/FP8 or BF16/BF16 that should work regardless. Line 53 explicitly supports cutlass::platform::is_same<T, WeightType>::value, which matches any same-type pair.
The Hopper implementation (lines 68-89) uses a better pattern: conditional compilation guards only the FP4-specific parts inside the return statement, allowing non-FP4 specializations to work without ENABLE_FP4.
Apply this diff to follow the Hopper pattern:

```diff
-#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
-    defined(ENABLE_FP4)  // TODO Is there a better choice
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
   return (cutlass::platform::is_same<T, WeightType>::value ||
+#ifdef ENABLE_FP4
           (cutlass::platform::is_same<T, __nv_fp8_e4m3>::value &&
-           cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value)) &&
+           cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value) ||
+#else
+          false ||
+#endif
+          false) &&
          cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value &&
          Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h
around lines 51-52, the current top-level guard requires ENABLE_FP4 and blocks
all Blackwell (SM100) MOE specializations when FP4 is not enabled; change this
to only guard FP4-specific logic: keep the CUTLASS_ARCH_MMA_SM100_SUPPORTED
check on the specialization but remove ENABLE_FP4 from the #if, then wrap only
the FP4-specific condition(s) inside the specialization (the return expression)
with #if defined(ENABLE_FP4) / #endif as the Hopper implementation does (lines
~68-89), ensuring the same-type WeightType check remains unguarded so non-FP4
types (FP8/BF16/etc.) compile normally.
Looks good to me, mainly just want to understand why we need to disable the tile shape.
CodeRabbit's comment about a more granular FP4 guard might make sense, but I assume that if we are compiling with Blackwell support we should also have FP4.
```cpp
    CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
    CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
    CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
    CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
```
Are there any SM89 GPUs that can support CtaShape16x256x128_WarpShape16x64x128, or is this an SM120 addition?
Having a small M value like this helps with low-latency cases, so I'd want to understand why it's not supported before disabling it.
At the very least, can you leave a comment saying what the difference between the two lists is, so people don't have to manually compare the items?
There are some compilation errors on CI, @amirkl94 would you mind taking a look?
📌 Description
Fixed a few compilation issues for L40, and removed one GEMM tactic for sm == 89 that crashes with the "GPU lacks the shared memory resources" assertion.

🧪 Tests

Ran pytest tests/moe/test_trtllm_cutlass_fused_moe.py manually on an L40 GPU and verified all tests passed.

Summary by CodeRabbit
- New Features
- Bug Fixes / Compatibility
- Performance