Skip to content

Commit 32eb4f2

Browse files
committed
Minor improvements to splitk related changes
1 parent 5a673d4 commit 32eb4f2

File tree

2 files changed

+50
-32
lines changed

2 files changed

+50
-32
lines changed

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
4040
{
4141
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
4242

43+
// Full K needed for matrix B
44+
const index_t Kt = karg.K;
45+
4346
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
4447

45-
const index_t num_k_per_block = karg.K / (GridwiseGemm::KLane * GridwiseGemm::KPack);
48+
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
4649
const index_t k_id = blockIdx.z * num_k_per_block;
4750

4851
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
@@ -51,7 +54,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
5154
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
5255
p_shared,
5356
karg,
54-
k_id);
57+
k_id,
58+
Kt);
5559
}
5660
#else
5761
ignore = karg;
@@ -78,8 +82,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
7882
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
7983
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
8084

81-
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
82-
const index_t num_k_per_block = karg.K / (GridwiseGemm::KLane * GridwiseGemm::KPack);
85+
// Full K needed for matrix B
86+
const index_t Kt = karg.K;
87+
88+
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
89+
90+
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
8391
const index_t k_id = blockIdx.z * num_k_per_block;
8492

8593
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
@@ -89,7 +97,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
8997
p_shared_0,
9098
p_shared_1,
9199
karg,
92-
k_id);
100+
k_id,
101+
Kt);
93102
}
94103
#else
95104
ignore = karg;
@@ -1147,7 +1156,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
11471156
const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled,
11481157
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
11491158
c_grid_desc_mblock_mperblock_nblock_nperblock,
1150-
index_t k_id)
1159+
const index_t k_id)
11511160
{
11521161
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
11531162
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -1479,11 +1488,10 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
14791488
CDataType* p_c_grid,
14801489
void* p_shared,
14811490
const Problem& problem,
1482-
index_t k_id)
1491+
const index_t k_id,
1492+
const index_t Kt)
14831493
{
1484-
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1485-
// recompute K without splitK for matrix B
1486-
const index_t Kt = problem.K + problem.KRead * (problem.KBatch - 1);
1494+
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
14871495
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
14881496
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
14891497
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
@@ -1527,7 +1535,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
15271535
const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled,
15281536
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
15291537
c_grid_desc_mblock_mperblock_nblock_nperblock,
1530-
index_t k_id)
1538+
const index_t k_id)
15311539
{
15321540
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
15331541
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -1868,11 +1876,10 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
18681876
void* p_shared_0,
18691877
void* p_shared_1,
18701878
const Problem& problem,
1871-
index_t k_id)
1879+
const index_t k_id,
1880+
const index_t Kt)
18721881
{
1873-
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1874-
// recompute K without splitK for matrix B
1875-
const index_t Kt = problem.K + problem.KRead * (problem.KBatch - 1);
1882+
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
18761883
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
18771884
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
18781885
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
4343
{
4444
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
4545

46+
// Full K needed for matrix B
47+
const index_t Kt = karg.K;
48+
4649
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
47-
const index_t num_k_per_block =
48-
karg.K / (GridwiseGemm::KLane * GridwiseGemm::KPackPerGroup);
49-
const index_t k_id = blockIdx.z * num_k_per_block;
50+
51+
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
52+
const index_t k_id = blockIdx.z * num_k_per_block;
5053

5154
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
5255
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
@@ -58,7 +61,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
5861
karg.a_element_op,
5962
karg.b_element_op,
6063
karg.c_element_op,
61-
k_id);
64+
k_id,
65+
Kt);
6266
}
6367
#else
6468
ignore = karg;
@@ -83,10 +87,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
8387
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
8488
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
8589

90+
// Full K needed for matrix B
91+
const index_t Kt = karg.K;
92+
8693
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
87-
const index_t num_k_per_block =
88-
karg.K / (GridwiseGemm::KLane * GridwiseGemm::KPackPerGroup);
89-
const index_t k_id = blockIdx.z * num_k_per_block;
94+
95+
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
96+
const index_t k_id = blockIdx.z * num_k_per_block;
9097

9198
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
9299
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
@@ -99,7 +106,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
99106
karg.a_element_op,
100107
karg.b_element_op,
101108
karg.c_element_op,
102-
k_id);
109+
k_id,
110+
Kt);
103111
}
104112
#else
105113
ignore = karg;
@@ -1172,7 +1180,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
11721180
AElementwiseOperation a_element_op,
11731181
BElementwiseOperation b_element_op,
11741182
CElementwiseOperation c_element_op,
1175-
index_t k_id)
1183+
const index_t k_id,
1184+
const index_t Kt)
11761185
{
11771186
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
11781187
Run<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
@@ -1186,7 +1195,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
11861195
b_element_op,
11871196
c_element_op,
11881197
block_2_ctile_map,
1189-
k_id);
1198+
k_id,
1199+
Kt);
11901200
}
11911201

11921202
template <typename Block2CTileMap,
@@ -1203,12 +1213,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
12031213
BElementwiseOperation b_element_op,
12041214
CElementwiseOperation c_element_op,
12051215
const Block2CTileMap& block_2_ctile_map,
1206-
index_t k_id)
1216+
const index_t k_id,
1217+
const index_t Kt)
12071218
{
12081219
ignore = b_element_op;
12091220
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1210-
// recompute K without splitK for matrix B
1211-
const index_t Kt = problem.K + problem.KRead * (problem.KBatch - 1);
12121221
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
12131222

12141223
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
@@ -1611,7 +1620,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
16111620
AElementwiseOperation a_element_op,
16121621
BElementwiseOperation b_element_op,
16131622
CElementwiseOperation c_element_op,
1614-
index_t k_id)
1623+
const index_t k_id,
1624+
const index_t Kt)
16151625
{
16161626
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
16171627
Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
@@ -1626,7 +1636,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
16261636
b_element_op,
16271637
c_element_op,
16281638
block_2_ctile_map,
1629-
k_id);
1639+
k_id,
1640+
Kt);
16301641
}
16311642

16321643
template <typename Block2CTileMap,
@@ -1644,11 +1655,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
16441655
BElementwiseOperation b_element_op,
16451656
CElementwiseOperation c_element_op,
16461657
const Block2CTileMap& block_2_ctile_map,
1647-
index_t k_id)
1658+
const index_t k_id,
1659+
const index_t Kt)
16481660
{
16491661
ignore = b_element_op;
16501662
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1651-
const index_t Kt = problem.K + problem.KRead * (problem.KBatch - 1);
16521663
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
16531664
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
16541665
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);

0 commit comments

Comments
 (0)