chore: upgrade cutlass moe kernel launcher to match trtllm #1925
base: main
Conversation
/bot run

[FAILED] Pipeline #36551350: 1/17 passed

/bot run
Walkthrough

Threads per-expert swizzled activation scaling, AB-layout swap, and finalize-epilogue fusion throughout MOE Cutlass paths; adds SM90 epilogue scatter visitor, dynamic tile/cluster enums (SM100/SM103/SM120), many launcher/dispatcher signature/template changes, OSS namespace variants, and multiple explicit template instantiations.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant App as Application
    participant Runner as CutlassMoeFCRunner
    participant Profiler as GemmProfilerBackend
    participant Dispatch as Dispatch Layer
    participant Kernel as CUDA Kernel
    App->>Runner: runMoe(..., swizzled_input_sf, unpadded_hidden_size, swap_ab, ...)
    activate Runner
    alt profiling
        Runner->>Profiler: prepare(..., epilogue_fusion, swap_ab, unpadded_hidden_size_profiler)
        Profiler-->>Runner: profiled inputs & workspaces
    end
    alt swizzled_input_sf == true
        Runner->>Runner: expand inputs using swizzled SF layout
    else
        Runner->>Runner: expand inputs using linear SF layout
    end
    alt epilogue_fusion == FINALIZE
        Runner->>Dispatch: computeStridesTmaWarpSpecialized(router_scales, permuted_row_map, ...)
        Dispatch-->>Runner: strides & finalize params
    end
    alt swap_ab == true
        Runner->>Runner: swap act/weight pointers and strides
    end
    Runner->>Dispatch: select & launch tile/cluster (dynamic_cga considered)
    Dispatch->>Kernel: launch GEMM kernel (possibly fused finalize epilogue)
    Kernel-->>App: results
    deactivate Runner
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
[FAILED] Pipeline #36760904: 1/17 passed

/bot run

/bot run

/gemini review
Code Review
This pull request upgrades the cutlass moe kernel launcher to match trtllm, facilitating further feature upgrades. The changes include modifications to copyright years, adding a new template class, and adjusting kernel configurations and function calls for improved performance and functionality. The review focuses on correctness and maintainability, ensuring the changes align with the project's coding standards and improve the overall code quality.
Actionable comments posted: 13
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (9)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h (1)
17-26: Add include guards and necessary headers for self-contained compilation.

The header currently lacks `#pragma once`/include guards and the includes required for `int64_t` and `cudaStream_t`. It depends on transitive includes from consumers (e.g., `moe_gemm_template_dispatch.h` includes `<cuda.h>` before including this header), making it fragile. Add the following:

```diff
+#pragma once
+#include <cstdint>
+#if !defined(__CUDACC_RTC__)
+#include <cuda_runtime_api.h>
+#else
+using cudaStream_t = void*;
+#endif
+
+#ifdef __cplusplus
 namespace tensorrt_llm::kernels::cutlass_kernels_oss {
 template <typename ElementType_, typename CutlassWeightType_, int MaxTileM_, int TileN_,
           int TileK_, int Stages_, typename EpilogueTag>
 void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B,
     ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C,
     int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k,
     int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy);
 }
+#endif
```

Optional: use `int32_t` instead of bare `int` for API stability (the line with `int num_experts, int multi_processor_count`).

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)
546-582: Same default parameter concerns apply here.

The min-latency path has the same hardcoded defaults and assumptions as the regular `runMoe` path. The same validation and documentation suggestions from the earlier comment apply here.

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (3)
4040-4065: Uninitialized byte value used for the MXFP8 scaling memset when the E8M0 helper is unavailable.

`weight_block_scale_value_int` is only assigned under `FLASHINFER_ENABLE_FP8_E8M0` && CUDA >= 12.8; otherwise it is uninitialized and passed to `cudaMemsetAsync`, resulting in undefined fill values.
Initialize a safe default encoding for 1.0f, or guard the path. Suggested minimal fix (1.0f encodes to 0x7F for e8m0; adjust if your helper provides a constant):
```diff
-  TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF weight_block_scale_value_int{};
-#if defined(FLASHINFER_ENABLE_FP8_E8M0) && CUDART_VERSION >= 12080
+  TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF weight_block_scale_value_int =
+      static_cast<TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF>(0x7F);  // e8m0(1.0f)
+#if defined(FLASHINFER_ENABLE_FP8_E8M0) && CUDART_VERSION >= 12080
   __nv_fp8_e8m0 tmp;
   tmp.__x = __nv_cvt_float_to_e8m0(1.0f, __NV_SATFINITE, cudaRoundPosInf);
   std::memcpy(&weight_block_scale_value_int, &tmp, sizeof(tmp));
 #endif
```

If you prefer stricter behavior, conditionally `TLLM_CHECK` for the feature instead of defaulting.
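The 0x7F default suggested above can be sanity-checked on the host: e8m0 is just a biased exponent byte (bias 127), so an exact power of two 2^e encodes as e + 127. A minimal sketch, where `encode_e8m0_pow2` and `decode_e8m0` are hypothetical helpers for illustration, not the CUDA intrinsic (which also handles rounding and saturation):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// e8m0 stores only a biased exponent: value = 2^(byte - 127), so 1.0f -> 0x7F.
// Hypothetical host-side helper restricted to exact powers of two.
inline std::uint8_t encode_e8m0_pow2(float v) {
  int e = 0;
  float m = std::frexp(v, &e);  // v = m * 2^e with m in [0.5, 1)
  assert(m == 0.5f);            // exact power of two only
  return static_cast<std::uint8_t>((e - 1) + 127);
}

inline float decode_e8m0(std::uint8_t b) {
  return std::ldexp(1.0f, static_cast<int>(b) - 127);
}
```

Round-tripping 1.0f through these helpers confirms that 0x7F is the e8m0 encoding of 1.0f.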
1006-1054: Guard potential nullptr from the input SF lookup and avoid const_cast.

When `input_sf` is provided, `cvt_quant_get_sf_out_offset` may still return nullptr; dereferencing without a check risks UB. Also, passing `const_cast` to an API expecting non-const is fragile.
Add a nullptr check and provide a const overload or wrap the pointer without `const_cast`:

```diff
-        auto const sf_in =
-            cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
-                                        NumThreadsPerSF>(
-                std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
-                num_cols / VecSize,
-                const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
-                QuantizationSFLayout::SWIZZLED_128x4);
-        *sf_out = *sf_in;
+        auto const sf_in =
+            cvt_quant_get_sf_out_offset<const TmaWarpSpecializedGroupedGemmInput::ElementSF,
+                                        NumThreadsPerSF>(
+                std::nullopt, source_token_id, elem_idx, std::nullopt, num_cols / VecSize,
+                input_sf, QuantizationSFLayout::SWIZZLED_128x4);
+        if (sf_in) { *sf_out = *sf_in; }
```

Apply similar handling in the LINEAR branch.
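For context on what the swizzled lookup indexes, here is a host-side sketch of a 128x4 swizzled offset computation. The tiling and inner ordering are my assumption of the SWIZZLED_128x4 scheme (rows tiled by 128, SF columns by 4, inner order [r % 32][(r % 128) / 32][c % 4]); `swizzled_128x4_offset` is a hypothetical helper, not the library's `cvt_quant_get_sf_out_offset`:

```cpp
#include <cassert>
#include <cstdint>

// Sketch of a 128x4 swizzled scale-factor offset (assumed layout, see lead-in).
inline std::int64_t swizzled_128x4_offset(std::int64_t row, std::int64_t col,
                                          std::int64_t num_sf_cols) {
  std::int64_t const col_tiles = (num_sf_cols + 3) / 4;  // columns padded to 4
  std::int64_t const tile = (row / 128) * col_tiles + (col / 4);
  std::int64_t const inner = ((row % 32) * 4 + (row % 128) / 32) * 4 + (col % 4);
  return tile * (128 * 4) + inner;  // each tile holds 128 * 4 = 512 scale factors
}
```

Within one 128x4 tile the mapping is a bijection onto [0, 512), which is what makes the layout safe to memset or copy tile-by-tile.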
2918-2926: Confirmed: missing parameter and incorrect argument order at line 2920.

The function signature requires 17 parameters with `padded_cols`, `actual_cols`, and `experts_per_token` in positions 10–12, but the line 2920 call site passes only 16 arguments with `hidden_size`, `k`, and `num_experts_per_node` at those positions. The call omits `unpadded_hidden_size` entirely. The corrected call sites at lines 3297 and 3304 demonstrate the proper argument order with all 17 parameters.

```diff
-    finalizeMoeRoutingKernelLauncher<OutputType, UnfusedGemmOutputType>(
-        static_cast<UnfusedGemmOutputType const*>(gemm_output), final_output, fc2_expert_biases,
-        unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row,
-        token_selected_experts, expert_first_token_offset, num_rows, hidden_size, k,
-        num_experts_per_node, parallelism_config, enable_alltoall, enable_pdl, stream);
+    finalizeMoeRoutingKernelLauncher<OutputType, UnfusedGemmOutputType>(
+        static_cast<UnfusedGemmOutputType const*>(gemm_output), final_output, fc2_expert_biases,
+        unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row,
+        token_selected_experts, expert_first_token_offset, num_rows, hidden_size,
+        unpadded_hidden_size, k, num_experts_per_node, parallelism_config, enable_alltoall,
+        enable_pdl, stream);
```

tests/moe/test_trtllm_cutlass_fused_moe.py (1)
1096-1100: Update skip reason to match the devices list.

The guard allows SM100/110/120, but the reason says "only supported on SM100 and SM110". Add SM120 for consistency.

```diff
-    reason="MXFP8xMXFP4 is only supported on SM100 and SM110",
+    reason="MXFP8xMXFP4 is only supported on SM100, SM110 and SM120",
```
112-117: Missing `<type_traits>` include for `std::is_same_v`.

`std::is_same_v` is used but `<type_traits>` isn't included in this header, risking compile errors depending on include order.

Apply:

```diff
 #include "cutlass_extensions/epilogue_helpers.h"
+#include <type_traits>
```

Alternatively, use `cutlass::platform::is_same` consistently.
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h (1)
24-37: Remove stale namespace import in implementation file.

The implementation in `fpA_intB_launcher_sm90.inl` (line 49) contains `using namespace tensorrt_llm::kernels::cutlass_kernels;`, which imports the old namespace inside the new `cutlass_kernels_oss` namespace. This line should be removed since:

- The function is correctly declared and implemented within `cutlass_kernels_oss`
- All callers are also in `cutlass_kernels_oss` and resolve correctly via unqualified lookup
- The stale using directive creates unnecessary coupling to the old namespace and violates the separation intent of the rename

Fix location: `csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl` line 49: delete the `using namespace tensorrt_llm::kernels::cutlass_kernels;` line.

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)
868-889: Contradictory FP8 fallback: the check forbids it, yet the code dispatches it.

`TLLM_CHECK_WITH_INFO(!use_fp8, …)` guarantees a throw, yet the branch below still dispatches FP8 to Sm89. Pick one behavior.
Two options:
- Disallow FP8 fallback (remove dead dispatch):
```diff
-    if constexpr (use_fp8) {
-      cutlass_kernels_oss::dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType,
-                                                    cutlass::arch::Sm89, EpilogueTag>(
-          inputs, multi_processor_count_);
-    } else {
+    {
       cutlass_kernels_oss::dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType,
                                                     cutlass::arch::Sm80, EpilogueTag>(
           inputs, multi_processor_count_);
     }
```
- Or allow FP8 fallback (drop the check):
```diff
-  TLLM_CHECK_WITH_INFO(!use_fp8, "No fallback FP8 implementation available");
+  // FP8 fallback to SM89 is allowed here.
```

Please confirm the intended behavior.
♻️ Duplicate comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h (1)
29-35: Namespace rename ripple effects are broad.

This file now exposes `sm90_*` dispatchers under `cutlass_kernels_oss`. Ensure all indirect users (e.g., `moe_gemm_template_dispatch_tma_ws.h`) and tests reference the new namespace. Avoid duplicate comments; see the launcher header note for the verification script.
🧹 Nitpick comments (27)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp4.cu (1)
1-15: Consider updating the copyright year to reflect 2024-2025.

The Apache License header shows copyright years as 2020-2023, which may be outdated for a 2025 PR. While this is a minor concern, it's worth updating to reflect the current active development period if this PR makes material changes to the file.
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h (1)
35-36: Consider adding documentation for the DYNAMIC_CGA parameter.

The filtering logic correctly includes the `DYNAMIC_CGA` flag, which will filter shapes when dynamic cluster shapes are enabled during FAST_BUILD. However, there's no inline documentation explaining what `DYNAMIC_CGA` represents or when it should be set to `true`.

Consider adding a brief comment above the struct definition:

```diff
+// DYNAMIC_CGA: Set to true when using dynamic cluster group array shapes.
+// When enabled, shapes will be filtered during FAST_BUILD to reduce compile time.
 template <class ArchTag, class TileShape, class ClusterShape, bool DYNAMIC_CGA, class ActivationType>
 struct should_filter_tma_warp_specialized_gemm_problem_shape {
```

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h (4)
59-60: Remove extra blank line for consistency.

The blank line after the debug logging statement is inconsistent with other logging additions throughout the file (e.g., lines 235, 273, 427, 479, 502, 537). For consistency, debug logging should not be followed by blank lines.

Apply this diff:

```diff
   TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
-
 #ifdef ENABLE_BF16
```
316-317: Remove extra blank line for consistency.

Similar to line 60, this blank line after the debug logging statement is inconsistent with other logging additions throughout the file.

Apply this diff:

```diff
   TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
-
   // Don't instantiate configs that are not supported pre-hopper. Produce a sensible error instead.
```
409-411: Consider the performance impact of logging in the destructor.

Logging in the destructor adds overhead to every object destruction. While useful for debugging, consider whether this logging should be conditional (e.g., only in debug builds) or if it might impact performance in production, especially if many instances are created and destroyed.
516-517: Remove extra blank line for consistency.

This is the third instance of an inconsistent blank line after debug logging (also at lines 60 and 317). For consistency, remove this blank line to match the style used elsewhere in the file.

Apply this diff:

```diff
   TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
-
   if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY) {
```

csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp (2)
288-317: Add divisibility asserts for vectorization/layout compatibility.
`VecSize` is derived from alignment and per-thread values, but the chosen major dimension may not be divisible by `VecSize`, leading to malformed tilers. Add compile-time checks.

```diff
@@
   int constexpr NumThreads = CUTE_STATIC_V(size(args.tiled_copy));
   int constexpr NumValTile = product(take<0,2>(shape(cD_epi)));
   int constexpr MaxVecSize = cute::min(AlignmentOutput, NumValTile / NumThreads);
+  if constexpr (cutlass::gemm::detail::is_k_major<StrideOutput>()) {
+    static_assert((size<1>(args.epi_tile) % MaxVecSize) == 0, "EPI_N must be divisible by vector size.");
+  } else if constexpr (cutlass::gemm::detail::is_mn_major<StrideOutput>()) {
+    static_assert((size<0>(args.epi_tile) % MaxVecSize) == 0, "EPI_M must be divisible by vector size.");
+  }
```

Optionally, relax using a computed `VecSize` that divides the corresponding dimension (fallback to a smaller factor) if the compile-time constraints are too strict. Do you want a patch to auto-derive such a `VecSize`?
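The "computed `VecSize` that divides the dimension" fallback can be as small as a constexpr halving loop. A sketch with a hypothetical `largest_dividing_vec` helper, assuming the maximum vector size is a power of two:

```cpp
// Largest power-of-two vector size <= max_vec that divides extent exactly.
// Assumes max_vec is itself a power of two.
constexpr int largest_dividing_vec(int max_vec, int extent) {
  int v = max_vec;
  while (v > 1 && extent % v != 0) v /= 2;
  return v;
}

static_assert(largest_dividing_vec(8, 64) == 8, "already divisible");
static_assert(largest_dividing_vec(8, 12) == 4, "falls back to 4");
static_assert(largest_dividing_vec(8, 7) == 1, "odd extent degrades to scalar");
```

Because the result always divides the extent, the derived tiler is well-formed for any epilogue tile shape, at the cost of narrower stores when the extent is not a multiple of the alignment.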
51-52: Avoid `using namespace detail;` in a public header.

Namespace pollution in headers can cause ADL surprises and ODR issues. Prefer explicit qualification or targeted using-declarations.

```diff
-using namespace cute;
-using namespace detail;
+using namespace cute;
```

Remove `using namespace detail;` and fully qualify uses (e.g., `cutlass::gemm::detail::is_major`).
75-76: Fix typo in comment.

"supress" should be "suppress".

```diff
-      return nvinfer1::DataType::kFLOAT;  // supress compiler warning
+      return nvinfer1::DataType::kFLOAT;  // suppress compiler warning
```
116-117: Fix typo in comment.

Same typo as line 76.

```diff
-      return nullptr;  // supress compiler warning
+      return nullptr;  // suppress compiler warning
```

flashinfer/jit/gemm/cutlass/generate_kernels.py (1)
382-384: Clarify the cluster shape usage comment.

The comment states "We use a runtime cluster shape for SM100, so we only use cluster shapes to distinguish between 1SM and 2SM variants." This is potentially confusing: the cluster shape is used for variant selection, not runtime configuration. Consider rewording to make this clearer, e.g., "The cga_shape parameter selects between 1SM and 2SM kernel variants; actual cluster shapes are determined at runtime."
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu (1)
20-41: Avoid duplicate explicit instantiations across TUs.

If the same runner variants are instantiated in other .cu/.cc files, link-time ODR errors will occur. Consider using `extern template` declarations in headers and centralizing the definitions in a single TU.
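The extern-template pattern suggested here can be sketched with a hypothetical `Runner` (a stand-in, not the real `CutlassMoeFCRunner` signature): the header carries the `extern template` declaration so consumer TUs skip implicit instantiation, and exactly one TU carries the explicit instantiation definition:

```cpp
// What would live in the header: the template definition plus an extern
// template declaration (Runner is a hypothetical stand-in).
template <typename T>
struct Runner {
  T run(T x) { return x + T{1}; }
};
extern template struct Runner<float>;  // header: suppress implicit instantiation

// What would live in the single instantiation TU (e.g. one .cu file):
template struct Runner<float>;  // the one explicit instantiation definition
```

With this split, every variant has exactly one definition in the binary, so adding a second instantiation file for the same variant becomes a link error instead of silent duplication.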
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl (4)
64-65: Make debug log portable.

`__PRETTY_FUNCTION__` is GCC/Clang-specific. Use `__func__` to avoid MSVC breaks.

```diff
-  TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
+  TLLM_LOG_DEBUG(__func__);
```
197-199: Fix error text clarity.

Minor but user-facing: "must a multiple" should be "must be a multiple".

```diff
-      std::string err_msg = "The group size must a multiple of " + std::to_string(cta_shape_k);
+      std::string err_msg = "The group size must be a multiple of " + std::to_string(cta_shape_k);
```
253-259: Unify error reporting; avoid std::cout in the error path.

Use the project logger or throw with the message; printing to stdout is noisy in libraries.

```diff
-    std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " +
-                          std::string(cutlassGetStatusString(can_implement));
-    std::cout << err_msg << std::endl;
-    throw std::runtime_error("[TensorRT LLM Error][fpA_intB Runner] " + err_msg);
+    std::string err_msg = std::string("[TensorRT LLM Error][fpA_intB Runner] ") +
+                          "fpA_intB cutlass kernel will fail for params. Error: " +
+                          std::string(cutlassGetStatusString(can_implement));
+    TLLM_LOG_ERROR(err_msg.c_str());
+    throw std::runtime_error(err_msg);
```
48-50: Limit using-namespace to needed symbols.

Pulling the entire `cutlass_kernels` namespace into `cutlass_kernels_oss` risks ADL/ODR collisions. Prefer targeted using-declarations for the specific adapters/types needed in this TU.
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h (1)
74-75: Debug logs: acceptable, but consider noise.

`TLLM_LOG_DEBUG` at hot entry points can be noisy at scale. If needed, wrap it behind a verbose flag.
Also applies to: 126-127, 160-161
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h (1)
19-22: Guard the CUDA runtime include for non-CUDA toolchains.

Mirror the portability guard to avoid static-analysis/host-only build breaks.

```diff
-#include <cuda_runtime_api.h>
+#if defined(__has_include)
+#  if __has_include(<cuda_runtime_api.h>)
+#    include <cuda_runtime_api.h>
+#  else
+struct CUstream_st;
+using cudaStream_t = CUstream_st*;
+#  endif
+#else
+#  include <cuda_runtime_api.h>
+#endif
```

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl (3)
73-74: Make debug log portable.

Replace `__PRETTY_FUNCTION__` with `__func__`.

```diff
-  TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
+  TLLM_LOG_DEBUG(__func__);
```
222-226: Unify error reporting; avoid std::cout.

Use the logger then throw to keep stderr consistent.

```diff
-    std::string err_msg = "mixed dtype WS grouped cutlass kernel will fail for params. Error: " +
-                          std::string(cutlassGetStatusString(can_implement));
-    std::cout << err_msg << std::endl;
-    throw std::runtime_error("[Mixed dtype WS grouped GEMM] " + err_msg);
+    std::string err_msg = std::string("[Mixed dtype WS grouped GEMM] ") +
+                          "cutlass kernel will fail for params. Error: " +
+                          std::string(cutlassGetStatusString(can_implement));
+    TLLM_LOG_ERROR(err_msg.c_str());
+    throw std::runtime_error(err_msg);
```
187-189: Device id hardcoded to 0.

On multi-GPU systems this can misreport hardware info. Initialize from the current device.

```diff
-  hw_info.device_id = 0;
+  int dev = 0;
+  (void)cudaGetDevice(&dev);
+  hw_info.device_id = dev;
```

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu (2)
51-73: Stabilize the workspace index mapping (avoid future drift).

Hard-coded indices in `configureWorkspace` must stay in lockstep with `workspaceBuffers`' array order. Introduce named indices to prevent brittle coupling.
Apply along these lines:
```diff
+// Keep these enums in sync with workspaceBuffers() return order
+enum class WSIdx : int {
+  ProblemShape = 0,
+  StrideAct,
+  StrideWeight,
+  StrideC,
+  StrideD,
+  PtrAct,
+  PtrWeight,
+  PtrC,
+  PtrD,
+  AlphaScales,
+  SfAct,
+  SfWeight,
+  StrideSfAct,
+  StrideSfWeight,
+  Int4ProblemShape,
+  Int4SFA,
+  Int4StrideSFA,
+  FinalizeBias,
+  FinalizeRouterScales,
+  FinalizeSourceTokenIndex
+};
 ...
-  shape_info.problem_shapes = reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(pointers[0]);
+  shape_info.problem_shapes = reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
+      pointers[static_cast<int>(WSIdx::ProblemShape)]);
 ...
-  ptr_act = reinterpret_cast<void const**>(pointers[5]);
+  ptr_act = reinterpret_cast<void const**>(pointers[static_cast<int>(WSIdx::PtrAct)]);
 ...
-  fused_finalize_epilogue.ptr_bias = reinterpret_cast<void const**>(pointers[17]);
-  fused_finalize_epilogue.ptr_router_scales = reinterpret_cast<float const**>(pointers[18]);
-  fused_finalize_epilogue.ptr_source_token_index = reinterpret_cast<int const**>(pointers[19]);
+  fused_finalize_epilogue.ptr_bias = reinterpret_cast<void const**>(
+      pointers[static_cast<int>(WSIdx::FinalizeBias)]);
+  fused_finalize_epilogue.ptr_router_scales = reinterpret_cast<float const**>(
+      pointers[static_cast<int>(WSIdx::FinalizeRouterScales)]);
+  fused_finalize_epilogue.ptr_source_token_index = reinterpret_cast<int const**>(
+      pointers[static_cast<int>(WSIdx::FinalizeSourceTokenIndex)]);
```

Also applies to: 94-125
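To keep such an index enum from silently drifting out of sync with the buffer list, a `Count` sentinel plus a `static_assert` catches mismatches at compile time. A sketch with three hypothetical buffers, not the real twenty-entry workspace:

```cpp
#include <array>
#include <cstddef>

// Hypothetical index enum with a Count sentinel as the last enumerator.
enum class WSIdx : int { ProblemShape = 0, PtrAct, PtrWeight, Count };

// Per-buffer sizes; the order must match WSIdx. The array size is written
// out explicitly so adding an enumerator without a size entry fails to build.
constexpr std::array<std::size_t, 3> kBufferBytes{16, 8, 8};

static_assert(kBufferBytes.size() == static_cast<std::size_t>(WSIdx::Count),
              "workspace buffer list drifted out of sync with WSIdx");

constexpr std::size_t buffer_bytes(WSIdx i) {
  return kBufferBytes[static_cast<std::size_t>(i)];
}
```

The same sentinel could also size the `pointers` array in `configureWorkspace`, making the coupling explicit rather than conventional.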
171-176: Correct log label: this is generic FpX, not FP4-only.

Rename "Fp4 Block Scaling Factors …" to "FpX Block Scaling Factors …" to avoid confusion when MXFPX is used.
```diff
-    ss << "Fp4 Block Scaling Factors Act: "
+    ss << "FpX Block Scaling Factors Act: "
 ...
-    ss << "Fp4 Block Scaling Factors Weight: "
+    ss << "FpX Block Scaling Factors Weight: "
```

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (1)
52-52: Remove unused include.

`<mutex>` isn't used here; drop it to speed up compile and reduce surface.

```diff
-#include <mutex>
```

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)
372-372: Fix misleading namespace close comment.

The brace closes `tensorrt_llm::kernels::cutlass_kernels_oss`, not the root namespace.

```diff
-}  // namespace tensorrt_llm
+}  // namespace tensorrt_llm::kernels::cutlass_kernels_oss
```

csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h (2)
394-418: Improve readability: print names, not raw ints.

Use the existing helpers to render human-readable schedules and shapes.

```diff
-         << "\n\tmainloop sched: " << (int)mainloop_schedule
-         << "\n\tepi sched: " << (int)epilogue_schedule
+         << "\n\tmainloop sched: " << get_mainloop_schedule_name(mainloop_schedule)
+         << "\n\tepi sched: " << (epilogue_schedule == EpilogueScheduleType::AUTO ? "auto"
+                                  : epilogue_schedule == EpilogueScheduleType::NO_SMEM ? "no_smem"
+                                  : "tma")
```

And keep the `getTileConfigAsName()`/`get_cluster_shape_name` helpers you already added.
435-459: Use names in the stream operator for configs too.

Replace integer dumps with the name helpers to simplify telemetry parsing.

```diff
-    out << "tile_config_sm90_enum: " << config.getTileConfigAsInt()
+    out << "tile_config: " << config.getTileConfigAsName()
 ...
-        << ", cluster_shape_enum: " << int(config.cluster_shape)
-        << ", dynamic_cluster_shape_enum: " << int(config.dynamic_cluster_shape)
-        << ", fallback_cluster_shape_enum: " << int(config.fallback_cluster_shape)
+        << ", cluster_shape: " << get_cluster_shape_name(config.cluster_shape)
+        << ", dynamic_cluster_shape: " << get_cluster_shape_name(config.dynamic_cluster_shape)
+        << ", fallback_cluster_shape: " << get_cluster_shape_name(config.fallback_cluster_shape)
```
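Such name helpers usually take the exhaustive-switch shape below; the enumerators and the helper here are illustrative stand-ins, not the actual `gemm_configs.h` definitions:

```cpp
#include <cassert>
#include <cstring>

// Illustrative enum; the real ClusterShape in gemm_configs.h has more values.
enum class ClusterShape { Undefined, ClusterShape_1x1x1, ClusterShape_2x1x1 };

// Exhaustive switch: with -Wswitch, a newly added enumerator triggers a
// warning here, unlike an int cast which silently prints a raw number.
constexpr char const* get_cluster_shape_name(ClusterShape s) {
  switch (s) {
    case ClusterShape::Undefined: return "undefined";
    case ClusterShape::ClusterShape_1x1x1: return "1x1x1";
    case ClusterShape::ClusterShape_2x1x1: return "2x1x1";
  }
  return "unknown";
}
```

Streaming the returned string keeps log lines self-describing without any lookup table to maintain at the call site.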
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (40)
- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu (2 hunks)
- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (54 hunks)
- csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (8 hunks)
- csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h (1 hunks)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp (1 hunks)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h (11 hunks)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp (5 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h (17 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h (9 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h (2 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl (5 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (7 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h (22 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl (2 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl (4 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp8.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp4_fp4.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp4.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_uint4.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (14 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (11 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h (7 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu (4 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (5 hunks)
- flashinfer/jit/gemm/cutlass/generate_kernels.py (21 hunks)
- tests/moe/test_trtllm_cutlass_fused_moe.py (1 hunks)
flashinfer/jit/gemm/cutlass/generate_kernels.py (1)
flashinfer/jit/gemm/cutlass/cutlass_library.py (4)
EpilogueScheduleType(710-726)DataType(73-120)KernelScheduleType(519-587)GemmKind(1008-1021)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
tensorrt_llm(38-172)csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h (3)
tensorrt_llm(33-150)kernels(34-149)cutlass_kernels(35-148)
🪛 Clang (14.0.6)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h
[error] 17-17: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h
[error] 18-18: 'cstdint' file not found
(clang-diagnostic-error)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp
[error] 38-38: 'cutlass/cutlass.h' file not found
(clang-diagnostic-error)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h
[error] 17-17: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 17-17: expected ';' after top level declarator
(clang-diagnostic-error)
[error] 17-17: expected identifier or '('
(clang-diagnostic-error)
| TLLM_THROW("Min latency mode is no longer supported"); | ||
| } | ||
|
|
Min-latency hard-disabled but still invoked.
computeStridesTmaWarpSpecializedLowLatency now throws, yet setupTmaWarpSpecializedInputs still calls it when min_latency_mode is true, breaking that path.
Either remove the min-latency path or route min-latency to the non-low-latency stride builder (fusion NONE, swap_ab on). Example fix below updates setupTmaWarpSpecializedInputs to use computeStridesTmaWarpSpecialized for min-latency; then this throw can remain (unused).
@@
- return Self::computeStridesTmaWarpSpecializedLowLatency(
- gemm1_tma_ws_input, gemm2_tma_ws_input, num_rows, fc1_out_size, hidden_size, hidden_size,
- inter_size, num_experts_per_node, reinterpret_cast<T const*>(gemm1_input),
- reinterpret_cast<T const*>(gemm2_input), fc1_expert_weights, fc2_expert_weights,
- quant_params.fp8.dequant_fc1, quant_params.fp8.dequant_fc2, input_sf, fc2_fp4_act_scale_,
- quant_params, nullptr, nullptr, reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
- reinterpret_cast<UnfusedGemmOutputType*>(fc2_result_),
- min_latency_params.num_active_experts_per_node, min_latency_params.active_expert_global_ids,
- start_expert, enable_pdl, stream);
+ // Use regular stride builder with fusion NONE and swap_ab already set; finalize fusion not used.
+ return Self::computeStridesTmaWarpSpecialized(
+ expert_first_token_offset_, gemm1_tma_ws_input, gemm2_tma_ws_input,
+ /*num_tokens*/ num_rows, /*expanded_num_tokens*/ num_rows, // min-latency packs per expert
+ /*gemm1_n*/ fc1_out_size, /*gemm1_k*/ hidden_size,
+ /*gemm2_n*/ hidden_size, /*gemm2_k*/ inter_size,
+ num_experts_per_node,
+ reinterpret_cast<T const*>(gemm1_input), reinterpret_cast<T const*>(gemm2_input),
+ fc1_expert_weights, fc2_expert_weights,
+ quant_params.fp8.dequant_fc1, quant_params.fp8.dequant_fc2,
+ input_sf, fc2_fp4_act_scale_, quant_params,
+ /*bias1*/ nullptr, /*bias2*/ nullptr,
+ reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
+ reinterpret_cast<UnfusedGemmOutputType*>(fc2_result_),
+ /*router_scales*/ nullptr, /*permuted_row_to_unpermuted_row*/ nullptr,
+ enable_pdl, stream);Committable suggestion skipped: line range outside the PR's diff.
| // Get tactics for both GEMM1 and GEMM2, combine them | ||
| auto gemm1_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1); | ||
| auto gemm2_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_2); | ||
| mAllProfiles = gemm1_tactics; | ||
| mAllProfiles.insert(mAllProfiles.end(), gemm2_tactics.begin(), gemm2_tactics.end()); |
🧩 Analysis chain
Verify tactic selection logic for GEMM1 vs GEMM2.
The tactics from GEMM1 and GEMM2 are now concatenated into a single mAllProfiles vector. Ensure that when selecting tactics by profile_id (e.g., at Line 618), the correct tactics are applied to the corresponding GEMM. The setRunnerProfiles method expects 2 profile IDs, so verify the indexing logic correctly distinguishes between GEMM1 and GEMM2 tactics.
🏁 Script executed:
#!/bin/bash
# Verify how profile_id is used to select tactics and ensure GEMM1/GEMM2 separation
# Check setRunnerProfiles implementation and how it maps profile IDs
rg -n -A 10 -B 3 'setRunnerProfiles' --type cpp
# Look for getTactics usage and how tactics are distinguished
rg -n -A 5 'getTactics.*GEMM' --type cpp

Length of output: 0
🏁 Script executed:
#!/bin/bash
# First, verify the file exists and examine the context around the modified lines
echo "=== Examining the target file around lines 226-230 ==="
head -n 240 csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu | tail -n 30
echo ""
echo "=== Search for setRunnerProfiles (broader search) ==="
rg -n 'setRunnerProfiles' csrc/fused_moe/cutlass_backend/
echo ""
echo "=== Search for how mAllProfiles is used ==="
rg -n -C 5 'mAllProfiles' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
echo ""
echo "=== Search for profile_id usage around line 618 ==="
sed -n '610,630p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
echo ""
echo "=== Search for getTactics calls ==="
rg -n 'getTactics' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

Length of output: 806
🏁 Script executed:
#!/bin/bash
# Examine the target file and trace tactic selection logic
echo "=== Examining lines around 226-230 (mAllProfiles concatenation) ==="
sed -n '220,240p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
echo ""
echo "=== Search for mAllProfiles usage throughout the file ==="
rg -n 'mAllProfiles' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
echo ""
echo "=== Examining setRunnerProfiles method (lines 757-778 mentioned in scratchpad) ==="
sed -n '750,790p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
echo ""
echo "=== Search for getTactics calls and GEMM references ==="
rg -n 'getTactics|GEMM_1|GEMM_2' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
echo ""
echo "=== Examine profile_id usage pattern ==="
sed -n '610,630p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

Length of output: 5633
🏁 Script executed:
#!/bin/bash
# Examine the context around line 618 and how profile is used
echo "=== Full context around line 618 (profile selection) ==="
sed -n '600,650p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
echo ""
echo "=== Search for where setRunnerProfiles is called ==="
rg -n 'setRunnerProfiles' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu -A 2 -B 2
echo ""
echo "=== Search for mKernelRunner->setTactic calls ==="
rg -n 'setTactic' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu -A 2 -B 2
echo ""
echo "=== Check runMoe signature and how it's called ==="
rg -n 'void runMoe' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu -A 8

Length of output: 5456
🏁 Script executed:
#!/bin/bash
# Search for where profile_ids is generated/populated and if there's any offset logic
echo "=== Search for profile_ids origin and computation ==="
rg -n 'profile_ids' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu | head -30
echo ""
echo "=== Search for getNumProfiles or similar methods ==="
rg -n 'getNumProfiles|getNumTactics|size()' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
echo ""
echo "=== Check class member getters and how tactics are exposed ==="
sed -n '580,595p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
echo ""
echo "=== Look for any comments or logic about profile indexing/offset ==="
rg -n -B 3 -A 3 'mAllProfiles\[' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

Length of output: 12101
Apply offset correction when indexing GEMM2 tactics from concatenated mAllProfiles.
The tactics from GEMM1 and GEMM2 are concatenated into a single mAllProfiles vector (lines 229–230), but when selecting tactics by profile_ids in setRunnerProfiles (lines 772–775), the GEMM2 profile index is not offset. This causes GEMM2 to incorrectly select from GEMM1's tactics range.
After concatenation, GEMM2 tactics start at index gemm1_tactics.size(), but this offset is not applied when accessing mAllProfiles.at(profile_ids.value()[1]).
Fix: Apply offset for GEMM2 tactic selection:
best_gemm2_profile = profile_ids.value()[1] == -1
                         ? best_gemm2_profile
                         : mAllProfiles.at(profile_ids.value()[1] + gemm1_size);

where gemm1_size = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1).size() (stored during initialization).
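The indexing hazard can be sketched independently of the kernel code. This is a minimal Python model of the concatenated-profile lookup; the tactic names and list contents are hypothetical, only the offset arithmetic mirrors the review's fix:

```python
# Sketch of the concatenated-profile indexing fix (all names hypothetical).
gemm1_tactics = ["g1_tile64", "g1_tile128"]  # tactics reported for GEMM1
gemm2_tactics = ["g2_tile64", "g2_tile128"]  # tactics reported for GEMM2

all_profiles = gemm1_tactics + gemm2_tactics  # mirrors the mAllProfiles concatenation
gemm1_size = len(gemm1_tactics)               # stored during initialization

def select_gemm2(profile_id):
    # -1 is the "keep default" sentinel; otherwise offset into GEMM2's range.
    if profile_id == -1:
        return None
    return all_profiles[profile_id + gemm1_size]

# Without the + gemm1_size offset, profile_id 0 would wrongly return "g1_tile64".
assert select_gemm2(0) == "g2_tile64"
assert select_gemm2(-1) is None
```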
🤖 Prompt for AI Agents
In csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
around lines 226–230 (and referencing setRunnerProfiles around lines 772–775),
the concatenation of GEMM1 and GEMM2 tactics into mAllProfiles does not account
for the GEMM1 length when later indexing GEMM2 tactics, causing wrong
selections; fix by capturing and storing gemm1_size during initialization (e.g.,
gemm1_size = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1).size()) and
then when selecting GEMM2 use profile index + gemm1_size (apply this offset
where mAllProfiles.at(profile_ids.value()[1]) is used), keeping the -1 sentinel
handling unchanged so that if profile_ids.value()[1] == -1 you keep the default,
otherwise index mAllProfiles at profile_ids.value()[1] + gemm1_size.
| // HACK Define default values for parameters we don't have good values for | ||
| bool const swizzled_input_sf = false; // Assume input_sf is not swizzled by default | ||
| int64_t const unpadded_hidden_size = hidden_size; // Assume no padding by default | ||
| bool const use_lora = false; // No lora support yet |
Document assumptions for default parameter values.
The hardcoded defaults for swizzled_input_sf, unpadded_hidden_size, and use_lora may not hold in all scenarios. Particularly, setting unpadded_hidden_size = hidden_size assumes no padding, which may be incorrect if padding occurs upstream. Consider:
- Validating these assumptions with assertions when these features are actually used
- Adding TODO comments linking to future work that will provide proper values
- Documenting these limitations in user-facing documentation
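The suggested validation can be sketched as follows; the function name and the `input_is_swizzled` parameter are hypothetical illustrations, not part of the binding's API:

```python
def validate_moe_defaults(hidden_size, unpadded_hidden_size,
                          swizzled_input_sf, input_is_swizzled):
    # Padding can only grow the dimension, never shrink it.
    assert unpadded_hidden_size <= hidden_size, \
        "unpadded_hidden_size exceeds hidden_size"
    # Only claim a swizzled scaling-factor layout when the input really uses one.
    assert not swizzled_input_sf or input_is_swizzled, \
        "swizzled_input_sf set but input layout is linear"

# The hardcoded defaults (no padding, linear SF layout) pass trivially.
validate_moe_defaults(hidden_size=4096, unpadded_hidden_size=4096,
                      swizzled_input_sf=False, input_is_swizzled=False)
```

Checks like these fail loudly the moment an upstream caller starts padding or swizzling, instead of silently computing with stale defaults.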
🤖 Prompt for AI Agents
In csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
around lines 372-375 the code hardcodes defaults for swizzled_input_sf,
unpadded_hidden_size and use_lora which may be incorrect in some configurations;
update the code to (1) add a TODO comment with a link to the tracking issue/PR
for providing proper values, (2) add runtime assertions or checks where these
flags/values are actually used (e.g., assert unpadded_hidden_size <= hidden_size
and verify swizzled_input_sf only when the input layout indicates swizzling),
and (3) if possible derive unpadded_hidden_size from upstream tensor metadata or
add it as an explicit parameter to the caller API and fall back to the current
default only with a clear warning log; also add a brief note in the module’s
user-facing docs describing this limitation and the expectation until full
support is implemented.
| int64_t const unpadded_hidden_size_profiler = hidden_size; // HACK no padding by default | ||
| #ifdef USING_OSS_CUTLASS_MOE_GEMM | ||
| mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile, | ||
| DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype), | ||
| DtypeUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k), | ||
| hidden_size, inter_size, group_size, ActivationType::Swiglu, USE_BIAS, | ||
| USE_LORA, min_latency_mode, | ||
| hidden_size, unpadded_hidden_size_profiler, inter_size, group_size, | ||
| ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode, | ||
| /*need_weights*/ false, parallelism_config, enable_alltoall); | ||
| #else | ||
| mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile, | ||
| DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype), | ||
| DtypeUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k), | ||
| hidden_size, inter_size, group_size, ActivationType::Swiglu, USE_BIAS, | ||
| USE_LORA, min_latency_mode, | ||
| hidden_size, unpadded_hidden_size_profiler, inter_size, group_size, | ||
| ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode, | ||
| /*need_weights*/ false, parallelism_config); | ||
| #endif |
Consistent default parameter pattern in profiler.
The profiler also uses the hardcoded unpadded_hidden_size default. This is consistent with the runtime paths but shares the same concerns about correctness. If padding is actually present, profiling may select suboptimal tactics.
🤖 Prompt for AI Agents
In csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
around lines 642 to 657, the code passes a hardcoded
unpadded_hidden_size_profiler equal to hidden_size into mProfiler->init, which
forces the profiler to assume no padding; replace this with the actual unpadded
hidden size used at runtime (compute or reuse the existing unpadded_hidden_size
variable that accounts for padding) and pass that variable into both the OSS and
non-OSS init calls so the profiler sees the true unpadded dimension and can pick
appropriate tactics.
...t_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp
| #include "tensorrt_llm/common/cudaUtils.h" | ||
|
|
||
| namespace tensorrt_llm::kernels::cutlass_kernels { | ||
| namespace tensorrt_llm::kernels::cutlass_kernels_oss { |
🧩 Analysis chain
Verify that all callers have been updated to the new cutlass_kernels_oss namespace.
The namespace change from tensorrt_llm::kernels::cutlass_kernels to tensorrt_llm::kernels::cutlass_kernels_oss is a breaking API change. Ensure all usage sites across the codebase have been updated to reference the new namespace.
Run the following script to verify namespace consistency:
Also applies to: 98-98
🏁 Script executed:
#!/bin/bash
# Description: Verify all references to the MOE launcher use the new cutlass_kernels_oss namespace
echo "Checking for lingering references to old cutlass_kernels namespace..."
# Search for old namespace references (excluding this file and commented code)
rg -t cpp -t cu -t h "cutlass_kernels::" \
--glob "!*/fused_moe_gemm_launcher_sm80.inl" \
-C2
echo -e "\nSearching for new cutlass_kernels_oss namespace usage..."
# Find all references to the new namespace
rg -t cpp -t cu -t h "cutlass_kernels_oss::" -C2
echo -e "\nSearching for sm80_generic_fused_moe_gemm_kernelLauncher calls..."
# Find all call sites of this launcher function
rg -t cpp -t cu "sm80_generic_fused_moe_gemm_kernelLauncher" -C3

Length of output: 275
🏁 Script executed:
#!/bin/bash
# Description: Verify namespace migration with improved file discovery
echo "=== Finding all C++/CUDA files in cutlass_kernels directory ==="
fd -e cpp -e cu -e h -e inl "cutlass_kernels" | head -20
echo -e "\n=== Searching for old namespace (cutlass_kernels::) in all C++/CUDA files ==="
rg "cutlass_kernels::" --type-add 'cuda:*.cu' --type-add 'cuda:*.cuh' -C2 | head -40
echo -e "\n=== Searching for new namespace (cutlass_kernels_oss::) ==="
rg "cutlass_kernels_oss::" -C2 | head -40
echo -e "\n=== Searching for sm80_generic_fused_moe_gemm_kernelLauncher function calls ==="
rg "sm80_generic_fused_moe_gemm_kernelLauncher" -C3 | head -40
echo -e "\n=== Checking the target file for namespace declaration ==="
fd -e inl "fused_moe_gemm_launcher_sm80.inl" -x cat {}

Length of output: 21869
Namespace migration is incomplete; related files still reference the old namespace.
The namespace change to cutlass_kernels_oss is not fully propagated. Files in the new namespace still contain references to the old tensorrt_llm::kernels::cutlass_kernels:: namespace:
- moe_gemm_template_dispatch_tma_ws_mixed_dtype.h: uses tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput and TmaWarpSpecializedGroupedGemmInput
- moe_gemm_template_dispatch_tma_ws.h: uses tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput, isValidHopperMOESpecialisation, isValidBlackwellMOESpecialisation, and TllmToCutlassTypeAdapter
Update all type and function references in these files to use the new namespace path.
🤖 Prompt for AI Agents
In
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl
(around line 27) and the related files
moe_gemm_template_dispatch_tma_ws_mixed_dtype.h and
moe_gemm_template_dispatch_tma_ws.h, references to the old namespace
tensorrt_llm::kernels::cutlass_kernels:: remain and must be updated to the
migrated namespace tensorrt_llm::kernels::cutlass_kernels_oss::; search for all
uses of GroupedGemmInput, TmaWarpSpecializedGroupedGemmInput,
isValidHopperMOESpecialisation, isValidBlackwellMOESpecialisation,
TllmToCutlassTypeAdapter and any other symbols qualified with the old namespace
and replace their qualification with
tensorrt_llm::kernels::cutlass_kernels_oss:: (or add appropriate using
declarations in the header scope), ensuring includes and forward declarations
still match the new namespace so compilation resolves correctly.
| #include <cuda_runtime_api.h> | ||
|
|
||
| #include "../../include/moe_gemm_kernels.h" | ||
| #include "cutlass_extensions/gemm_configs.h" | ||
| #include "cutlass_extensions/weight_only_quant_op.h" |
Guard CUDA runtime include for non-CUDA toolchains.
Static analysis reports: 'cuda_runtime_api.h' not found. Provide a fallback forward-declaration.
-#include <cuda_runtime_api.h>
+#if defined(__has_include)
+# if __has_include(<cuda_runtime_api.h>)
+# include <cuda_runtime_api.h>
+# else
+struct CUstream_st;
+using cudaStream_t = CUstream_st*;
+# endif
+#else
+# include <cuda_runtime_api.h>
+#endif📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| #include <cuda_runtime_api.h> | |
| #include "../../include/moe_gemm_kernels.h" | |
| #include "cutlass_extensions/gemm_configs.h" | |
| #include "cutlass_extensions/weight_only_quant_op.h" | |
| #if defined(__has_include) | |
| # if __has_include(<cuda_runtime_api.h>) | |
| # include <cuda_runtime_api.h> | |
| # else | |
| struct CUstream_st; | |
| using cudaStream_t = CUstream_st*; | |
| # endif | |
| #else | |
| # include <cuda_runtime_api.h> | |
| #endif | |
| #include "../../include/moe_gemm_kernels.h" | |
| #include "cutlass_extensions/gemm_configs.h" | |
| #include "cutlass_extensions/weight_only_quant_op.h" |
🧰 Tools
🪛 Clang (14.0.6)
[error] 17-17: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
🤖 Prompt for AI Agents
In
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.h
lines 17-21, the raw #include <cuda_runtime_api.h> causes static-analysis/build
failures on non-CUDA toolchains; wrap the include in a guard that only pulls in
the CUDA header when available (e.g. use a preprocessor check like
__has_include(<cuda_runtime_api.h>) or an appropriate CUDA macro) and provide a
minimal fallback forward-declaration/typedef for cudaStream_t (or equivalent)
when the header is not present so non-CUDA builds can compile.
| if (gemm.get_workspace_size(arguments) > hopper_inputs.gemm_workspace_size) { | ||
| TLLM_LOG_ERROR("[Mixed dtype WS grouped GEMM] given workspace size insufficient, %d < %d.", | ||
| gemm.get_workspace_size(arguments), hopper_inputs.gemm_workspace_size); | ||
| } |
Do not proceed on insufficient workspace; also avoid printf specifiers in logger.
Use a clear exception; mixing %d with size_t risks UB depending on logger.
- if (gemm.get_workspace_size(arguments) > hopper_inputs.gemm_workspace_size) {
- TLLM_LOG_ERROR("[Mixed dtype WS grouped GEMM] given workspace size insufficient, %d < %d.",
- gemm.get_workspace_size(arguments), hopper_inputs.gemm_workspace_size);
- }
+ {
+ size_t required = gemm.get_workspace_size(arguments);
+ if (required > hopper_inputs.gemm_workspace_size) {
+ std::string msg = "[Mixed dtype WS grouped GEMM] given workspace size insufficient, required=" +
+ std::to_string(required) + " > provided=" +
+ std::to_string(hopper_inputs.gemm_workspace_size);
+ throw std::runtime_error(msg);
+ }
+ }📝 Committable suggestion
| if (gemm.get_workspace_size(arguments) > hopper_inputs.gemm_workspace_size) { | |
| TLLM_LOG_ERROR("[Mixed dtype WS grouped GEMM] given workspace size insufficient, %d < %d.", | |
| gemm.get_workspace_size(arguments), hopper_inputs.gemm_workspace_size); | |
| } | |
| { | |
| size_t required = gemm.get_workspace_size(arguments); | |
| if (required > hopper_inputs.gemm_workspace_size) { | |
| std::string msg = "[Mixed dtype WS grouped GEMM] given workspace size insufficient, required=" + | |
| std::to_string(required) + " > provided=" + | |
| std::to_string(hopper_inputs.gemm_workspace_size); | |
| throw std::runtime_error(msg); | |
| } | |
| } |
🤖 Prompt for AI Agents
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl
around lines 211-214: currently the code logs an error with printf-style "%d"
and continues on when workspace is insufficient; change this to stop execution
and report sizes safely. Replace the TLLM_LOG_ERROR call with throwing a clear
exception (e.g. throw std::runtime_error) that constructs a message using
type-safe conversions (std::to_string or casting to unsigned long long when
formatting) so no printf specifiers are mixed with size_t, and ensure the
function returns/aborts by throwing when gemm.get_workspace_size(arguments) >
hopper_inputs.gemm_workspace_size.
.../nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h
Show resolved
Hide resolved
| # We need to remove the '::' because this will break the instantiation macro | ||
| epi_fusion = epi_fusion.split(":")[-1] | ||
| epi_tag = epi_tag.split(":")[-1] | ||
| epi_sched = epi_sched.split(":")[-1] | ||
| epi_sched = epi_sched.replace( | ||
| "1Sm", "" | ||
| ) # Hack to WAR missing `PtrArrayTmaWarpSpecialized` type |
Fragile string manipulation for type names.
The string manipulation to remove :: and 1Sm from type names is a workaround for missing types. This is brittle - if upstream type names change, this will break silently. Consider:
- Adding a comment explaining which types are missing and why
- Adding validation to detect if the expected patterns are not found
- Filing a bug to add the missing types upstream
🤖 Prompt for AI Agents
In flashinfer/jit/gemm/cutlass/generate_kernels.py around lines 265-271, the
current ad-hoc string chops (split(":")[-1] and replace("1Sm","")) silently
mutate type names and are brittle; update the code by (1) adding a concise
comment that documents exactly which upstream types are missing and why the
workaround is required, (2) adding validation checks after the string operations
that assert the resulting names match expected patterns (e.g., raise a clear
exception or log an error if the original string did not contain '::' or if
'1Sm' was not present so the hack didn’t apply), and (3) add a TODO comment
referencing an issue in the issue tracker and, if possible, create that bug
upstream to remove the workaround. Ensure the validation fails loudly rather
than allowing silent incorrect names.
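One way to make the workaround fail loudly rather than silently mangle names is to validate each transformation. This is a sketch only; the `1Sm` suffix handling mirrors the hack described in the review, and the helper names are hypothetical:

```python
def strip_namespace(name: str) -> str:
    # The instantiation macro needs the unqualified name; fail loudly if the
    # input was not a qualified C++ name as expected.
    if "::" not in name:
        raise ValueError(f"expected a qualified C++ name, got {name!r}")
    return name.split("::")[-1]

def drop_1sm_suffix(sched: str) -> str:
    # WAR for the missing PtrArrayTmaWarpSpecialized type; verify the hack
    # actually applied so an upstream rename cannot slip through unnoticed.
    stripped = sched.replace("1Sm", "")
    if stripped == sched:
        raise ValueError(f"expected '1Sm' in schedule name, got {sched!r}")
    return stripped

assert strip_namespace("cutlass::epilogue::TmaWarpSpecialized1Sm") == "TmaWarpSpecialized1Sm"
assert drop_1sm_suffix("TmaWarpSpecialized1Sm") == "TmaWarpSpecialized"
```

With validation in place, an upstream type rename surfaces as a generator error at build time instead of a cryptic instantiation failure later.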
|
/bot run |
|
[FAILED] Pipeline #36832441: 1/17 passed |
|
still looking into one more error |
Actionable comments posted: 1
🧹 Nitpick comments (3)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (3)
342-378: Non-GROUPED path throws; prefer empty set to allow fallback (or keep dispatcher fix).

Throwing here aborts tactic discovery. If the dispatcher fix is adopted, this is fine. If not, return {} instead of throwing to let upstream handle fallback without crashing.

-  } else {
-    TLLM_THROW("Not Implemented: SM100 GEMM candidates have not been defined.");
-  }
+  } else {
+    // No SM100 configs defined for non-grouped GEMM yet; return empty to allow fallback.
+    return {};
+  }

267-340: Shadowed identifier ‘config’ — rename local for clarity and to avoid -Wshadow.

The local CutlassGemmConfig config{...} shadows the function parameter config (bitmask). Rename the local to cfg.

-  for (auto [tile, cluster] : tile_configs) {
-    CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, schedule,
-                             cluster, dynamic_cluster_shape, fallback_cluster_shape,
-                             sm};
-    candidate_configs.push_back(config);
-  }
+  for (auto [tile, cluster] : tile_configs) {
+    CutlassGemmConfig cfg{tile, MainloopScheduleType::AUTO, schedule,
+                          cluster, dynamic_cluster_shape, fallback_cluster_shape,
+                          sm};
+    candidate_configs.push_back(cfg);
+  }

Also consider a defensive check to ensure sm is in [100,119] since this helper is SM100-specific.

446-529: Remove legacy SM100 implementation instead of commenting it out.

Keeping a large commented block hurts readability and risks divergence. Prefer deletion, or #if 0 ... #endif with a brief rationale and a removal timeline.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (4 hunks)
🔇 Additional comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (1)
353-372: Review suggestions are valid but incomplete; FP4 filtering already exists, FP8 lacks schedule gating.

Your review correctly identifies that FP4's NO_SMEM exclusion is already implicit (line 278 returns {} for non-TMA schedules). However, you missed that FP8_ONLY (line 326) does not filter schedules: it generates configs with both TMA and NO_SMEM applied to the same tile set, potentially wastefully.

Additionally, no build flag exists to gate the cluster shapes (4x2x1, Undefined) as suggested; only FAST_BUILD exists, and it drastically reduces the set to a single config. The shell command provided (rg 'CtaShape.*SM100') searches for definitions but doesn't measure the actual candidate count generated.

The loop at lines 353-372 produces up to 8 calls to get_candidate_configs_sm100_dynamic_cluster_shape per GROUPED_GEMM invocation. With FP8_ONLY, this generates ~72 configs, but the profiling impact claim lacks supporting metrics.

Suggested corrections:

- Add schedule filtering for FP8 if NO_SMEM yields no benefit (parallel to FP4's TMA-only restriction at line 278).
- Consider a build flag for selective cluster shapes, or gather profiling data to justify the current shape count.
| if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) { | ||
| return get_candidate_configs_sm100(config_type_param); | ||
| return get_candidate_configs_sm100(config_type_param, sm); | ||
| } |
🧩 Analysis chain
SM100 dispatcher will throw on non-GROUPED GEMM — route only when GROUPED is set.
get_candidate_configs_sm100(...) throws for non-GROUPED, but the dispatcher now unconditionally routes BW SM in [100,119). This likely explains CI red (1/17). Guard the route to SM100 by GROUPED_GEMM and otherwise fall through to the generic path.
Apply:
- if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
- return get_candidate_configs_sm100(config_type_param, sm);
- }
+ if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
+ if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) {
+ return get_candidate_configs_sm100(config_type_param, sm);
+ }
+ // Non-grouped on BW SM100–SM119: fall through to generic tiling below.
+ }

Run to confirm no other hard routes to SM100 exist and to surface non-GROUPED BW usage:
🏁 Script executed:
#!/bin/bash
rg -nP -C3 '\bget_candidate_configs\s*\('
rg -nP 'GROUPED_GEMM|BLACKWELL|SM100|SM110|SM120'

Length of output: 71381
Let me examine the full implementation of `get_candidate_configs_sm100` to verify whether it throws for non-GROUPED cases. Let me use a shell script to directly extract and examine the function:
🏁 Script executed:
#!/bin/bash
# Extract get_candidate_configs_sm100 function and surrounding context
cat csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp | sed -n '342,380p'

Length of output: 2096
🏁 Script executed:
#!/bin/bash
# Check the complete dispatcher function and fallthrough behavior
cat csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp | sed -n '638,670p'

Length of output: 1495
🏁 Script executed:
#!/bin/bash
# Look for TLLM_THROW in the function and check conditions
rg -n -A2 -B2 'TLLM_THROW.*SM100.*GEMM' csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp

Length of output: 308
Guard the SM100 dispatcher against non-GROUPED GEMM to prevent runtime throws.
The function get_candidate_configs_sm100 throws "Not Implemented: SM100 GEMM candidates have not been defined" for non-GROUPED paths, but the dispatcher at lines 653–655 unconditionally routes Blackwell SM [100,119) to get_candidate_configs_sm100 when the BLACKWELL flag is set, regardless of GROUPED_GEMM status. This causes a runtime throw whenever non-grouped GEMM is requested on this SM range with the Blackwell flag enabled—likely explaining the CI failure.
Apply the fix:
- if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
- return get_candidate_configs_sm100(config_type_param, sm);
- }
+ if (sm >= 100 && sm < 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) {
+ if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) {
+ return get_candidate_configs_sm100(config_type_param, sm);
+ }
+ // Non-grouped on BW SM100–SM119: fall through to generic tiling below.
+ }

…to feature/cutlass_moe_3xfp4
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)
632-637: Profiler still reads GEMM2 tactic indices without offset. When `gemm_idx == 2`, `profile_id` continues to address `mAllProfiles` as if GEMM2 entries started at zero. With GEMM1+GEMM2 concatenated, this pulls the wrong tactic during profiling and produces misleading autotune results. Please adjust `runGemmProfile` to add the GEMM1 count before indexing the concatenated vector, with appropriate bounds checks, mirroring the runtime fix above.

Apply this diff to keep profiler selection consistent:
- auto profile = profile_id == -1 ? mAllProfiles.front() : mAllProfiles[profile_id];
+ auto base_index = (gemm_idx == 1) ? 0 : mGemm1ProfileCount;
+ auto selected_index = profile_id == -1 ? base_index : base_index + profile_id;
+ TVM_FFI_ICHECK_LT(selected_index, mAllProfiles.size())
+     << "Invalid profile index for GEMM " << gemm_idx;
+ auto profile = mAllProfiles.at(selected_index);

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)
936-958: Silent config failures may hide real errors. The `CALC_SIZE_FUSION` macro catches `TllmException` and logs at TRACE level (lines 945-947), then continues. While line 958 checks that at least one config succeeded (`has_config`), this approach may silently skip configs that should be valid.

Issues:
CALC_SIZE_FUSIONmacro catchesTllmExceptionand logs at TRACE level (lines 945-947), then continues. While line 958 checks that at least one config succeeded (has_config), this approach may silently skip configs that should be valid.Issues:
- Users running without TRACE logging won't see why certain configs failed
- If multiple configs are expected but only one succeeds, the failure is hidden
- Legitimate configuration errors (typos, invalid parameters) vs. intentionally unsupported configs are treated the same
Consider one of these approaches:
- Log at WARNING level for unexpected failures:
  } catch (tensorrt_llm::common::TllmException const& e) {
-   TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size %s",
-                  e.what());
+   TLLM_LOG_WARNING("Config skipped when calculating MOE workspace size: %s", e.what());
  }
- Track and report which specific configs failed:
std::vector<std::string> failed_configs;
// ... in catch block:
failed_configs.push_back(conf.toString());
// ... after loop:
if (!failed_configs.empty()) {
  TLLM_LOG_DEBUG("Skipped %zu configs: ...", failed_configs.size());
}
♻️ Duplicate comments (1)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)
787-797: Still missing GEMM2 tactic offset in concatenated profile list. The concatenated `mAllProfiles` array still indexes GEMM2 profiles as if they started at 0. As a result, any explicit `profile_ids` routed to GEMM2 (or even the implicit default) reuse GEMM1 tactics, so the wrong kernel launches under tuning and at runtime. This has already been flagged in a prior review and is a release blocker; please apply the offset fix before landing.

Apply this diff to store the GEMM1 count and offset GEMM2 lookups:
@@
-  auto gemm1_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1);
-  auto gemm2_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_2);
-  mAllProfiles = gemm1_tactics;
+  auto gemm1_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1);
+  auto gemm2_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_2);
+  mGemm1ProfileCount = gemm1_tactics.size();
+  mAllProfiles = gemm1_tactics;
@@
-  auto best_gemm1_profile = mAllProfiles.front();
-  auto best_gemm2_profile = mAllProfiles.front();
+  auto best_gemm1_profile = mAllProfiles.front();
+  auto best_gemm2_profile =
+      mAllProfiles.at(mGemm1ProfileCount);  // first GEMM2 tactic
@@
-  best_gemm1_profile = profile_ids.value()[0] == -1 ? best_gemm1_profile
-                                                    : mAllProfiles.at(profile_ids.value()[0]);
-  best_gemm2_profile = profile_ids.value()[1] == -1 ? best_gemm2_profile
-                                                    : mAllProfiles.at(profile_ids.value()[1]);
+  if (profile_ids.value()[0] != -1)
+    best_gemm1_profile = mAllProfiles.at(profile_ids.value()[0]);
+  if (profile_ids.value()[1] != -1) {
+    auto gemm2_index = profile_ids.value()[1] + mGemm1ProfileCount;
+    TVM_FFI_ICHECK_LT(gemm2_index, mAllProfiles.size());
+    best_gemm2_profile = mAllProfiles.at(gemm2_index);
+  }
@@
   std::vector<Profile> mAllProfiles;
+  size_t mGemm1ProfileCount{0};
🧹 Nitpick comments (3)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (3)
797-799: SM version major check is fragile. The check `inputs.gemm_config.sm_version / 10 == sm_ / 10` relies on NVIDIA's SM numbering scheme:
- SM90 / 10 = 9
- SM100-109 / 10 = 10
- SM120-129 / 10 = 12
While this works for current architectures (Hopper, Blackwell, future variants), it's fragile if SM versioning changes. Consider adding a comment explaining this assumption or using a more explicit check:
-  // Check the major version of the SM matches
-  TLLM_CHECK_WITH_INFO(inputs.gemm_config.sm_version / 10 == sm_ / 10,
+  // Check the SM architecture family matches (90->9, 100-109->10, 120-129->12)
+  TLLM_CHECK_WITH_INFO(inputs.gemm_config.sm_version / 10 == sm_ / 10,
                        "Using SM %d configuration for SM %d device", inputs.gemm_config.sm_version, sm_);
781-788: Document the SM120+ FP8 fallback to Ada (SM89). The comment states "For SM120+ pure FP8 MoE (not FP8 x FP4), redirect to SM89 (Ada) FP8 kernel implementations," but this creates a non-obvious special case where SM120+ doesn't use its native TMA warp specialized path for pure FP8.
Consider adding a more detailed comment explaining why this fallback is necessary (e.g., kernel not yet implemented, performance considerations, etc.).
926-927: Workspace sizing always assumes finalize fusion support. Line 926 always passes `true` for `supports_finalize_fusion`, meaning the workspace calculation accounts for all possible fusion configs even if finalize fusion won't be used at runtime. This is conservative and safe, but may over-allocate workspace memory.

If finalize fusion is optional and not always enabled, consider:
- Passing the actual fusion support status if known at workspace allocation time
- Adding a comment explaining this conservative sizing choice
- Documenting the potential memory overhead
- auto configs = getTmaWarpSpecializedConfigs(sm_, true);
+ // Conservative: size for all fusion variants including FINALIZE to ensure sufficient workspace
+ // This may over-allocate if finalize fusion is not used at runtime
+ auto configs = getTmaWarpSpecializedConfigs(sm_, /* supports_finalize_fusion */ true);
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (54 hunks)
- csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (8 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (3)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (3): tensorrt_llm (38-172), multi_processor_count_ (327-327), EpilogueFusion (176-333)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (2): get_candidate_configs (638-689), get_candidate_configs (638-640)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h (1): EpilogueScheduleType (197-433)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (2)
- include/flashinfer/trtllm/fused_moe/runner.h (3): hidden_size (265-265), num_experts (263-263), top_k (270-270)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1): enable_pdl (220-220)
🔇 Additional comments (4)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (4)
70-228: Namespace split looks good. The introduction of the `cutlass_kernels_oss` namespace and the cross-namespace type adapter references are correctly structured for the OSS/internal separation architecture.
258-524: Dispatch routing to OSS namespace is correct. The dispatcher correctly routes to `cutlass_kernels_oss` launchers while preserving the original namespace structure.
599-629: Review comment is incorrect; the two SM103 code paths handle different activation data types. The conditions at line 599 and lines 624-629 are not confusing separate paths; they represent mutually exclusive configurations:

- Line 599: `use_wfp4afp8` (FP8 activations + FP4 weights) has no dedicated SM103 support and remaps to SM100.
- Line 624: `use_fp4` (FP4 activations) has dedicated SM103 kernels but also includes SM100 configs as fallback.

These are different data types, not arbitrary branching logic. No clarification needed.
Likely an incorrect or invalid review comment.
631-666: The primary concern in this review is based on an incorrect understanding of the code. The check at line 656, `TLLM_CHECK_WITH_INFO(!config.swap_ab, ...)`, is safe and will not fail. Here's why:

At line 653, `swap_ab_configs` is a fresh copy of `tma_ws_configs`. The `std::transform` at lines 654-659 reads from the source range (`swap_ab_configs`) and writes to the destination (via `back_inserter` into `tma_ws_configs`). The lambda parameter `auto& config` receives references from the source range only; `swap_ab_configs` is never modified during iteration. This is the standard and safe usage of `std::transform`.

Regarding the empty config concern: the function can legitimately return empty for unsupported SM/architecture combinations. However, `getConfigs()` (lines 537-544) provides a fallback mechanism by combining results from both `getTmaWarpSpecializedConfigs()` and `getAmpereConfigs()`, mitigating complete config loss for most scenarios.

The suggested validation is reasonable defensive programming but not necessary given the fallback in the calling function.
Likely an incorrect or invalid review comment.
|
/bot run |
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (1)
48-56: Fix includes: add missing std headers; drop unused `<mutex>`. This file uses `std::array`, `std::is_same_v`, `std::get`, and `std::cout` but doesn't include the required headers; `<mutex>` appears unused.

Apply:
 #include <cuda_fp16.h>
 #include <math.h>
-
-#include <mutex>
+#include <array>
+#include <tuple>
+#include <type_traits>
+#include <iostream>
🧹 Nitpick comments (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (2)
67-112: SM100 dispatch helper: header hygiene + C++ version assumption. Logic looks correct. Ensure we have the proper headers for `std::array` and `std::is_same_v` (see include fix). Also confirm the build uses C++17+ since this relies on CTAD for `std::array`.
Please confirm the project enforces -std=c++17 or newer.
380-380: Misleading closing comment. This brace closes the function, not a namespace. Update the comment to avoid confusion.
-} // namespace tensorrt_llm
+} // end dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (4)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (2): tensorrt_llm (70-227), tensorrt_llm (526-533)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (3): tensorrt_llm (38-172), EpilogueFusion (176-333), FpXBlockScalingType (190-235)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h (2): tensorrt_llm (27-119), isValidBlackwellMOESpecialisation (54-68)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.h (1): tensorrt_llm (23-36)
🔇 Additional comments (7)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (7)
63-66: Namespace aliasing reads well. Re-exporting TmaWarpSpecializedGroupedGemmInput into the OSS namespace keeps call sites clean. LGTM.
141-147: Guarded SM103 error path is consistent. The compile-time gate and error message are clear and match the new arch split. LGTM.
224-231: SM100 cluster-shape constraint check is clear. Restricting to 1x1x1 and 2x1x1 for dynamic CGA is reasonable. LGTM.
239-246: SM103 FP4-only tile gate seems correct. The SM103 path constrained to FP4/FP4 with specific tiles matches traits. LGTM.
404-414: Default-case messages are helpful. Good diagnostics for undefined/heuristic tile configs. LGTM.
488-501: Workspace size reuse via dispatcher is neat. Reusing the dispatcher for WS calc with a dummy input is a good way to keep logic in sync. LGTM.
432-448: Code is correct: CutlassTileConfigSM103 is a type alias to SM100. The review comment assumes SM103 is a distinct enum class, but verification shows `CutlassTileConfigSM103` is explicitly aliased to `CutlassTileConfigSM100`, with a comment stating "An alias to make the SHAPE_CASE macro work". When the macro expands `SHAPE_CASE(103, ...)`, the case labels resolve to `CutlassTileConfigSM100` enum values, matching the switch expression type. No type incompatibility exists, and no refactoring is needed.

Likely an incorrect or invalid review comment.
| else { | ||
| #ifdef ENABLE_FP4 | ||
| auto getFunc = [&]() { | ||
| if constexpr (std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, __nv_fp4_e2m1>) { | ||
| TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type == | ||
| TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, | ||
| "MXFPX is the only supported scaling type for WFP4AFP8"); | ||
| return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher< | ||
| Arch, T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape, true, | ||
| false>; | ||
| } else { | ||
| TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type != | ||
| TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, | ||
| "MXFPX is not supported for the selected weight combination"); | ||
| return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher< | ||
| Arch, T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape, false, | ||
| false>; | ||
| } | ||
| }; | ||
| getFunc()(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); | ||
| #if defined(ENABLE_FP4) | ||
| constexpr static bool is_wfp4afp8 = | ||
| std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, __nv_fp4_e2m1>; | ||
| #else | ||
| TLLM_THROW("FP4 data type is not supported on this architecture and CUDA version"); | ||
| constexpr static bool is_wfp4afp8 = false; | ||
| #endif | ||
| if constexpr (is_wfp4afp8) { | ||
| TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type == | ||
| TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, | ||
| "MXFPX is the only supported scaling type for WFP4AFP8"); | ||
| } else { | ||
| TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type != | ||
| TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, | ||
| "MXFPX is not supported for the selected weight combination"); | ||
| } | ||
|
|
||
| if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { | ||
| bool const dynamic_cga = | ||
| gemm_config.dynamic_cluster_shape != cutlass_extensions::ClusterShape::Undefined; | ||
| bool const swap_ab = hopper_input.swap_ab; | ||
| auto cluster_shape = | ||
| cutlass_extensions::enum_to_shape_tuple(gemm_config.dynamic_cluster_shape); | ||
| auto cluster_shape_cute = cute::Shape<int32_t, int32_t, cute::_1>{ | ||
| std::get<0>(cluster_shape), std::get<1>(cluster_shape), cute::_1{}}; | ||
| auto cluster_shape_fallback = | ||
| cutlass_extensions::enum_to_shape_tuple(gemm_config.fallback_cluster_shape); | ||
| auto cluster_shape_cute_fallback = cute::Shape<int32_t, int32_t, cute::_1>{ | ||
| std::get<0>(cluster_shape_fallback), std::get<1>(cluster_shape_fallback), cute::_1{}}; | ||
|
|
||
| // HACK debug the gemm_config used to produce selected_func | ||
| std::cout << "[SM100 gemm_config] sm_version=" << gemm_config.sm_version | ||
| << ", tile_config_sm100=" << static_cast<int>(gemm_config.tile_config_sm100) | ||
| << ", epilogue_schedule=" << static_cast<int>(gemm_config.epilogue_schedule) | ||
| << ", dynamic_cluster_shape=" << static_cast<int>(gemm_config.dynamic_cluster_shape) | ||
| << ", fallback_cluster_shape=" | ||
| << static_cast<int>(gemm_config.fallback_cluster_shape) << std::endl; | ||
|
|
||
| auto selected_func = | ||
| getDispatchFunctionForSM100<Arch, T, WeightType, OutputType, EpilogueTag, FUSION, | ||
| TileShape, ClusterShape, is_wfp4afp8>( | ||
| gemm_config.epilogue_schedule, dynamic_cga, swap_ab); | ||
| selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, | ||
| workspace_size, cluster_shape_cute, cluster_shape_cute_fallback); | ||
| } else if constexpr (Arch::kMinComputeCapability >= 120 || Arch::kMinComputeCapability == 90) { |
Remove/guard std::cout debug in header; use project logging.
Printing from a header via std::cout is noisy, pulls in iostream globally, and hurts performance. Replace with TLLM_LOG_TRACE (gemm_config.toString()) or guard with a debug macro.
Apply:
- // HACK debug the gemm_config used to produce selected_func
- std::cout << "[SM100 gemm_config] sm_version=" << gemm_config.sm_version
- << ", tile_config_sm100=" << static_cast<int>(gemm_config.tile_config_sm100)
- << ", epilogue_schedule=" << static_cast<int>(gemm_config.epilogue_schedule)
- << ", dynamic_cluster_shape=" << static_cast<int>(gemm_config.dynamic_cluster_shape)
- << ", fallback_cluster_shape="
- << static_cast<int>(gemm_config.fallback_cluster_shape) << std::endl;
+ TLLM_LOG_TRACE("[SM100 gemm_config] %s", gemm_config.toString().c_str());📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| else { | |
| #ifdef ENABLE_FP4 | |
| auto getFunc = [&]() { | |
| if constexpr (std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, __nv_fp4_e2m1>) { | |
| TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type == | |
| TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, | |
| "MXFPX is the only supported scaling type for WFP4AFP8"); | |
| return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher< | |
| Arch, T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape, true, | |
| false>; | |
| } else { | |
| TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type != | |
| TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, | |
| "MXFPX is not supported for the selected weight combination"); | |
| return &kernels::cutlass_kernels::tma_warp_specialized_generic_moe_gemm_kernelLauncher< | |
| Arch, T, WeightType, OutputType, EpilogueTag, FUSION, TileShape, ClusterShape, false, | |
| false>; | |
| } | |
| }; | |
| getFunc()(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); | |
| #if defined(ENABLE_FP4) | |
| constexpr static bool is_wfp4afp8 = | |
| std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, __nv_fp4_e2m1>; | |
| #else | |
| TLLM_THROW("FP4 data type is not supported on this architecture and CUDA version"); | |
| constexpr static bool is_wfp4afp8 = false; | |
| #endif | |
| if constexpr (is_wfp4afp8) { | |
| TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type == | |
| TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, | |
| "MXFPX is the only supported scaling type for WFP4AFP8"); | |
| } else { | |
| TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type != | |
| TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, | |
| "MXFPX is not supported for the selected weight combination"); | |
| } | |
| if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { | |
| bool const dynamic_cga = | |
| gemm_config.dynamic_cluster_shape != cutlass_extensions::ClusterShape::Undefined; | |
| bool const swap_ab = hopper_input.swap_ab; | |
| auto cluster_shape = | |
| cutlass_extensions::enum_to_shape_tuple(gemm_config.dynamic_cluster_shape); | |
| auto cluster_shape_cute = cute::Shape<int32_t, int32_t, cute::_1>{ | |
| std::get<0>(cluster_shape), std::get<1>(cluster_shape), cute::_1{}}; | |
| auto cluster_shape_fallback = | |
| cutlass_extensions::enum_to_shape_tuple(gemm_config.fallback_cluster_shape); | |
| auto cluster_shape_cute_fallback = cute::Shape<int32_t, int32_t, cute::_1>{ | |
| std::get<0>(cluster_shape_fallback), std::get<1>(cluster_shape_fallback), cute::_1{}}; | |
| // HACK debug the gemm_config used to produce selected_func | |
| std::cout << "[SM100 gemm_config] sm_version=" << gemm_config.sm_version | |
| << ", tile_config_sm100=" << static_cast<int>(gemm_config.tile_config_sm100) | |
| << ", epilogue_schedule=" << static_cast<int>(gemm_config.epilogue_schedule) | |
| << ", dynamic_cluster_shape=" << static_cast<int>(gemm_config.dynamic_cluster_shape) | |
| << ", fallback_cluster_shape=" | |
| << static_cast<int>(gemm_config.fallback_cluster_shape) << std::endl; | |
| auto selected_func = | |
| getDispatchFunctionForSM100<Arch, T, WeightType, OutputType, EpilogueTag, FUSION, | |
| TileShape, ClusterShape, is_wfp4afp8>( | |
| gemm_config.epilogue_schedule, dynamic_cga, swap_ab); | |
| selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, | |
| workspace_size, cluster_shape_cute, cluster_shape_cute_fallback); | |
| } else if constexpr (Arch::kMinComputeCapability >= 120 || Arch::kMinComputeCapability == 90) { | |
| else { | |
| #if defined(ENABLE_FP4) | |
| constexpr static bool is_wfp4afp8 = | |
| std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, __nv_fp4_e2m1>; | |
| #else | |
| constexpr static bool is_wfp4afp8 = false; | |
| #endif | |
| if constexpr (is_wfp4afp8) { | |
| TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type == | |
| TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, | |
| "MXFPX is the only supported scaling type for WFP4AFP8"); | |
| } else { | |
| TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type != | |
| TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, | |
| "MXFPX is not supported for the selected weight combination"); | |
| } | |
| if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { | |
| bool const dynamic_cga = | |
| gemm_config.dynamic_cluster_shape != cutlass_extensions::ClusterShape::Undefined; | |
| bool const swap_ab = hopper_input.swap_ab; | |
| auto cluster_shape = | |
| cutlass_extensions::enum_to_shape_tuple(gemm_config.dynamic_cluster_shape); | |
| auto cluster_shape_cute = cute::Shape<int32_t, int32_t, cute::_1>{ | |
| std::get<0>(cluster_shape), std::get<1>(cluster_shape), cute::_1{}}; | |
| auto cluster_shape_fallback = | |
| cutlass_extensions::enum_to_shape_tuple(gemm_config.fallback_cluster_shape); | |
| auto cluster_shape_cute_fallback = cute::Shape<int32_t, int32_t, cute::_1>{ | |
| std::get<0>(cluster_shape_fallback), std::get<1>(cluster_shape_fallback), cute::_1{}}; | |
| TLLM_LOG_TRACE("[SM100 gemm_config] %s", gemm_config.toString().c_str()); | |
| auto selected_func = | |
| getDispatchFunctionForSM100<Arch, T, WeightType, OutputType, EpilogueTag, FUSION, | |
| TileShape, ClusterShape, is_wfp4afp8>( | |
| gemm_config.epilogue_schedule, dynamic_cga, swap_ab); | |
| selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, | |
| workspace_size, cluster_shape_cute, cluster_shape_cute_fallback); | |
| } else if constexpr (Arch::kMinComputeCapability >= 120 || Arch::kMinComputeCapability == 90) { |
|
[FAILED] Pipeline #37018845: 1/17 passed |
📌 Description
Upgraded cutlass moe kernel launcher to match trtllm. This facilitates further feature upgrades (perhaps separate PRs)
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (unittest, etc.).

Reviewer Notes
This currently compiles; it still needs testing, but may be sufficient to unblock other work.
Once tested OK, I'll remove the draft status.
Summary by CodeRabbit
New Features
Improvements
Bug Fixes