-
Notifications
You must be signed in to change notification settings - Fork 1.8k
MXFP4 x BF16 CUTLASS MoE backend perf and profiling improvement on Hopper #8721
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -304,27 +304,30 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm90(CutlassGemmConfig::Can | |||||||||||||||||||||||||||||||||
| if (has_w4afp8) | ||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||
| bool const has_coop_supported = sm90_supports_coop(tile_config); | ||||||||||||||||||||||||||||||||||
| std::set<MainloopScheduleType> mainloop_schedules{MainloopScheduleType::PINGPONG}; | ||||||||||||||||||||||||||||||||||
| if (has_coop_supported) | ||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||
| mainloop_schedules.insert(MainloopScheduleType::COOPERATIVE); | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // It seems that ping-pong scheduler will never be selected. | ||||||||||||||||||||||||||||||||||
| // To shorten the tactic time, remove all alternative options involving ping-pong scheduler. | ||||||||||||||||||||||||||||||||||
| if (!has_coop_supported) | ||||||||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||||||||
| // Due to the limitation on the number of registers on SM, | ||||||||||||||||||||||||||||||||||
| // cooperative scheduler does not support CtaShape128x128x128B. | ||||||||||||||||||||||||||||||||||
| if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B) | ||||||||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+310
to
+315
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Add braces around single-statement if bodies. The coding guidelines require that if statements always be followed by brace-delimited statements. Both continue statements lack the required braces. As per coding guidelines. Apply this diff to add braces: - if (!has_coop_supported)
- continue;
+ if (!has_coop_supported)
+ {
+ continue;
+ }
// Due to the limitation on the number of registers on SM,
// cooperative scheduler does not support CtaShape128x128x128B.
- if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
- continue;
+ if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
+ {
+ continue;
+ }📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How much performance are we leaving on the table here? Is there a way to reduce the number of stages or otherwise relieve register pressure |
||||||||||||||||||||||||||||||||||
| MainloopScheduleType mainloop_schedule = MainloopScheduleType::COOPERATIVE; | ||||||||||||||||||||||||||||||||||
| auto const epilogue_schedule = EpilogueScheduleType::AUTO; | ||||||||||||||||||||||||||||||||||
| for (auto const& mainloop_schedule : mainloop_schedules) | ||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||
| CutlassGemmConfig candidate( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x1x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| candidate = CutlassGemmConfig( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x1x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| candidate = CutlassGemmConfig( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x2x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| candidate = CutlassGemmConfig( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x2x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| CutlassGemmConfig candidate( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x1x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| candidate = CutlassGemmConfig( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x1x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| candidate = CutlassGemmConfig( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x2x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| candidate = CutlassGemmConfig( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x2x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,55 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /* | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1
to
+15
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update copyright year to include 2025. Header stops at 2023; repository guidelines require current year on source files. As per coding guidelines. - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "moe_gemm_mixed_utils.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace tensorrt_llm::kernels::cutlass_kernels | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| __global__ void interleave_fp4_for_Hopper_mixed_gemm_kernel( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int block_id = blockIdx.x; block_id < rows / 2; block_id += gridDim.x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int col_id = threadIdx.x; col_id < cols / 2; col_id += blockDim.x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int row_id = block_id / 8 * 16 + block_id % 8; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int index_a = row_id * cols / 2 + col_id; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int index_b = (row_id + 8) * cols / 2 + col_id; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
StudyingShao marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t fp4x2_a = weight[index_a]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t fp4x2_b = weight[index_b]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t fp4_temp_a = (fp4x2_a & 0xF0U) >> 4; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t fp4_temp_b = (fp4x2_b & 0x0FU) << 4; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fp4x2_a = (fp4x2_a & 0x0FU) | fp4_temp_b; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fp4x2_b = (fp4x2_b & 0xF0U) | fp4_temp_a; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_interleaved[index_a] = fp4x2_a; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_interleaved[index_b] = fp4x2_b; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void interleave_fp4_for_Hopper_mixed_gemm(uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // column-major input | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| interleave_fp4_for_Hopper_mixed_gemm_kernel<<<1024, 1024>>>(weight, weight_interleaved, rows, cols); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } // namespace tensorrt_llm::kernels::cutlass_kernels | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| /* | ||
| * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <cstdint> | ||
|
|
||
| namespace tensorrt_llm::kernels::cutlass_kernels | ||
| { | ||
|
|
||
| void interleave_fp4_for_Hopper_mixed_gemm(uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols); | ||
|
|
||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What models have you tested this with? I am hesitant to remove this without a comprehensive sweep of multiple model architectures like Mixtral, DeepSeek, Llama4 and GPT-OSS. Its hard to say what the next DeepSeek moment will look like.
I also dont think tactic selection time is actually a significant concern. There are lots of tactics sure, but weight loading is usually just as long. Maybe we should add a fast profile mode that users can opt into