@@ -68,7 +68,7 @@ template <
   class CopyOpR2S_
 >
 class CollectiveEpilogue<
-    IntelXeXMX16,
+    IntelXeGeneric,
     CtaTileMNK_,
     ElementC_,
     StrideC_,
@@ -86,7 +86,7 @@ class CollectiveEpilogue<
   //
   // Type Aliases
   //
-  using DispatchPolicy = IntelXeXMX16;
+  using DispatchPolicy = IntelXeGeneric;
   using CtaTileMNK = CtaTileMNK_;
   using FusionCallbacks = FusionCallbacks_;
   using ElementC = ElementC_;
@@ -101,9 +101,13 @@ class CollectiveEpilogue<
   using CopyOpR2S = CopyOpR2S_;
 
   using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits<FusionCallbacks>::Operation;
-  using GmemTiledCopyC = conditional_t<cute::is_void_v<CopyOpG2R>, XE_2D_U32x8x16_LD_N, CopyOpG2R>;
+  using GmemTiledCopyC = conditional_t<cute::is_void_v<CopyOpG2R>, XE_LOAD_2D<32, 8, 16>, CopyOpG2R>;
   using GmemTiledCopyD = cute::conditional_t<not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>,
-                                             CopyOpR2G, XE_2D_U32x8x16_ST_N>;
+                                             CopyOpR2G, XE_STORE_2D<32, 8, 16>>;
+  static_assert(std::is_same_v<GmemTiledCopyC, XE_LOAD_2D<32, 8, 16>>,
+                "Current epilogue implementation only supports load op XE_LOAD_2D<32, 8, 16>");
+  static_assert(std::is_same_v<GmemTiledCopyD, XE_STORE_2D<32, 8, 16>>,
+                "Current epilogue implementation only supports store op XE_STORE_2D<32, 8, 16>");
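+  // Note: the <32, 8, 16> parameters are presumed to mean 32-bit elements in an 8x16 block,
+  // mirroring the XE_2D_U32x8x16_LD_N / XE_2D_U32x8x16_ST_N atoms these ops replace.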
   using ElementOutput = ElementD;
   using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
   using ElementAccumulator = ElementCompute;
@@ -118,22 +122,11 @@ class CollectiveEpilogue<
   static_assert(std::is_same_v<SmemLayoutAtomC, void>, "Copy operation to shared memory is not supported");
   static_assert(std::is_same_v<SmemLayoutAtomD, void>, "Copy operation to shared memory is not supported");
 
-  using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
-
-  using Trait_D = Copy_Traits<GmemTiledCopyD, StrideD>;
-  using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})));
-  using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom<Trait_D, ElementD>{}, Layout<CopyThreadShape>{}, val_layout_store_D{}));
-
 private:
   constexpr static bool is_source_supported = not cute::is_void_v<ElementC> && not cute::is_void_v<CopyOpG2R>;
   constexpr static bool is_destination_supported = not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>;
 
   using NonVoidElementC = conditional_t<is_source_supported, ElementC, ElementD>;
-  using Trait_C = Copy_Traits<GmemTiledCopyC, StrideC>;
-  using NonVoidTrait_C = conditional_t<is_source_supported, Trait_C, Trait_D>;
-  using val_layout_load_C = decltype(make_layout(shape_div(typename NonVoidTrait_C::BlockShape{}, CopyThreadShape{})));
-  using NonVoidValLayoutLoad_C = conditional_t<is_source_supported, val_layout_load_C, val_layout_store_D>;
-  using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom<NonVoidTrait_C, NonVoidElementC>{}, Layout<CopyThreadShape>{}, NonVoidValLayoutLoad_C{}));
 
   constexpr static bool is_m_major_C = detail::is_m_major<StrideC>();
   constexpr static bool is_m_major_D = detail::is_m_major<StrideD>();
@@ -156,6 +149,15 @@ class CollectiveEpilogue<
   };
   using TensorStorage = typename SharedStorage::TensorStorage;
 
+  // Helper to get tensor types
+  template <class Element, class Stride>
+  using TensorTypeC = decltype(make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
+                                           make_layout(make_shape(int{}, int{}, int{}), Stride{})));
+
+  template <class Element, class Stride>
+  using TensorTypeD = decltype(make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)),
+                                           make_layout(make_shape(int{}, int{}, int{}), Stride{})));
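+  // For example, TensorTypeC<ElementC, StrideC> is the decltype of a rank-3 (M, N, L) tensor
+  // over a const global pointer; the actual pointer and extents are supplied at runtime by
+  // to_underlying_arguments() below.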
+
   // Host side epilogue arguments
   struct Arguments {
     typename FusionCallbacks::Arguments thread{};
@@ -168,8 +170,8 @@ class CollectiveEpilogue<
   // Device side epilogue params
   struct Params {
     typename FusionCallbacks::Params thread{};
-    XE_Copy_C xe_load_c;
-    XE_Copy_D xe_store_d;
+    TensorTypeC<ElementC, StrideC> mC;
+    TensorTypeD<ElementD, StrideD> mD;
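+    // Params now carries the full (M, N, L) global C/D tensors; the tiled copies that
+    // previously lived here are created per output tile in operator() via make_block_2d_copy_CD.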
   };
 
   //
@@ -185,23 +187,13 @@ class CollectiveEpilogue<
     // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
     auto problem_shape_MNKL = append<4>(problem_shape, 1);
     auto [M, N, K, L] = problem_shape_MNKL;
-
-    XE_Copy_C xe_load_c = {};
-    if constexpr (is_source_supported) {
-      auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC));
-      xe_load_c = {xe_load_c.with(mC)};
-    }
-
-    XE_Copy_D xe_store_d = {};
-    if constexpr (is_destination_supported) {
-      auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD));
-      xe_store_d = {xe_store_d.with(mD)};
-    }
+    auto mC = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M, N, L), args.dC));
+    auto mD = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M, N, L), args.dD));
 
     return {
       FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace),
-      xe_load_c,
-      xe_store_d,
+      mC,
+      mD
     };
   }
 
@@ -272,6 +264,37 @@ class CollectiveEpilogue<
     return fusion_callbacks.is_producer_load_needed();
   }
 
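+  // Promotes the nested copy modes of a thread-partitioned fragment to top-level modes,
+  // assuming the partitioned layout has the nested form ((V, (FragsM, FragsN))); the result
+  // has shape ((V, 1), FragsM, FragsN) so fragments can be indexed as t(_, epi_m, epi_n).
+  // Illustrative example: layout ((8,(4,2)):(1,(16,128))) becomes ((8,1),4,2):((1,0),16,128).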
+  template <typename Tensor>
+  CUTLASS_DEVICE auto reshape_with_unit_insertion(Tensor&& tensor) {
+    using namespace cute;
+
+    auto orig_layout = tensor.layout();
+    auto orig_shape = orig_layout.shape();
+    auto orig_stride = orig_layout.stride();
+
+    auto first_dim = get<0>(orig_shape);
+    auto outer_part = get<0>(first_dim);
+    auto inner_part = get<1>(first_dim);
+
+    auto first_stride = get<0>(orig_stride);
+    auto outer_stride = get<0>(first_stride);
+    auto inner_stride = get<1>(first_stride);
+
+    auto target_shape = make_shape(
+      make_shape(outer_part, _1{}),
+      get<0>(inner_part),
+      get<1>(inner_part)
+    );
+
+    auto target_stride = make_stride(
+      make_stride(outer_stride, _0{}),
+      get<0>(inner_stride),
+      get<1>(inner_stride)
+    );
+
+    return make_tensor(tensor.data(), make_layout(target_shape, target_stride));
+  }
+
   template <
     class ProblemShapeMNKL,
     class TileShapeMNK,
@@ -288,7 +311,6 @@ class CollectiveEpilogue<
       TiledMma tiled_mma,
       int thread_idx) {
 
-    (void) tiled_mma;
     using namespace cute;
 
     static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
@@ -299,12 +321,11 @@ class CollectiveEpilogue<
     static constexpr auto BLK_M = get<0>(CtaTileMNK{});
     static constexpr auto BLK_N = get<1>(CtaTileMNK{});
     static constexpr auto BLK_K = get<2>(CtaTileMNK{});
-    // static_assert(is_same_v<typename TiledMma::ThrLayoutVMNK, int>, "assertation fail");
     static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
     static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
     static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());
-
-    static_assert(
+
+    static_assert(
       BLK_M % ATOM_M == 0 &&
       BLK_N % ATOM_N == 0 &&
       BLK_K % ATOM_K == 0,
@@ -318,46 +339,53 @@ class CollectiveEpilogue<
     static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group
 
     static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;
-
+
     // Indexing variables
     auto [M, N, K, L] = problem_shape_mnkl;
     auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
-    auto m_sg = get_sub_group_id() / ATOM_N;
-    auto n_sg = get_sub_group_id() % ATOM_N;
-
-    auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{});
 
     auto sg_local_m_coord = get_sub_group_id() / ATOM_N;
     auto sg_local_n_coord = get_sub_group_id() % ATOM_N;
 
     auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord;
     auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord;
     auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord);
-
+
+    auto wg_coord = make_coord(m_coord, n_coord, k_coord, l_coord);
     bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();
 
+    /*
+     * NOTE: Automatic selection of load/store operations using make_block_2d_copy_C/make_block_2d_copy_D
+     * is currently not supported. The current implementation is restricted to specific load/store
+     * operations with dimensions 16x8, which are tightly coupled to the MMA atom size requirements.
+     *
+     * TODO: A future enhancement will add automatic selection of load/store operations
+     * in CollectiveEpilogue to provide more flexible dimension support.
+     */
+    auto batch_idx = get<3>(wg_coord);
+    auto copy_c = make_block_2d_copy_CD(GmemTiledCopyC{}, tiled_mma, params.mC(_,_,batch_idx));
+    auto copy_d = make_block_2d_copy_CD(GmemTiledCopyD{}, tiled_mma, params.mD(_,_,batch_idx));
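+    // copy_c / copy_d replace the tiled-copy members formerly stored in Params (xe_load_c /
+    // xe_store_d); they are built per work-group here so they can be constructed against
+    // tiled_mma and the current batch slice of the C/D tensors.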
+
     // Represent the full output tensor
     Tensor mD_mnl = cute::get_xe_tensor(make_shape(M,N,L));
 
-    // Tile the output tensor per WG and select the tile for current WG
-    Tensor g_wg_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord));   // (BLK_M,BLK_N)
-
-    // Tile the output tensor per SG and select tile for the current SG
-    Tensor gD = local_tile(g_wg_D, take<0,2>(SubgroupTileShape{}), make_coord(m_sg,n_sg));              // (SG_M,SG_N)
+    // Tile the output tensor for the current workgroup
+    Tensor gD = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), remove<2>(wg_coord));                       // (BLK_M,BLK_N)
 
-    auto thread_xe_load_c = params.xe_load_c.get_thread_slice(thread_idx);
-    Tensor tCgC = thread_xe_load_c.partition_S(gD);
+    auto thread_xe_load_c = copy_c.get_thread_slice(thread_idx);
+    Tensor tCgC = reshape_with_unit_insertion(thread_xe_load_c.partition_S(gD));
 
-    auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx);
-    Tensor tCgD = thread_xe_store_d.partition_D(gD);
+    auto thread_xe_store_d = copy_d.get_thread_slice(thread_idx);
+    Tensor tCgD = reshape_with_unit_insertion(thread_xe_store_d.partition_D(gD));
 
     Tensor trC = make_tensor<NonVoidElementC>(Shape<Int<FragmentSize>>{});
     Tensor trD_compute = make_tensor<ElementCompute>(Shape<Int<FragmentSize>>{});
 
     // Because Sm90 uses shared memory, they are not tied to using the same accumulator values
     // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be
     // sure that we are operating on the same values.
-    ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx);
+    ThrCopy thread_g2r = copy_c.get_slice(thread_idx);
+    auto mn_shape = shape(typename decltype(copy_d)::Tiler_MN{});
 
     // OOB predication for tile quantization "residue"
     // Absolute coordinate tensors (dynamic)
@@ -366,7 +394,7 @@ class CollectiveEpilogue<
     Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord));           // (CTA_M,CTA_N)
     Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, mn_shape));                            // (G2R,G2R_M,G2R_N,EPI_M,EPI_N)
 
-    Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout());                                              // (G2R,G2R_M,G2R_N,EPI_M,EPI_N)
+    Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout());
 
     // Get the fusion callbacks
     // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles
@@ -378,7 +406,7 @@ class CollectiveEpilogue<
       sg_coord,
       tiled_mma,
       mn_shape,
-      params.xe_store_d,
+      copy_d,
       cD,
       residue_mn,
       tRS_cD,
@@ -403,19 +431,20 @@ class CollectiveEpilogue<
         FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K;
     constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{});
     static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN");
-
+
+
     auto synchronize = [&] () {};
     CUTLASS_PRAGMA_UNROLL
     for (int epi_n = 0; epi_n < FragsN; epi_n++) {
       CUTLASS_PRAGMA_UNROLL
       for (int epi_m = 0; epi_m < FragsM; epi_m++) {
         cst_callbacks.begin_loop(epi_m, epi_n);
-
+
         // Instead of calling is_C_load_needed directly, do a hierarchical check
         // so that the runtime check is not emitted when ElementC is void
         if constexpr (is_source_supported) {
           if (is_C_load_needed) {
-            copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC);
+            copy(copy_c, tCgC(_, epi_m, epi_n), trC);
           }
         }
 
@@ -428,21 +457,23 @@ class CollectiveEpilogue<
           trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
         }
         cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag);
-
+
         if constexpr (is_destination_supported) {
           CUTLASS_PRAGMA_UNROLL
           for (int i = 0; i < size(trD_compute_frag); ++i) {
             trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, RegisterElementD, FragmentSize>{}(trD_compute_frag(i));
           }
-          copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n));
+          copy(copy_d, trD, tCgD(_, epi_m, epi_n));
         }
 
         cst_callbacks.end_loop(epi_m, epi_n);
+
       }
     }
 
     cst_callbacks.end();
-  }
+
+  }
 
 private:
   Params const& params;