
Commit e2fde37

Use newer version of copy_atom in epilogue collective (#573)
This PR introduces a collective epilogue flow implemented with the new copy atoms. Example test used: 00_bmg_gemm.cpp
1 parent fee297d commit e2fde37

File tree

7 files changed: +627 −64 lines

examples/00_bmg_gemm/00_bmg_gemm.cpp

Lines changed: 7 additions & 4 deletions
@@ -350,6 +350,10 @@ int main(int argc, const char** argv)
   // Refer https://github.com/intel/sycl-tla/blob/main/media/docs/cpp/xe_rearchitecture.md
   using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>;
   using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>;
+  using GmemTiledCopyC = XE_LOAD_2D<32, 8, 16>;
+  using GmemTiledCopyD = XE_STORE_2D<32, 8, 16>;
+
+
 
   // Workgroup-level tile
   using TileShape = Shape<_256, _256, _32>;
@@ -369,9 +373,8 @@ int main(int argc, const char** argv)
 
   // For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
   constexpr int PipelineStages = 2;
-  // For older version of copy/mma atom, use cutlass::gemm::MainloopIntelXeXMX16 as dispatch policy
   using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
-  using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
+  using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;
 
   // This is the 'default' epilogue operation (Linear Combination) which performs everything in:
   // (D = alpha * (A*B) + beta * C)
@@ -394,9 +397,9 @@ int main(int argc, const char** argv)
       ElementOutput,
       cutlass::gemm::TagToStrideC_t<LayoutD>, // Converts CUTLASS 2.x to CUTLASS 3.x representation
       FusionCallBacks,
-      XE_2D_U32x8x16_LD_N, // The copy atom used to load matrix C
+      GmemTiledCopyC,      // The copy atom used to load matrix C
       void, void,
-      XE_2D_U32x8x16_ST_N, // The copy atom used to store matrix D
+      GmemTiledCopyD,      // The copy atom used to store matrix D
       void, void>;
 
   // GEMM Mainloop - iteration over blocks in K dimension
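Taken together, the three hunks above switch the example from the legacy copy atoms and dispatch policy to the new ones. A condensed view of the resulting configuration, assembled from this diff (the rest of the example is unchanged):

    // New-style copy atoms for the epilogue (this commit):
    using GmemTiledCopyC = XE_LOAD_2D<32, 8, 16>;   // loads matrix C
    using GmemTiledCopyD = XE_STORE_2D<32, 8, 16>;  // stores matrix D

    // New dispatch policy selecting the reworked collective epilogue:
    using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;

    // The aliases then take the slots previously occupied by
    // XE_2D_U32x8x16_LD_N / XE_2D_U32x8x16_ST_N in the CollectiveEpilogue
    // parameter list:
    //   ..., FusionCallBacks,
    //   GmemTiledCopyC, // the copy atom used to load matrix C
    //   void, void,
    //   GmemTiledCopyD, // the copy atom used to store matrix D
    //   void, void>;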

include/cute/atom/copy_traits_xe_2d.hpp

Lines changed: 11 additions & 1 deletion
@@ -1297,12 +1297,22 @@ template <class CopyOp, class TiledMMA, class CTensor>
 auto get_block_2d_copy_C(TiledMMA const& tiled_mma, CTensor const& c_tensor)
 {
   if constexpr (!std::is_void_v<CopyOp>) {
-    return make_block_2d_copy_C(CopyOp{}, tiled_mma, c_tensor);
+    return make_block_2d_copy_CD(CopyOp{}, tiled_mma, c_tensor);
   } else {
     return make_block_2d_copy_C(tiled_mma, c_tensor);
   }
 }
 
+template <class CopyOp, class TiledMMA, class DTensor>
+auto get_block_2d_copy_D(TiledMMA const& tiled_mma, DTensor const& d_tensor)
+{
+  if constexpr (!std::is_void_v<CopyOp>) {
+    return make_block_2d_copy_CD(CopyOp{}, tiled_mma, d_tensor);
+  } else {
+    return make_block_2d_copy_D(tiled_mma, d_tensor);
+  }
+}
+
 //
 // Display utilities
 //
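These helpers give the epilogue a single dispatch point: an explicit copy op routes through the new make_block_2d_copy_CD, while void falls back to automatic selection. A hedged sketch of the intended call sites (tiled_mma, c_tensor, and d_tensor are placeholder names, not part of this commit):

    // Explicit atom: routed through make_block_2d_copy_CD (new-style path).
    auto copy_c = get_block_2d_copy_C<XE_LOAD_2D<32, 8, 16>>(tiled_mma, c_tensor);
    // void: automatic selection via make_block_2d_copy_D.
    auto copy_d = get_block_2d_copy_D<void>(tiled_mma, d_tensor);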

include/cutlass/epilogue/collective/collective_epilogue.hpp

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ class CollectiveEpilogue {
 #include "sm100_epilogue_array_tma_warpspecialized.hpp"
 #if defined (SYCL_INTEL_TARGET)
 #include "xe_epilogue.hpp"
+#include "xe_epilogue_legacy.hpp"
 #include "xe_array_epilogue.hpp"
 #endif
 //
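The new include suggests the previous epilogue implementation is kept alongside the reworked one. Assuming xe_epilogue_legacy.hpp retains the old IntelXeXMX16 specialization (the header itself is not shown in this diff), selection between the two happens through the dispatch policy alone:

    // New path (this commit): requires XE_LOAD_2D / XE_STORE_2D copy atoms.
    using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;
    // Legacy path (assumed): the old behavior behind the old policy.
    // using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;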

include/cutlass/epilogue/collective/xe_epilogue.hpp

Lines changed: 90 additions & 59 deletions
@@ -68,7 +68,7 @@ template <
   class CopyOpR2S_
 >
 class CollectiveEpilogue<
-    IntelXeXMX16,
+    IntelXeGeneric,
     CtaTileMNK_,
     ElementC_,
     StrideC_,
@@ -86,7 +86,7 @@ class CollectiveEpilogue<
   //
   // Type Aliases
   //
-  using DispatchPolicy = IntelXeXMX16;
+  using DispatchPolicy = IntelXeGeneric;
   using CtaTileMNK = CtaTileMNK_;
   using FusionCallbacks = FusionCallbacks_;
   using ElementC = ElementC_;
@@ -101,9 +101,13 @@ class CollectiveEpilogue<
   using CopyOpR2S = CopyOpR2S_;
 
   using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits<FusionCallbacks>::Operation;
-  using GmemTiledCopyC = conditional_t<cute::is_void_v<CopyOpG2R>, XE_2D_U32x8x16_LD_N, CopyOpG2R>;
+  using GmemTiledCopyC = conditional_t<cute::is_void_v<CopyOpG2R>, XE_LOAD_2D<32, 8, 16>, CopyOpG2R>;
   using GmemTiledCopyD = cute::conditional_t<not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>,
-                                             CopyOpR2G, XE_2D_U32x8x16_ST_N>;
+                                             CopyOpR2G, XE_STORE_2D<32, 8, 16>>;
+  static_assert(std::is_same_v<GmemTiledCopyC, XE_LOAD_2D<32, 8, 16>>,
+                "Current epilogue implementation only support load op XE_LOAD_2D<32, 8, 16>");
+  static_assert(std::is_same_v<GmemTiledCopyD, XE_STORE_2D<32, 8, 16>>,
+                "Current epilogue implementation only support store op XE_STORE_2D<32, 8, 16>");
   using ElementOutput = ElementD;
   using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
   using ElementAccumulator = ElementCompute;
@@ -118,22 +122,11 @@ class CollectiveEpilogue<
   static_assert(std::is_same_v<SmemLayoutAtomC, void>, "Copy operation to shared memory is not supported");
   static_assert(std::is_same_v<SmemLayoutAtomD, void>, "Copy operation to shared memory is not supported");
 
-  using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
-
-  using Trait_D = Copy_Traits<GmemTiledCopyD, StrideD>;
-  using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})));
-  using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom<Trait_D, ElementD>{}, Layout<CopyThreadShape>{}, val_layout_store_D{}));
-
 private:
   constexpr static bool is_source_supported = not cute::is_void_v<ElementC> && not cute::is_void_v<CopyOpG2R>;
   constexpr static bool is_destination_supported = not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>;
 
   using NonVoidElementC = conditional_t<is_source_supported, ElementC, ElementD>;
-  using Trait_C = Copy_Traits<GmemTiledCopyC, StrideC>;
-  using NonVoidTrait_C = conditional_t<is_source_supported, Trait_C, Trait_D>;
-  using val_layout_load_C = decltype(make_layout(shape_div(typename NonVoidTrait_C::BlockShape{}, CopyThreadShape{})));
-  using NonVoidValLayoutLoad_C = conditional_t<is_source_supported, val_layout_load_C, val_layout_store_D>;
-  using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom<NonVoidTrait_C, NonVoidElementC>{}, Layout<CopyThreadShape>{}, NonVoidValLayoutLoad_C{}));
 
   constexpr static bool is_m_major_C = detail::is_m_major<StrideC>();
   constexpr static bool is_m_major_D = detail::is_m_major<StrideD>();
@@ -156,6 +149,15 @@ class CollectiveEpilogue<
   };
   using TensorStorage = typename SharedStorage::TensorStorage;
 
+  // Helper to get tensor types
+  template<class Element, class Stride>
+  using TensorTypeC = decltype(make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
+                                           make_layout(make_shape(int{}, int{}, int{}), Stride{})));
+
+  template<class Element, class Stride>
+  using TensorTypeD = decltype(make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)),
+                                           make_layout(make_shape(int{}, int{}, int{}), Stride{})));
+
   // Host side epilogue arguments
   struct Arguments {
     typename FusionCallbacks::Arguments thread{};
@@ -168,8 +170,8 @@ class CollectiveEpilogue<
   // Device side epilogue params
   struct Params {
     typename FusionCallbacks::Params thread{};
-    XE_Copy_C xe_load_c;
-    XE_Copy_D xe_store_d;
+    TensorTypeC<ElementC, StrideC> mC;
+    TensorTypeD<ElementD, StrideD> mD;
   };
 
   //
@@ -185,23 +187,13 @@ class CollectiveEpilogue<
     // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)
     auto problem_shape_MNKL = append<4>(problem_shape, 1);
     auto [M, N, K, L] = problem_shape_MNKL;
-
-    XE_Copy_C xe_load_c = {};
-    if constexpr (is_source_supported) {
-      auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC));
-      xe_load_c = {xe_load_c.with(mC)};
-    }
-
-    XE_Copy_D xe_store_d = {};
-    if constexpr (is_destination_supported) {
-      auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD));
-      xe_store_d = {xe_store_d.with(mD)};
-    }
+    auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC));
+    auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD));
 
     return {
       FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace),
-      xe_load_c,
-      xe_store_d,
+      mC,
+      mD
     };
   }
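Net effect of the Params and to_underlying_arguments changes: the host no longer bakes tiled-copy objects into Params; it stores the raw C/D tensors, and the device code (later hunks) rebuilds the block-2D copies per work-group tile. Condensed from this diff:

    // Host side: store whole tensors, no copy objects.
    auto mC = make_tensor(make_gmem_ptr(args.ptr_C),
                          make_layout(make_shape(M, N, L), args.dC));
    auto mD = make_tensor(make_gmem_ptr(args.ptr_D),
                          make_layout(make_shape(M, N, L), args.dD));

    // Device side: build the copies from the stored tensors per batch index.
    auto copy_c = make_block_2d_copy_CD(GmemTiledCopyC{}, tiled_mma, params.mC(_,_,batch_idx));
    auto copy_d = make_block_2d_copy_CD(GmemTiledCopyD{}, tiled_mma, params.mD(_,_,batch_idx));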

@@ -272,6 +264,37 @@ class CollectiveEpilogue<
     return fusion_callbacks.is_producer_load_needed();
   }
 
+  template<typename Tensor>
+  CUTLASS_DEVICE auto reshape_with_unit_insertion(Tensor&& tensor) {
+    using namespace cute;
+
+    auto orig_layout = tensor.layout();
+    auto orig_shape = orig_layout.shape();
+    auto orig_stride = orig_layout.stride();
+
+    auto first_dim = get<0>(orig_shape);
+    auto outer_part = get<0>(first_dim);
+    auto inner_part = get<1>(first_dim);
+
+    auto first_stride = get<0>(orig_stride);
+    auto outer_stride = get<0>(first_stride);
+    auto inner_stride = get<1>(first_stride);
+
+    auto target_shape = make_shape(
+        make_shape(outer_part, _1{}),
+        get<0>(inner_part),
+        get<1>(inner_part)
+    );
+
+    auto target_stride = make_stride(
+        make_stride(outer_stride, _0{}),
+        get<0>(inner_stride),
+        get<1>(inner_stride)
+    );
+
+    return make_tensor(tensor.data(), make_layout(target_shape, target_stride));
+  }
+
   template<
     class ProblemShapeMNKL,
     class TileShapeMNK,
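reshape_with_unit_insertion flattens the hierarchical mode-0 that partition_S/partition_D return into a rank-3 view the epilogue loop can address as tCgC(_, epi_m, epi_n). A worked illustration with assumed sizes (the real extents depend on the copy atom and MMA tiling):

    using namespace cute;
    // Assumed input: rank-1 layout whose mode-0 is ((VALS, (FRAGS_M, FRAGS_N))).
    auto orig = make_layout(
        make_shape (make_shape(Int<8>{}, make_shape(Int<2>{}, Int<4>{}))),
        make_stride(make_stride(Int<1>{}, make_stride(Int<8>{}, Int<16>{}))));
    // The helper rewrites it to rank-3:
    //   shape  ((8, _1), 2, 4)
    //   stride ((1, _0), 8, 16)
    // i.e. a unit mode (stride 0) is inserted next to VALS and the two inner
    // modes are promoted, yielding the (V, EPI_M, EPI_N) view used by copy().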
@@ -288,7 +311,6 @@ class CollectiveEpilogue<
       TiledMma tiled_mma,
       int thread_idx) {
 
-    (void) tiled_mma;
     using namespace cute;
 
     static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
@@ -299,12 +321,11 @@ class CollectiveEpilogue<
     static constexpr auto BLK_M = get<0>(CtaTileMNK{});
     static constexpr auto BLK_N = get<1>(CtaTileMNK{});
     static constexpr auto BLK_K = get<2>(CtaTileMNK{});
-    // static_assert(is_same_v<typename TiledMma::ThrLayoutVMNK, int>, "assertation fail");
     static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
     static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
     static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());
-
-    static_assert(
+
+    static_assert(
       BLK_M % ATOM_M == 0 &&
       BLK_N % ATOM_N == 0 &&
       BLK_K % ATOM_K == 0,
@@ -318,46 +339,53 @@ class CollectiveEpilogue<
     static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group
 
     static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;
-
+
     // Indexing variables
     auto [M, N, K, L] = problem_shape_mnkl;
     auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
-    auto m_sg = get_sub_group_id() / ATOM_N;
-    auto n_sg = get_sub_group_id() % ATOM_N;
-
-    auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{});
 
     auto sg_local_m_coord = get_sub_group_id() / ATOM_N;
     auto sg_local_n_coord = get_sub_group_id() % ATOM_N;
 
     auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord;
     auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord;
     auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord);
-
+
+    auto wg_coord = make_coord(m_coord, n_coord, k_coord, l_coord);
     bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();
 
+    /*
+     * NOTE: Automatic selection of load/store operations using make_block_2d_copy_C/make_block_2d_copy_D
+     * is currently not supported. The current implementation is restricted to specific load/store
+     * operations with dimensions 16x8, which are tightly coupled to the MMA atom size requirements.
+     *
+     * TODO: Future enhancement will include automatic selection of load/store operations
+     * in collectiveEpilogue to provide more flexible dimension support.
+     */
+    auto batch_idx = get<3>(wg_coord);
+    auto copy_c = make_block_2d_copy_CD(GmemTiledCopyC{}, tiled_mma, params.mC(_,_,batch_idx));
+    auto copy_d = make_block_2d_copy_CD(GmemTiledCopyD{}, tiled_mma, params.mD(_,_,batch_idx));
+
     // Represent the full output tensor
     Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L));
 
-    // Tile the output tensor per WG and select the tile for current WG
-    Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N)
-
-    // Tile the output tensor per SG and select tile for the current SG
-    Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg)); // (SG_M,SG_N)
+    // Tile the output tensor for the current workgroup
+    Tensor gD = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), remove<2>(wg_coord)); // (BLK_M,BLK_N)
 
-    auto thread_xe_load_c = params.xe_load_c.get_thread_slice(thread_idx);
-    Tensor tCgC = thread_xe_load_c.partition_S(gD);
+    auto thread_xe_load_c = copy_c.get_thread_slice(thread_idx);
+    Tensor tCgC = reshape_with_unit_insertion(thread_xe_load_c.partition_S(gD));
 
-    auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx);
-    Tensor tCgD = thread_xe_store_d.partition_D(gD);
+    auto thread_xe_store_d = copy_d.get_thread_slice(thread_idx);
+    Tensor tCgD = reshape_with_unit_insertion(thread_xe_store_d.partition_D(gD));
 
     Tensor trC = make_tensor<NonVoidElementC>(Shape<Int<FragmentSize>>{});
     Tensor trD_compute = make_tensor<ElementCompute>(Shape<Int<FragmentSize>>{});
 
     // Because Sm90 uses shared memory, they are not tied to using the same accumulator values
     // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be
     // sure that we are operating on the same values.
-    ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx);
+    ThrCopy thread_g2r = copy_c.get_slice(thread_idx);
+    auto mn_shape = shape(typename decltype(copy_d)::Tiler_MN{});
 
     // OOB predication for tile quantization "residue"
     // Absolute coordinate tensors (dynamic)
@@ -366,7 +394,7 @@ class CollectiveEpilogue<
     Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N)
     Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape)); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N)
 
-    Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N)
+    Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout());
 
     // Get the fusion callbacks
     // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles
@@ -378,7 +406,7 @@ class CollectiveEpilogue<
         sg_coord,
         tiled_mma,
         mn_shape,
-        params.xe_store_d,
+        copy_d,
         cD,
         residue_mn,
         tRS_cD,
@@ -403,19 +431,20 @@ class CollectiveEpilogue<
         FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K;
     constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{});
     static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" );
-
+
+
     auto synchronize = [&] () {};
     CUTLASS_PRAGMA_UNROLL
     for (int epi_n = 0; epi_n < FragsN; epi_n++) {
       CUTLASS_PRAGMA_UNROLL
       for (int epi_m = 0; epi_m < FragsM; epi_m++) {
         cst_callbacks.begin_loop(epi_m, epi_n);
-
+
         //Instead of calling is_C_load_needed. We do heirachical check
         //so that runtime check not there when ElementC is void
         if constexpr (is_source_supported) {
           if (is_C_load_needed) {
-            copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC);
+            copy(copy_c, tCgC(_, epi_m, epi_n), trC);
           }
         }
 
@@ -428,21 +457,23 @@ class CollectiveEpilogue<
             trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
           }
           cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag);
-
+
           if constexpr (is_destination_supported) {
             CUTLASS_PRAGMA_UNROLL
             for (int i = 0; i < size(trD_compute_frag); ++i) {
               trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, RegisterElementD, FragmentSize>{}(trD_compute_frag(i));
             }
-            copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n));
+            copy(copy_d, trD, tCgD(_, epi_m, epi_n));
           }
 
           cst_callbacks.end_loop(epi_m, epi_n);
+
         }
       }
 
     cst_callbacks.end();
-  }
+
+  }
 
 private:
   Params const& params;