@@ -51,6 +51,22 @@ void copy_kernel(TensorS S, TensorD D, ThreadLayout) {
5151 using namespace cute ;
5252
5353 // Slice the tiled tensors
54+ // This line slices the tiled tensor S to get the tile for the current work
55+ // group. S is a 3D tensor with layout ((M, N), m', n') where:
56+ // - (M, N) is the block/tile shape (first mode)
57+ // - m' is the number of tiles in the M dimension (second mode)
58+ // - n' is the number of tiles in the N dimension (third mode)
59+ //
60+ // The indexing S(make_coord(_, _), x, y) selects:
61+ // - make_coord(_, _): Takes all elements from the first mode (M, N), i.e.,
62+ // the entire tile
63+ // - compat::work_group_id::x(): Selects the x-th tile along the m'
64+ // dimension
65+ // - compat::work_group_id::y(): Selects the y-th tile along the n'
66+ // dimension
67+ //
68+ // Result: A 2D tensor of shape (BlockShape_M, BlockShape_N) corresponding to
69+ // the tile assigned to the current work group.
5470 Tensor tile_S = S (make_coord (_, _), compat::work_group_id::x (),
5571 compat::work_group_id::y ()); // (BlockShape_M, BlockShape_N)
5672 Tensor tile_D = D (make_coord (_, _), compat::work_group_id::x (),
@@ -64,78 +80,78 @@ void copy_kernel(TensorS S, TensorD D, ThreadLayout) {
6480 tile_S, ThreadLayout{}, compat::local_id::x ()); // (ThrValM, ThrValN)
6581 Tensor thr_tile_D = local_partition (
6682 tile_D, ThreadLayout{}, compat::local_id::x ()); // (ThrValM, ThrValN)
67- //
68-
69- // Construct a register-backed Tensor with the same shape as each thread's
70- // partition Use make_tensor to try to match the layout of thr_tile_S
71- Tensor fragment = make_tensor_like (thr_tile_S); // (ThrValM, ThrValN)
72-
73- // Copy from GMEM to RMEM and from RMEM to GMEM
74- copy (thr_tile_S, fragment);
75- copy (fragment, thr_tile_D);
76- }
77-
78- template <typename Element> void copy_direct (TransposeParams<Element> params) {
79- //
80- // Given a 2D shape, perform an efficient copy
81- //
82-
83- using namespace cute ;
84-
85- //
86- // Make tensors
87- //
88- auto tensor_shape = make_shape (params.M , params.N );
89- auto gmemLayoutS = make_layout (tensor_shape, LayoutRight{});
90- auto gmemLayoutD = make_layout (tensor_shape, LayoutRight{});
91- Tensor tensor_S = make_tensor (make_gmem_ptr (params.input ), gmemLayoutS);
92- Tensor tensor_D = make_tensor (make_gmem_ptr (params.output ), gmemLayoutD);
93-
94- //
95- // Tile tensors
96- //
9783
98- // Define a statically sized block (M, N).
99- // Note, by convention, capital letters are used to represent static modes.
100- auto block_shape = make_shape (Int< 1 >{}, Int< 16384 >{});
84+ // Construct a register-backed Tensor with the same shape as each thread's
85+ // partition. make_tensor_like is used to try to match the layout of thr_tile_S.
86+ Tensor fragment = make_tensor_like (thr_tile_S); // (ThrValM, ThrValN)
10187
102- if ((size<0 >(tensor_shape) % size<0 >(block_shape)) ||
103- (size<1 >(tensor_shape) % size<1 >(block_shape))) {
104- std::cerr << " The tensor shape must be divisible by the block shape."
105- << std::endl;
106- }
107- // Equivalent check to the above
108- if (not evenly_divides (tensor_shape, block_shape)) {
109- std::cerr << " Expected the block_shape to evenly divide the tensor shape."
110- << std::endl;
88+ // Copy from GMEM to RMEM and from RMEM to GMEM
89+ copy (thr_tile_S, fragment);
90+ copy (fragment, thr_tile_D);
11191 }
11292
113- // Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile
114- // shape, and modes (m', n') correspond to the number of tiles.
115- //
116- // These will be used to determine the CUDA kernel grid dimensions.
117- Tensor tiled_tensor_S =
118- tiled_divide (tensor_S, block_shape); // ((M, N), m', n')
119- Tensor tiled_tensor_D =
120- tiled_divide (tensor_D, block_shape); // ((M, N), m', n')
121-
122- // Thread arrangement
123- Layout thr_layout =
124- make_layout (make_shape (Int<1 >{}, Int<1024 >{}), LayoutRight{});
125-
126- //
127- // Determine grid and block dimensions
128- //
129-
130- auto gridDim = compat::dim3 (
131- size<1 >(tiled_tensor_S),
132- size<2 >(tiled_tensor_S)); // Grid shape corresponds to modes m' and n'
133- auto blockDim = compat::dim3 (size (thr_layout));
134-
135- //
136- // Launch the kernel
137- //
138- compat::launch<copy_kernel<decltype (tiled_tensor_S), decltype (tiled_tensor_D),
139- decltype (thr_layout)>>(
140- gridDim, blockDim, tiled_tensor_S, tiled_tensor_D, thr_layout);
141- }
93+ template <typename Element>
94+ void copy_direct (TransposeParams<Element> params) {
95+ //
96+ // Copy params.input to params.output, both viewed as (params.M, params.N) tensors, by launching copy_kernel over a grid of tiles
97+ //
98+
99+ using namespace cute ;
100+
101+ //
102+ // Make tensors: wrap the input/output gmem pointers with identical LayoutRight layouts of shape (params.M, params.N)
103+ //
104+ auto tensor_shape = make_shape (params.M , params.N );
105+ auto gmemLayoutS = make_layout (tensor_shape, LayoutRight{});
106+ auto gmemLayoutD = make_layout (tensor_shape, LayoutRight{});
107+ Tensor tensor_S = make_tensor (make_gmem_ptr (params.input ), gmemLayoutS);
108+ Tensor tensor_D = make_tensor (make_gmem_ptr (params.output ), gmemLayoutD);
109+
110+ //
111+ // Tile tensors
112+ //
113+
114+ // Define a statically sized block (M, N) = (1, 16384).
115+ // Note, by convention, capital letters are used to represent static modes.
116+ auto block_shape = make_shape (Int<1 >{}, Int<16384 >{});
117+
118+ if ((size<0 >(tensor_shape) % size<0 >(block_shape)) ||
119+ (size<1 >(tensor_shape) % size<1 >(block_shape))) {
120+ std::cerr << " The tensor shape must be divisible by the block shape."
121+ << std::endl;
122+ }
123+ // Equivalent check to the above. NOTE(review): both checks only print to std::cerr and do not return, so the tiling and launch below still execute -- confirm this is intended
124+ if (not evenly_divides (tensor_shape, block_shape)) {
125+ std::cerr << " Expected the block_shape to evenly divide the tensor shape."
126+ << std::endl;
127+ }
128+
129+ // Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static
130+ // tile shape, and modes (m', n') correspond to the number of tiles.
131+ //
132+ // These will be used to determine the kernel grid dimensions (see gridDim below).
133+ Tensor tiled_tensor_S =
134+ tiled_divide (tensor_S, block_shape); // ((M, N), m', n')
135+ Tensor tiled_tensor_D =
136+ tiled_divide (tensor_D, block_shape); // ((M, N), m', n')
137+
138+ // Thread arrangement: a static (1, 1024) LayoutRight layout; its size is used as the block dimension below
139+ Layout thr_layout =
140+ make_layout (make_shape (Int<1 >{}, Int<1024 >{}), LayoutRight{});
141+
142+ //
143+ // Determine grid and block dimensions
144+ //
145+
146+ auto gridDim = compat::dim3 (
147+ size<1 >(tiled_tensor_S),
148+ size<2 >(tiled_tensor_S)); // Grid shape corresponds to modes m' and n'
149+ auto blockDim = compat::dim3 (size (thr_layout));
150+
151+ //
152+ // Launch copy_kernel, instantiated on the tiled tensor types and the thread layout type
153+ //
154+ compat::launch<copy_kernel<decltype (tiled_tensor_S),
155+ decltype (tiled_tensor_D), decltype (thr_layout)>>(
156+ gridDim, blockDim, tiled_tensor_S, tiled_tensor_D, thr_layout);
157+ }
0 commit comments