
Conversation

@aleozlx
Collaborator

@aleozlx commented Oct 13, 2025

📌 Description

Upgraded the cutlass MoE kernel launcher to match trtllm. This facilitates further feature upgrades (perhaps in separate PRs).

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

This currently compiles. It still needs testing, but may be sufficient to unblock other work.
Once it tests OK, I'll remove the draft status.

Summary by CodeRabbit

  • New Features

    • Expanded mixed-precision MOE support (SM90/SM100/SM103/SM120), finalize-epilogue fusion, per-expert activation scaling swizzles, and optional A/B layout swap.
  • Improvements

    • More dynamic GEMM/tile configurations (dynamic cluster shapes), richer profiling and workspace orchestration, broader launcher dispatch paths, improved runtime tracing/logging, and safer default handling for new runtime flags.
  • Bug Fixes

    • Suppressed compiler warnings, tightened runtime validations, and updated test gating for device capabilities.

@aleozlx self-assigned this Oct 13, 2025
@yzh119
Collaborator

yzh119 commented Oct 14, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !81 has been created, and the CI pipeline #36551350 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #36551350: 1/17 passed

@yongwww self-assigned this Oct 15, 2025
@yongwww
Collaborator

yongwww commented Oct 16, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !81 has been updated with latest changes, and the CI pipeline #36760904 is currently running. I'll report back once the pipeline job completes.

@coderabbitai
Contributor

coderabbitai bot commented Oct 17, 2025

Walkthrough

Threads per-expert swizzled activation scaling, AB-layout swap, and finalize-epilogue fusion throughout MOE Cutlass paths; adds SM90 epilogue scatter visitor, dynamic tile/cluster enums (SM100/SM103/SM120), many launcher/dispatcher signature/template changes, OSS namespace variants, and multiple explicit template instantiations.

Changes

Cohort / File(s) Summary
Fused MOE instantiations
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu
Added explicit CutlassMoeFCRunner template instantiation <__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3> and removed a header comment/semicolon formatting.
Fused MOE kernel core
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
Expanded API to support swizzled_input_sf, AB swap_ab, finalize-epilogue fusion, padded/unpadded column handling; updated kernels, launchers, macros, runner runMoe, stride computations, and profiler plumbing.
Bindings & runner call sites
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
Added default stubs for new params and threaded swizzled_input_sf, unpadded_hidden_size, and use_lora through runMoe/min-latency variants and profiler init; consolidated tactics lists.
Public utilities
csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h
Added alias template <bool VALUE> using ConstBool = ConstExprWrapper<bool, VALUE>;.
SM90 epilogue scatter
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp
New SM90 epilogue scatter/fusion types: Sm90ScatterPtrArray, ScaledAcc per-row/col variants, composite scatter-fusion pointer arrays, FusionCallbacks specializations, workspace/Arguments/Params types.
Gemm config & shapes refactor
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h
Refactored tile/cluster enums to shape-based IDs; added shape_tuple_to_enum/enum_to_shape_tuple, SM100/SM120 enums, TileShape/ClusterShape, EpilogueFusionType, swap_ab, dynamic/fallback cluster shapes, and richer introspection/name functions (an illustrative sketch of the shape-tuple encoding follows the table).
Gather / cute qualification
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp
Removed top-level using; qualified cute:: symbols, adjusted local using directives and namespace blocks.
Heuristics & traits
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h, .../moe_tma_warp_specialized_traits.h
Added DYNAMIC_CGA template flag to shape-filter predicate and expanded FP4/FP8-aware MOE specialization checks with ENABLE_FP4 gating.
Namespace migration → _oss
multiple files under csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/... (e.g., fpA_intB_gemm/*, moe_gemm/*, launchers/*, moe_gemm_template_dispatch*.h, tma_ws_*)
Public namespace renamed to tensorrt_llm::kernels::cutlass_kernels_oss; dispatcher/launcher references updated and using-declarations added where necessary.
MOE kernel interface & types
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h, .../moe_kernels.h
Reworked grouped-gemm input types (e.g., ptr_act/ptr_weight), fixed layout/stride aliases, added swap_ab, use_reduction, increased workspace buffers (17→20), introduced MoeGemmId, and changed public signatures (getTactics, runMoe, gemm2) to accept new params like swizzled_input_sf and unpadded_hidden_size.
Workspace/input construction
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_input.cu
Increased workspaceBuffers to 20; remapped workspace pointer layout and buffer assignments; changed setFinalizeFusionParams signature and finalize-related fields (new ordering and use_reduction).
Launcher signatures & templates
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/*, moe_gemm_tma_ws_launcher.h
Expanded template parameters (EpilogueSchedule, DYNAMIC_CGA, SwapAB), added dynamic/fallback cluster shape arguments, additional bias/output/token metadata parameters, and added logging in several launchers.
Dispatch / TMA selection
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch*.h, moe_gemm_template_dispatch_tma_ws.h
Added getDispatchFunctionForSM100, integrated finalize-fusion and swap_ab into config generation and selection, handled SM100/SM103 special cases, dynamic-cluster selection, and routed launchers via OSS namespace variants.
Template instantiations & license/include cleanups
multiple moe_gemm_kernels_*.cu (fp16/fp32/bf16/fp8/uint4/uint8 variants)
Replaced license headers with Apache-2.0, simplified include paths to local moe_gemm_template_dispatch.h, and added explicit MoeGemmRunner template instantiations for several type combinations.
Kernel generation (JIT)
flashinfer/jit/gemm/cutlass/generate_kernels.py
Added dynamic_cga and swap_ab constructor params; threaded flags into generated instantiations; added generate_sm103_operations and extended SM100 generation to accept arch.
Tests
tests/moe/test_trtllm_cutlass_fused_moe.py
Extended MXFP8xMXFP4 test skip condition to include SM120.
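
Following up on the gemm_configs.h row above, here is an illustrative-only sketch of the shape_tuple_to_enum / enum_to_shape_tuple idea; the real header's encoding, value widths, and enum types are not shown in this summary, so everything below is an assumption, not the actual implementation.

#include <tuple>

// Illustrative encoding only: pack an (M, N, K) tile or cluster shape into one
// integer ID, assuming every dimension stays below 1000. The actual
// gemm_configs.h scheme may use different widths or dedicated enum classes.
constexpr int shape_tuple_to_enum(int m, int n, int k) {
  return m * 1000000 + n * 1000 + k;  // e.g. (128, 256, 64) -> 128256064
}

constexpr std::tuple<int, int, int> enum_to_shape_tuple(int id) {
  return {id / 1000000, (id / 1000) % 1000, id % 1000};
}

static_assert(enum_to_shape_tuple(shape_tuple_to_enum(128, 256, 64)) ==
                  std::tuple<int, int, int>(128, 256, 64),
              "round-trip between shape tuple and enum ID");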

Sequence Diagram(s)

sequenceDiagram
    participant App as Application
    participant Runner as CutlassMoeFCRunner
    participant Profiler as GemmProfilerBackend
    participant Dispatch as Dispatch Layer
    participant Kernel as CUDA Kernel

    App->>Runner: runMoe(..., swizzled_input_sf, unpadded_hidden_size, swap_ab, ...)
    activate Runner

    alt profiling
        Runner->>Profiler: prepare(..., epilogue_fusion, swap_ab, unpadded_hidden_size_profiler)
        Profiler-->>Runner: profiled inputs & workspaces
    end

    alt swizzled_input_sf == true
        Runner->>Runner: expand inputs using swizzled SF layout
    else
        Runner->>Runner: expand inputs using linear SF layout
    end

    alt epilogue_fusion == FINALIZE
        Runner->>Dispatch: computeStridesTmaWarpSpecialized(router_scales, permuted_row_map, ...)
        Dispatch-->>Runner: strides & finalize params
    end

    alt swap_ab == true
        Runner->>Runner: swap act/weight pointers and strides
    end

    Runner->>Dispatch: select & launch tile/cluster (dynamic_cga considered)
    Dispatch->>Kernel: launch GEMM kernel (possibly fused finalize epilogue)
    Kernel-->>App: results
    deactivate Runner
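
To make the swap_ab step above concrete, here is a minimal self-contained sketch; the struct and field names are assumptions for illustration and do not mirror the actual TmaWarpSpecializedGroupedGemmInput layout.

#include <cstdint>
#include <utility>

// Hypothetical operand bundle for one grouped-GEMM launch; real field names differ.
struct GroupedGemmOperands {
  void const** ptr_act;     // per-expert activation pointers
  void const** ptr_weight;  // per-expert weight pointers
  int64_t* stride_act;      // per-expert activation strides
  int64_t* stride_weight;   // per-expert weight strides
};

// When swap_ab is requested, A and B (activations and weights) simply trade
// places so the kernel sees the transposed problem.
inline void apply_swap_ab(GroupedGemmOperands& in, bool swap_ab) {
  if (swap_ab) {
    std::swap(in.ptr_act, in.ptr_weight);
    std::swap(in.stride_act, in.stride_weight);
  }
}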

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Poem

🐰
I hopped through namespaces, tidy and neat,
Swizzled scales and swapped A/B feet.
Finalize fused, shapes learn to roam,
SM90 scatter finds a new home.
Apache winds hum — the kernels feel light!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage — ⚠️ Warning — Docstring coverage is 9.09%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title Check — ✅ Passed — The pull request title "chore: upgrade cutlass moe kernel launcher to match trtllm" directly and specifically describes the primary change in the changeset. It clearly identifies what is being upgraded (cutlass moe kernel launcher) and why (to match trtllm), which aligns with the extensive file modifications summarized in the raw summary across multiple cutlass-related files. The title is concise, single-sentence, and provides sufficient context for a developer scanning commit history to understand the main objective without being vague or generic.
  • Description Check — ✅ Passed — The pull request description follows the repository template structure and is mostly complete. The "📌 Description" section provides clear, specific information about upgrading the cutlass MOE kernel launcher to match trtllm and notes it facilitates further feature upgrades. The pre-commit checklist items are all marked as completed, and the "🧪 Tests" section is partially completed with an explanation in the "Reviewer Notes" section indicating testing is pending but the code currently compiles. The "🔍 Related Issues" section is left empty, which is a non-critical omission when there are no related issues to link. Overall, the description provides substantive information and context rather than being vague or generic.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #36760904: 1/17 passed

@yongwww
Collaborator

yongwww commented Oct 17, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !81 has been updated with latest changes, and the CI pipeline #36801731 is currently running. I'll report back once the pipeline job completes.

@yongwww
Copy link
Collaborator

yongwww commented Oct 17, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !81 has been updated with latest changes, and the CI pipeline #36827105 is currently running. I'll report back once the pipeline job completes.

@yongwww marked this pull request as ready for review October 18, 2025 00:41
@yongwww
Collaborator

yongwww commented Oct 18, 2025

/gemini review

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request upgrades the cutlass moe kernel launcher to match trtllm, facilitating further feature upgrades. The changes include modifications to copyright years, adding a new template class, and adjusting kernel configurations and function calls for improved performance and functionality. The review focuses on correctness and maintainability, ensuring the changes align with the project's coding standards and improve the overall code quality.

Contributor

@coderabbitai bot left a comment


Actionable comments posted: 13

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (9)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h (1)

17-26: Add include guards and necessary headers for self-contained compilation.

The header currently lacks #pragma once, include guards, and required includes for int64_t and cudaStream_t. It depends on transitive includes from consumers (e.g., moe_gemm_template_dispatch.h includes <cuda.h> before including this header), making it fragile. Add the following:

+#pragma once
+#include <cstdint>
+#if !defined(__CUDACC_RTC__)
+#include <cuda_runtime_api.h>
+#else
+using cudaStream_t = void*;
+#endif
+
+#ifdef __cplusplus
 namespace tensorrt_llm::kernels::cutlass_kernels_oss {
 template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_, int TileK_,
           int Stages_, typename EpilogueTag>
 void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
                                                 ElementType_ const* biases, bool bias_is_broadcast,
                                                 ElementType_* C,
                                                 int64_t const* total_tokens_including_expert,
                                                 int64_t num_rows, int64_t gemm_n, int64_t gemm_k,
                                                 int num_experts, int multi_processor_count,
                                                 cudaStream_t stream, int* kernel_occupancy);
 }
+#endif

Optional: use int32_t instead of bare int for API stability (the int num_experts and int multi_processor_count parameters).

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)

546-582: Same default parameter concerns apply here.

The min-latency path has the same hardcoded defaults and assumptions as the regular runMoe path. The same validation and documentation suggestions from the earlier comment apply here.

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (3)

4040-4065: Uninitialized byte value used for MXFP8 scaling memset when E8M0 helper is unavailable.

weight_block_scale_value_int is only assigned under FLASHINFER_ENABLE_FP8_E8M0 && CUDA>=12.8; otherwise it is uninitialized and passed to cudaMemsetAsync, resulting in undefined fill values.

Initialize a safe default encoding for 1.0f, or guard the path. Suggested minimal fix (1.0f encodes to 0x7F for e8m0; adjust if your helper provides a constant):

-      TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF weight_block_scale_value_int{};
-#if defined(FLASHINFER_ENABLE_FP8_E8M0) && CUDART_VERSION >= 12080
+      TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF weight_block_scale_value_int =
+          static_cast<TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF>(0x7F); // e8m0(1.0f)
+#if defined(FLASHINFER_ENABLE_FP8_E8M0) && CUDART_VERSION >= 12080
       __nv_fp8_e8m0 tmp;
       tmp.__x = __nv_cvt_float_to_e8m0(1.0f, __NV_SATFINITE, cudaRoundPosInf);
       std::memcpy(&weight_block_scale_value_int, &tmp, sizeof(tmp));
 #endif

If you prefer stricter behavior, conditionally TLLM_CHECK for the feature instead of defaulting.


1006-1054: Guard potential nullptr from input SF lookup and avoid const_cast.

When input_sf is provided, cvt_quant_get_sf_out_offset may still return nullptr; dereferencing it without a check risks UB. Also, using const_cast to satisfy an API that expects a non-const pointer is fragile.

Add a nullptr check and provide a const overload or wrap the pointer without const_cast:

-      auto const sf_in =
-          cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
-                                      NumThreadsPerSF>(
-              std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
-              num_cols / VecSize,
-              const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
-              QuantizationSFLayout::SWIZZLED_128x4);
-      *sf_out = *sf_in;
+      auto const sf_in =
+          cvt_quant_get_sf_out_offset<const TmaWarpSpecializedGroupedGemmInput::ElementSF,
+                                      NumThreadsPerSF>(
+              std::nullopt, source_token_id, elem_idx, std::nullopt, num_cols / VecSize,
+              input_sf, QuantizationSFLayout::SWIZZLED_128x4);
+      if (sf_in) { *sf_out = *sf_in; }

Apply similar handling in the LINEAR branch.


2918-2926: Confirmed: Missing parameter and incorrect argument order at line 2920.

The function signature requires 17 parameters with padded_cols, actual_cols, and experts_per_token in positions 10–12, but the line 2920 call site passes only 16 arguments with hidden_size, k, and num_experts_per_node at those positions. The call omits unpadded_hidden_size entirely. The corrected call sites at lines 3297 and 3304 demonstrate the proper argument order with all 17 parameters.

-  finalizeMoeRoutingKernelLauncher<OutputType, UnfusedGemmOutputType>(
-      static_cast<UnfusedGemmOutputType const*>(gemm_output), final_output, fc2_expert_biases,
-      unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row,
-      token_selected_experts, expert_first_token_offset, num_rows, hidden_size, k,
-      num_experts_per_node, parallelism_config, enable_alltoall, enable_pdl, stream);
+  finalizeMoeRoutingKernelLauncher<OutputType, UnfusedGemmOutputType>(
+      static_cast<UnfusedGemmOutputType const*>(gemm_output), final_output, fc2_expert_biases,
+      unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row,
+      token_selected_experts, expert_first_token_offset, num_rows, hidden_size,
+      unpadded_hidden_size, k, num_experts_per_node, parallelism_config, enable_alltoall,
+      enable_pdl, stream);
tests/moe/test_trtllm_cutlass_fused_moe.py (1)

1096-1100: Update skip reason to match devices list.

Guard allows SM100/110/120, but reason says “only supported on SM100 and SM110”. Add SM120 for consistency.

-    reason="MXFP8xMXFP4 is only supported on SM100 and SM110",
+    reason="MXFP8xMXFP4 is only supported on SM100, SM110 and SM120",
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (1)

112-117: Missing <type_traits> include for std::is_same_v

std::is_same_v is used but <type_traits> isn’t included in this header, risking compile errors depending on include order.

Apply:

 #include "cutlass_extensions/epilogue_helpers.h"
+#include <type_traits>

Alternatively, use cutlass::platform::is_same consistently.

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h (1)

24-37: Remove stale namespace import in implementation file

The implementation in fpA_intB_launcher_sm90.inl (line 49) contains using namespace tensorrt_llm::kernels::cutlass_kernels; which imports the old namespace inside the new cutlass_kernels_oss namespace. This line should be removed since:

  • The function is correctly declared and implemented within cutlass_kernels_oss
  • All callers are also in cutlass_kernels_oss and resolve correctly via unqualified lookup
  • The stale using directive creates unnecessary coupling to the old namespace and violates the separation intent of the rename

Fix location: csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl line 49 — delete the using namespace tensorrt_llm::kernels::cutlass_kernels; line.

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)

868-889: Contradictory FP8 fallback: check forbids it, code dispatches it.

TLLM_CHECK_WITH_INFO(!use_fp8, …) guarantees a throw, yet the branch below still dispatches FP8 to Sm89. Pick one behavior.

Two options:

  • Disallow FP8 fallback (remove dead dispatch):
-      if constexpr (use_fp8) {
-        cutlass_kernels_oss::dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType,
-                                                      cutlass::arch::Sm89, EpilogueTag>(
-            inputs, multi_processor_count_);
-      } else {
+      {
         cutlass_kernels_oss::dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType,
                                                       cutlass::arch::Sm80, EpilogueTag>(
             inputs, multi_processor_count_);
       }
  • Or allow FP8 fallback (drop the check):
-      TLLM_CHECK_WITH_INFO(!use_fp8, "No fallback FP8 implementation available");
+      // FP8 fallback to SM89 is allowed here.

Please confirm intended behavior.

♻️ Duplicate comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h (1)

29-35: Namespace rename ripple effects are broad

This file now exposes sm90_* dispatchers under cutlass_kernels_oss. Ensure all indirect users (e.g., moe_gemm_template_dispatch_tma_ws.h) and tests reference the new namespace. Avoid duplicate comments; see launcher header note for verification script.

🧹 Nitpick comments (27)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp4.cu (1)

1-15: Consider updating copyright year to reflect 2024-2025.

The Apache License header shows copyright years as 2020-2023, which may be outdated for a 2025 PR. While this is a minor concern, it's worth updating to reflect the current active development period if this PR makes material changes to the file.

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h (1)

35-36: Consider adding documentation for DYNAMIC_CGA parameter.

The filtering logic correctly includes the DYNAMIC_CGA flag, which will filter shapes when dynamic cluster shapes are enabled during FAST_BUILD. However, there's no inline documentation explaining what DYNAMIC_CGA represents or when it should be set to true.

Consider adding a brief comment above the struct definition:

+// DYNAMIC_CGA: Set to true when using dynamic cluster group array shapes.
+// When enabled, shapes will be filtered during FAST_BUILD to reduce compile time.
 template <class ArchTag, class TileShape, class ClusterShape, bool DYNAMIC_CGA,
           class ActivationType>
 struct should_filter_tma_warp_specialized_gemm_problem_shape {
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h (4)

59-60: Remove extra blank line for consistency.

The blank line after the debug logging statement is inconsistent with other logging additions throughout the file (e.g., lines 235, 273, 427, 479, 502, 537). For consistency, debug logging should not be followed by blank lines.

Apply this diff:

 TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
-
 #ifdef ENABLE_BF16

316-317: Remove extra blank line for consistency.

Similar to line 60, this blank line after the debug logging statement is inconsistent with other logging additions throughout the file.

Apply this diff:

 TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
-
 // Don't instantiate configs that are not supported pre-hopper. Produce a sensible error instead.

409-411: Consider the performance impact of logging in destructor.

Logging in the destructor adds overhead to every object destruction. While useful for debugging, consider whether this logging should be conditional (e.g., only in debug builds) or if it might impact performance in production, especially if many instances are created and destroyed.
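
A minimal sketch of the debug-only gating suggested here, assuming an NDEBUG-style build switch; the class name is hypothetical, and TLLM_LOG_DEBUG is stubbed so the sketch compiles on its own (the real file gets it from the project's logging header).

#include <cstdio>

#ifndef TLLM_LOG_DEBUG
#define TLLM_LOG_DEBUG(msg) std::fprintf(stderr, "%s\n", msg)
#endif

struct ExampleGemmRunner {
  ~ExampleGemmRunner() {
#if !defined(NDEBUG)
    TLLM_LOG_DEBUG(__func__);  // compiled out entirely in release builds
#endif
  }
};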


516-517: Remove extra blank line for consistency.

This is the third instance of an inconsistent blank line after debug logging (also at lines 60 and 317). For consistency, remove this blank line to match the style used elsewhere in the file.

Apply this diff:

 TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
-
 if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY) {
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp (2)

288-317: Add divisibility asserts for vectorization/layout compatibility.

VecSize is derived from alignment and per-thread values, but the chosen major dimension may not be divisible by VecSize, leading to malformed tilers. Add compile-time checks.

@@
     int constexpr NumThreads = CUTE_STATIC_V(size(args.tiled_copy));
     int constexpr NumValTile = product(take<0,2>(shape(cD_epi)));
     int constexpr MaxVecSize = cute::min(AlignmentOutput, NumValTile / NumThreads);
+    if constexpr (cutlass::gemm::detail::is_k_major<StrideOutput>()) {
+      static_assert((size<1>(args.epi_tile) % MaxVecSize) == 0, "EPI_N must be divisible by vector size.");
+    } else if constexpr (cutlass::gemm::detail::is_mn_major<StrideOutput>()) {
+      static_assert((size<0>(args.epi_tile) % MaxVecSize) == 0, "EPI_M must be divisible by vector size.");
+    }

Optionally, relax using a computed VecSize that divides the corresponding dimension (fallback to smaller factor) if compile-time constraints are too strict. Do you want a patch to auto-derive such VecSize?


51-52: Avoid using namespace detail; in a public header.

Namespace pollution in headers can cause ADL surprises and ODR issues. Prefer explicit qualification or targeted using-declarations.

-using namespace cute;
-using namespace detail;
+using namespace cute;
// Remove `using namespace detail;` and fully qualify uses (e.g., cutlass::gemm::detail::is_major)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (2)

75-76: Fix typo in comment.

"supress" should be "suppress".

-    return nvinfer1::DataType::kFLOAT;  // supress compiler warning
+    return nvinfer1::DataType::kFLOAT;  // suppress compiler warning

116-117: Fix typo in comment.

Same typo as Line 76.

-    return nullptr;  // supress compiler warning
+    return nullptr;  // suppress compiler warning
flashinfer/jit/gemm/cutlass/generate_kernels.py (1)

382-384: Clarify cluster shape usage comment.

The comment states "We use a runtime cluster shape for SM100, so we only use cluster shapes to distinguish between 1SM and 2SM variants." This is potentially confusing - the cluster shape is used for variant selection, not runtime configuration. Consider rewording to make this clearer, e.g., "The cga_shape parameter selects between 1SM and 2SM kernel variants; actual cluster shapes are determined at runtime."

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu (1)

20-41: Avoid duplicate explicit instantiations across TUs

If the same runner variants are instantiated in other .cu/.cc files, link-time ODR errors will occur. Consider using extern template declarations in headers and centralizing definitions in a single TU.
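
A minimal sketch of the extern-template pattern, shown with the runner instantiation from this file; includes and exact file placement are omitted and assumed.

// In a shared header: declare that the instantiation is defined elsewhere, so
// including TUs do not emit their own copy (avoids duplicate-definition risk).
extern template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t,
                                         __nv_bfloat16, __nv_fp8_e4m3>;

// In exactly one .cu translation unit: the single explicit definition.
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t,
                                  __nv_bfloat16, __nv_fp8_e4m3>;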

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl (4)

64-65: Make debug log portable.

__PRETTY_FUNCTION__ is GCC/Clang-specific. Use __func__ to avoid MSVC breaks.

-  TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
+  TLLM_LOG_DEBUG(__func__);

197-199: Fix error text clarity.

Minor but user-facing: “must a multiple” → “must be a multiple”.

-        std::string err_msg = "The group size must a multiple of " + std::to_string(cta_shape_k);
+        std::string err_msg = "The group size must be a multiple of " + std::to_string(cta_shape_k);

253-259: Unify error reporting; avoid std::cout in error path.

Use project logger or throw with message; printing to stdout is noisy in libraries.

-      std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " +
-                            std::string(cutlassGetStatusString(can_implement));
-      std::cout << err_msg << std::endl;
-      throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg);
+      std::string err_msg = std::string("[TensorRT LLM Error][fpA_intB Runner] ") +
+                            "fpA_intB cutlass kernel will fail for params. Error: " +
+                            std::string(cutlassGetStatusString(can_implement));
+      TLLM_LOG_ERROR(err_msg.c_str());
+      throw std::runtime_error(err_msg);

48-50: Limit using-namespace to needed symbols.

Pulling the entire cutlass_kernels namespace into cutlass_kernels_oss risks ADL/ODR collisions.

Prefer targeted using-declarations for the specific adapters/types needed in this TU.
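
A minimal sketch of the targeted-using alternative; the single symbol shown is illustrative, and the exact set of names this TU needs (and their precise source namespace) is an assumption.

namespace tensorrt_llm::kernels::cutlass_kernels_oss {

// Avoid: imports every symbol (and its ADL behavior) from the old namespace.
// using namespace tensorrt_llm::kernels::cutlass_kernels;

// Prefer: pull in only the specific adapters/types this translation unit uses.
using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;

}  // namespace tensorrt_llm::kernels::cutlass_kernels_oss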

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h (1)

74-75: Debug logs: acceptable, but consider noise.

TLLM_LOG_DEBUG at hot entry points can be noisy at scale. If needed, wrap behind a verbose flag.

Also applies to: 126-127, 160-161

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h (1)

19-22: Guard CUDA runtime include for non-CUDA toolchains.

Mirror the portability guard to avoid SA/host-only build breaks.

-#include <cuda_runtime_api.h>
+#if defined(__has_include)
+#  if __has_include(<cuda_runtime_api.h>)
+#    include <cuda_runtime_api.h>
+#  else
+struct CUstream_st;
+using cudaStream_t = CUstream_st*;
+#  endif
+#else
+#  include <cuda_runtime_api.h>
+#endif
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl (3)

73-74: Make debug log portable.

Replace __PRETTY_FUNCTION__ with __func__.

-  TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
+  TLLM_LOG_DEBUG(__func__);

222-226: Unify error reporting; avoid std::cout.

Use logger then throw to keep stderr consistent.

-    std::string err_msg = "mixed dtype WS grouped cutlass kernel will fail for params. Error: " +
-                          std::string(cutlassGetStatusString(can_implement));
-    std::cout << err_msg << std::endl;
-    throw std::runtime_error("[Mixed dtype WS grouped GEMM] " + err_msg);
+    std::string err_msg = std::string("[Mixed dtype WS grouped GEMM] ") +
+                          "cutlass kernel will fail for params. Error: " +
+                          std::string(cutlassGetStatusString(can_implement));
+    TLLM_LOG_ERROR(err_msg.c_str());
+    throw std::runtime_error(err_msg);

187-189: Device id hardcoded to 0.

On multi-GPU systems this can misreport hardware info. Initialize from current device.

-  hw_info.device_id = 0;
+  int dev = 0;
+  (void)cudaGetDevice(&dev);
+  hw_info.device_id = dev;
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu (2)

51-73: Stabilize workspace index mapping (avoid future drift).

Hard-coded indices in configureWorkspace must stay in lockstep with workspaceBuffers’ array order. Introduce named indices to prevent brittle coupling.

Apply along these lines:

+// Keep these enums in sync with workspaceBuffers() return order
+enum class WSIdx : int {
+  ProblemShape = 0,
+  StrideAct,
+  StrideWeight,
+  StrideC,
+  StrideD,
+  PtrAct,
+  PtrWeight,
+  PtrC,
+  PtrD,
+  AlphaScales,
+  SfAct,
+  SfWeight,
+  StrideSfAct,
+  StrideSfWeight,
+  Int4ProblemShape,
+  Int4SFA,
+  Int4StrideSFA,
+  FinalizeBias,
+  FinalizeRouterScales,
+  FinalizeSourceTokenIndex
+};
...
-  shape_info.problem_shapes = reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(pointers[0]);
+  shape_info.problem_shapes = reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
+      pointers[static_cast<int>(WSIdx::ProblemShape)]);
...
-  ptr_act = reinterpret_cast<void const**>(pointers[5]);
+  ptr_act = reinterpret_cast<void const**>(pointers[static_cast<int>(WSIdx::PtrAct)]);
...
-  fused_finalize_epilogue.ptr_bias = reinterpret_cast<void const**>(pointers[17]);
-  fused_finalize_epilogue.ptr_router_scales = reinterpret_cast<float const**>(pointers[18]);
-  fused_finalize_epilogue.ptr_source_token_index = reinterpret_cast<int const**>(pointers[19]);
+  fused_finalize_epilogue.ptr_bias = reinterpret_cast<void const**>(
+      pointers[static_cast<int>(WSIdx::FinalizeBias)]);
+  fused_finalize_epilogue.ptr_router_scales = reinterpret_cast<float const**>(
+      pointers[static_cast<int>(WSIdx::FinalizeRouterScales)]);
+  fused_finalize_epilogue.ptr_source_token_index = reinterpret_cast<int const**>(
+      pointers[static_cast<int>(WSIdx::FinalizeSourceTokenIndex)]);

Also applies to: 94-125


171-176: Correct log label: this is generic FpX, not FP4-only.

Rename “Fp4 Block Scaling Factors …” to “FpX Block Scaling Factors …” to avoid confusion when MXFPX is used.

-    ss << "Fp4 Block Scaling Factors Act: " 
+    ss << "FpX Block Scaling Factors Act: "
...
-    ss << "Fp4 Block Scaling Factors Weight: " 
+    ss << "FpX Block Scaling Factors Weight: "
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (1)

52-52: Remove unused include.

<mutex> isn’t used here; drop it to speed up compilation and reduce the include surface.

-#include <mutex>
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)

372-372: Fix misleading namespace close comment.

The brace closes tensorrt_llm::kernels::cutlass_kernels_oss, not the root namespace.

-}  // namespace tensorrt_llm
+}  // namespace tensorrt_llm::kernels::cutlass_kernels_oss
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h (2)

394-418: Improve readability: print names, not raw ints.

Use the existing helpers to render human‑readable schedules and shapes.

-             << "\n\tmainloop sched: " << (int)mainloop_schedule
-             << "\n\tepi sched: " << (int)epilogue_schedule
+             << "\n\tmainloop sched: " << get_mainloop_schedule_name(mainloop_schedule)
+             << "\n\tepi sched: " << (epilogue_schedule == EpilogueScheduleType::AUTO ? "auto"
+                                  : epilogue_schedule == EpilogueScheduleType::NO_SMEM ? "no_smem"
+                                  : "tma")

And keep getTileConfigAsName()/get_cluster_shape_name you already added.


435-459: Use names in stream operator for configs too.

Replace integer dumps with the name helpers to simplify telemetry parsing.

- out << "tile_config_sm90_enum: " << config.getTileConfigAsInt()
+ out << "tile_config: " << config.getTileConfigAsName()
...
-    << ", cluster_shape_enum: " << int(config.cluster_shape)
-    << ", dynamic_cluster_shape_enum: " << int(config.dynamic_cluster_shape)
-    << ", fallback_cluster_shape_enum: " << int(config.fallback_cluster_shape)
+    << ", cluster_shape: " << get_cluster_shape_name(config.cluster_shape)
+    << ", dynamic_cluster_shape: " << get_cluster_shape_name(config.dynamic_cluster_shape)
+    << ", fallback_cluster_shape: " << get_cluster_shape_name(config.fallback_cluster_shape)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 80bdea5 and eddb10b.

📒 Files selected for processing (40)
  • csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu (2 hunks)
  • csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (54 hunks)
  • csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (8 hunks)
  • csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h (1 hunks)
  • csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp (1 hunks)
  • csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h (11 hunks)
  • csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp (5 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h (17 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h (9 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h (2 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl (5 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (7 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h (22 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl (2 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl (4 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp8.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp4_fp4.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp4.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_uint4.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (14 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (11 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h (7 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu (4 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (5 hunks)
  • flashinfer/jit/gemm/cutlass/generate_kernels.py (21 hunks)
  • tests/moe/test_trtllm_cutlass_fused_moe.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (17)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h (1)
  • ActivationType (22-31)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h (1)
  • cutlass_kernels_oss (29-67)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h (1)
  • cutlass_kernels_oss (25-37)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (2)
include/flashinfer/trtllm/fused_moe/runner.h (3)
  • hidden_size (251-251)
  • num_experts (249-249)
  • top_k (256-256)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h (3)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h (1)
  • cutlass_kernels_oss (24-37)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h (1)
  • cutlass_kernels_oss (25-37)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (1)
  • std (81-95)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h (5)
  • cutlass (114-116)
  • cutlass (120-122)
  • cutlass (127-129)
  • cutlass (132-134)
  • cutlass (140-142)
csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h (2)
  • __nv_fp8_e4m3 (204-206)
  • __nv_fp8_e4m3 (220-220)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • tensorrt_llm (38-172)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu (1)
include/flashinfer/trtllm/fused_moe/runner.h (2)
  • num_experts (249-249)
  • hidden_size (251-251)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (3)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (3)
  • tensorrt_llm (38-172)
  • EpilogueFusion (176-333)
  • FpXBlockScalingType (190-235)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (2)
  • tensorrt_llm (27-119)
  • isValidBlackwellMOESpecialisation (54-68)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h (1)
  • tensorrt_llm (23-36)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (2)
  • enable_pdl (220-220)
  • EpilogueFusion (176-333)
csrc/nv_internal/cpp/kernels/quantization.cu (2)
  • void (256-280)
  • void (282-300)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (4)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (4)
  • tensorrt_llm (38-172)
  • multi_processor_count_ (327-327)
  • EpilogueFusion (176-333)
  • FpXBlockScalingType (190-235)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h (9)
  • tensorrt_llm (33-150)
  • kernels (34-149)
  • cutlass (114-116)
  • cutlass (120-122)
  • cutlass (127-129)
  • cutlass (132-134)
  • cutlass (140-142)
  • cutlass_kernels (35-148)
  • __nv_fp8_e5m2 (91-93)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (2)
  • get_candidate_configs (525-576)
  • get_candidate_configs (525-527)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h (1)
  • EpilogueScheduleType (197-433)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu (6)
  • workspaceBuffers (26-73)
  • workspaceBuffers (26-27)
  • configureWorkspace (81-128)
  • configureWorkspace (81-84)
  • setFinalizeFusionParams (130-146)
  • setFinalizeFusionParams (130-133)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (11)
  • std (81-95)
  • if (135-139)
  • if (142-146)
  • if (149-153)
  • if (156-160)
  • if (321-321)
  • if (323-323)
  • if (325-325)
  • if (426-439)
  • if (441-463)
  • if (463-476)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h (4)
  • string (261-268)
  • if (277-279)
  • if (279-281)
  • if (281-283)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • tensorrt_llm (38-172)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h (2)
  • TileShape (205-392)
  • ClusterShape (246-373)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (11)
  • std (81-95)
  • if (135-139)
  • if (142-146)
  • if (149-153)
  • if (156-160)
  • if (321-321)
  • if (323-323)
  • if (325-325)
  • if (426-439)
  • if (441-463)
  • if (463-476)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu (2)
  • toString (148-181)
  • toString (148-148)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • tensorrt_llm (38-172)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (2)
  • tensorrt_llm (70-227)
  • tensorrt_llm (526-533)
flashinfer/jit/gemm/cutlass/generate_kernels.py (1)
flashinfer/jit/gemm/cutlass/cutlass_library.py (4)
  • EpilogueScheduleType (710-726)
  • DataType (73-120)
  • KernelScheduleType (519-587)
  • GemmKind (1008-1021)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • tensorrt_llm (38-172)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h (3)
  • tensorrt_llm (33-150)
  • kernels (34-149)
  • cutlass_kernels (35-148)
🪛 Clang (14.0.6)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h

[error] 17-17: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h

[error] 18-18: 'cstdint' file not found

(clang-diagnostic-error)

csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp

[error] 38-38: 'cutlass/cutlass.h' file not found

(clang-diagnostic-error)

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h

[error] 17-17: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 17-17: expected ';' after top level declarator

(clang-diagnostic-error)


[error] 17-17: expected identifier or '('

(clang-diagnostic-error)

Comment on lines +3948 to 3950
TLLM_THROW("Min latency mode is no longer supported");
}

Contributor


⚠️ Potential issue | 🔴 Critical

Min-latency hard-disabled but still invoked.

computeStridesTmaWarpSpecializedLowLatency now throws, yet setupTmaWarpSpecializedInputs still calls it when min_latency_mode is true, breaking that path.

Either remove the min-latency path or route min-latency to the non-low-latency stride builder (fusion NONE, swap_ab on). Example fix below updates setupTmaWarpSpecializedInputs to use computeStridesTmaWarpSpecialized for min-latency; then this throw can remain (unused).

@@
-    return Self::computeStridesTmaWarpSpecializedLowLatency(
-        gemm1_tma_ws_input, gemm2_tma_ws_input, num_rows, fc1_out_size, hidden_size, hidden_size,
-        inter_size, num_experts_per_node, reinterpret_cast<T const*>(gemm1_input),
-        reinterpret_cast<T const*>(gemm2_input), fc1_expert_weights, fc2_expert_weights,
-        quant_params.fp8.dequant_fc1, quant_params.fp8.dequant_fc2, input_sf, fc2_fp4_act_scale_,
-        quant_params, nullptr, nullptr, reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
-        reinterpret_cast<UnfusedGemmOutputType*>(fc2_result_),
-        min_latency_params.num_active_experts_per_node, min_latency_params.active_expert_global_ids,
-        start_expert, enable_pdl, stream);
+    // Use regular stride builder with fusion NONE and swap_ab already set; finalize fusion not used.
+    return Self::computeStridesTmaWarpSpecialized(
+        expert_first_token_offset_, gemm1_tma_ws_input, gemm2_tma_ws_input,
+        /*num_tokens*/ num_rows, /*expanded_num_tokens*/ num_rows,  // min-latency packs per expert
+        /*gemm1_n*/ fc1_out_size, /*gemm1_k*/ hidden_size,
+        /*gemm2_n*/ hidden_size,  /*gemm2_k*/ inter_size,
+        num_experts_per_node,
+        reinterpret_cast<T const*>(gemm1_input), reinterpret_cast<T const*>(gemm2_input),
+        fc1_expert_weights, fc2_expert_weights,
+        quant_params.fp8.dequant_fc1, quant_params.fp8.dequant_fc2,
+        input_sf, fc2_fp4_act_scale_, quant_params,
+        /*bias1*/ nullptr, /*bias2*/ nullptr,
+        reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
+        reinterpret_cast<UnfusedGemmOutputType*>(fc2_result_),
+        /*router_scales*/ nullptr, /*permuted_row_to_unpermuted_row*/ nullptr,
+        enable_pdl, stream);

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +226 to +230
// Get tactics for both GEMM1 and GEMM2, combine them
auto gemm1_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1);
auto gemm2_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_2);
mAllProfiles = gemm1_tactics;
mAllProfiles.insert(mAllProfiles.end(), gemm2_tactics.begin(), gemm2_tactics.end());
Contributor


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify tactic selection logic for GEMM1 vs GEMM2.

The tactics from GEMM1 and GEMM2 are now concatenated into a single mAllProfiles vector. Ensure that when selecting tactics by profile_id (e.g., at Line 618), the correct tactics are applied to the corresponding GEMM. The setRunnerProfiles method expects 2 profile IDs, so verify the indexing logic correctly distinguishes between GEMM1 and GEMM2 tactics.


🏁 Script executed:

#!/bin/bash
# Verify how profile_id is used to select tactics and ensure GEMM1/GEMM2 separation

# Check setRunnerProfiles implementation and how it maps profile IDs
rg -n -A 10 -B 3 'setRunnerProfiles' --type cpp

# Look for getTactics usage and how tactics are distinguished
rg -n -A 5 'getTactics.*GEMM' --type cpp

Length of output: 0


🏁 Script executed:

#!/bin/bash
# First, verify the file exists and examine the context around the modified lines
echo "=== Examining the target file around lines 226-230 ==="
head -n 240 csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu | tail -n 30

echo ""
echo "=== Search for setRunnerProfiles (broader search) ==="
rg -n 'setRunnerProfiles' csrc/fused_moe/cutlass_backend/

echo ""
echo "=== Search for how mAllProfiles is used ==="
rg -n -C 5 'mAllProfiles' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Search for profile_id usage around line 618 ==="
sed -n '610,630p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Search for getTactics calls ==="
rg -n 'getTactics' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

Length of output: 806


🏁 Script executed:

#!/bin/bash
# Examine the target file and trace tactic selection logic

echo "=== Examining lines around 226-230 (mAllProfiles concatenation) ==="
sed -n '220,240p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Search for mAllProfiles usage throughout the file ==="
rg -n 'mAllProfiles' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Examining setRunnerProfiles method (lines 757-778 mentioned in scratchpad) ==="
sed -n '750,790p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Search for getTactics calls and GEMM references ==="
rg -n 'getTactics|GEMM_1|GEMM_2' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Examine profile_id usage pattern ==="
sed -n '610,630p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

Length of output: 5633


🏁 Script executed:

#!/bin/bash
# Examine the context around line 618 and how profile is used
echo "=== Full context around line 618 (profile selection) ==="
sed -n '600,650p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Search for where setRunnerProfiles is called ==="
rg -n 'setRunnerProfiles' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu -A 2 -B 2

echo ""
echo "=== Search for mKernelRunner->setTactic calls ==="
rg -n 'setTactic' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu -A 2 -B 2

echo ""
echo "=== Check runMoe signature and how it's called ==="
rg -n 'void runMoe' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu -A 8

Length of output: 5456


🏁 Script executed:

#!/bin/bash
# Search for where profile_ids is generated/populated and if there's any offset logic
echo "=== Search for profile_ids origin and computation ==="
rg -n 'profile_ids' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu | head -30

echo ""
echo "=== Search for getNumProfiles or similar methods ==="
rg -n 'getNumProfiles|getNumTactics|size()' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Check class member getters and how tactics are exposed ==="
sed -n '580,595p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Look for any comments or logic about profile indexing/offset ==="
rg -n -B 3 -A 3 'mAllProfiles\[' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

Length of output: 12101


Apply offset correction when indexing GEMM2 tactics from concatenated mAllProfiles.

The tactics from GEMM1 and GEMM2 are concatenated into a single mAllProfiles vector (lines 229–230), but when selecting tactics by profile_ids in setRunnerProfiles (lines 772–775), the GEMM2 profile index is not offset. This causes GEMM2 to incorrectly select from GEMM1's tactics range.

After concatenation, GEMM2 tactics start at index gemm1_tactics.size(), but this offset is not applied when accessing mAllProfiles.at(profile_ids.value()[1]).

Fix: Apply offset for GEMM2 tactic selection:

best_gemm2_profile = profile_ids.value()[1] == -1 ? best_gemm2_profile
                                                  : mAllProfiles.at(profile_ids.value()[1] + gemm1_size);

where gemm1_size = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1).size() (stored during initialization).

🤖 Prompt for AI Agents
In csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
around lines 226–230 (and referencing setRunnerProfiles around lines 772–775),
the concatenation of GEMM1 and GEMM2 tactics into mAllProfiles does not account
for the GEMM1 length when later indexing GEMM2 tactics, causing wrong
selections; fix by capturing and storing gemm1_size during initialization (e.g.,
gemm1_size = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1).size()) and
then when selecting GEMM2 use profile index + gemm1_size (apply this offset
where mAllProfiles.at(profile_ids.value()[1]) is used), keeping the -1 sentinel
handling unchanged so that if profile_ids.value()[1] == -1 you keep the default,
otherwise index mAllProfiles at profile_ids.value()[1] + gemm1_size.

Comment on lines +372 to +375
// HACK Define default values for parameters we don't have good values for
bool const swizzled_input_sf = false; // Assume input_sf is not swizzled by default
int64_t const unpadded_hidden_size = hidden_size; // Assume no padding by default
bool const use_lora = false; // No lora support yet
Contributor


⚠️ Potential issue | 🟡 Minor

Document assumptions for default parameter values.

The hardcoded defaults for swizzled_input_sf, unpadded_hidden_size, and use_lora may not hold in all scenarios. In particular, setting unpadded_hidden_size = hidden_size assumes no padding, which may be incorrect if padding occurs upstream. Consider:

  1. Validating these assumptions with assertions when these features are actually used
  2. Adding TODO comments linking to future work that will provide proper values
  3. Documenting these limitations in user-facing documentation
🤖 Prompt for AI Agents
In csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
around lines 372-375 the code hardcodes defaults for swizzled_input_sf,
unpadded_hidden_size and use_lora which may be incorrect in some configurations;
update the code to (1) add a TODO comment with a link to the tracking issue/PR
for providing proper values, (2) add runtime assertions or checks where these
flags/values are actually used (e.g., assert unpadded_hidden_size <= hidden_size
and verify swizzled_input_sf only when the input layout indicates swizzling),
and (3) if possible derive unpadded_hidden_size from upstream tensor metadata or
add it as an explicit parameter to the caller API and fall back to the current
default only with a clear warning log; also add a brief note in the module’s
user-facing docs describing this limitation and the expectation until full
support is implemented.
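
As a concrete form of suggestion 1 above, the checks could be collected into a small guard invoked wherever the defaults are consumed; a self-contained sketch (validateMoeDefaults is a hypothetical name, and plain assert stands in for the TLLM_CHECK_WITH_INFO macro the binding would normally use):

#include <cassert>
#include <cstdint>

// Fail fast if the hardcoded defaults stop matching reality upstream.
void validateMoeDefaults(bool swizzled_input_sf, int64_t unpadded_hidden_size,
                         int64_t hidden_size, bool use_lora) {
  assert(unpadded_hidden_size <= hidden_size &&
         "unpadded_hidden_size must not exceed hidden_size");
  assert(!swizzled_input_sf && "swizzled input scale factors are not wired up in this binding yet");
  assert(!use_lora && "LoRA is not supported by this MoE path yet");
  // Keep release builds (NDEBUG) warning-free.
  (void)swizzled_input_sf; (void)unpadded_hidden_size; (void)hidden_size; (void)use_lora;
}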

Comment on lines 642 to 657
+    int64_t const unpadded_hidden_size_profiler = hidden_size;  // HACK no padding by default
 #ifdef USING_OSS_CUTLASS_MOE_GEMM
     mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
                     DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype),
                     DtypeUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
-                    hidden_size, inter_size, group_size, ActivationType::Swiglu, USE_BIAS,
-                    USE_LORA, min_latency_mode,
+                    hidden_size, unpadded_hidden_size_profiler, inter_size, group_size,
+                    ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode,
                     /*need_weights*/ false, parallelism_config, enable_alltoall);
 #else
     mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
                     DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype),
                     DtypeUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
-                    hidden_size, inter_size, group_size, ActivationType::Swiglu, USE_BIAS,
-                    USE_LORA, min_latency_mode,
+                    hidden_size, unpadded_hidden_size_profiler, inter_size, group_size,
+                    ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode,
                     /*need_weights*/ false, parallelism_config);
 #endif

⚠️ Potential issue | 🟡 Minor

Consistent default parameter pattern in profiler.

The profiler also uses the hardcoded unpadded_hidden_size default. This is consistent with the runtime paths but shares the same concerns about correctness. If padding is actually present, profiling may select suboptimal tactics.

🤖 Prompt for AI Agents
In csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
around lines 642 to 657, the code passes a hardcoded
unpadded_hidden_size_profiler equal to hidden_size into mProfiler->init, which
forces the profiler to assume no padding; replace this with the actual unpadded
hidden size used at runtime (compute or reuse the existing unpadded_hidden_size
variable that accounts for padding) and pass that variable into both the OSS and
non-OSS init calls so the profiler sees the true unpadded dimension and can pick
appropriate tactics.

#include "tensorrt_llm/common/cudaUtils.h"

-namespace tensorrt_llm::kernels::cutlass_kernels {
+namespace tensorrt_llm::kernels::cutlass_kernels_oss {

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify that all callers have been updated to the new cutlass_kernels_oss namespace.

The namespace change from tensorrt_llm::kernels::cutlass_kernels to tensorrt_llm::kernels::cutlass_kernels_oss is a breaking API change. Ensure all usage sites across the codebase have been updated to reference the new namespace.

Run the following script to verify namespace consistency:

Also applies to: 98-98


🏁 Script executed:

#!/bin/bash
# Description: Verify all references to the MOE launcher use the new cutlass_kernels_oss namespace

echo "Checking for lingering references to old cutlass_kernels namespace..."

# Search for old namespace references (excluding this file and commented code)
rg -t cpp -t cu -t h "cutlass_kernels::" \
  --glob "!*/fused_moe_gemm_launcher_sm80.inl" \
  -C2

echo -e "\nSearching for new cutlass_kernels_oss namespace usage..."

# Find all references to the new namespace
rg -t cpp -t cu -t h "cutlass_kernels_oss::" -C2

echo -e "\nSearching for sm80_generic_fused_moe_gemm_kernelLauncher calls..."

# Find all call sites of this launcher function
rg -t cpp -t cu "sm80_generic_fused_moe_gemm_kernelLauncher" -C3

Length of output: 275


🏁 Script executed:

#!/bin/bash
# Description: Verify namespace migration with improved file discovery

echo "=== Finding all C++/CUDA files in cutlass_kernels directory ==="
fd -e cpp -e cu -e h -e inl "cutlass_kernels" | head -20

echo -e "\n=== Searching for old namespace (cutlass_kernels::) in all C++/CUDA files ==="
rg "cutlass_kernels::" --type-add 'cuda:*.cu' --type-add 'cuda:*.cuh' -C2 | head -40

echo -e "\n=== Searching for new namespace (cutlass_kernels_oss::) ==="
rg "cutlass_kernels_oss::" -C2 | head -40

echo -e "\n=== Searching for sm80_generic_fused_moe_gemm_kernelLauncher function calls ==="
rg "sm80_generic_fused_moe_gemm_kernelLauncher" -C3 | head -40

echo -e "\n=== Checking the target file for namespace declaration ==="
fd -e inl "fused_moe_gemm_launcher_sm80.inl" -x cat {}

Length of output: 21869


Namespace migration is incomplete; related files still reference the old namespace.

The namespace change to cutlass_kernels_oss is not fully propagated. Files in the new namespace still contain references to the old tensorrt_llm::kernels::cutlass_kernels:: namespace:

  • moe_gemm_template_dispatch_tma_ws_mixed_dtype.h: uses tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput and TmaWarpSpecializedGroupedGemmInput
  • moe_gemm_template_dispatch_tma_ws.h: uses tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput, isValidHopperMOESpecialisation, isValidBlackwellMOESpecialisation, and TllmToCutlassTypeAdapter

Update all type and function references in these files to use the new namespace path.

🤖 Prompt for AI Agents
In
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl
(around line 27) and the related files
moe_gemm_template_dispatch_tma_ws_mixed_dtype.h and
moe_gemm_template_dispatch_tma_ws.h, references to the old namespace
tensorrt_llm::kernels::cutlass_kernels:: remain and must be updated to the
migrated namespace tensorrt_llm::kernels::cutlass_kernels_oss::; search for all
uses of GroupedGemmInput, TmaWarpSpecializedGroupedGemmInput,
isValidHopperMOESpecialisation, isValidBlackwellMOESpecialisation,
TllmToCutlassTypeAdapter and any other symbols qualified with the old namespace
and replace their qualification with
tensorrt_llm::kernels::cutlass_kernels_oss:: (or add appropriate using
declarations in the header scope), ensuring includes and forward declarations
still match the new namespace so compilation resolves correctly.
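
One low-risk way to keep both spellings valid while the remaining files migrate is to re-export the shared types into the new namespace, similar to how the TMA WS dispatch header re-exports TmaWarpSpecializedGroupedGemmInput. A self-contained illustration of the pattern, with empty toy structs standing in for the real input types:

#include <type_traits>

namespace tensorrt_llm::kernels::cutlass_kernels {
struct GroupedGemmInput {};                    // toy stand-in
struct TmaWarpSpecializedGroupedGemmInput {};  // toy stand-in
}  // namespace tensorrt_llm::kernels::cutlass_kernels

namespace tensorrt_llm::kernels::cutlass_kernels_oss {
// Re-export the shared input types so code qualified with either namespace keeps
// compiling while the migration finishes.
using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput;
using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
}  // namespace tensorrt_llm::kernels::cutlass_kernels_oss

// Either qualification now names the same type.
static_assert(std::is_same_v<tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput,
                             tensorrt_llm::kernels::cutlass_kernels_oss::GroupedGemmInput>);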

Comment on lines 17 to 21
#include <cuda_runtime_api.h>

#include "../../include/moe_gemm_kernels.h"
#include "cutlass_extensions/gemm_configs.h"
#include "cutlass_extensions/weight_only_quant_op.h"

⚠️ Potential issue | 🟠 Major

Guard CUDA runtime include for non-CUDA toolchains.

Static analysis reports: 'cuda_runtime_api.h' not found. Provide a fallback forward-declaration.

-#include <cuda_runtime_api.h>
+#if defined(__has_include)
+#  if __has_include(<cuda_runtime_api.h>)
+#    include <cuda_runtime_api.h>
+#  else
+struct CUstream_st;
+using cudaStream_t = CUstream_st*;
+#  endif
+#else
+#  include <cuda_runtime_api.h>
+#endif
🧰 Tools
🪛 Clang (14.0.6)

[error] 17-17: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

🤖 Prompt for AI Agents
In
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h
lines 17-21, the raw #include <cuda_runtime_api.h> causes static-analysis/build
failures on non-CUDA toolchains; wrap the include in a guard that only pulls in
the CUDA header when available (e.g. use a preprocessor check like
__has_include(<cuda_runtime_api.h>) or an appropriate CUDA macro) and provide a
minimal fallback forward-declaration/typedef for cudaStream_t (or equivalent)
when the header is not present so non-CUDA builds can compile.

Comment on lines 211 to 214
if (gemm.get_workspace_size(arguments) > hopper_inputs.gemm_workspace_size) {
TLLM_LOG_ERROR("[Mixed dtype WS grouped GEMM] given workspace size insufficient, %d < %d.",
gemm.get_workspace_size(arguments), hopper_inputs.gemm_workspace_size);
}

⚠️ Potential issue | 🟠 Major

Do not proceed on insufficient workspace; also avoid printf specifiers in logger.

Use a clear exception; mixing %d with size_t risks UB depending on logger.

-  if (gemm.get_workspace_size(arguments) > hopper_inputs.gemm_workspace_size) {
-    TLLM_LOG_ERROR("[Mixed dtype WS grouped GEMM] given workspace size insufficient, %d < %d.",
-                   gemm.get_workspace_size(arguments), hopper_inputs.gemm_workspace_size);
-  }
+  {
+    size_t required = gemm.get_workspace_size(arguments);
+    if (required > hopper_inputs.gemm_workspace_size) {
+      std::string msg = "[Mixed dtype WS grouped GEMM] given workspace size insufficient, required=" +
+                        std::to_string(required) + " > provided=" +
+                        std::to_string(hopper_inputs.gemm_workspace_size);
+      throw std::runtime_error(msg);
+    }
+  }
🤖 Prompt for AI Agents
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
around lines 211-214: currently the code logs an error with printf-style "%d"
and continues on when workspace is insufficient; change this to stop execution
and report sizes safely. Replace the TLLM_LOG_ERROR call with throwing a clear
exception (e.g. throw std::runtime_error) that constructs a message using
type-safe conversions (std::to_string or casting to unsigned long long when
formatting) so no printf specifiers are mixed with size_t, and ensure the
function returns/aborts by throwing when gemm.get_workspace_size(arguments) >
hopper_inputs.gemm_workspace_size.

Comment on lines +265 to +271
# We need to remove the '::' because this will break the instantiation macro
epi_fusion = epi_fusion.split(":")[-1]
epi_tag = epi_tag.split(":")[-1]
epi_sched = epi_sched.split(":")[-1]
epi_sched = epi_sched.replace(
"1Sm", ""
) # Hack to WAR missing `PtrArrayTmaWarpSpecialized` type

⚠️ Potential issue | 🟡 Minor

Fragile string manipulation for type names.

The string manipulation to remove :: and 1Sm from type names is a workaround for missing types. This is brittle - if upstream type names change, this will break silently. Consider:

  1. Adding a comment explaining which types are missing and why
  2. Adding validation to detect if the expected patterns are not found
  3. Filing a bug to add the missing types upstream
🤖 Prompt for AI Agents
In flashinfer/jit/gemm/cutlass/generate_kernels.py around lines 265-271, the
current ad-hoc string chops (split(":")[-1] and replace("1Sm","")) silently
mutate type names and are brittle; update the code by (1) adding a concise
comment that documents exactly which upstream types are missing and why the
workaround is required, (2) adding validation checks after the string operations
that assert the resulting names match expected patterns (e.g., raise a clear
exception or log an error if the original string did not contain '::' or if
'1Sm' was not present so the hack didn’t apply), and (3) add a TODO comment
referencing an issue in the issue tracker and, if possible, create that bug
upstream to remove the workaround. Ensure the validation fails loudly rather
than allowing silent incorrect names.

@yongwww
Collaborator

yongwww commented Oct 18, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !81 has been updated with latest changes, and the CI pipeline #36832441 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #36832441: 1/17 passed

@aleozlx
Collaborator Author

aleozlx commented Oct 21, 2025

still looking into one more error


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (3)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (3)

342-378: Non-GROUPED path throws; prefer empty set to allow fallback (or keep dispatcher fix).

Throwing here aborts tactic discovery. If the dispatcher fix is adopted, this is fine. If not, return {} instead of throwing to let upstream handle fallback without crashing.

-  } else {
-    TLLM_THROW("Not Implemented: SM100 GEMM candidates have not been defined.");
-  }
+  } else {
+    // No SM100 configs defined for non-grouped GEMM yet; return empty to allow fallback.
+    return {};
+  }

267-340: Shadowed identifier ‘config’ — rename local for clarity and to avoid -Wshadow.

The local CutlassGemmConfig config{...} shadows the function parameter config (bitmask). Rename the local to cfg.

-  for (auto [tile, cluster] : tile_configs) {
-    CutlassGemmConfig config{tile,    MainloopScheduleType::AUTO, schedule,
-                             cluster, dynamic_cluster_shape,      fallback_cluster_shape,
-                             sm};
-    candidate_configs.push_back(config);
-  }
+  for (auto [tile, cluster] : tile_configs) {
+    CutlassGemmConfig cfg{tile,    MainloopScheduleType::AUTO, schedule,
+                          cluster, dynamic_cluster_shape,      fallback_cluster_shape,
+                          sm};
+    candidate_configs.push_back(cfg);
+  }

Also consider a defensive check to ensure sm is in [100,119] since this helper is SM100-specific.
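
The defensive range check mentioned above could be as small as the following sketch (checkSm100Family is a hypothetical free function; in the real file the existing TLLM_CHECK_WITH_INFO macro would be the natural spelling instead of a throw):

#include <stdexcept>
#include <string>

// Reject SMs outside the Blackwell SM100 family this helper was written for.
void checkSm100Family(int sm) {
  if (sm < 100 || sm >= 120) {
    throw std::invalid_argument("SM100 candidate helper called for SM " + std::to_string(sm));
  }
}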


446-529: Remove legacy SM100 implementation instead of commenting it out.

Keeping a large commented block hurts readability and risks divergence. Prefer deletion or #if 0 ... #endif with a brief rationale + removal timeline.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ddb1345 and 2563556.

📒 Files selected for processing (1)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (4 hunks)
🔇 Additional comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (1)

353-372: Review suggestions are valid but incomplete; FP4 filtering already exists, FP8 lacks schedule gating.

Your review correctly identifies that FP4's NO_SMEM exclusion is already implicit (line 278 returns {} for non-TMA schedules). However, you missed that FP8_ONLY (line 326) does not filter schedules—it generates configs with both TMA and NO_SMEM applied to the same tile set, potentially wastefully.

Additionally, no build flag exists to gate the cluster shapes (4x2x1, Undefined) as suggested; only FAST_BUILD exists, and it drastically reduces the candidate set to a single config. The shell command provided (rg 'CtaShape.*SM100') searches for definitions but doesn't measure the actual number of candidates generated.

The loop at lines 353–372 produces up to 8 calls to get_candidate_configs_sm100_dynamic_cluster_shape per GROUPED_GEMM invocation. With FP8_ONLY, this generates ~72 configs, but the profiling impact claim lacks supporting metrics.

Suggested corrections:

  • Add schedule filtering for FP8 if NO_SMEM yields no benefit (parallel to FP4's TMA-only restriction at line 278).
  • Consider a build flag for selective cluster shapes, or gather profiling data to justify current shape count.

Comment on lines 653 to 655
   if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
-    return get_candidate_configs_sm100(config_type_param);
+    return get_candidate_configs_sm100(config_type_param, sm);
   }

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

SM100 dispatcher will throw on non-GROUPED GEMM — route only when GROUPED is set.

get_candidate_configs_sm100(...) throws for non-GROUPED, but the dispatcher now unconditionally routes Blackwell SMs in [100, 120). This likely explains the red CI (1/17 passed). Guard the route to SM100 with GROUPED_GEMM and otherwise fall through to the generic path.

Apply:

-  if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
-    return get_candidate_configs_sm100(config_type_param, sm);
-  }
+  if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
+    if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) {
+      return get_candidate_configs_sm100(config_type_param, sm);
+    }
+    // Non-grouped on BW SM100–SM119: fall through to generic tiling below.
+  }

Run to confirm no other hard routes to SM100 exist and to surface non-GROUPED BW usage:


🏁 Script executed:

#!/bin/bash
rg -nP -C3 '\bget_candidate_configs\s*\('
rg -nP 'GROUPED_GEMM|BLACKWELL|SM100|SM110|SM120'

Length of output: 71381


Let me examine the full implementation of get_candidate_configs_sm100 to verify whether it throws for non-GROUPED cases. Let me use a shell script to directly extract and examine the function:


🏁 Script executed:

#!/bin/bash
# Extract get_candidate_configs_sm100 function and surrounding context
cat csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp | sed -n '342,380p'

Length of output: 2096


🏁 Script executed:

#!/bin/bash
# Check the complete dispatcher function and fallthrough behavior
cat csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp | sed -n '638,670p'

Length of output: 1495


🏁 Script executed:

#!/bin/bash
# Look for TLLM_THROW in the function and check conditions
rg -n -A2 -B2 'TLLM_THROW.*SM100.*GEMM' csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp

Length of output: 308


Guard SM100 dispatcher to non-GROUPED GEMM to prevent runtime throws.

The function get_candidate_configs_sm100 throws "Not Implemented: SM100 GEMM candidates have not been defined" for non-GROUPED paths, but the dispatcher at lines 653–655 unconditionally routes Blackwell SMs in [100, 120) to get_candidate_configs_sm100 when the BLACKWELL flag is set, regardless of GROUPED_GEMM status. This causes a runtime throw whenever non-grouped GEMM is requested on this SM range with the Blackwell flag enabled, which likely explains the CI failure.

Apply the fix:

-  if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
-    return get_candidate_configs_sm100(config_type_param, sm);
-  }
+  if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
+    if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) {
+      return get_candidate_configs_sm100(config_type_param, sm);
+    }
+    // Non-grouped on BW SM100–SM119: fall through to generic tiling below.
+  }


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)

632-637: Profiler still reads GEMM2 tactic indices without offset. When gemm_idx == 2, profile_id continues to address mAllProfiles as if GEMM2 entries started at zero. With GEMM1+GEMM2 concatenated, this pulls the wrong tactic during profiling and produces misleading autotune results. Please adjust runGemmProfile to add the GEMM1 count before indexing the concatenated vector, with appropriate bounds checks, mirroring the runtime fix above.

Apply this diff to keep profiler selection consistent:

-    auto profile = profile_id == -1 ? mAllProfiles.front() : mAllProfiles[profile_id];
+    auto base_index = (gemm_idx == 1) ? 0 : mGemm1ProfileCount;
+    auto selected_index = profile_id == -1 ? base_index : base_index + profile_id;
+    TVM_FFI_ICHECK_LT(selected_index, mAllProfiles.size())
+        << "Invalid profile index for GEMM " << gemm_idx;
+    auto profile = mAllProfiles.at(selected_index);
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)

936-958: Silent config failures may hide real errors.

The CALC_SIZE_FUSION macro catches TllmException and logs at TRACE level (lines 945-947), then continues. While line 958 checks that at least one config succeeded (has_config), this approach may silently skip configs that should be valid.

Issues:

  • Users running without TRACE logging won't see why certain configs failed
  • If multiple configs are expected but only one succeeds, the failure is hidden
  • Legitimate configuration errors (typos, invalid parameters) vs. intentionally unsupported configs are treated the same

Consider one of these approaches:

  1. Log at WARNING level for unexpected failures:
     } catch (tensorrt_llm::common::TllmException const& e) {
-      TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size %s",
-                     e.what());
+      TLLM_LOG_WARNING("Config skipped when calculating MOE workspace size: %s", e.what());
     }
  2. Track and report which specific configs failed:
std::vector<std::string> failed_configs;
// ... in catch block:
failed_configs.push_back(conf.toString());
// ... after loop:
if (!failed_configs.empty()) {
  TLLM_LOG_DEBUG("Skipped %zu configs: ...", failed_configs.size());
}
♻️ Duplicate comments (1)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)

787-797: Still missing GEMM2 tactic offset in concatenated profile list. The concatenated mAllProfiles array still indexes GEMM2 profiles as if they started at 0. As a result, any explicit profile_ids routed to GEMM2 (or even the implicit default) reuse GEMM1 tactics, so the wrong kernel launches under tuning and at runtime. This has already been flagged in a prior review and is a release blocker; please apply the offset fix before landing.

Apply this diff to store the GEMM1 count and offset GEMM2 lookups:

@@
-    auto gemm1_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1);
-    auto gemm2_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_2);
-    mAllProfiles = gemm1_tactics;
+    auto gemm1_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1);
+    auto gemm2_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_2);
+    mGemm1ProfileCount = gemm1_tactics.size();
+    mAllProfiles = gemm1_tactics;
@@
-    auto best_gemm1_profile = mAllProfiles.front();
-    auto best_gemm2_profile = mAllProfiles.front();
+    auto best_gemm1_profile = mAllProfiles.front();
+    auto best_gemm2_profile =
+        mAllProfiles.at(mGemm1ProfileCount);  // first GEMM2 tactic
@@
-      best_gemm1_profile = profile_ids.value()[0] == -1 ? best_gemm1_profile
-                                                        : mAllProfiles.at(profile_ids.value()[0]);
-      best_gemm2_profile = profile_ids.value()[1] == -1 ? best_gemm2_profile
-                                                        : mAllProfiles.at(profile_ids.value()[1]);
+      if (profile_ids.value()[0] != -1)
+        best_gemm1_profile = mAllProfiles.at(profile_ids.value()[0]);
+      if (profile_ids.value()[1] != -1) {
+        auto gemm2_index = profile_ids.value()[1] + mGemm1ProfileCount;
+        TVM_FFI_ICHECK_LT(gemm2_index, mAllProfiles.size());
+        best_gemm2_profile = mAllProfiles.at(gemm2_index);
+      }
@@
   std::vector<Profile> mAllProfiles;
+  size_t mGemm1ProfileCount{0};
🧹 Nitpick comments (3)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (3)

797-799: SM version major check is fragile.

The check inputs.gemm_config.sm_version / 10 == sm_ / 10 relies on NVIDIA's SM numbering scheme:

  • SM90 / 10 = 9
  • SM100-109 / 10 = 10
  • SM120-129 / 10 = 12

While this works for current architectures (Hopper, Blackwell, future variants), it's fragile if SM versioning changes. Consider adding a comment explaining this assumption or using a more explicit check:

-        // Check the major version of the SM matches
-        TLLM_CHECK_WITH_INFO(inputs.gemm_config.sm_version / 10 == sm_ / 10,
+        // Check the SM architecture family matches (90->9, 100-109->10, 120-129->12)
+        TLLM_CHECK_WITH_INFO(inputs.gemm_config.sm_version / 10 == sm_ / 10,
                              "Using SM %d configuration for SM %d device",
                              inputs.gemm_config.sm_version, sm_);
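
If an explicit spelling is ever preferred, the same rule can be factored into a named helper; a tiny self-contained sketch (smFamily is a made-up name):

// The division-by-10 rule the check relies on, spelled out so the intent is
// visible at the call site: 90 -> 9, 100-109 -> 10, 120-129 -> 12.
constexpr int smFamily(int sm_version) { return sm_version / 10; }

static_assert(smFamily(90) == 9);
static_assert(smFamily(103) == 10);
static_assert(smFamily(121) == 12);
static_assert(smFamily(90) != smFamily(100));  // a Hopper config cannot pass a Blackwell check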

781-788: Document the SM120+ FP8 fallback to Ada (SM89).

The comment states "For SM120+ pure FP8 MoE (not FP8 x FP4), redirect to SM89 (Ada) FP8 kernel implementations," but this creates a non-obvious special case where SM120+ doesn't use its native TMA warp specialized path for pure FP8.

Consider adding a more detailed comment explaining why this fallback is necessary (e.g., kernel not yet implemented, performance considerations, etc.).


926-927: Workspace sizing always assumes finalize fusion support.

Line 926 always passes true for supports_finalize_fusion, meaning the workspace calculation accounts for all possible fusion configs even if finalize fusion won't be used at runtime. This is conservative and safe, but may over-allocate workspace memory.

If finalize fusion is optional and not always enabled, consider:

  1. Passing the actual fusion support status if known at workspace allocation time
  2. Adding a comment explaining this conservative sizing choice
  3. Documenting the potential memory overhead
-    auto configs = getTmaWarpSpecializedConfigs(sm_, true);
+    // Conservative: size for all fusion variants including FINALIZE to ensure sufficient workspace
+    // This may over-allocate if finalize fusion is not used at runtime
+    auto configs = getTmaWarpSpecializedConfigs(sm_, /* supports_finalize_fusion */ true);
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2563556 and bfe2852.

📒 Files selected for processing (3)
  • csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (54 hunks)
  • csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (8 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (3)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (3)
  • tensorrt_llm (38-172)
  • multi_processor_count_ (327-327)
  • EpilogueFusion (176-333)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (2)
  • get_candidate_configs (638-689)
  • get_candidate_configs (638-640)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h (1)
  • EpilogueScheduleType (197-433)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (2)
include/flashinfer/trtllm/fused_moe/runner.h (3)
  • hidden_size (265-265)
  • num_experts (263-263)
  • top_k (270-270)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
🔇 Additional comments (4)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (4)

70-228: Namespace split looks good.

The introduction of cutlass_kernels_oss namespace and the cross-namespace type adapter references are correctly structured for the OSS/internal separation architecture.


258-524: Dispatch routing to OSS namespace is correct.

The dispatcher correctly routes to cutlass_kernels_oss launchers while preserving the original namespace structure.


599-629: Review comment is incorrect; the two SM103 code paths handle different activation data types.

The conditions at line 599 and lines 624-629 are not confusing separate paths—they represent mutually exclusive configurations:

  • Line 599: use_wfp4afp8 (FP8 activations + FP4 weights) has no dedicated SM103 support and remaps to SM100
  • Line 624: use_fp4 (FP4 activations) has dedicated SM103 kernels but also includes SM100 configs as fallback

These are different data types, not arbitrary branching logic. No clarification needed.

Likely an incorrect or invalid review comment.


631-666: The primary concern in this review is based on an incorrect understanding of the code.

The check at line 656 TLLM_CHECK_WITH_INFO(!config.swap_ab, ...) is safe and will not fail. Here's why:

At line 653, swap_ab_configs is a fresh copy of tma_ws_configs. The std::transform at lines 654-659 reads from the source range (swap_ab_configs) and writes to the destination (via back_inserter into tma_ws_configs). The lambda parameter auto& config receives references from the source range only—swap_ab_configs is never modified during iteration. This is the standard and safe usage of std::transform.

Regarding the empty config concern: The function can legitimately return empty for unsupported SM/architecture combinations. However, getConfigs() (line 537-544) provides a fallback mechanism by combining results from both getTmaWarpSpecializedConfigs() and getAmpereConfigs(), mitigating complete config loss for most scenarios.

The suggested validation is reasonable defensive programming but not necessary given the fallback in the calling function.

Likely an incorrect or invalid review comment.
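
A standalone illustration of the copy-then-append pattern discussed in the 631-666 note, with toy int elements standing in for the real CutlassGemmConfig entries:

#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

int main() {
  std::vector<int> tma_ws_configs{1, 2, 3};

  // Snapshot the configs, then append a transformed copy of the snapshot.
  // std::transform reads only from swap_ab_configs, so push_backs into
  // tma_ws_configs (which may reallocate) never invalidate the source range.
  std::vector<int> swap_ab_configs = tma_ws_configs;
  std::transform(swap_ab_configs.begin(), swap_ab_configs.end(),
                 std::back_inserter(tma_ws_configs),
                 [](int const& config) { return config + 10; });  // e.g. mark swap_ab

  assert((tma_ws_configs == std::vector<int>{1, 2, 3, 11, 12, 13}));
  return 0;
}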

@aleozlx
Collaborator Author

aleozlx commented Oct 22, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !81 has been updated with latest changes, and the CI pipeline #37018845 is currently running. I'll report back once the pipeline job completes.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (1)

48-56: Fix includes: add missing std headers; drop unused <mutex>.

This file uses std::array, std::is_same_v, std::get, and std::cout but doesn't include the required headers; <mutex> appears unused.

Apply:

 #include <cuda_fp16.h>
 #include <math.h>
-
-#include <mutex>
+#include <array>
+#include <tuple>
+#include <type_traits>
+#include <iostream>
🧹 Nitpick comments (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (2)

67-112: SM100 dispatch helper: header hygiene + C++ version assumption.

Logic looks correct. Ensure we have the proper headers for std::array and std::is_same_v (see include fix). Also confirm build uses C++17+ since this relies on CTAD for std::array.

Please confirm the project enforces -std=c++17 or newer.
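
The C++17 requirement comes from class template argument deduction on std::array; a minimal illustration:

#include <array>

// Deduces std::array<int, 3>. Under C++14 this line does not compile because
// CTAD does not exist, which is why the dispatch helper needs -std=c++17 or newer.
constexpr std::array candidates{100, 103, 120};
static_assert(candidates.size() == 3);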


380-380: Misleading closing comment.

This brace closes the function, not a namespace. Update the comment to avoid confusion.

-}  // namespace tensorrt_llm
+}  // end dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bfe2852 and da54367.

📒 Files selected for processing (1)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (4)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (2)
  • tensorrt_llm (70-227)
  • tensorrt_llm (526-533)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (3)
  • tensorrt_llm (38-172)
  • EpilogueFusion (176-333)
  • FpXBlockScalingType (190-235)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (2)
  • tensorrt_llm (27-119)
  • isValidBlackwellMOESpecialisation (54-68)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h (1)
  • tensorrt_llm (23-36)
🔇 Additional comments (7)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (7)

63-66: Namespace aliasing reads well.

Re-exporting TmaWarpSpecializedGroupedGemmInput into the OSS namespace keeps call sites clean. LGTM.


141-147: Guarded SM103 error path is consistent.

The compile-time gate and error message are clear and match the new arch split. LGTM.


224-231: SM100 cluster-shape constraint check is clear.

Restricting to 1x1x1 and 2x1x1 for dynamic CGA is reasonable. LGTM.


239-246: SM103 FP4-only tile gate seems correct.

The SM103 path constrained to FP4/FP4 with specific tiles matches traits. LGTM.


404-414: Default-case messages are helpful.

Good diagnostics for undefined/heuristic tile configs. LGTM.


488-501: Workspace size reuse via dispatcher is neat.

Reusing the dispatcher for WS calc with a dummy input is a good way to keep logic in sync. LGTM.


432-448: Code is correct—CutlassTileConfigSM103 is a type alias to SM100.

The review comment assumes SM103 is a distinct enum class, but verification shows CutlassTileConfigSM103 is explicitly aliased to CutlassTileConfigSM100, with a comment stating "An alias to make the SHAPE_CASE macro work". When the macro expands SHAPE_CASE(103, ...), the case labels resolve to CutlassTileConfigSM100 enum values, matching the switch expression type. No type incompatibility exists, and no refactoring is needed.

Likely an incorrect or invalid review comment.
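
For readers unfamiliar with the trick, the aliasing pattern the verification found boils down to the following (a toy enum with one illustrative value, not the real definition):

// The SM103 "enum" is just another name for the SM100 enum, so a switch over
// CutlassTileConfigSM100 can use SHAPE_CASE(103, ...) labels with no type mismatch.
enum class CutlassTileConfigSM100 { CtaShape128x128x128B };
using CutlassTileConfigSM103 = CutlassTileConfigSM100;  // alias so the SHAPE_CASE macro works

static_assert(CutlassTileConfigSM103::CtaShape128x128x128B ==
              CutlassTileConfigSM100::CtaShape128x128x128B);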

Comment on lines 162 to +206
else {
#ifdef ENABLE_FP4
auto getFunc = [&]() {
if constexpr (std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, __nv_fp4_e2m1>) {
TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type ==
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX,
"MXFPX is the only supported scaling type for WFP4AFP8");
return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher<
Arch, T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape, true,
false>;
} else {
TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type !=
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX,
"MXFPX is not supported for the selected weight combination");
return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher<
Arch, T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape, false,
false>;
}
};
getFunc()(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size);
#if defined(ENABLE_FP4)
constexpr static bool is_wfp4afp8 =
std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, __nv_fp4_e2m1>;
#else
TLLM_THROW("FP4 data type is not supported on this architecture and CUDA version");
constexpr static bool is_wfp4afp8 = false;
#endif
if constexpr (is_wfp4afp8) {
TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type ==
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX,
"MXFPX is the only supported scaling type for WFP4AFP8");
} else {
TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type !=
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX,
"MXFPX is not supported for the selected weight combination");
}

if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) {
bool const dynamic_cga =
gemm_config.dynamic_cluster_shape != cutlass_extensions::ClusterShape::Undefined;
bool const swap_ab = hopper_input.swap_ab;
auto cluster_shape =
cutlass_extensions::enum_to_shape_tuple(gemm_config.dynamic_cluster_shape);
auto cluster_shape_cute = cute::Shape<int32_t, int32_t, cute::_1>{
std::get<0>(cluster_shape), std::get<1>(cluster_shape), cute::_1{}};
auto cluster_shape_fallback =
cutlass_extensions::enum_to_shape_tuple(gemm_config.fallback_cluster_shape);
auto cluster_shape_cute_fallback = cute::Shape<int32_t, int32_t, cute::_1>{
std::get<0>(cluster_shape_fallback), std::get<1>(cluster_shape_fallback), cute::_1{}};

// HACK debug the gemm_config used to produce selected_func
std::cout << "[SM100 gemm_config] sm_version=" << gemm_config.sm_version
<< ", tile_config_sm100=" << static_cast<int>(gemm_config.tile_config_sm100)
<< ", epilogue_schedule=" << static_cast<int>(gemm_config.epilogue_schedule)
<< ", dynamic_cluster_shape=" << static_cast<int>(gemm_config.dynamic_cluster_shape)
<< ", fallback_cluster_shape="
<< static_cast<int>(gemm_config.fallback_cluster_shape) << std::endl;

auto selected_func =
getDispatchFunctionForSM100<Arch, T, WeightType, OutputType, EpilogueTag, FUSION,
TileShape, ClusterShape, is_wfp4afp8>(
gemm_config.epilogue_schedule, dynamic_cga, swap_ab);
selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy,
workspace_size, cluster_shape_cute, cluster_shape_cute_fallback);
} else if constexpr (Arch::kMinComputeCapability >= 120 || Arch::kMinComputeCapability == 90) {

⚠️ Potential issue | 🟠 Major

Remove/guard std::cout debug in header; use project logging.

Printing from a header via std::cout is noisy, pulls in iostream globally, and hurts performance. Replace with TLLM_LOG_TRACE (gemm_config.toString()) or guard with a debug macro.

Apply:

-      // HACK debug the gemm_config used to produce selected_func
-      std::cout << "[SM100 gemm_config] sm_version=" << gemm_config.sm_version
-                << ", tile_config_sm100=" << static_cast<int>(gemm_config.tile_config_sm100)
-                << ", epilogue_schedule=" << static_cast<int>(gemm_config.epilogue_schedule)
-                << ", dynamic_cluster_shape=" << static_cast<int>(gemm_config.dynamic_cluster_shape)
-                << ", fallback_cluster_shape="
-                << static_cast<int>(gemm_config.fallback_cluster_shape) << std::endl;
+      TLLM_LOG_TRACE("[SM100 gemm_config] %s", gemm_config.toString().c_str());

@flashinfer-bot
Collaborator

[FAILED] Pipeline #37018845: 1/17 passed
