11// SPDX-License-Identifier: MIT
22// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
33
4+ // This example demonstrates 2D block scale quantization (N×K) for BQuant
5+ // using non-preshuffled configuration.
6+ // NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
7+ // This is currently done separately to avoid too verbose dispatching.
8+
#include <cstring>
#include <iostream>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
@@ -17,7 +22,7 @@ template <typename GemmConfig,
1722 typename ALayout,
1823 typename BLayout,
1924 typename CLayout,
20- uint32_t QuantGroupSize,
25+ typename QuantGroupSize,
2126 ck_tile::QuantType QuantMode,
2227 typename CDEElementWise>
2328float gemm_calc_quant (const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
@@ -57,11 +62,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
5762 GemmTraits,
5863 ComputeDataType>;
5964
65+ // This example only supports BQuant (no AQuant)
66+ // For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3
6067 using BaseGemmPipeline = std::conditional_t <
6168 GemmConfig::PreshuffleB == true ,
6269 ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
63- ck_tile::BaseAQuantGemmPipelineAgBgCrMem<GemmPipelineProblem>>; // memory pipeline hardcoded
64- // for aquant
70+ ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
6571
6672 const ck_tile::index_t K_split =
6773 (args.K + GemmConfig::K_Tile - 1 ) / GemmConfig::K_Tile * GemmConfig::K_Tile;
@@ -229,7 +235,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
229235
230236template <typename GemmConfig,
231237 typename TypeConfig,
232- uint32_t QuantGroupSize,
238+ typename QuantGroupSize,
233239 ck_tile::QuantType QuantMode>
234240int run_gemm_example_prec_type (std::string a_layout, std::string b_layout, int argc, char * argv[])
235241{
@@ -266,146 +272,99 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
266272 return 0 ;
267273}
268274
275+ // Forward declaration for dispatch function
276+ template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
277+ int dispatch_by_data_type (const std::string& data_type,
278+ const std::string& quant_mode,
279+ const std::string& a_layout,
280+ const std::string& b_layout,
281+ int argc,
282+ char * argv[]);
283+
// Parse a quantization group-size string of the form "MxNxK" (e.g. "1x32x128").
// A single number with no 'x' separator is interpreted as the K dimension,
// with M = N = 1 (the historical 1D-group behavior).
// Throws std::runtime_error on a malformed string or a non-positive dimension.
std::tuple<int, int, int> parse_group_size(const std::string& group_size_str)
{
    // Parse one dimension strictly: the whole token must be consumed and the
    // value must be positive. std::stoi alone would silently accept trailing
    // garbage ("128junk") and surface unhelpful std::invalid_argument on
    // non-numeric input, so normalize every failure to the format error below.
    const auto to_dim = [](const std::string& tok) -> int {
        try
        {
            std::size_t consumed = 0;
            const int value     = std::stoi(tok, &consumed);
            if(consumed != tok.size() || value <= 0)
            {
                throw std::runtime_error("bad dimension");
            }
            return value;
        }
        catch(const std::exception&)
        {
            throw std::runtime_error("Invalid group_size format! Expected MxNxK (e.g., 1x32x128)");
        }
    };

    const std::size_t first_x = group_size_str.find('x');
    if(first_x == std::string::npos)
    {
        // Single number provided, assume it's the K dimension
        return {1, 1, to_dim(group_size_str)};
    }

    const std::size_t second_x = group_size_str.find('x', first_x + 1);
    if(second_x == std::string::npos)
    {
        throw std::runtime_error("Invalid group_size format! Expected MxNxK (e.g., 1x32x128)");
    }

    const int m = to_dim(group_size_str.substr(0, first_x));
    const int n = to_dim(group_size_str.substr(first_x + 1, second_x - first_x - 1));
    const int k = to_dim(group_size_str.substr(second_x + 1));

    return {m, n, k};
}
309+
269310template <template <typename PreType> typename GemmConfig>
270311int run_gemm_example (int argc, char * argv[])
271312{
272313 auto [result, arg_parser] = create_args (argc, argv);
273314 if (!result)
274315 return -1 ;
275316
276- std::string data_type = arg_parser.get_str (" prec" );
277- std::string a_layout = arg_parser.get_str (" a_layout" );
278- std::string b_layout = arg_parser.get_str (" b_layout" );
317+ std::string data_type = arg_parser.get_str (" prec" );
318+ std::string a_layout = arg_parser.get_str (" a_layout" );
319+ std::string b_layout = arg_parser.get_str (" b_layout" );
320+ std::string quant_mode = arg_parser.get_str (" quant_mode" );
321+ std::string group_size_str = arg_parser.get_str (" group_size" );
322+
323+ auto [m_group, n_group, k_group] = parse_group_size (group_size_str);
324+
325+ // Dispatch based on group size (M, N, K)
326+ return dispatch_group_size_ct<GemmConfig>(m_group, n_group, k_group, [&](auto QGS_) {
327+ using QuantGroupSize = decltype (QGS_);
328+ return dispatch_by_data_type<GemmConfig, QuantGroupSize>(
329+ data_type, quant_mode, a_layout, b_layout, argc, argv);
330+ });
331+ }
279332
280- std::string quant_mode = arg_parser.get_str (" quant_mode" );
333+ template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
334+ int dispatch_by_data_type (const std::string& data_type,
335+ const std::string& quant_mode,
336+ const std::string& a_layout,
337+ const std::string& b_layout,
338+ int argc,
339+ char * argv[])
340+ {
341+ // This example ONLY supports BQuant for 2D block scale quantization
342+ if (quant_mode != " bquant" )
343+ {
344+ throw std::runtime_error (" This example only supports BQuant! Use --quant_mode=bquant" );
345+ }
281346
282347 if (data_type == " fp8" )
283348 {
284349 using TypeConfig =
285350 decltype (GemmQuantTypeConfig<ck_tile::fp8_t , ck_tile::fp8_t , ck_tile::half_t , float >{});
286351
287- if (quant_mode == " aquant" )
288- {
289- return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t >,
290- TypeConfig,
291- 128 ,
292- ck_tile::QuantType::AQuantGrouped>(
293- a_layout, b_layout, argc, argv);
294- }
295- else if (quant_mode == " bquant" )
296- {
297- return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t >,
298- TypeConfig,
299- 128 ,
300- ck_tile::QuantType::BQuantGrouped>(
301- a_layout, b_layout, argc, argv);
302- }
303- else if (quant_mode == " rowcol" )
304- {
305- return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t >,
306- TypeConfig,
307- 128 ,
308- ck_tile::QuantType::RowColQuant>(
309- a_layout, b_layout, argc, argv);
310- }
311- else if (quant_mode == " tensor" )
312- {
313- return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t >,
314- TypeConfig,
315- 128 ,
316- ck_tile::QuantType::TensorQuant>(
317- a_layout, b_layout, argc, argv);
318- }
319- else
320- {
321- throw std::runtime_error (
322- " Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'" );
323- }
352+ return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t >,
353+ TypeConfig,
354+ QuantGroupSize,
355+ ck_tile::QuantType::BQuantGrouped>(
356+ a_layout, b_layout, argc, argv);
324357 }
325358 else if (data_type == " bf8" )
326359 {
327360 using TypeConfig =
328361 decltype (GemmQuantTypeConfig<ck_tile::bf8_t , ck_tile::bf8_t , ck_tile::half_t , float >{});
329362
330- if (quant_mode == " aquant" )
331- {
332- return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t >,
333- TypeConfig,
334- 128 ,
335- ck_tile::QuantType::AQuantGrouped>(
336- a_layout, b_layout, argc, argv);
337- }
338- else if (quant_mode == " bquant" )
339- {
340- return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t >,
341- TypeConfig,
342- 128 ,
343- ck_tile::QuantType::BQuantGrouped>(
344- a_layout, b_layout, argc, argv);
345- }
346- else if (quant_mode == " rowcol" )
347- {
348- return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t >,
349- TypeConfig,
350- 128 ,
351- ck_tile::QuantType::RowColQuant>(
352- a_layout, b_layout, argc, argv);
353- }
354- else if (quant_mode == " tensor" )
355- {
356- return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t >,
357- TypeConfig,
358- 128 ,
359- ck_tile::QuantType::TensorQuant>(
360- a_layout, b_layout, argc, argv);
361- }
362- else
363- {
364- throw std::runtime_error (
365- " Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'" );
366- }
367- }
368- else if (data_type == " i4fp8" )
369- {
370- using TypeConfig = decltype (GemmQuantTypeConfig<ck_tile::pk_int4_t ,
371- ck_tile::fp8_t ,
372- ck_tile::half_t ,
373- ck_tile::fp8_t >{});
374-
375- if (quant_mode == " aquant" )
376- {
377- return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t >,
378- TypeConfig,
379- 128 ,
380- ck_tile::QuantType::AQuantGrouped>(
381- a_layout, b_layout, argc, argv);
382- }
383- else
384- {
385- throw std::runtime_error (
386- " Unsupported quantization mode for this datatype! Use 'aquant'." );
387- }
388- }
389- else if (data_type == " i4bf8" )
390- {
391- using TypeConfig = decltype (GemmQuantTypeConfig<ck_tile::pk_int4_t ,
392- ck_tile::bf8_t ,
393- ck_tile::half_t ,
394- ck_tile::bf8_t >{});
395-
396- if (quant_mode == " aquant" )
397- {
398- return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t >,
399- TypeConfig,
400- 128 ,
401- ck_tile::QuantType::AQuantGrouped>(
402- a_layout, b_layout, argc, argv);
403- }
404- else
405- {
406- throw std::runtime_error (
407- " Unsupported quantization mode for this datatype! Use 'aquant'." );
408- }
363+ return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t >,
364+ TypeConfig,
365+ QuantGroupSize,
366+ ck_tile::QuantType::BQuantGrouped>(
367+ a_layout, b_layout, argc, argv);
409368 }
410369 else if (data_type == " fp8i4" )
411370 {
@@ -414,19 +373,11 @@ int run_gemm_example(int argc, char* argv[])
414373 ck_tile::half_t ,
415374 ck_tile::fp8_t >{});
416375
417- if (quant_mode == " bquant" )
418- {
419- return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t >,
420- TypeConfig,
421- 128 ,
422- ck_tile::QuantType::BQuantGrouped>(
423- a_layout, b_layout, argc, argv);
424- }
425- else
426- {
427- throw std::runtime_error (
428- " Unsupported quantization mode for this datatype! Use 'bquant'." );
429- }
376+ return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t >,
377+ TypeConfig,
378+ QuantGroupSize,
379+ ck_tile::QuantType::BQuantGrouped>(
380+ a_layout, b_layout, argc, argv);
430381 }
431382 else if (data_type == " bf8i4" )
432383 {
@@ -435,27 +386,39 @@ int run_gemm_example(int argc, char* argv[])
435386 ck_tile::half_t ,
436387 ck_tile::bf8_t >{});
437388
438- if (quant_mode == " bquant" )
439- {
440- return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t >,
441- TypeConfig,
442- 128 ,
443- ck_tile::QuantType::BQuantGrouped>(
444- a_layout, b_layout, argc, argv);
445- }
446- else
447- {
448- throw std::runtime_error (
449- " Unsupported quantization mode for this datatype! Use 'bquant'." );
450- }
389+ return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t >,
390+ TypeConfig,
391+ QuantGroupSize,
392+ ck_tile::QuantType::BQuantGrouped>(
393+ a_layout, b_layout, argc, argv);
451394 }
452395 else
453396 {
454397 throw std::runtime_error (" Unsupported data type for this operation !!!" );
455398 }
456399}
457400
401+ template <template <typename > typename GemmConfig, typename F>
402+ int dispatch_group_size_ct (int m, int n, int k, F&& f)
403+ {
404+ // This expands into a sequence of `if (m==M && n==N && k==K) { ... }`
405+ #define DISPATCH_ONE (M, N, K ) \
406+ if (m == M && n == N && k == K) \
407+ { \
408+ using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<M, N, K>>; \
409+ return f (QuantGroupSize{}); \
410+ }
411+
412+ CK_TILE_SUPPORTED_QUANT_GROUPS (DISPATCH_ONE)
413+
414+ #undef DISPATCH_ONE
415+
416+ throw std::runtime_error (
417+ " Unsupported group size! Please add it to CK_TILE_SUPPORTED_QUANT_GROUPS(X)." );
418+ }
419+
458420int main (int argc, char * argv[])
459421{
460- return !run_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
422+ // Use non-preshuffled GemmConfig for 2D block scale support
423+ return !run_gemm_example<GemmConfigBQuantPrefill>(argc, argv);
461424}
0 commit comments