Skip to content

Commit 16e85cf

Browse files
samremes, Copilot, CongMa13, and ThomasNing
authored
[CK_TILE] B matrix 2D block scale gemm (#3074)
* Refactor quant group size to be configurable for M/N/K, not just K * add some asserts for configurations not implemented * start setting of group size for N dimension * enable 2d for reference quant gemm * WIP: trying to figure out tile dstr and/or indexing for scale matrix * WIP * Fix handling of n dim blocks in tile windows etc * remove commented code and enable all tests again * fix formatting * Add more specialized tile distributions * Enable NWarps replication for bquant tile dstr * fix formatting * fix format * Fix some issues from the merge * fix formatting * one more fix to tile dstr, and revert debug initialization * Remove commented code Co-authored-by: Copilot <[email protected]> * simplify conditions that are needed for tile distributions * only enable the working group sizes in tests * fix formatting * Update tile distribution for 2D bquant * add some documentation and 2d block scale example * fix formatting * Add in Changlog and restructure the quant 2d example * fix CMake * support the change for blockscale 2d * fix the test file --------- Co-authored-by: Copilot <[email protected]> Co-authored-by: Cong Ma <[email protected]> Co-authored-by: ThomasNing <[email protected]>
1 parent 73f6378 commit 16e85cf

24 files changed

+473
-360
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
2424
* Added WMMA (gfx12) support for FMHA.
2525
* Added pooling kernel in CK_TILE
2626
* Added top-k sigmoid kernel in CK_TILE
27+
* Added the blockscale 2D support for CK_TILE GEMM.
2728

2829
### Changed
2930

example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp

Lines changed: 114 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
// SPDX-License-Identifier: MIT
22
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
33

4+
// This example demonstrates 2D block scale quantization (N×K) for BQuant
5+
// using non-preshuffled configuration.
6+
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
7+
// This is currently done separately to avoid too verbose dispatching.
8+
49
#include <cstring>
510
#include <iostream>
611
#include <ostream>
@@ -17,7 +22,7 @@ template <typename GemmConfig,
1722
typename ALayout,
1823
typename BLayout,
1924
typename CLayout,
20-
uint32_t QuantGroupSize,
25+
typename QuantGroupSize,
2126
ck_tile::QuantType QuantMode,
2227
typename CDEElementWise>
2328
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
@@ -57,11 +62,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
5762
GemmTraits,
5863
ComputeDataType>;
5964

65+
// This example only supports BQuant (no AQuant)
66+
// For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3
6067
using BaseGemmPipeline = std::conditional_t<
6168
GemmConfig::PreshuffleB == true,
6269
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
63-
ck_tile::BaseAQuantGemmPipelineAgBgCrMem<GemmPipelineProblem>>; // memory pipeline hardcoded
64-
// for aquant
70+
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
6571

6672
const ck_tile::index_t K_split =
6773
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
@@ -229,7 +235,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
229235

230236
template <typename GemmConfig,
231237
typename TypeConfig,
232-
uint32_t QuantGroupSize,
238+
typename QuantGroupSize,
233239
ck_tile::QuantType QuantMode>
234240
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
235241
{
@@ -266,146 +272,99 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
266272
return 0;
267273
}
268274

275+
// Forward declaration for dispatch function
276+
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
277+
int dispatch_by_data_type(const std::string& data_type,
278+
const std::string& quant_mode,
279+
const std::string& a_layout,
280+
const std::string& b_layout,
281+
int argc,
282+
char* argv[]);
283+
284+
// Helper function to parse group size string "MxNxK"
285+
std::tuple<int, int, int> parse_group_size(const std::string& group_size_str)
286+
{
287+
int m = 1, n = 1, k = 128;
288+
289+
size_t first_x = group_size_str.find('x');
290+
if(first_x == std::string::npos)
291+
{
292+
// Single number provided, assume it's the K dimension
293+
k = std::stoi(group_size_str);
294+
return {1, 1, k};
295+
}
296+
297+
size_t second_x = group_size_str.find('x', first_x + 1);
298+
if(second_x == std::string::npos)
299+
{
300+
throw std::runtime_error("Invalid group_size format! Expected MxNxK (e.g., 1x32x128)");
301+
}
302+
303+
m = std::stoi(group_size_str.substr(0, first_x));
304+
n = std::stoi(group_size_str.substr(first_x + 1, second_x - first_x - 1));
305+
k = std::stoi(group_size_str.substr(second_x + 1));
306+
307+
return {m, n, k};
308+
}
309+
269310
template <template <typename PreType> typename GemmConfig>
270311
int run_gemm_example(int argc, char* argv[])
271312
{
272313
auto [result, arg_parser] = create_args(argc, argv);
273314
if(!result)
274315
return -1;
275316

276-
std::string data_type = arg_parser.get_str("prec");
277-
std::string a_layout = arg_parser.get_str("a_layout");
278-
std::string b_layout = arg_parser.get_str("b_layout");
317+
std::string data_type = arg_parser.get_str("prec");
318+
std::string a_layout = arg_parser.get_str("a_layout");
319+
std::string b_layout = arg_parser.get_str("b_layout");
320+
std::string quant_mode = arg_parser.get_str("quant_mode");
321+
std::string group_size_str = arg_parser.get_str("group_size");
322+
323+
auto [m_group, n_group, k_group] = parse_group_size(group_size_str);
324+
325+
// Dispatch based on group size (M, N, K)
326+
return dispatch_group_size_ct<GemmConfig>(m_group, n_group, k_group, [&](auto QGS_) {
327+
using QuantGroupSize = decltype(QGS_);
328+
return dispatch_by_data_type<GemmConfig, QuantGroupSize>(
329+
data_type, quant_mode, a_layout, b_layout, argc, argv);
330+
});
331+
}
279332

280-
std::string quant_mode = arg_parser.get_str("quant_mode");
333+
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
334+
int dispatch_by_data_type(const std::string& data_type,
335+
const std::string& quant_mode,
336+
const std::string& a_layout,
337+
const std::string& b_layout,
338+
int argc,
339+
char* argv[])
340+
{
341+
// This example ONLY supports BQuant for 2D block scale quantization
342+
if(quant_mode != "bquant")
343+
{
344+
throw std::runtime_error("This example only supports BQuant! Use --quant_mode=bquant");
345+
}
281346

282347
if(data_type == "fp8")
283348
{
284349
using TypeConfig =
285350
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
286351

287-
if(quant_mode == "aquant")
288-
{
289-
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
290-
TypeConfig,
291-
128,
292-
ck_tile::QuantType::AQuantGrouped>(
293-
a_layout, b_layout, argc, argv);
294-
}
295-
else if(quant_mode == "bquant")
296-
{
297-
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
298-
TypeConfig,
299-
128,
300-
ck_tile::QuantType::BQuantGrouped>(
301-
a_layout, b_layout, argc, argv);
302-
}
303-
else if(quant_mode == "rowcol")
304-
{
305-
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
306-
TypeConfig,
307-
128,
308-
ck_tile::QuantType::RowColQuant>(
309-
a_layout, b_layout, argc, argv);
310-
}
311-
else if(quant_mode == "tensor")
312-
{
313-
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
314-
TypeConfig,
315-
128,
316-
ck_tile::QuantType::TensorQuant>(
317-
a_layout, b_layout, argc, argv);
318-
}
319-
else
320-
{
321-
throw std::runtime_error(
322-
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
323-
}
352+
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
353+
TypeConfig,
354+
QuantGroupSize,
355+
ck_tile::QuantType::BQuantGrouped>(
356+
a_layout, b_layout, argc, argv);
324357
}
325358
else if(data_type == "bf8")
326359
{
327360
using TypeConfig =
328361
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
329362

330-
if(quant_mode == "aquant")
331-
{
332-
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
333-
TypeConfig,
334-
128,
335-
ck_tile::QuantType::AQuantGrouped>(
336-
a_layout, b_layout, argc, argv);
337-
}
338-
else if(quant_mode == "bquant")
339-
{
340-
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
341-
TypeConfig,
342-
128,
343-
ck_tile::QuantType::BQuantGrouped>(
344-
a_layout, b_layout, argc, argv);
345-
}
346-
else if(quant_mode == "rowcol")
347-
{
348-
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
349-
TypeConfig,
350-
128,
351-
ck_tile::QuantType::RowColQuant>(
352-
a_layout, b_layout, argc, argv);
353-
}
354-
else if(quant_mode == "tensor")
355-
{
356-
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
357-
TypeConfig,
358-
128,
359-
ck_tile::QuantType::TensorQuant>(
360-
a_layout, b_layout, argc, argv);
361-
}
362-
else
363-
{
364-
throw std::runtime_error(
365-
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
366-
}
367-
}
368-
else if(data_type == "i4fp8")
369-
{
370-
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
371-
ck_tile::fp8_t,
372-
ck_tile::half_t,
373-
ck_tile::fp8_t>{});
374-
375-
if(quant_mode == "aquant")
376-
{
377-
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
378-
TypeConfig,
379-
128,
380-
ck_tile::QuantType::AQuantGrouped>(
381-
a_layout, b_layout, argc, argv);
382-
}
383-
else
384-
{
385-
throw std::runtime_error(
386-
"Unsupported quantization mode for this datatype! Use 'aquant'.");
387-
}
388-
}
389-
else if(data_type == "i4bf8")
390-
{
391-
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
392-
ck_tile::bf8_t,
393-
ck_tile::half_t,
394-
ck_tile::bf8_t>{});
395-
396-
if(quant_mode == "aquant")
397-
{
398-
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
399-
TypeConfig,
400-
128,
401-
ck_tile::QuantType::AQuantGrouped>(
402-
a_layout, b_layout, argc, argv);
403-
}
404-
else
405-
{
406-
throw std::runtime_error(
407-
"Unsupported quantization mode for this datatype! Use 'aquant'.");
408-
}
363+
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
364+
TypeConfig,
365+
QuantGroupSize,
366+
ck_tile::QuantType::BQuantGrouped>(
367+
a_layout, b_layout, argc, argv);
409368
}
410369
else if(data_type == "fp8i4")
411370
{
@@ -414,19 +373,11 @@ int run_gemm_example(int argc, char* argv[])
414373
ck_tile::half_t,
415374
ck_tile::fp8_t>{});
416375

417-
if(quant_mode == "bquant")
418-
{
419-
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
420-
TypeConfig,
421-
128,
422-
ck_tile::QuantType::BQuantGrouped>(
423-
a_layout, b_layout, argc, argv);
424-
}
425-
else
426-
{
427-
throw std::runtime_error(
428-
"Unsupported quantization mode for this datatype! Use 'bquant'.");
429-
}
376+
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
377+
TypeConfig,
378+
QuantGroupSize,
379+
ck_tile::QuantType::BQuantGrouped>(
380+
a_layout, b_layout, argc, argv);
430381
}
431382
else if(data_type == "bf8i4")
432383
{
@@ -435,27 +386,39 @@ int run_gemm_example(int argc, char* argv[])
435386
ck_tile::half_t,
436387
ck_tile::bf8_t>{});
437388

438-
if(quant_mode == "bquant")
439-
{
440-
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
441-
TypeConfig,
442-
128,
443-
ck_tile::QuantType::BQuantGrouped>(
444-
a_layout, b_layout, argc, argv);
445-
}
446-
else
447-
{
448-
throw std::runtime_error(
449-
"Unsupported quantization mode for this datatype! Use 'bquant'.");
450-
}
389+
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
390+
TypeConfig,
391+
QuantGroupSize,
392+
ck_tile::QuantType::BQuantGrouped>(
393+
a_layout, b_layout, argc, argv);
451394
}
452395
else
453396
{
454397
throw std::runtime_error("Unsupported data type for this operation !!!");
455398
}
456399
}
457400

401+
template <template <typename> typename GemmConfig, typename F>
402+
int dispatch_group_size_ct(int m, int n, int k, F&& f)
403+
{
404+
// This expands into a sequence of `if (m==M && n==N && k==K) { ... }`
405+
#define DISPATCH_ONE(M, N, K) \
406+
if(m == M && n == N && k == K) \
407+
{ \
408+
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<M, N, K>>; \
409+
return f(QuantGroupSize{}); \
410+
}
411+
412+
CK_TILE_SUPPORTED_QUANT_GROUPS(DISPATCH_ONE)
413+
414+
#undef DISPATCH_ONE
415+
416+
throw std::runtime_error(
417+
"Unsupported group size! Please add it to CK_TILE_SUPPORTED_QUANT_GROUPS(X).");
418+
}
419+
458420
int main(int argc, char* argv[])
459421
{
460-
return !run_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
422+
// Use non-preshuffled GemmConfig for 2D block scale support
423+
return !run_gemm_example<GemmConfigBQuantPrefill>(argc, argv);
461424
}

0 commit comments

Comments (0)