@@ -43,10 +43,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
4343 {
4444 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte ()];
4545
46+ // Full K needed for matrix B
47+ const index_t Kt = karg.K ;
48+
4649 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset (karg, blockIdx.z );
47- const index_t num_k_per_block =
48- karg. K / ( GridwiseGemm::KLane * GridwiseGemm::KPackPerGroup );
49- const index_t k_id = blockIdx.z * num_k_per_block;
50+
51+ const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled (karg. K );
52+ const index_t k_id = blockIdx.z * num_k_per_block;
5053
5154 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
5255 karg.p_a_grid + splitk_batch_offset.a_k_split_offset ,
@@ -58,7 +61,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
5861 karg.a_element_op ,
5962 karg.b_element_op ,
6063 karg.c_element_op ,
61- k_id);
64+ k_id,
65+ Kt);
6266 }
6367#else
6468 ignore = karg;
@@ -83,10 +87,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
8387 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte ()];
8488 __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte ()];
8589
90+ // Full K needed for matrix B
91+ const index_t Kt = karg.K ;
92+
8693 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset (karg, blockIdx.z );
87- const index_t num_k_per_block =
88- karg. K / ( GridwiseGemm::KLane * GridwiseGemm::KPackPerGroup );
89- const index_t k_id = blockIdx.z * num_k_per_block;
94+
95+ const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled (karg. K );
96+ const index_t k_id = blockIdx.z * num_k_per_block;
9097
9198 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
9299 karg.p_a_grid + splitk_batch_offset.a_k_split_offset ,
@@ -99,7 +106,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
99106 karg.a_element_op ,
100107 karg.b_element_op ,
101108 karg.c_element_op ,
102- k_id);
109+ k_id,
110+ Kt);
103111 }
104112#else
105113 ignore = karg;
@@ -1172,7 +1180,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
11721180 AElementwiseOperation a_element_op,
11731181 BElementwiseOperation b_element_op,
11741182 CElementwiseOperation c_element_op,
1175- index_t k_id)
1183+ const index_t k_id,
1184+ const index_t Kt)
11761185 {
11771186 const auto block_2_ctile_map = Block2CTileMapDefault{problem.M , problem.N , 4 };
11781187 Run<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
@@ -1186,7 +1195,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
11861195 b_element_op,
11871196 c_element_op,
11881197 block_2_ctile_map,
1189- k_id);
1198+ k_id,
1199+ Kt);
11901200 }
11911201
11921202 template <typename Block2CTileMap,
@@ -1203,12 +1213,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
12031213 BElementwiseOperation b_element_op,
12041214 CElementwiseOperation c_element_op,
12051215 const Block2CTileMap& block_2_ctile_map,
1206- index_t k_id)
1216+ const index_t k_id,
1217+ const index_t Kt)
12071218 {
12081219 ignore = b_element_op;
12091220 index_t BN0Shuffled = CalculateBN0Shuffled (problem.N );
1210- // recompute K without splitK for matrix B
1211- const index_t Kt = problem.K + problem.KRead * (problem.KBatch - 1 );
12121221 index_t BK0Shuffled = CalculateBK0Shuffled (Kt);
12131222
12141223 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1 (
@@ -1611,7 +1620,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
16111620 AElementwiseOperation a_element_op,
16121621 BElementwiseOperation b_element_op,
16131622 CElementwiseOperation c_element_op,
1614- index_t k_id)
1623+ const index_t k_id,
1624+ const index_t Kt)
16151625 {
16161626 const auto block_2_ctile_map = Block2CTileMapDefault{problem.M , problem.N , 4 };
16171627 Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
@@ -1626,7 +1636,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
16261636 b_element_op,
16271637 c_element_op,
16281638 block_2_ctile_map,
1629- k_id);
1639+ k_id,
1640+ Kt);
16301641 }
16311642
16321643 template <typename Block2CTileMap,
@@ -1644,11 +1655,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
16441655 BElementwiseOperation b_element_op,
16451656 CElementwiseOperation c_element_op,
16461657 const Block2CTileMap& block_2_ctile_map,
1647- index_t k_id)
1658+ const index_t k_id,
1659+ const index_t Kt)
16481660 {
16491661 ignore = b_element_op;
16501662 index_t BN0Shuffled = CalculateBN0Shuffled (problem.N );
1651- const index_t Kt = problem.K + problem.KRead * (problem.KBatch - 1 );
16521663 index_t BK0Shuffled = CalculateBK0Shuffled (Kt);
16531664 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1 (
16541665 problem.M , problem.MPadded , problem.K , problem.KPadded , problem.StrideA , problem.AK0 );
0 commit comments