
Commit 3945647

feat: Add MXFP8xMXFP4 gemm support
Signed-off-by: Daniel Stokes <[email protected]>
1 parent 36b87b8 commit 3945647

27 files changed (+735, −459 lines)

cpp/micro_benchmarks/gen-moe-benchmark-file.py

Lines changed: 5 additions & 5 deletions
@@ -54,15 +54,15 @@ def populate_benchmark_config(**kwargs):
 
 
     # Default Mixtral configurations
-    num_experts = 8
-    k = 2
+    num_experts = 256
+    k = 8
     hidden_size = 4096
-    inter_size = 14336
-    tp_size = 4
+    inter_size = 2048
+    tp_size = 8
     ep_size = 1
     world_rank = 0
     act_fn = 3
-    dtype_string = make_dtype_string()  # All dtypes
+    dtype_string = make_dtype_string(["fp4", "wfp4afp8"])  # All dtypes
     routing_string = make_routing_string(
         name="uniform",
         is_distribution=True)  # Use the default uniform random distribution

cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h

Lines changed: 51 additions & 23 deletions
@@ -298,14 +298,21 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
     using WeightType = typename TypeTuple_::WeightType;
     using OutputType = typename TypeTuple_::OutputType;
     constexpr static bool INT4 = std::is_same_v<WeightType, cutlass::uint4b_t>;
-    constexpr static bool FP4 = std::is_same_v<DataType, SafeFP4>;
-    constexpr static bool FP8 = std::is_same_v<DataType, SafeFP8>;
-    constexpr static bool INT_QUANT = !std::is_same_v<DataType, WeightType>;
-    using InputType = std::conditional_t<FP4, OutputType, DataType>;
-    using WeightStorage = std::conditional_t<INT_QUANT || FP4, uint8_t, WeightType>;
-    constexpr static int WEIGHT_ELEM_PER_BYTE = (INT4 || FP4) ? 2 : 1;
+    constexpr static bool NVFP4 = std::is_same_v<DataType, SafeFP4> && std::is_same_v<WeightType, SafeFP4>;
+    constexpr static bool FP8 = std::is_same_v<DataType, SafeFP8> && std::is_same_v<WeightType, SafeFP8>;
+    constexpr static bool WFP4AFP8 = std::is_same_v<WeightType, SafeFP4> && std::is_same_v<DataType, SafeFP8>;
+    constexpr static bool INT_QUANT = !std::is_same_v<DataType, WeightType>
+        && (std::is_same_v<WeightType, cutlass::uint4b_t> || std::is_same_v<WeightType, uint8_t>);
+    constexpr static bool ANY_FP4 = NVFP4 || WFP4AFP8;
+    using InputType = std::conditional_t<NVFP4, OutputType, DataType>;
+    using WeightStorage = std::conditional_t<INT_QUANT || ANY_FP4, uint8_t, WeightType>;
+    constexpr static int WEIGHT_ELEM_PER_BYTE = (INT4 || ANY_FP4) ? 2 : 1;
     int const BASE_HIDDEN_SIZE = 64 / sizeof(WeightType) * WEIGHT_ELEM_PER_BYTE;
 
+    constexpr static int64_t FP4_VECTOR_SIZE = NVFP4
+        ? tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize
+        : tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize;
+
     std::vector<BufferManager::IBufferPtr> managed_buffers;
     int* mSelectedExperts{};
     DataType* mInputTensor{};
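Editor's note: the new flags reduce to a handful of type-trait checks over the activation/weight pair. Below is a minimal standalone sketch of that classification; SafeFP4 and SafeFP8 are empty stand-in tags, and the 16- and 32-element block-scale sizes are assumed values (NVFP4 vs MX-format block sizes) rather than the TmaWarpSpecializedGroupedGemmInput constants.

// Sketch only: classify an activation/weight type pair the way the new
// NVFP4 / WFP4AFP8 / ANY_FP4 flags do.
#include <cstdint>
#include <type_traits>

struct SafeFP4 {}; // stand-in tag for the benchmark's FP4 type
struct SafeFP8 {}; // stand-in tag for the benchmark's FP8 type

template <class DataType, class WeightType>
struct QuantTraits
{
    static constexpr bool NVFP4 = std::is_same_v<DataType, SafeFP4> && std::is_same_v<WeightType, SafeFP4>;
    static constexpr bool WFP4AFP8 = std::is_same_v<DataType, SafeFP8> && std::is_same_v<WeightType, SafeFP4>;
    static constexpr bool ANY_FP4 = NVFP4 || WFP4AFP8;
    // Two 4-bit weights are packed per byte whenever the weight type is FP4.
    static constexpr int WEIGHT_ELEM_PER_BYTE = ANY_FP4 ? 2 : 1;
    // Assumed block sizes: NVFP4 scales every 16 elements, MX formats every 32.
    static constexpr std::int64_t FP4_VECTOR_SIZE = NVFP4 ? 16 : 32;
};

// Pure NVFP4: activations and weights are both FP4.
static_assert(QuantTraits<SafeFP4, SafeFP4>::NVFP4);
// The new mixed path: FP8 activations with MXFP4 weights.
static_assert(QuantTraits<SafeFP8, SafeFP4>::WFP4AFP8);
static_assert(QuantTraits<SafeFP8, SafeFP4>::WEIGHT_ELEM_PER_BYTE == 2);
static_assert(QuantTraits<SafeFP8, SafeFP4>::FP4_VECTOR_SIZE == 32);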
@@ -316,12 +323,12 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
 
     constexpr static nvinfer1::DataType toDTypeID()
     {
-        if (FP8)
+        if (FP8 || WFP4AFP8)
             return nvinfer1::DataType::kFP8;
-        if (FP4)
+        if (NVFP4)
             return nvinfer1::DataType::kFP4;
         if (INT_QUANT && INT4)
-            return nvinfer1::DataType::kINT4; // Hack to distinguish int4, use unsigned
+            return nvinfer1::DataType::kINT4;
         if (INT_QUANT)
             return nvinfer1::DataType::kINT8;
         if (std::is_same_v<DataType, float>)
@@ -331,9 +338,29 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
 #ifdef ENABLE_BF16
         if (std::is_same_v<DataType, nv_bfloat16>)
             return nvinfer1::DataType::kBF16;
-#else
+#endif
         TLLM_THROW("Unrecognised format");
+    };
+
+    constexpr static nvinfer1::DataType toWTypeID()
+    {
+        if (FP8)
+            return nvinfer1::DataType::kFP8;
+        if (NVFP4 || WFP4AFP8)
+            return nvinfer1::DataType::kFP4;
+        if (INT_QUANT && INT4)
+            return nvinfer1::DataType::kINT4;
+        if (INT_QUANT)
+            return nvinfer1::DataType::kINT8;
+        if (std::is_same_v<DataType, float>)
+            return nvinfer1::DataType::kFLOAT;
+        if (std::is_same_v<DataType, half>)
+            return nvinfer1::DataType::kHALF;
+#ifdef ENABLE_BF16
+        if (std::is_same_v<DataType, nv_bfloat16>)
+            return nvinfer1::DataType::kBF16;
 #endif
+        TLLM_THROW("Unrecognised format");
     };
 
     template <class T>
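Editor's note: the separate toWTypeID() exists because the mixed MXFP8xMXFP4 path breaks the old assumption that one ID describes the whole GEMM — activations report as FP8 while weights report as FP4. A reduced sketch of that split; the enum stands in for nvinfer1::DataType and the template flags mirror the fixture's.

// Sketch only: report activation and weight formats separately.
enum class DType
{
    kFP8,
    kFP4,
    kOther
};

template <bool FP8, bool NVFP4, bool WFP4AFP8>
struct DtypeReport
{
    static constexpr DType activation()
    {
        if (FP8 || WFP4AFP8) // activations are FP8 in both the pure-FP8 and mixed cases
            return DType::kFP8;
        if (NVFP4)
            return DType::kFP4;
        return DType::kOther;
    }

    static constexpr DType weight()
    {
        if (FP8)
            return DType::kFP8;
        if (NVFP4 || WFP4AFP8) // weights are FP4 for either FP4 flavour
            return DType::kFP4;
        return DType::kOther;
    }
};

// MXFP8xMXFP4: one GEMM, two element formats.
using Mixed = DtypeReport</*FP8=*/false, /*NVFP4=*/false, /*WFP4AFP8=*/true>;
static_assert(Mixed::activation() == DType::kFP8 && Mixed::weight() == DType::kFP4);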
@@ -345,7 +372,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
         }
         else if constexpr (std::is_same_v<T, SafeFP4>)
         {
-            return nvinfer1::DataType::kINT64;
+            return nvinfer1::DataType::kFP4;
         }
         else if constexpr (std::is_same_v<T, uint8_t>)
         {
@@ -380,10 +407,10 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
         static_assert(!FP8, "FP8 Tests enabled on unsupported CUDA version");
 #endif
 #ifndef ENABLE_FP4
-        static_assert(!FP4, "FP4 Tests enabled on unsupported CUDA version");
+        static_assert(!ANY_FP4, "FP4 Tests enabled on unsupported CUDA version");
 #endif
         bool should_skip_unsupported_fp8 = getSMVersion() < 89 && FP8;
-        bool should_skip_unsupported_fp4 = (getSMVersion() < 100 || getSMVersion() >= 120) && FP4;
+        bool should_skip_unsupported_fp4 = (getSMVersion() < 100 || getSMVersion() >= 120) && ANY_FP4;
         return should_skip_unsupported_fp8 || should_skip_unsupported_fp4;
     }
 
@@ -496,8 +523,9 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
         mGatedMultiplier = mIsGated ? 2 : 1;
         auto const gated_inter = mInterSize * mGatedMultiplier;
 
-        size_t workspace_size = mMoERunner.getWorkspaceSize(mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK,
-            mActType, {}, mUseLora, /*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, mUsePrequantScale);
+        size_t workspace_size
+            = mMoERunner.getWorkspaceSize(mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mActType, {},
+                mUseLora, /*use_deepseek_fp8_block_scale=*/false, /*min_latency_mode=*/false, mUsePrequantScale);
 
         mWorkspace = allocBuffer<char>(workspace_size);
         size_t const expert_matrix_size = mNumExperts * mHiddenSize * mInterSize;
@@ -528,20 +556,19 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
 
             mQuantParams = QuantParams::FP8(mExpertFP8Scale1, mExpertFP8Scale2, mExpertFP8Scale3);
         }
-        else if constexpr (FP4)
+        else if constexpr (ANY_FP4)
         {
             mExpertFP4ActScale1 = allocBuffer<float>(1);
-            mExpertFP4WeightSf1 = allocBuffer<ElementSF>(num_experts * gated_inter * mHiddenSize
-                / tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::BlockScaleVectorSize);
+            mExpertFP4WeightSf1 = allocBuffer<ElementSF>(num_experts * gated_inter * mHiddenSize / FP4_VECTOR_SIZE);
             mExpertFP4GlobalScale1 = allocBuffer<float>(num_experts);
 
             mExpertFP4ActScale2 = allocBuffer<float>(1);
-            mExpertFP4WeightSf2 = allocBuffer<ElementSF>(num_experts * mInterSize * mHiddenSize
-                / tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::BlockScaleVectorSize);
+            mExpertFP4WeightSf2 = allocBuffer<ElementSF>(num_experts * mInterSize * mHiddenSize / FP4_VECTOR_SIZE);
             mExpertFP4GlobalScale2 = allocBuffer<float>(num_experts);
 
-            mQuantParams = QuantParams::FP4(mExpertFP4ActScale1, mExpertFP4WeightSf1, mExpertFP4GlobalScale1,
-                mExpertFP4ActScale2, mExpertFP4WeightSf2, mExpertFP4GlobalScale2);
+            auto func = NVFP4 ? QuantParams::FP4 : QuantParams::FP8MXFP4;
+            mQuantParams = func(mExpertFP4ActScale1, mExpertFP4WeightSf1, mExpertFP4GlobalScale1, mExpertFP4ActScale2,
+                mExpertFP4WeightSf2, mExpertFP4GlobalScale2);
         }
 
         mSelectedExperts = allocBuffer<int>(mTotalTokens * mK);
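Editor's note: the `auto func = NVFP4 ? QuantParams::FP4 : QuantParams::FP8MXFP4;` line works because both static factories share a signature, so the ternary decays to a single function-pointer type. A reduced sketch of the pattern; QuantConfig and Factories are made-up stand-ins, not the real QuantParams API.

// Sketch only: pick one of two same-signature static factories at compile time.
struct QuantConfig
{
    float const* act_scale;
    float const* weight_block_scales; // one scale factor per FP4_VECTOR_SIZE weights
    float const* global_scale;
};

struct Factories
{
    static QuantConfig NVFP4(float const* a, float const* w, float const* g)
    {
        return {a, w, g};
    }

    static QuantConfig MXFP8xMXFP4(float const* a, float const* w, float const* g)
    {
        return {a, w, g};
    }
};

template <bool UseNvfp4>
QuantConfig makeQuantConfig(float const* a, float const* w, float const* g)
{
    // Both branches have type QuantConfig (*)(float const*, float const*, float const*).
    auto func = UseNvfp4 ? &Factories::NVFP4 : &Factories::MXFP8xMXFP4;
    return func(a, w, g);
}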
@@ -737,7 +764,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
             mExpertWeight1, mExpertBias1, mActType, mExpertWeight2, mExpertBias2, mQuantParams, mTotalTokens,
             mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
             parallelism_config, mUseLora, mLoraParams,
-            /*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, min_latency_params, stream);
+            /*use_deepseek_fp8_block_scale=*/false, /*min_latency_mode=*/false, min_latency_params, stream);
     }
 
     void runBenchmark(benchmark::State& state);
@@ -775,6 +802,7 @@ void MixtureOfExpertsBenchmark<TypeTuple_>::runBenchmark(benchmark::State& state
     state.counters["act_fn"] = (int) mActType;
     state.counters["routing_config"] = (int) routing_config;
     state.counters["dtype"] = (int) toDTypeID();
+    state.counters["wtype"] = (int) toWTypeID();
 
     std::stringstream ss;
     ss << "Experts,K,Hidden,Inter,TP,EP,Rank,Tokens,Bias,Scale,Actfn,Tactic,Routing=";

cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu

Lines changed: 8 additions & 2 deletions
@@ -377,7 +377,7 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
             {
                 continue;
             }
-            else if (BenchClass::FP4 && !hasDtype("fp4"))
+            else if (BenchClass::NVFP4 && !hasDtype("fp4"))
             {
                 continue;
             }
@@ -403,6 +403,10 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
             {
                 continue;
             }
+            else if (BenchClass::WFP4AFP8 && !hasDtype("wfp4afp8"))
+            {
+                continue;
+            }
         }
 
         // Do this after filtering datatypes as tactics only make sense if we know the data type
@@ -559,6 +563,7 @@ BENCHMARK_BASIC(SafeFP8, SafeFP8, half)
 #endif
 #ifdef ENABLE_FP4
 BENCHMARK_BASIC(SafeFP4, SafeFP4, half)
+BENCHMARK_BASIC(SafeFP8, SafeFP4, half)
 #endif
 
 void delayedRegisterBenchmark()
@@ -578,6 +583,7 @@ void delayedRegisterBenchmark()
 #endif
 #ifdef ENABLE_FP4
     BENCHMARK_BASIC_DO_REGISTER(SafeFP4, SafeFP4, half);
+    BENCHMARK_BASIC_DO_REGISTER(SafeFP8, SafeFP4, half);
 #endif
 }
 }
@@ -657,7 +663,7 @@ void help()
         "Useful for quick perf tests, prefer a full sweep and manually setting the tactic for more accurate "
         "results"
         "- dtypes - A list of dtypes to run this config through.\n"
-        "Allowed values are: fp8, int4, int8, float, half, bfloat16\n"
+        "Allowed values are: fp8, fp4, wfp4afp8, int4, int8, float, half, bfloat16\n"
         "If this argument is omitted all dtypes will be run. Note, not all tactics are supported for all "
         "dtypes,\n"
        "unsupported tactics will be skipped with a warning.\n"

cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu

Lines changed: 7 additions & 4 deletions
@@ -253,11 +253,14 @@ public:
         }
         if constexpr (GetQuantType<Pattern> == QuantType::kFP4)
         {
-            PackedVec<DType> pack_val = *reinterpret_cast<PackedVec<DType> const*>(&val);
-            auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2>(std::nullopt, token_id, m_access_id_in_token,
-                std::nullopt, m_params.hidden_dim, reinterpret_cast<uint32_t*>(m_params.scale_out), m_params.layout);
+            constexpr int SF_VEC_SIZE = 16;
+            using PackedVec = PackedVec<DType>;
+            PackedVec pack_val = *reinterpret_cast<PackedVec const*>(&val);
+            auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2, SF_VEC_SIZE>(std::nullopt, token_id,
+                m_access_id_in_token, std::nullopt, m_params.hidden_dim,
+                reinterpret_cast<uint32_t*>(m_params.scale_out), m_params.layout);
             reinterpret_cast<uint32_t*>(m_params.quant_out)[m_access_id]
-                = cvt_warp_fp16_to_fp4(pack_val, m_scale_factor, sf_out);
+                = cvt_warp_fp16_to_fp4<DType, SF_VEC_SIZE, false>(pack_val, m_scale_factor, sf_out);
         }
         else if constexpr (GetQuantType<Pattern> == QuantType::kFP8)
         {

cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu

Lines changed: 6 additions & 4 deletions
@@ -147,12 +147,14 @@ __device__ __forceinline__ void fused_op(
     }
     if constexpr (QuantOut)
     {
-        PackedVec<DType> pack_val = *reinterpret_cast<PackedVec<DType> const*>(&norm_val);
-        auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2>(std::nullopt /* batchIdx */, token_id,
-            access_id_in_token, std::nullopt /* numRows */, params.hidden_dim,
+        constexpr int SF_VEC_SIZE = 16;
+        using PackedVec = PackedVec<DType>;
+        PackedVec pack_val = *reinterpret_cast<PackedVec const*>(&norm_val);
+        auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2, SF_VEC_SIZE>(std::nullopt /* batchIdx */,
+            token_id, access_id_in_token, std::nullopt /* numRows */, params.hidden_dim,
             reinterpret_cast<uint32_t*>(params.scale_out), params.layout);
         reinterpret_cast<uint32_t*>(params.quant_out)[access_id]
-            = cvt_warp_fp16_to_fp4(pack_val, *params.scale_factor, sf_out);
+            = cvt_warp_fp16_to_fp4<DType, SF_VEC_SIZE, false>(pack_val, *params.scale_factor, sf_out);
     }
 }
 
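Editor's note: both fused all-reduce kernels now pin the scale-factor vector size at 16, i.e. one scale factor per 16-element block of the hidden dimension. A quick back-of-envelope check; the numbers are illustrative (hidden_dim matches the benchmark default above).

// Sketch only: how many scale factors a token produces with SF_VEC_SIZE = 16.
#include <cstdio>

int main()
{
    constexpr int SF_VEC_SIZE = 16;  // block-scale vector size pinned by the kernels above
    constexpr int hidden_dim = 4096; // example hidden size
    static_assert(hidden_dim % SF_VEC_SIZE == 0, "hidden_dim must be a multiple of the block size");
    constexpr int sf_per_token = hidden_dim / SF_VEC_SIZE; // 256 scale factors per token
    std::printf("%d scale factors per token\n", sf_per_token);
    return 0;
}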

cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
@@ -42,12 +42,14 @@ set_directory_properties(
 
 set(INSTANTIATION_GENERATION_DIR
     ${CMAKE_CURRENT_BINARY_DIR}/cutlass_instantiations)
+
 execute_process(
   WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/python/
   COMMAND
     ${Python3_EXECUTABLE} generate_kernels.py -a
     "${CMAKE_CUDA_ARCHITECTURES_ORIG};${CMAKE_CUDA_ARCHITECTURES_NATIVE}" -o
     ${INSTANTIATION_GENERATION_DIR}
+  OUTPUT_VARIABLE _KERNEL_GEN_OUTPUT
   RESULT_VARIABLE _KERNEL_GEN_SUCCESS)
 
 
@@ -57,6 +59,23 @@ if(NOT _KERNEL_GEN_SUCCESS MATCHES 0)
   )
 endif()
 
+file(GLOB_RECURSE INSTANTIATIONS_GENERATED ${INSTANTIATION_GENERATION_DIR}/*.cu)
+string(STRIP "${_KERNEL_GEN_OUTPUT}" _KERNEL_GEN_OUTPUT)
+
+# Sort both lists to ensure order doesn't matter
+list(SORT _KERNEL_GEN_OUTPUT)
+list(SORT INSTANTIATIONS_GENERATED)
+
+# Compare the lists
+if(NOT _KERNEL_GEN_OUTPUT STREQUAL INSTANTIATIONS_GENERATED)
+  list(REMOVE_ITEM INSTANTIATIONS_GENERATED ${_KERNEL_GEN_OUTPUT})
+  message(
+    WARNING
+      "There exist stale generated kernels in ${INSTANTIATION_GENERATION_DIR}. Removing these files:\n${INSTANTIATIONS_GENERATED}"
+  )
+  file(REMOVE ${INSTANTIATIONS_GENERATED})
+endif()
+
 # Get the sources for Mixed Input GEMM launchers
 file(GLOB_RECURSE MIXED_CU_INSTANTIATIONS
     ${INSTANTIATION_GENERATION_DIR}/gemm/*.cu)

cpp/tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_impl_sm90.h

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
 #include "cutlass/gemm/kernel/gemm_universal.hpp"
 #include "cutlass/gemm/kernel/tile_scheduler.hpp"
 #include "cutlass_extensions/communication/collective/sm90_allreduce_nvls_warpspecialized.hpp"
+#include "cutlass_extensions/epilogue/fusion/sm90_visitor_allreduce_tma_warpspecialized.hpp"
 #include "cutlass_extensions/gemm/kernel/sm90_gemm_allreduce_tma_warpspecialized_pingpong.hpp"
 
 #include "tensorrt_llm/common/cudaUtils.h"

cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu

Lines changed: 0 additions & 12 deletions
@@ -21,18 +21,6 @@
 namespace tensorrt_llm::kernels::fp8_blockscale_gemm
 {
 
-template <typename ElementA, typename ElementB, typename ElementD>
-CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::CutlassFp8BlockScaleGemmRunner()
-{
-    TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
-}
-
-template <typename ElementA, typename ElementB, typename ElementD>
-CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::~CutlassFp8BlockScaleGemmRunner()
-{
-    TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
-}
-
 template <typename ElementA, typename ElementB, typename ElementD>
 void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::gemm(void* mat_d, void const* mat_a,
     void const* mat_b, int shape_m, int shape_n, int shape_k, cudaStream_t stream, float const* scales_a,

cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h

Lines changed: 2 additions & 2 deletions
@@ -85,8 +85,8 @@ template <typename ElementA, typename ElementB, typename ElementD>
 class CutlassFp8BlockScaleGemmRunner : public CutlassFp8BlockScaleGemmRunnerInterface
 {
 public:
-    CutlassFp8BlockScaleGemmRunner();
-    ~CutlassFp8BlockScaleGemmRunner();
+    CutlassFp8BlockScaleGemmRunner() = default;
+    ~CutlassFp8BlockScaleGemmRunner() override = default;
 
     void gemm(void* mat_d, void const* mat_a, void const* mat_b, int shape_m, int shape_n, int shape_k,
         cudaStream_t stream, float const* scales_a = nullptr, float const* scales_b = nullptr) override;
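Editor's note: with the debug-logging constructor and destructor removed from the .cu file above, the header can simply default both members in-class. A generic sketch of the `= default` / `override` pattern; everything except the pattern itself is illustrative.

// Sketch only: a derived runner defaulting its special members, with the
// destructor marked override against the interface's virtual destructor.
struct GemmRunnerInterface
{
    virtual ~GemmRunnerInterface() = default;
    virtual void gemm() = 0;
};

struct GemmRunner : GemmRunnerInterface
{
    GemmRunner() = default;
    ~GemmRunner() override = default; // defaulted, still overrides the virtual destructor
    void gemm() override {}
};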
