Commit 3578f8e

Add transposed load kernel through the new Xe Copy Atoms
1 parent 1f53721 commit 3578f8e

7 files changed (+254, -83 lines)

examples/cute/tutorial/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ if (CUTLASS_ENABLE_SYCL)

   cutlass_example_add_executable(
     cute_tutorial_tiled_transpose
-    transpose/tiled_transpose_sycl.cpp
+    transpose/main.cpp
   )

   cutlass_example_add_executable(
examples/cute/tutorial/transpose/block_2d_transposed_copy.h

Lines changed: 148 additions & 0 deletions (new file)

#pragma once
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. Copyright (C) 2025 Intel Corporation, All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#include <cute/tensor.hpp>
#include <cute/util/compat.hpp>
#include <sycl/ext/intel/experimental/grf_size_properties.hpp>
#include <sycl/sycl.hpp>

#include "cutlass/util/print_error.hpp"
#include "util.h"

template <class TensorS, class TensorD, class BlockShape, class BlockShapeTrans,
          class ThreadLayout>
void block2DTransposedLoadKernel(TensorS const S, TensorD const DT,
                                 BlockShape const block_shape,
                                 BlockShapeTrans const block_shape_transposed,
                                 ThreadLayout const thread_layout) {
  using namespace cute;
  using Element = typename TensorS::value_type;

  /* Get workgroup and local IDs */
  auto item = sycl::ext::oneapi::this_work_item::get_nd_item<2>();
  auto wg_m = int(item.get_group(0));
  auto wg_n = int(item.get_group(1));
  auto local_id = int(item.get_local_id(0));

  /* Proxy coordinate tensors */
  Tensor cS = make_identity_tensor(S.shape());   // (M,N)
  Tensor cDT = make_identity_tensor(DT.shape()); // (N,M)

  auto wg_coord = make_coord(wg_m, wg_n);
  auto wg_coord_transposed = make_coord(wg_n, wg_m);

  // Tensor data = ...                                         // (M, N)
  // Tensor cta_data = local_tile(data, Shape<16, 16>{},
  //                              make_coord(blockIdx.x, blockIdx.y)); // (_32,_64)
  Tensor gS = local_tile(cS, block_shape, wg_coord); // (BLK_M,BLK_N)
  Tensor gDT = local_tile(cDT, block_shape_transposed,
                          wg_coord_transposed); // (BLK_N,BLK_M)

  constexpr int CopyBits = sizeof_bits_v<Element>;
  auto transposed_load_op = XE_LOAD_2D_TRANSPOSE<CopyBits, 8, 8>{};
  auto store_op = XE_STORE_2D<CopyBits, 8, 8>{};

  /* Slice TiledCopy operations to thread (work-item) level */
  auto transpose_S = make_block_2d_copy(transposed_load_op, S);
  auto thr_transpose_S = transpose_S.get_slice(local_id);

  auto store_DT = make_block_2d_copy(store_op, DT);
  auto thr_copy_DT = store_DT.get_slice(local_id);

  /* Register fragments for the transposed copy */
  auto tSrS = thr_transpose_S.partition_sg_fragment_D(gS);
  auto tDrD = thr_copy_DT.partition_sg_fragment_D(gDT);

  /* Partition global tensor (proxies) for copies */
  Tensor tSgS = thr_transpose_S.partition_S(gS);
  Tensor tDgD = thr_copy_DT.partition_D(gDT);

  // if (cute::thread(0, 0)) {
  //   print(tSgS); print("\n");
  //   print(tSrS); print("\n");
  //   print(tDgD); print("\n");
  // }

  copy(transpose_S, tSgS, tSrS);
  // copy(tSrS, tDrD);
  copy(store_DT, tSrS, tDgD);
}

class TransposeCuteName;
template <typename Element>
void block_2d_transposed_copy(TransposeParams<Element> params) {

  using namespace cute;
  //
  // Make Tensors
  //
  auto tensor_shape = make_shape(params.M, params.N);
  auto tensor_shape_trans = make_shape(params.N, params.M);
  auto gmemLayoutS = make_layout(tensor_shape, LayoutRight{});
  auto gmemLayoutD = make_layout(tensor_shape_trans, LayoutRight{});
  Tensor tensor_S = make_tensor(make_gmem_ptr(params.input), gmemLayoutS);
  Tensor tensor_DT = make_tensor(make_gmem_ptr(params.output), gmemLayoutD);

  // Make a transposed view of the output
  // auto gmemLayoutDT = make_layout(tensor_shape, GenColMajor{});
  // Tensor tensor_DT = make_tensor(make_gmem_ptr(params.output), gmemLayoutDT);

  sycl::queue Q;

  //
  // Tile tensors
  //

  using bM = Int<32>;
  using bN = Int<8>;

  auto block_shape = make_shape(bM{}, bN{});       // (bM, bN)
  auto block_shape_trans = make_shape(bN{}, bM{}); // (bN, bM)

  sycl::range<2> local = {bM{}, 1};
  sycl::range<2> global = {local[0] * ceil_div(shape<0>(tensor_S), bM{}),
                           local[1] * ceil_div(shape<1>(tensor_S), bN{})};

  auto threadLayout = make_layout(make_shape(bM{}, Int<1>{}), LayoutRight{});

  namespace syclex = sycl::ext::oneapi::experimental;
  namespace intelex = sycl::ext::intel::experimental;

  syclex::properties kernel_props{syclex::sub_group_size<16>,
                                  intelex::grf_size<256>};

  auto event = Q.parallel_for<TransposeCuteName>(
      sycl::nd_range<2>(global, local), kernel_props, [=](auto) {
        block2DTransposedLoadKernel(tensor_S, tensor_DT, block_shape,
                                    block_shape_trans, threadLayout);
      });
}

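The kernel above tiles coordinate proxies built with make_identity_tensor rather than slicing the data tensors directly; the block 2D copy atoms take the real tensors and resolve addresses themselves. A standalone host-side sketch of that proxy-and-tile step (a minimal sketch with illustrative sizes, not the kernel's 32x8 blocks):

#include <cute/tensor.hpp>

using namespace cute;

int main() {
  // Coordinate proxy for an 8x4 tensor: element (i,j) is the coordinate
  // (i,j) itself, so tiling the proxy shows which coordinates a tile covers.
  Tensor c = make_identity_tensor(make_shape(Int<8>{}, Int<4>{}));

  // Select the (1,0) tile of shape (4,2): rows 4..7, columns 0..1.
  Tensor t = local_tile(c, make_shape(Int<4>{}, Int<2>{}), make_coord(1, 0));

  print(t(0, 0)); // (4,0) -- the first coordinate covered by this tile
  print("\n");
  return 0;
}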
examples/cute/tutorial/transpose/copy_direct.h

Lines changed: 87 additions & 71 deletions

@@ -51,6 +51,22 @@ void copy_kernel(TensorS S, TensorD D, ThreadLayout) {
   using namespace cute;

   // Slice the tiled tensors
+  // This line slices the tiled tensor S to get the tile for the current work
+  // group. S is a 3D tensor with layout ((M, N), m', n') where:
+  // - (M, N) is the block/tile shape (first mode)
+  // - m' is the number of tiles in the M dimension (second mode)
+  // - n' is the number of tiles in the N dimension (third mode)
+  //
+  // The indexing S(make_coord(_, _), x, y) selects:
+  // - make_coord(_, _): Takes all elements from the first mode (M, N), i.e.,
+  //   the entire tile
+  // - compat::work_group_id::x(): Selects the x-th tile along the m'
+  //   dimension
+  // - compat::work_group_id::y(): Selects the y-th tile along the n'
+  //   dimension
+  //
+  // Result: A 2D tensor of shape (BlockShape_M, BlockShape_N) corresponding to
+  // the tile assigned to the current work group.
   Tensor tile_S = S(make_coord(_, _), compat::work_group_id::x(),
                     compat::work_group_id::y()); // (BlockShape_M, BlockShape_N)
   Tensor tile_D = D(make_coord(_, _), compat::work_group_id::x(),
@@ -64,78 +80,78 @@ void copy_kernel(TensorS S, TensorD D, ThreadLayout) {
       tile_S, ThreadLayout{}, compat::local_id::x()); // (ThrValM, ThrValN)
   Tensor thr_tile_D = local_partition(
       tile_D, ThreadLayout{}, compat::local_id::x()); // (ThrValM, ThrValN)
-  //
-
-  // Construct a register-backed Tensor with the same shape as each thread's
-  // partition Use make_tensor to try to match the layout of thr_tile_S
-  Tensor fragment = make_tensor_like(thr_tile_S); // (ThrValM, ThrValN)
-
-  // Copy from GMEM to RMEM and from RMEM to GMEM
-  copy(thr_tile_S, fragment);
-  copy(fragment, thr_tile_D);
-}
-
-template <typename Element> void copy_direct(TransposeParams<Element> params) {
-  //
-  // Given a 2D shape, perform an efficient copy
-  //
-
-  using namespace cute;
-
-  //
-  // Make tensors
-  //
-  auto tensor_shape = make_shape(params.M, params.N);
-  auto gmemLayoutS = make_layout(tensor_shape, LayoutRight{});
-  auto gmemLayoutD = make_layout(tensor_shape, LayoutRight{});
-  Tensor tensor_S = make_tensor(make_gmem_ptr(params.input), gmemLayoutS);
-  Tensor tensor_D = make_tensor(make_gmem_ptr(params.output), gmemLayoutD);
-
-  //
-  // Tile tensors
-  //

-  // Define a statically sized block (M, N).
-  // Note, by convention, capital letters are used to represent static modes.
-  auto block_shape = make_shape(Int<1>{}, Int<16384>{});
+  // Construct a register-backed Tensor with the same shape as each thread's
+  // partition Use make_tensor to try to match the layout of thr_tile_S
+  Tensor fragment = make_tensor_like(thr_tile_S); // (ThrValM, ThrValN)

-  if ((size<0>(tensor_shape) % size<0>(block_shape)) ||
-      (size<1>(tensor_shape) % size<1>(block_shape))) {
-    std::cerr << "The tensor shape must be divisible by the block shape."
-              << std::endl;
-  }
-  // Equivalent check to the above
-  if (not evenly_divides(tensor_shape, block_shape)) {
-    std::cerr << "Expected the block_shape to evenly divide the tensor shape."
-              << std::endl;
+  // Copy from GMEM to RMEM and from RMEM to GMEM
+  copy(thr_tile_S, fragment);
+  copy(fragment, thr_tile_D);
 }

-  // Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile
-  // shape, and modes (m', n') correspond to the number of tiles.
-  //
-  // These will be used to determine the CUDA kernel grid dimensions.
-  Tensor tiled_tensor_S =
-      tiled_divide(tensor_S, block_shape); // ((M, N), m', n')
-  Tensor tiled_tensor_D =
-      tiled_divide(tensor_D, block_shape); // ((M, N), m', n')
-
-  // Thread arrangement
-  Layout thr_layout =
-      make_layout(make_shape(Int<1>{}, Int<1024>{}), LayoutRight{});
-
-  //
-  // Determine grid and block dimensions
-  //
-
-  auto gridDim = compat::dim3(
-      size<1>(tiled_tensor_S),
-      size<2>(tiled_tensor_S)); // Grid shape corresponds to modes m' and n'
-  auto blockDim = compat::dim3(size(thr_layout));
-
-  //
-  // Launch the kernel
-  //
-  compat::launch<copy_kernel<decltype(tiled_tensor_S), decltype(tiled_tensor_D),
-                             decltype(thr_layout)>>(
-      gridDim, blockDim, tiled_tensor_S, tiled_tensor_D, thr_layout);
-}
+template <typename Element>
+void copy_direct(TransposeParams<Element> params) {
+  //
+  // Given a 2D shape, perform an efficient copy
+  //
+
+  using namespace cute;
+
+  //
+  // Make tensors
+  //
+  auto tensor_shape = make_shape(params.M, params.N);
+  auto gmemLayoutS = make_layout(tensor_shape, LayoutRight{});
+  auto gmemLayoutD = make_layout(tensor_shape, LayoutRight{});
+  Tensor tensor_S = make_tensor(make_gmem_ptr(params.input), gmemLayoutS);
+  Tensor tensor_D = make_tensor(make_gmem_ptr(params.output), gmemLayoutD);
+
+  //
+  // Tile tensors
+  //
+
+  // Define a statically sized block (M, N).
+  // Note, by convention, capital letters are used to represent static modes.
+  auto block_shape = make_shape(Int<1>{}, Int<16384>{});
+
+  if ((size<0>(tensor_shape) % size<0>(block_shape)) ||
+      (size<1>(tensor_shape) % size<1>(block_shape))) {
+    std::cerr << "The tensor shape must be divisible by the block shape."
+              << std::endl;
+  }
+  // Equivalent check to the above
+  if (not evenly_divides(tensor_shape, block_shape)) {
+    std::cerr << "Expected the block_shape to evenly divide the tensor shape."
+              << std::endl;
+  }
+
+  // Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static
+  // tile shape, and modes (m', n') correspond to the number of tiles.
+  //
+  // These will be used to determine the CUDA kernel grid dimensions.
+  Tensor tiled_tensor_S =
+      tiled_divide(tensor_S, block_shape); // ((M, N), m', n')
+  Tensor tiled_tensor_D =
+      tiled_divide(tensor_D, block_shape); // ((M, N), m', n')
+
+  // Thread arrangement
+  Layout thr_layout =
+      make_layout(make_shape(Int<1>{}, Int<1024>{}), LayoutRight{});
+
+  //
+  // Determine grid and block dimensions
+  //
+
+  auto gridDim = compat::dim3(
+      size<1>(tiled_tensor_S),
+      size<2>(tiled_tensor_S)); // Grid shape corresponds to modes m' and n'
+  auto blockDim = compat::dim3(size(thr_layout));
+
+  //
+  // Launch the kernel
+  //
+  compat::launch<copy_kernel<decltype(tiled_tensor_S),
+                             decltype(tiled_tensor_D), decltype(thr_layout)>>(
+      gridDim, blockDim, tiled_tensor_S, tiled_tensor_D, thr_layout);
+}

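The ((M, N), m', n') tiling described in the comments above is easy to verify numerically. A minimal host-only sketch, assuming nothing beyond the public CuTe calls already used in this file (shapes are illustrative):

#include <cute/tensor.hpp>

using namespace cute;

int main() {
  float buf[4 * 8];
  Tensor A = make_tensor(
      &buf[0], make_layout(make_shape(Int<4>{}, Int<8>{}), LayoutRight{}));

  // (4, 8) ==> ((2, 4), 2, 2): mode 0 is one (2, 4) tile, modes 1 and 2
  // count the 2x2 grid of tiles.
  Tensor tiled = tiled_divide(A, make_shape(Int<2>{}, Int<4>{}));
  print(shape(tiled)); print("\n"); // ((_2,_4),_2,_2)

  // Slice out the tile for "work group" (1, 0), as copy_kernel does.
  Tensor tile = tiled(make_coord(_, _), 1, 0); // (2, 4)

  // Partition the tile across 8 "threads"; each owns a (1, 1) fragment.
  auto thr_layout = make_layout(make_shape(Int<2>{}, Int<4>{}), LayoutRight{});
  Tensor frag = local_partition(tile, thr_layout, 3);
  print(shape(frag)); print("\n"); // (_1,_1)
  return 0;
}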
examples/cute/tutorial/transpose/copy_smem.h

Lines changed: 4 additions & 4 deletions

@@ -62,10 +62,10 @@ void copySmemKernel(TensorS const S, TensorD const D, ThreadLayout,
   Tensor gS = S(make_coord(_, _), compat::work_group_id::x(),
                 compat::work_group_id::y()); // (bM, bN)
   Tensor gD = D(make_coord(_, _), compat::work_group_id::x(),
-                compat::work_group_id::y()); // (bN, bM)
+                compat::work_group_id::y()); // (bM, bN)

   Tensor sS = make_tensor(make_smem_ptr(shared_storage.smem.data()),
-                          SmemLayout{}); // (bN, bM)
+                          SmemLayout{}); // (bM, bN)

   auto tiled_copy_load = make_tiled_copy(
       Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
@@ -83,7 +83,7 @@ void copySmemKernel(TensorS const S, TensorD const D, ThreadLayout,
   Tensor tSgS = thr_copy_load.partition_S(gS);
   Tensor tSsS = thr_copy_load.partition_D(sS);
   //
-  Tensor tDsS = thr_copy_store.partition_D(sS);
+  Tensor tDsD = thr_copy_store.partition_S(sS);
   Tensor tDgD = thr_copy_store.partition_D(gD);

   copy(tiled_copy_load, tSgS, tSsS);
@@ -92,7 +92,7 @@ void copySmemKernel(TensorS const S, TensorD const D, ThreadLayout,
   cp_async_wait<0>();
   syncthreads();
   //
-  copy(tiled_copy_store, tDsS, tDgD);
+  copy(tiled_copy_store, tDsD, tDgD);
 }

 template <typename Element> void copy_smem(TransposeParams<Element> params) {

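The rename above follows the convention that partition_S shapes a tensor as the source of a tiled copy and partition_D as its destination; the shared-memory tile feeds the store, so it goes through partition_S. A minimal sketch of the pairing, assuming a generic vectorizing atom and a hypothetical 4x8 thread arrangement (thread_copy and its shapes are illustrative, not part of this example):

#include <cute/tensor.hpp>

using namespace cute;

template <class TensorSrc, class TensorDst>
void thread_copy(TensorSrc gS, TensorDst gD, int tid) {
  using Element = typename TensorSrc::value_type;

  // Tiled copy: 4x8 threads, one value per thread, plain vectorizing atom.
  auto tiled_copy = make_tiled_copy(
      Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
      make_layout(make_shape(Int<4>{}, Int<8>{}), LayoutRight{}),
      make_layout(make_shape(Int<1>{}, Int<1>{})));
  auto thr_copy = tiled_copy.get_thread_slice(tid);

  // Sources are partitioned with partition_S, destinations with partition_D;
  // swapping them applies the wrong mode of the copy's layout.
  Tensor tXgS = thr_copy.partition_S(gS);
  Tensor tXgD = thr_copy.partition_D(gD);
  copy(tiled_copy, tXgS, tXgD);
}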
examples/cute/tutorial/transpose/main.cpp

Lines changed: 5 additions & 0 deletions

@@ -1,3 +1,4 @@
+#include "block_2d_transposed_copy.h"
 #include "copy_direct.h"
 #include "copy_smem.h"
 #include "transpose_naive.h"
@@ -25,5 +26,9 @@ int main(int argc, char const **argv) {
   printf("\nTranspose through SMEM.:\n");
   benchmark<Element>(transpose_smem<Element>, M, N, iterations);

+  printf("Block 2d Transposed load\n");
+  benchmark<Element, true, false, false>(block_2d_transposed_copy<Element>, M,
+                                         N, iterations);
+
   return 0;
 }

examples/cute/tutorial/transpose/transpose_naive.h

Lines changed: 0 additions & 4 deletions

@@ -69,11 +69,8 @@ void transpose_naive(TransposeParams<Element> params) {
   // Make Tensors
   //
   auto tensor_shape = make_shape(params.M, params.N);
-  auto tensor_shape_trans = make_shape(params.N, params.M);
   auto gmemLayoutS = make_layout(tensor_shape, LayoutRight{});
-  auto gmemLayoutD = make_layout(tensor_shape_trans, LayoutRight{});
   Tensor tensor_S = make_tensor(make_gmem_ptr(params.input), gmemLayoutS);
-  Tensor tensor_D = make_tensor(make_gmem_ptr(params.output), gmemLayoutD);

   // Make a transposed view of the output
   auto gmemLayoutDT = make_layout(tensor_shape, GenColMajor{});
@@ -87,7 +84,6 @@ void transpose_naive(TransposeParams<Element> params) {
   using bN = Int<512>;

   auto block_shape = make_shape(bM{}, bN{}); // (bM, bN)
-  auto block_shape_trans = make_shape(bN{}, bM{}); // (bN, bM)

   Tensor tiled_tensor_S =
       tiled_divide(tensor_S, block_shape); // ((bM, bN), m', n')

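The deletions above are safe because transpose_naive already writes through a transposed view: an (M, N) column-major layout over the output buffer addresses exactly the same memory as the (N, M) row-major transpose, so the explicit tensor_D and block_shape_trans were never used. A tiny standalone check of that equivalence (a sketch with an illustrative 2x3 size):

#include <cassert>
#include <cute/tensor.hpp>

using namespace cute;

int main() {
  int buf[2 * 3] = {0, 1, 2, 3, 4, 5};

  // (2,3) column-major view and (3,2) row-major view of the same buffer.
  Tensor DT_view = make_tensor(
      &buf[0], make_layout(make_shape(Int<2>{}, Int<3>{}), GenColMajor{}));
  Tensor D = make_tensor(
      &buf[0], make_layout(make_shape(Int<3>{}, Int<2>{}), LayoutRight{}));

  // DT_view(i, j) and D(j, i) name the same element.
  assert(&DT_view(1, 2) == &D(2, 1));
  return 0;
}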