Skip to content

Commit 1700ddf

Browse files
committed
Ck benchmark
1 parent 8fec805 commit 1700ddf

File tree

89 files changed

+8090
-1568
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

89 files changed

+8090
-1568
lines changed
Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
// SPDX-License-Identifier: MIT
2+
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3+
4+
#pragma once
5+
6+
#include <string>
7+
#include <variant>
8+
9+
#include "ck_tile/core.hpp"
10+
#include "ck_tile/host/kernel_launch.hpp"
11+
#include "ck_tile/ops/epilogue.hpp"
12+
#include "ck_tile/ops/gemm.hpp"
13+
#include "ck_tile/utility/json_dump.hpp"
14+
15+
// Integer identifiers for the GEMM pipeline variants used in this file.
// Each config struct below stores one of these in its `Pipeline` member, and
// PipelineTypeTraits<> is specialized on the same values.
#define CK_TILE_PIPELINE_COMPUTE_V3 1
16+
#define CK_TILE_PIPELINE_MEMORY 2
17+
#define CK_TILE_PIPELINE_COMPUTE_V4 3
18+
#define CK_TILE_PIPELINE_COMPUTE_V5 4
19+
20+
struct GemmConfigBase
21+
{
22+
static constexpr bool kPadM = true;
23+
static constexpr bool kPadN = true;
24+
static constexpr bool kPadK = true;
25+
26+
static constexpr bool PermuteA = false;
27+
static constexpr bool PermuteB = false;
28+
29+
static constexpr bool TransposeC = false;
30+
static constexpr bool UseStructuredSparsity = false;
31+
32+
static constexpr int kBlockPerCu = 1;
33+
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
34+
static constexpr ck_tile::index_t TileParitionerM01 = 4;
35+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
36+
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
37+
static constexpr ck_tile::index_t NumWaveGroups = 1;
38+
static constexpr bool Preshuffle = false;
39+
static constexpr bool TiledMMAPermuteN = false;
40+
};
41+
42+
template <typename PrecType>
43+
struct GemmConfigMemoryInterwave : public GemmConfigBase
44+
{
45+
// Memory friendly for Interwave scheduler
46+
static constexpr ck_tile::index_t M_Tile = 128;
47+
static constexpr ck_tile::index_t N_Tile = 32;
48+
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
49+
50+
static constexpr ck_tile::index_t M_Warp = 4;
51+
static constexpr ck_tile::index_t N_Warp = 1;
52+
static constexpr ck_tile::index_t K_Warp = 1;
53+
54+
static constexpr ck_tile::index_t M_Warp_Tile = 32;
55+
static constexpr ck_tile::index_t N_Warp_Tile = 32;
56+
static constexpr ck_tile::index_t K_Warp_Tile = 16;
57+
58+
static constexpr bool DoubleSmemBuffer = false;
59+
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
60+
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
61+
};
62+
63+
template <typename PrecType>
64+
struct GemmConfigMemoryIntrawave : public GemmConfigBase
65+
{
66+
static constexpr ck_tile::index_t M_Tile = 128;
67+
static constexpr ck_tile::index_t N_Tile = 32;
68+
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
69+
70+
static constexpr ck_tile::index_t M_Warp = 4;
71+
static constexpr ck_tile::index_t N_Warp = 1;
72+
static constexpr ck_tile::index_t K_Warp = 1;
73+
74+
static constexpr ck_tile::index_t M_Warp_Tile = 32;
75+
static constexpr ck_tile::index_t N_Warp_Tile = 32;
76+
static constexpr ck_tile::index_t K_Warp_Tile = 16;
77+
78+
static constexpr bool DoubleSmemBuffer = false;
79+
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
80+
};
81+
82+
template <typename PrecType>
83+
struct GemmConfigComputeV3 : public GemmConfigBase
84+
{
85+
// Compute V3 only support Intrawave scheduler
86+
static constexpr ck_tile::index_t M_Tile = 16;
87+
static constexpr ck_tile::index_t N_Tile = 64;
88+
static constexpr ck_tile::index_t K_Tile = 64;
89+
90+
static constexpr ck_tile::index_t M_Warp = 1;
91+
static constexpr ck_tile::index_t N_Warp = 4;
92+
static constexpr ck_tile::index_t K_Warp = 1;
93+
94+
static constexpr ck_tile::index_t M_Warp_Tile = 16;
95+
static constexpr ck_tile::index_t N_Warp_Tile = 16;
96+
static constexpr ck_tile::index_t K_Warp_Tile = 32;
97+
98+
static constexpr bool DoubleSmemBuffer = false;
99+
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
100+
};
101+
102+
template <typename PrecType>
103+
struct GemmConfigComputeV3_1 : public GemmConfigBase
104+
{
105+
static constexpr ck_tile::index_t M_Tile = 256;
106+
static constexpr ck_tile::index_t N_Tile = 256;
107+
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
108+
109+
static constexpr ck_tile::index_t M_Warp = 2;
110+
static constexpr ck_tile::index_t N_Warp = 2;
111+
static constexpr ck_tile::index_t K_Warp = 1;
112+
113+
static constexpr ck_tile::index_t M_Warp_Tile = 32;
114+
static constexpr ck_tile::index_t N_Warp_Tile = 32;
115+
static constexpr ck_tile::index_t K_Warp_Tile = 16;
116+
117+
static constexpr bool DoubleSmemBuffer = false;
118+
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
119+
};
120+
121+
template <typename PrecType>
122+
struct GemmConfigComputeV3_2 : public GemmConfigBase
123+
{
124+
static constexpr ck_tile::index_t M_Tile = 128;
125+
static constexpr ck_tile::index_t N_Tile = 128;
126+
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
127+
128+
static constexpr ck_tile::index_t M_Warp = 2;
129+
static constexpr ck_tile::index_t N_Warp = 2;
130+
static constexpr ck_tile::index_t K_Warp = 1;
131+
132+
static constexpr ck_tile::index_t M_Warp_Tile = 16;
133+
static constexpr ck_tile::index_t N_Warp_Tile = 16;
134+
static constexpr ck_tile::index_t K_Warp_Tile = 32;
135+
136+
static constexpr bool DoubleSmemBuffer = false;
137+
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
138+
139+
static constexpr int kBlockPerCu = 2;
140+
};
141+
142+
template <typename PrecType>
143+
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
144+
{
145+
static constexpr ck_tile::index_t M_Tile = 128;
146+
static constexpr ck_tile::index_t N_Tile = 128;
147+
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
148+
149+
static constexpr ck_tile::index_t M_Warp = 4;
150+
static constexpr ck_tile::index_t N_Warp = 2;
151+
static constexpr ck_tile::index_t K_Warp = 1;
152+
153+
static constexpr ck_tile::index_t M_Warp_Tile = 16;
154+
static constexpr ck_tile::index_t N_Warp_Tile = 16;
155+
static constexpr ck_tile::index_t K_Warp_Tile = 16;
156+
157+
static constexpr bool DoubleSmemBuffer = false;
158+
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
159+
160+
static constexpr int kBlockPerCu = 2;
161+
};
162+
163+
template <typename PrecType>
164+
struct GemmConfigComputeV4 : public GemmConfigBase
165+
{
166+
// Compute V4 only support Intrawave scheduler
167+
// Using the ping pong reader in the lds level
168+
static constexpr ck_tile::index_t M_Tile = 256;
169+
static constexpr ck_tile::index_t N_Tile = 256;
170+
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
171+
172+
static constexpr ck_tile::index_t M_Warp = 2;
173+
static constexpr ck_tile::index_t N_Warp = 2;
174+
static constexpr ck_tile::index_t K_Warp = 1;
175+
176+
static constexpr ck_tile::index_t M_Warp_Tile = 32;
177+
static constexpr ck_tile::index_t N_Warp_Tile = 32;
178+
static constexpr ck_tile::index_t K_Warp_Tile = 16;
179+
180+
static constexpr bool DoubleSmemBuffer = true;
181+
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
182+
};
183+
184+
template <typename PrecType>
185+
struct GemmConfigComputeV4_1 : public GemmConfigBase
186+
{
187+
static constexpr ck_tile::index_t M_Tile = 256;
188+
static constexpr ck_tile::index_t N_Tile = 256;
189+
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
190+
191+
static constexpr ck_tile::index_t M_Warp = 2;
192+
static constexpr ck_tile::index_t N_Warp = 2;
193+
static constexpr ck_tile::index_t K_Warp = 1;
194+
195+
static constexpr ck_tile::index_t M_Warp_Tile = 32;
196+
static constexpr ck_tile::index_t N_Warp_Tile = 32;
197+
static constexpr ck_tile::index_t K_Warp_Tile = 16;
198+
199+
static constexpr bool DoubleSmemBuffer = true;
200+
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
201+
};
202+
203+
template <typename PrecType>
204+
struct GemmConfigComputeV5 : public GemmConfigBase
205+
{
206+
static constexpr ck_tile::index_t M_Tile = 128;
207+
static constexpr ck_tile::index_t N_Tile = 128;
208+
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
209+
210+
static constexpr ck_tile::index_t M_Warp = 1;
211+
static constexpr ck_tile::index_t N_Warp = 1;
212+
static constexpr ck_tile::index_t K_Warp = 2;
213+
214+
static constexpr ck_tile::index_t M_Warp_Tile = 32;
215+
static constexpr ck_tile::index_t N_Warp_Tile = 32;
216+
static constexpr ck_tile::index_t K_Warp_Tile = 16;
217+
218+
static constexpr bool DoubleSmemBuffer = false;
219+
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
220+
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
221+
};
222+
223+
template <typename InDataType, typename WeiDataType = InDataType, typename OutDataType = InDataType>
224+
struct ConvTypeConfig;
225+
226+
template <>
227+
struct ConvTypeConfig<ck_tile::half_t>
228+
{
229+
using InDataType = ck_tile::half_t;
230+
using WeiDataType = ck_tile::half_t;
231+
using AccDataType = float;
232+
using OutDataType = ck_tile::half_t;
233+
// ToDo: Add more bias config to support different categories of GEMM.
234+
};
235+
236+
template <>
237+
struct ConvTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
238+
{
239+
using InDataType = ck_tile::bf16_t;
240+
using WeiDataType = ck_tile::bf16_t;
241+
using AccDataType = float;
242+
using OutDataType = ck_tile::bf16_t;
243+
};
244+
245+
// Compile-time lookup from an element type to its short textual name.
// Only the explicit specializations are defined.
template <typename T>
struct DataTypeTraits;

// float is reported as "fp32".
template <>
struct DataTypeTraits<float>
{
    static constexpr char const* name = "fp32";
};
253+
254+
template <>
255+
struct DataTypeTraits<ck_tile::half_t>
256+
{
257+
static constexpr const char* name = "fp16";
258+
};
259+
260+
template <>
261+
struct DataTypeTraits<ck_tile::bf16_t>
262+
{
263+
static constexpr const char* name = "bf16";
264+
};
265+
266+
template <ck_tile::index_t PipelineId>
267+
struct PipelineTypeTraits;
268+
269+
template <>
270+
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
271+
{
272+
template <typename PipelineProblem>
273+
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
274+
template <typename PipelineProblem>
275+
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
276+
};
277+
278+
template <>
279+
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
280+
{
281+
template <typename PipelineProblem>
282+
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
283+
template <typename PipelineProblem>
284+
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
285+
};
286+
287+
template <>
288+
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
289+
{
290+
template <typename PipelineProblem>
291+
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
292+
template <typename PipelineProblem>
293+
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
294+
};
295+
296+
template <>
297+
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
298+
{
299+
template <typename PipelineProblem>
300+
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
301+
template <typename PipelineProblem>
302+
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
303+
};

example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include "grouped_convolution_backward_data_invoker.hpp"
1515
#include "run_grouped_convolution_bwd_data_example.inc"
1616

17-
template <template <typename PrecType> typename ConvConfig>
17+
template <template <typename PrecType> typename GemmConfig>
1818
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
1919
{
2020
using Invoker = GroupedConvolutionBackwardDataInvoker;
@@ -31,14 +31,14 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
3131
if(data_type == "fp16")
3232
{
3333
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
34-
ConvConfig<ck_tile::half_t>,
34+
GemmConfig<ck_tile::half_t>,
3535
ck_tile::half_t>(
3636
in_layout, wei_layout, out_layout, argc, argv);
3737
}
3838
else if(data_type == "bf16")
3939
{
4040
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
41-
ConvConfig<ck_tile::bf16_t>,
41+
GemmConfig<ck_tile::bf16_t>,
4242
ck_tile::bf16_t>(
4343
in_layout, wei_layout, out_layout, argc, argv);
4444
}
@@ -51,8 +51,8 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
5151
int main(int argc, char* argv[])
5252
{
5353
#if CK_TILE_USE_WMMA
54-
return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3_WMMA>(argc, argv);
54+
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3_WMMA>(argc, argv);
5555
#else
56-
return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3>(argc, argv);
56+
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3>(argc, argv);
5757
#endif
5858
}

0 commit comments

Comments
 (0)