Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,34 @@ __device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8(con
return bf16x8_raw;
}

__device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8_interleaved(
const __nv_fp4x8_storage_t fp4x8)
{
// interleaved version
// input fp4x8: 7564 3120
// output bf16x8: 7654 3210

__nv_bf16x8_storage_t bf16x8_raw;
__nv_bf16x2_storage_t* bf16x2_raw = reinterpret_cast<__nv_bf16x2_storage_t*>(&bf16x8_raw);

__nv_fp8x4_storage_t h_fp8x4_0to1_bits = (fp4x8 & 0xC0C0C0C0U) >> 6; // 7632
__nv_fp8x4_storage_t l_fp8x4_0to1_bits = (fp4x8 & 0x0C0C0C0CU) >> 2; // 5410

unsigned h4b_em_fp4x4 = (fp4x8 & 0x77770000U) >> 16U;
unsigned l4b_em_fp4x4 = (fp4x8 & 0x00007777U);

__nv_fp8x4_storage_t h4b_2to9_bits = cvt_lut_bf16(h4b_em_fp4x4); // 7564
__nv_fp8x4_storage_t l4b_2to9_bits = cvt_lut_bf16(l4b_em_fp4x4); // 3120

bf16x2_raw[0] = prmt(l_fp8x4_0to1_bits, l4b_2to9_bits, 0x5240U) << 6U; // 1 0
bf16x2_raw[1] = prmt(h_fp8x4_0to1_bits, l4b_2to9_bits, 0x5341U) << 6U; // 3 2

bf16x2_raw[2] = prmt(l_fp8x4_0to1_bits, h4b_2to9_bits, 0x7260U) << 6U; // 5 4
bf16x2_raw[3] = prmt(h_fp8x4_0to1_bits, h4b_2to9_bits, 0x7361U) << 6U; // 7 6

return bf16x8_raw;
}

template <class Collective>
struct MixedGroupedGemmInputUtils
{
Expand Down Expand Up @@ -330,7 +358,7 @@ struct MixedGroupedGemmInputUtils
auto&& src_ = cute::recast<__nv_fp4x8_storage_t>(src)(0);
auto&& dst_ = cute::recast<__nv_bf16x8_storage_t>(dst)(0);

dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8(src_);
dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8_interleaved(src_);
}

/// Utilities to dequantize A.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,28 @@ struct CollectiveMmaArrayMixedInput<
}

/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

// Override the FP8 conversion in CUTLASS to enforce the intended compiler behavior.
template <class T>
CUTLASS_DEVICE float scale_convertor(T scale)
{
if constexpr (cute::is_same_v<ElementA, cutlass::float_e2m1_t>)
{

cutlass::float_ue8m0_t scale_ue8m0 = scale;

uint32_t temp = 0;
temp = (temp | *reinterpret_cast<uint8_t*>(&scale_ue8m0)) << 23;
return *reinterpret_cast<float*>(&temp);
}
else
{
return static_cast<float>(scale);
}
}

/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <class FrgTensorC>
Expand Down Expand Up @@ -1084,12 +1106,12 @@ struct CollectiveMmaArrayMixedInput<
if (chunk_id_ == 0)
{
accum(accum_coord) = intermediate_array[chunk_id_](accum_coord)
* static_cast<float>(tCrS(scale_coord)[0]);
* scale_convertor(tCrS(scale_coord)[0]);
}
else
{
accum(accum_coord) = fma(intermediate_array[chunk_id_](accum_coord),
static_cast<float>(tCrS(scale_coord)[chunk_id_]), accum(accum_coord));
scale_convertor(tCrS(scale_coord)[chunk_id_]), accum(accum_coord));
}
}
}
Expand Down Expand Up @@ -1186,7 +1208,7 @@ struct CollectiveMmaArrayMixedInput<
auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0);

accum(accum_coord) = fma(intermediate_array[chunk_id_](accum_coord),
static_cast<float>(tCrS(scale_coord)[chunk_id_]), accum(accum_coord));
scale_convertor(tCrS(scale_coord)[chunk_id_]), accum(accum_coord));
}
}
}
Expand Down Expand Up @@ -1275,7 +1297,7 @@ struct CollectiveMmaArrayMixedInput<
int scale_idx = k_block / NumMMAsPerChunk;

accum(accum_coord) = fma(intermediate(accum_coord),
static_cast<float>(tCrS(scale_coord)[scale_idx]), accum(accum_coord));
scale_convertor(tCrS(scale_coord)[scale_idx]), accum(accum_coord));
}
}
}
Expand Down
43 changes: 23 additions & 20 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,27 +304,30 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm90(CutlassGemmConfig::Can
if (has_w4afp8)
{
bool const has_coop_supported = sm90_supports_coop(tile_config);
std::set<MainloopScheduleType> mainloop_schedules{MainloopScheduleType::PINGPONG};
if (has_coop_supported)
{
mainloop_schedules.insert(MainloopScheduleType::COOPERATIVE);
}

// It seems that ping-pong scheduler will never be selected.
// To shorten the tactic time, remove all alternative options involving ping-pong scheduler.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What models have you tested this with? I am hesitant to remove this without a comprehensive sweep of multiple model architectures like Mixtral, DeepSeek, Llama4 and GPT-OSS. Its hard to say what the next DeepSeek moment will look like.
I also dont think tactic selection time is actually a significant concern. There are lots of tactics sure, but weight loading is usually just as long. Maybe we should add a fast profile mode that users can opt into

if (!has_coop_supported)
continue;
// Due to the limitation on the number of registers on SM,
// cooperative scheduler does not support CtaShape128x128x128B.
if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
continue;
Comment on lines +310 to +315
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add braces around single-statement if bodies.

The coding guidelines require that if statements always be followed by brace-delimited statements. Both continue statements lack the required braces.

As per coding guidelines.

Apply this diff to add braces:

-            if (!has_coop_supported)
-                continue;
+            if (!has_coop_supported)
+            {
+                continue;
+            }
             // Due to the limitation on the number of registers on SM,
             // cooperative scheduler does not support CtaShape128x128x128B.
-            if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
-                continue;
+            if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
+            {
+                continue;
+            }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (!has_coop_supported)
continue;
// Due to the limitation on the number of registers on SM,
// cooperative scheduler does not support CtaShape128x128x128B.
if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
continue;
if (!has_coop_supported)
{
continue;
}
// Due to the limitation on the number of registers on SM,
// cooperative scheduler does not support CtaShape128x128x128B.
if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
{
continue;
}
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp around lines
310 to 315, two if statements use single-statement bodies with continue and must
be converted to brace-delimited blocks per the coding guidelines; change each
`if (condition) continue;` to `if (condition) { continue; }`, preserving
existing indentation and spacing.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How much performance are we leaving on the table here? Is there a way to reduce the number of stages or otherwise relieve register pressure

MainloopScheduleType mainloop_schedule = MainloopScheduleType::COOPERATIVE;
auto const epilogue_schedule = EpilogueScheduleType::AUTO;
for (auto const& mainloop_schedule : mainloop_schedules)
{
CutlassGemmConfig candidate(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x1x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x2x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x2x1);
candidate_configs.push_back(candidate);
}

CutlassGemmConfig candidate(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x1x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x2x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x2x1);
candidate_configs.push_back(candidate);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
Comment on lines +1 to +15
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Update copyright year to include 2025.

Header stops at 2023; repository guidelines require current year on source files. As per coding guidelines.

- * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2020-2025, NVIDIA CORPORATION.  All rights reserved.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu
around lines 1 to 15, the copyright header ends with "2023" but repository
guidelines require the current year; update the copyright range to include 2025
(e.g., "2020-2025" or "2020-2023, 2025" per project convention) and ensure the
license block formatting is preserved exactly as before.


#include "moe_gemm_mixed_utils.h"

namespace tensorrt_llm::kernels::cutlass_kernels
{

__global__ void interleave_fp4_for_Hopper_mixed_gemm_kernel(
uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols)
{
for (int block_id = blockIdx.x; block_id < rows / 2; block_id += gridDim.x)
{
for (int col_id = threadIdx.x; col_id < cols / 2; col_id += blockDim.x)
{
int row_id = block_id / 8 * 16 + block_id % 8;

int index_a = row_id * cols / 2 + col_id;
int index_b = (row_id + 8) * cols / 2 + col_id;

uint8_t fp4x2_a = weight[index_a];
uint8_t fp4x2_b = weight[index_b];

uint8_t fp4_temp_a = (fp4x2_a & 0xF0U) >> 4;
uint8_t fp4_temp_b = (fp4x2_b & 0x0FU) << 4;

fp4x2_a = (fp4x2_a & 0x0FU) | fp4_temp_b;
fp4x2_b = (fp4x2_b & 0xF0U) | fp4_temp_a;

weight_interleaved[index_a] = fp4x2_a;
weight_interleaved[index_b] = fp4x2_b;
}
}
}

void interleave_fp4_for_Hopper_mixed_gemm(uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols)
{
// column-major input
interleave_fp4_for_Hopper_mixed_gemm_kernel<<<1024, 1024>>>(weight, weight_interleaved, rows, cols);
}

} // namespace tensorrt_llm::kernels::cutlass_kernels
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cstdint>

namespace tensorrt_llm::kernels::cutlass_kernels
{

void interleave_fp4_for_Hopper_mixed_gemm(uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols);

}
29 changes: 29 additions & 0 deletions cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h"
#include "tensorrt_llm/thop/thUtils.h"

#if defined(TORCH_VERSION_MAJOR) \
Expand Down Expand Up @@ -398,6 +399,31 @@ Tensor mxfp4_dequantize_unswizzled(Tensor weight, Tensor scale, int64_t group_si
return dequant_weight;
}

Tensor fp4_interleave_for_Hopper_mixed_gemm(Tensor weight)
{
// weight (n, k / 2)
int const n = weight.size(0);
int const k = weight.size(1) * 2;

CHECK_TH_CUDA(weight);
CHECK_CONTIGUOUS(weight);

TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
TORCH_CHECK(weight.dtype() == torch::kUInt8, "Weight must be a packed uint8 tensor");
TORCH_CHECK(n % 16 == 0)
TORCH_CHECK(k % 16 == 0)

Tensor weight_interleaved
= torch::empty({n, k / 2}, torch::dtype(torch::kUInt8).device(torch::kCUDA).requires_grad(false));

uint8_t* weight_ptr = get_ptr<uint8_t>(weight);
uint8_t* weight_interleaved_ptr = get_ptr<uint8_t>(weight_interleaved);

interleave_fp4_for_Hopper_mixed_gemm(weight_ptr, weight_interleaved_ptr, n, k);

return weight_interleaved;
}

} // namespace torch_ext

// Utility methods that may be useful for preprocessing weights in torch.
Expand Down Expand Up @@ -432,3 +458,6 @@ static auto subbyte_transpose = torch::RegisterOperators("trtllm::_subbyte_trans

static auto mxfp4_dequantize_unswizzled
= torch::RegisterOperators("trtllm::mxfp4_dequantize_unswizzled", &torch_ext::mxfp4_dequantize_unswizzled);

static auto fp4_interleave_for_Hopper_mixed_gemm = torch::RegisterOperators(
"trtllm::fp4_interleave_for_Hopper_mixed_gemm", &torch_ext::fp4_interleave_for_Hopper_mixed_gemm);
11 changes: 11 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,9 +1390,11 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
pad_size_inter = module.intermediate_size_per_partition - w3_weight_shard.shape[
0]
if w3_weight_shard.ndim == 2:
# [intermediate_size, hidden_size]
pad_size_hidden = module.hidden_size // 2 - w3_weight_shard.shape[1]
pad_shape = (0, pad_size_hidden, 0, pad_size_inter)
elif w3_weight_shard.ndim == 1:
# [intermediate_size]
pad_shape = (0, pad_size_inter)
else:
raise NotImplementedError(
Expand All @@ -1404,6 +1406,10 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,

w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0)

if w3_weight_shard.ndim == 2:
w31_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(
w31_weight_shard)

dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype),
non_blocking=True)

Expand Down Expand Up @@ -1433,6 +1439,11 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
f"Invalid shape of w2_weight_shard {w2_weight_shard.shape}")

w2_weight_shard = torch.nn.functional.pad(w2_weight_shard, pad_shape)

if w2_weight_shard.ndim == 2:
w2_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(
w2_weight_shard)

dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype),
non_blocking=True)

Expand Down
Loading