# Feature: Add support for L40 FusedMoE in cutlass path #1973
base: main
Changes from all commits:
```diff
@@ -32,7 +32,8 @@ template <typename T, typename WeightType,
           TmaWarpSpecializedGroupedGemmInput::EpilogueFusion Fusion =
               TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE>
 constexpr bool isValidSM120MOESpecialisation() {
-#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)  // TODO Is there a better choice
+#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) && \
+    defined(ENABLE_FP4)  // TODO Is there a better choice
   return cutlass::platform::is_same<T, __nv_fp4_e2m1>::value &&
          cutlass::platform::is_same<T, WeightType>::value &&
          cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value &&
```
```diff
@@ -47,7 +48,8 @@ template <typename T, typename WeightType,
           TmaWarpSpecializedGroupedGemmInput::EpilogueFusion Fusion =
               TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE>
 constexpr bool isValidBlackwellMOESpecialisation() {
-#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)  // TODO Is there a better choice
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
+    defined(ENABLE_FP4)  // TODO Is there a better choice
```
|
**Comment on lines +51 to +52 (review bot):**

Overly restrictive guard blocks non-FP4 Blackwell specializations. Requiring `ENABLE_FP4` here disables every SM100 specialization, including ones that never touch FP4 types, whenever FP4 support is compiled out. The Hopper implementation (lines 68-89) uses a better pattern: conditional compilation guards only the FP4-specific parts inside the return statement, allowing non-FP4 specializations to work without `ENABLE_FP4`. Apply this diff to follow the Hopper pattern:

```diff
-#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
-    defined(ENABLE_FP4)  // TODO Is there a better choice
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
   return (cutlass::platform::is_same<T, WeightType>::value ||
+#ifdef ENABLE_FP4
          (cutlass::platform::is_same<T, __nv_fp8_e4m3>::value &&
-          cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value)) &&
+          cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value) ||
+#else
+         false ||
+#endif
+         false) &&
          cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value &&
          Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
```
**amirkl94:**

The issue here is `__nv_fp4_e2m1`, which I think isn't defined unless `ENABLE_FP4` is defined.

**Review bot (reply):**

@amirkl94, you're absolutely right that `__nv_fp4_e2m1` is only available when `ENABLE_FP4` is defined, so the FP4-specific comparison needs its own guard:

```cpp
#ifdef ENABLE_FP4
        (cutlass::platform::is_same<T, __nv_fp8_e4m3>::value &&
         cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value) ||
#endif
```

The key insight is that the first condition in the return statement:

```cpp
cutlass::platform::is_same<T, WeightType>::value
```

doesn't reference `__nv_fp4_e2m1`, so it compiles whether or not `ENABLE_FP4` is defined. By moving the top-level `defined(ENABLE_FP4)` check inside the return expression, non-FP4 Blackwell specializations remain available when FP4 support is compiled out. Let me know if you'd like me to refine the suggestion further!
For context, the hunk continues with the (unchanged) Blackwell return statement:

```diff
   return (cutlass::platform::is_same<T, WeightType>::value ||
           (cutlass::platform::is_same<T, __nv_fp8_e4m3>::value &&
            cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value)) &&
```
The second file wires an SM89 build into the JIT module generation:

```diff
@@ -18,7 +18,13 @@
 from . import env as jit_env
 from ..artifacts import ArtifactPath, CheckSumHash
-from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags
+from .core import (
+    JitSpec,
+    gen_jit_spec,
+    current_compilation_context,
+    sm90a_nvcc_flags,
+    sm89_nvcc_flags,
+)
 from .cpp_ext import is_cuda_version_at_least
 from .cubin_loader import get_cubin, get_meta_hash
 from .gemm.cutlass.generate_kernels import generate_gemm_operations
```
```diff
@@ -71,6 +77,16 @@ def gen_cutlass_fused_moe_sm90_module(use_fast_build: bool = False) -> JitSpec:
     return gen_cutlass_fused_moe_module(nvcc_flags, "90", use_fast_build)


+def gen_cutlass_fused_moe_sm89_module(use_fast_build: bool = False) -> JitSpec:
+    nvcc_flags = sm89_nvcc_flags + [
+        "-DENABLE_BF16",
+        "-DENABLE_FP8",
+        "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
+        "-DUSING_OSS_CUTLASS_MOE_GEMM",
+    ]
```
**Comment on lines +81 to +86 (review bot):**

Conditionally adding an empty string to the list of compiler flags can be confusing and less clean. It's better to conditionally append the flag to the list, so that no empty elements are added. This improves code readability and maintainability.
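The attached "Suggested change" body wasn't captured above; here is a sketch of the conditional-append pattern the comment describes, reusing only names that already appear in the diff and assuming the module's existing imports:

```python
def gen_cutlass_fused_moe_sm89_module(use_fast_build: bool = False) -> JitSpec:
    nvcc_flags = sm89_nvcc_flags + [
        "-DENABLE_BF16",
        "-DENABLE_FP8",
        "-DUSING_OSS_CUTLASS_MOE_GEMM",
    ]
    # Append the version-gated flag only when it applies, instead of
    # inserting an empty-string placeholder into the flag list.
    if is_cuda_version_at_least("12.8"):
        nvcc_flags.append("-DENABLE_FP8_BLOCK_SCALE")
    return gen_cutlass_fused_moe_module(nvcc_flags, "89", use_fast_build)
```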
The diff continues:

```diff
+    return gen_cutlass_fused_moe_module(nvcc_flags, "89", use_fast_build)


 def gen_cutlass_fused_moe_module(
     nvcc_flags: List[str], device_arch: str, use_fast_build: bool = False
 ) -> JitSpec:
```
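For context, a minimal usage sketch of the new entry point, assuming the same JitSpec workflow the other `gen_cutlass_fused_moe_*` helpers use (the `build_and_load()` call is an assumption here, not confirmed by this diff):

```python
# Hypothetical usage: JIT-compile the SM89 (Ada, e.g. L40) fused-MoE module.
spec = gen_cutlass_fused_moe_sm89_module(use_fast_build=True)
module = spec.build_and_load()  # assumed JitSpec method, per the other specs
```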
**Reviewer:**

Are there any SM89 GPUs that can support `CtaShape16x256x128_WarpShape16x64x128`, or is this an SM120 addition? Having a small M value like this helps with low-latency cases, so I'd want to understand why it's not supported before disabling it.
**Reviewer:**

At the very least, can you leave a comment explaining the difference between the two lists, so people don't have to compare the items manually?
**amirkl94:**

This is a removal from the sm89 path. When I tested it on an L40 GPU I got `Assertion failed: GPU lacks the shared memory resources to run GroupedGEMM kernel`. It might be that on other sm89 GPUs it will pass; the main issue is that this was the default tactic chosen when trying to use FusedMoE. I believe that moving it to be the last tactic would also fix my issue.
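A sketch of the reordering idea under discussion; the helper and attribute names here are hypothetical stand-ins, since the actual candidate-config plumbing lives in the CUTLASS MoE GEMM dispatch code:

```python
# Hypothetical sketch: demote the tile config so it is tried last and can
# no longer be selected as the default tactic on SM89 parts like the L40.
PROBLEM_TILE = "CtaShape16x256x128_WarpShape16x64x128"

def demote_problem_tile(configs):
    """Move configs using PROBLEM_TILE to the end, preserving relative order."""
    kept = [c for c in configs if c.tile_shape != PROBLEM_TILE]  # c.tile_shape is hypothetical
    demoted = [c for c in configs if c.tile_shape == PROBLEM_TILE]
    return kept + demoted
```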
**amirkl94:**

I tried moving this tile config to be the last one, and now the default tactic won't fail on L40. The issue is that if the autotuner is on, the tactics that use this tile config will report an error with a stack trace, which looks bad. @yzh119, do you think it would be OK to change the errors that happen in the autotuner to debug logs? Otherwise users will get spammed with error messages when they run autotuning on L40 FusedMoE.
**Reviewer:**

I think the autotuner should still output warnings, but just make them say "Skipping tactic x due to error. This tactic may not be supported by the current GPU architecture."

That said, I know there is a difference of opinion on whether we should proactively filter tactics as you have done here, the argument being that we should do the due diligence to determine which tactics are supported, so that we can raise an error when a tactic fails when it shouldn't. So I can see either side.
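A sketch of the behavior being proposed here; the autotuner loop and the `run_tactic` callback are hypothetical stand-ins, not the actual flashinfer autotuner API:

```python
import logging

logger = logging.getLogger("flashinfer.autotuner")

def profile_tactics(tactics, run_tactic):
    """Time each tactic; skip unsupported ones with a short warning."""
    timings = {}
    for tactic in tactics:
        try:
            timings[tactic] = run_tactic(tactic)
        except RuntimeError as err:
            # Short, readable warning instead of a full stack trace:
            logger.warning(
                "Skipping tactic %s due to error. This tactic may not be "
                "supported by the current GPU architecture.",
                tactic,
            )
            # Keep the full failure details available at debug level.
            logger.debug("Tactic %s failed with: %s", tactic, err)
    return timings
```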