@@ -298,14 +298,21 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
    using WeightType = typename TypeTuple_::WeightType;
    using OutputType = typename TypeTuple_::OutputType;
    constexpr static bool INT4 = std::is_same_v<WeightType, cutlass::uint4b_t>;
-   constexpr static bool FP4 = std::is_same_v<DataType, SafeFP4>;
-   constexpr static bool FP8 = std::is_same_v<DataType, SafeFP8>;
-   constexpr static bool INT_QUANT = !std::is_same_v<DataType, WeightType>;
-   using InputType = std::conditional_t<FP4, OutputType, DataType>;
-   using WeightStorage = std::conditional_t<INT_QUANT || FP4, uint8_t, WeightType>;
-   constexpr static int WEIGHT_ELEM_PER_BYTE = (INT4 || FP4) ? 2 : 1;
+   constexpr static bool NVFP4 = std::is_same_v<DataType, SafeFP4> && std::is_same_v<WeightType, SafeFP4>;
+   constexpr static bool FP8 = std::is_same_v<DataType, SafeFP8> && std::is_same_v<WeightType, SafeFP8>;
+   constexpr static bool WFP4AFP8 = std::is_same_v<WeightType, SafeFP4> && std::is_same_v<DataType, SafeFP8>;
+   constexpr static bool INT_QUANT = !std::is_same_v<DataType, WeightType>
+       && (std::is_same_v<WeightType, cutlass::uint4b_t> || std::is_same_v<WeightType, uint8_t>);
+   constexpr static bool ANY_FP4 = NVFP4 || WFP4AFP8;
+   using InputType = std::conditional_t<NVFP4, OutputType, DataType>;
+   using WeightStorage = std::conditional_t<INT_QUANT || ANY_FP4, uint8_t, WeightType>;
+   constexpr static int WEIGHT_ELEM_PER_BYTE = (INT4 || ANY_FP4) ? 2 : 1;

    int const BASE_HIDDEN_SIZE = 64 / sizeof(WeightType) * WEIGHT_ELEM_PER_BYTE;
+   constexpr static int64_t FP4_VECTOR_SIZE = NVFP4
+       ? tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize
+       : tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize;
+
    std::vector<BufferManager::IBufferPtr> managed_buffers;
    int* mSelectedExperts{};
    DataType* mInputTensor{};
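
Note: the flag block above is plain compile-time trait dispatch. A minimal, self-contained sketch of the same pattern, with illustrative stand-in tag types in place of the real TensorRT-LLM/CUTLASS types:

    #include <cstdint>
    #include <type_traits>

    struct SafeFP4 {}; // stand-in tags for illustration only
    struct SafeFP8 {};

    template <class DataType, class WeightType>
    struct FormatFlags
    {
        // NVFP4: activations and weights are both FP4.
        constexpr static bool NVFP4 = std::is_same_v<DataType, SafeFP4> && std::is_same_v<WeightType, SafeFP4>;
        // WFP4AFP8: FP4 weights driven by FP8 activations (mixed format).
        constexpr static bool WFP4AFP8 = std::is_same_v<WeightType, SafeFP4> && std::is_same_v<DataType, SafeFP8>;
        constexpr static bool ANY_FP4 = NVFP4 || WFP4AFP8;
        // FP4 packs two elements per byte, so storage falls back to uint8_t.
        using WeightStorage = std::conditional_t<ANY_FP4, uint8_t, WeightType>;
        constexpr static int WEIGHT_ELEM_PER_BYTE = ANY_FP4 ? 2 : 1;
    };

    static_assert(FormatFlags<SafeFP8, SafeFP4>::WFP4AFP8);
    static_assert(FormatFlags<SafeFP8, SafeFP4>::WEIGHT_ELEM_PER_BYTE == 2);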
@@ -316,12 +323,12 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture

    constexpr static nvinfer1::DataType toDTypeID()
    {
-       if (FP8)
+       if (FP8 || WFP4AFP8)
            return nvinfer1::DataType::kFP8;
-       if (FP4)
+       if (NVFP4)
            return nvinfer1::DataType::kFP4;
        if (INT_QUANT && INT4)
-           return nvinfer1::DataType::kINT4; // Hack to distinguish int4, use unsigned
+           return nvinfer1::DataType::kINT4;
        if (INT_QUANT)
            return nvinfer1::DataType::kINT8;
        if (std::is_same_v<DataType, float>)
@@ -331,9 +338,29 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
#ifdef ENABLE_BF16
        if (std::is_same_v<DataType, nv_bfloat16>)
            return nvinfer1::DataType::kBF16;
-#else
+#endif
        TLLM_THROW("Unrecognised format");
+   };
+
+   constexpr static nvinfer1::DataType toWTypeID()
+   {
+       if (FP8)
+           return nvinfer1::DataType::kFP8;
+       if (NVFP4 || WFP4AFP8)
+           return nvinfer1::DataType::kFP4;
+       if (INT_QUANT && INT4)
+           return nvinfer1::DataType::kINT4;
+       if (INT_QUANT)
+           return nvinfer1::DataType::kINT8;
+       if (std::is_same_v<DataType, float>)
+           return nvinfer1::DataType::kFLOAT;
+       if (std::is_same_v<DataType, half>)
+           return nvinfer1::DataType::kHALF;
+#ifdef ENABLE_BF16
+       if (std::is_same_v<DataType, nv_bfloat16>)
+           return nvinfer1::DataType::kBF16;
#endif
+       TLLM_THROW("Unrecognised format");
    };

    template <class T>
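
Note: for the mixed WFP4AFP8 case the two IDs deliberately diverge — toDTypeID() reports the activation format (FP8) while the new toWTypeID() reports the weight format (FP4). A reduced sketch of that behaviour, with a stand-in enum replacing nvinfer1::DataType and only the FP8/FP4 branches kept:

    enum class DType { kFP8, kFP4 }; // stand-in for nvinfer1::DataType

    template <bool FP8, bool NVFP4, bool WFP4AFP8>
    struct Ids
    {
        constexpr static DType toDTypeID() { return (FP8 || WFP4AFP8) ? DType::kFP8 : DType::kFP4; }
        constexpr static DType toWTypeID() { return (NVFP4 || WFP4AFP8) ? DType::kFP4 : DType::kFP8; }
    };

    // WFP4AFP8: activations report FP8, weights report FP4.
    static_assert(Ids<false, false, true>::toDTypeID() == DType::kFP8);
    static_assert(Ids<false, false, true>::toWTypeID() == DType::kFP4);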
@@ -345,7 +372,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
        }
        else if constexpr (std::is_same_v<T, SafeFP4>)
        {
-           return nvinfer1::DataType::kINT64;
+           return nvinfer1::DataType::kFP4;
        }
        else if constexpr (std::is_same_v<T, uint8_t>)
        {
@@ -380,10 +407,10 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
        static_assert(!FP8, "FP8 Tests enabled on unsupported CUDA version");
#endif
#ifndef ENABLE_FP4
-       static_assert(!FP4, "FP4 Tests enabled on unsupported CUDA version");
+       static_assert(!ANY_FP4, "FP4 Tests enabled on unsupported CUDA version");
#endif
        bool should_skip_unsupported_fp8 = getSMVersion() < 89 && FP8;
-       bool should_skip_unsupported_fp4 = (getSMVersion() < 100 || getSMVersion() >= 120) && FP4;
+       bool should_skip_unsupported_fp4 = (getSMVersion() < 100 || getSMVersion() >= 120) && ANY_FP4;
        return should_skip_unsupported_fp8 || should_skip_unsupported_fp4;
    }

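Note: the gating above encodes the hardware support matrix: FP8 needs SM89 or newer, and these FP4 paths are only expected on SM100–SM119. A standalone sketch of the predicate, with a hypothetical `sm` parameter standing in for the runtime getSMVersion() call:

    // Hypothetical free-function version of the skip predicate.
    constexpr bool shouldSkip(int sm, bool fp8, bool any_fp4)
    {
        bool skip_fp8 = sm < 89 && fp8;                     // FP8 requires SM89+
        bool skip_fp4 = (sm < 100 || sm >= 120) && any_fp4; // FP4 only on SM100..119
        return skip_fp8 || skip_fp4;
    }

    static_assert(shouldSkip(90, /*fp8=*/true, /*any_fp4=*/true));   // SM90: FP4 unsupported
    static_assert(!shouldSkip(100, /*fp8=*/true, /*any_fp4=*/true)); // SM100: both supported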
@@ -496,8 +523,9 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
        mGatedMultiplier = mIsGated ? 2 : 1;
        auto const gated_inter = mInterSize * mGatedMultiplier;

-       size_t workspace_size = mMoERunner.getWorkspaceSize(mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK,
-           mActType, {}, mUseLora, /*use_fp8_block_scaling=*/ false, /*min_latency_mode=*/ false, mUsePrequantScale);
+       size_t workspace_size
+           = mMoERunner.getWorkspaceSize(mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mActType, {},
+               mUseLora, /*use_deepseek_fp8_block_scale=*/ false, /*min_latency_mode=*/ false, mUsePrequantScale);

        mWorkspace = allocBuffer<char>(workspace_size);
        size_t const expert_matrix_size = mNumExperts * mHiddenSize * mInterSize;
@@ -528,20 +556,19 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture

            mQuantParams = QuantParams::FP8(mExpertFP8Scale1, mExpertFP8Scale2, mExpertFP8Scale3);
        }
-       else if constexpr (FP4)
+       else if constexpr (ANY_FP4)
        {
            mExpertFP4ActScale1 = allocBuffer<float>(1);
-           mExpertFP4WeightSf1 = allocBuffer<ElementSF>(num_experts * gated_inter * mHiddenSize
-               / tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::BlockScaleVectorSize);
+           mExpertFP4WeightSf1 = allocBuffer<ElementSF>(num_experts * gated_inter * mHiddenSize / FP4_VECTOR_SIZE);
            mExpertFP4GlobalScale1 = allocBuffer<float>(num_experts);

            mExpertFP4ActScale2 = allocBuffer<float>(1);
-           mExpertFP4WeightSf2 = allocBuffer<ElementSF>(num_experts * mInterSize * mHiddenSize
-               / tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::BlockScaleVectorSize);
+           mExpertFP4WeightSf2 = allocBuffer<ElementSF>(num_experts * mInterSize * mHiddenSize / FP4_VECTOR_SIZE);
            mExpertFP4GlobalScale2 = allocBuffer<float>(num_experts);

-           mQuantParams = QuantParams::FP4(mExpertFP4ActScale1, mExpertFP4WeightSf1, mExpertFP4GlobalScale1,
-               mExpertFP4ActScale2, mExpertFP4WeightSf2, mExpertFP4GlobalScale2);
+           auto func = NVFP4 ? QuantParams::FP4 : QuantParams::FP8MXFP4;
+           mQuantParams = func(mExpertFP4ActScale1, mExpertFP4WeightSf1, mExpertFP4GlobalScale1, mExpertFP4ActScale2,
+               mExpertFP4WeightSf2, mExpertFP4GlobalScale2);
        }

        mSelectedExperts = allocBuffer<int>(mTotalTokens * mK);
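
Note: `auto func = NVFP4 ? QuantParams::FP4 : QuantParams::FP8MXFP4;` is only well-formed because both factories are static member functions with identical signatures, so each operand decays to the same function-pointer type. A minimal sketch of the pattern with hypothetical factories (not the real QuantParams API):

    #include <iostream>

    struct Params { int mode; };

    struct Factory // hypothetical stand-in for QuantParams
    {
        static Params makeNvfp4(float) { return {1}; }
        static Params makeMxfp4(float) { return {2}; }
    };

    int main()
    {
        constexpr bool use_nvfp4 = false;
        // Both operands decay to Params (*)(float), so the ternary is well-formed.
        auto func = use_nvfp4 ? Factory::makeNvfp4 : Factory::makeMxfp4;
        std::cout << func(1.0f).mode << '\n'; // prints 2
    }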
@@ -734,7 +761,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
            mExpertWeight1, mExpertBias1, mActType, mExpertWeight2, mExpertBias2, mQuantParams, mTotalTokens,
            mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
            parallelism_config, mUseLora, mLoraParams,
-           /*use_fp8_block_scaling=*/ false, /*min_latency_mode=*/ false, min_latency_params, stream);
+           /*use_deepseek_fp8_block_scale=*/ false, /*min_latency_mode=*/ false, min_latency_params, stream);
    }

    void runBenchmark(benchmark::State& state);
@@ -772,6 +799,7 @@ void MixtureOfExpertsBenchmark<TypeTuple_>::runBenchmark(benchmark::State& state
    state.counters["act_fn"] = (int) mActType;
    state.counters["routing_config"] = (int) routing_config;
    state.counters["dtype"] = (int) toDTypeID();
+   state.counters["wtype"] = (int) toWTypeID();

    std::stringstream ss;
    ss << "Experts,K,Hidden,Inter,TP,EP,Rank,Tokens,Bias,Scale,Actfn,Tactic,Routing=";