// Licensed under the MIT License.
#pragma once

+/// The Loader and Storer in this file use all threads in a thread block
+/// cooperatively to transfer data tiles between global and shared memory.
+
+#include "kernels/copy/atom.cuh"
#include "kernels/copy/copy_traits.cuh"
+#include "kernels/copy/warp.cuh"

#include <cute/tensor.hpp>

namespace vptq::kernels::copy {

-
+namespace tl = vptq::tile_layout;
using namespace cute;

-namespace {
-/// ld.shared
-template <const int kBytes>
-DEVICE void ld_shared(void* dst, uint32_t src);
-
-/// ld.shared - 16b
-template <>
-DEVICE void ld_shared<2>(void* dst, uint32_t src) {
-  asm volatile("ld.shared.u16 %0, [%1];\n"
-               : "=h"(*reinterpret_cast<uint16_t*>(dst))
-               : "r"(src));
-}
-
-/// ld.shared - 32b
-template <>
-DEVICE void ld_shared<4>(void* dst, uint32_t src) {
-  asm volatile("ld.shared.u32 %0, [%1];\n"
-               : "=r"(*reinterpret_cast<uint32_t*>(dst))
-               : "r"(src));
-}
-
-/// ld.shared - 64b
-template <>
-DEVICE void ld_shared<8>(void* dst, uint32_t src) {
-  uint2* dst_u64 = reinterpret_cast<uint2*>(dst);
-  asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n"
-               : "=r"(dst_u64->x), "=r"(dst_u64->y)
-               : "r"(src));
-}
-
-/// ld.shared - 128b
-template <>
-DEVICE void ld_shared<16>(void* dst, uint32_t src) {
-  uint4* dst_u128 = reinterpret_cast<uint4*>(dst);
-  asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n"
-               : "=r"(dst_u128->x), "=r"(dst_u128->y), "=r"(dst_u128->z),
-                 "=r"(dst_u128->w)
-               : "r"(src));
-}
-
-/// st.shared
-template <int kBytes>
-DEVICE void st_shared(uint32_t dst, void const* src);
-
-/// st.shared - 16b
-template <>
-DEVICE void st_shared<2>(uint32_t dst, void const* src) {
-  asm volatile("st.shared.u16 [%0], %1;\n"
-               :
-               : "r"(dst), "h"(*reinterpret_cast<uint16_t const*>(src)));
-}
-
-/// st.shared - 32b
-template <>
-DEVICE void st_shared<4>(uint32_t dst, void const* src) {
-  asm volatile("st.shared.u32 [%0], %1;\n"
-               :
-               : "r"(dst), "r"(*reinterpret_cast<uint32_t const*>(src)));
-}
-
-/// st.shared - 64b
-template <>
-DEVICE void st_shared<8>(uint32_t dst, void const* src) {
-  uint2 const* src_u64 = reinterpret_cast<uint2 const*>(src);
-  asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n"
-               :
-               : "r"(dst), "r"(src_u64->x), "r"(src_u64->y));
-}
-
-/// st.shared - 128b
-template <>
-DEVICE void st_shared<16>(uint32_t dst, void const* src) {
-  uint4 const* src_u128 = reinterpret_cast<uint4 const*>(src);
-  asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
-               :
-               : "r"(dst), "r"(src_u128->x), "r"(src_u128->y),
-                 "r"(src_u128->z), "r"(src_u128->w));
-}
-
-/// st.global
-template <int kBytes>
-DEVICE void st_global(void* dst, const void* src);
-
-template <>
-DEVICE void st_global<16>(void* dst, const void* src) {
-  uint4 const* src_u128 = reinterpret_cast<uint4 const*>(src);
-  asm volatile("st.global.v4.b32 [%0], {%1, %2, %3, %4};\n"
-               :
-               : "l"(dst), "r"(src_u128->x), "r"(src_u128->y),
-                 "r"(src_u128->z), "r"(src_u128->w));
-}
-}  // namespace
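The wrappers above address shared memory through PTX's 32-bit .shared window, which is why they take a uint32_t rather than a pointer: a generic C++ pointer must first be converted with __cvta_generic_to_shared. Below is a minimal usage sketch of that calling convention, assuming the wrappers are visible in the current scope; the kernel, buffer, and sizes are illustrative assumptions, not part of this file.

// Hypothetical illustration; kernel name, buffer, and sizes are assumptions.
__global__ void shared_roundtrip_example(float* out) {
  __shared__ float buf[128];
  float frag[4] = {1.f, 2.f, 3.f, 4.f};  // a 16-byte register fragment

  if (threadIdx.x < 32) {  // one warp touches 32 * 16B = 512B of shared memory
    // Convert a generic pointer into a 32-bit .shared-window address.
    uint32_t s_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(buf + 4 * threadIdx.x));
    st_shared<16>(s_addr, frag);                 // registers -> shared, 128-bit
    ld_shared<16>(frag, s_addr);                 // shared -> registers, 128-bit
    st_global<16>(out + 4 * threadIdx.x, frag);  // registers -> global, 128-bit
  }
}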
-
-template <int kBytes>
-DEVICE void ld_shared_st_global(void* dst, uint32_t src);
-
-template <>
-DEVICE void ld_shared_st_global<16>(void* dst, uint32_t src) {
-  unsigned tmp[4];
-  ld_shared<16>(tmp, src);
-  st_global<16>(dst, tmp);
-}
-
-template <const int kBytes>
-DEVICE void ld_global_st_shared(uint32_t dst, void const* src) {
-  static_assert(kBytes == 4 || kBytes == 8 || kBytes == 16);
-
-#if (__CUDA_ARCH__ >= 800)
-  // cp.async is available on Ampere (SM80, SM86) and newer architectures,
-  // e.g. Hopper (SM90).
-  // TODO(ying): add a wrapper to allow choosing between different caching
-  // policies (e.g. "cache all levels").
-  asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(dst),
-               "l"(src), "n"(kBytes));
-#else
-  unsigned tmp[kBytes / 4];
-  if constexpr (kBytes == 16) {
-    asm volatile("ld.global.v4.b32 {%0, %1, %2, %3}, [%4];\n"
-                 : "=r"(tmp[0]), "=r"(tmp[1]), "=r"(tmp[2]), "=r"(tmp[3])
-                 : "l"(src));
-    asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(dst),
-                 "r"(tmp[0]), "r"(tmp[1]), "r"(tmp[2]), "r"(tmp[3]));
-  } else if constexpr (kBytes == 8) {
-    asm volatile("ld.global.v2.b32 {%0, %1}, [%2];\n"
-                 : "=r"(tmp[0]), "=r"(tmp[1])
-                 : "l"(src));
-    asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(dst), "r"(tmp[0]),
-                 "r"(tmp[1]));
-  } else if constexpr (kBytes == 4) {
-    asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(tmp[0]) : "l"(src));
-    asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(dst), "r"(tmp[0]));
-  }
-#endif
-}
-
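On SM80 and newer, ld_global_st_shared lowers to cp.async, which returns before the data has actually arrived in shared memory, so a caller has to commit the async group and wait on it before reading the tile. Below is a minimal sketch of that pattern, assuming the wrapper above is in scope; the kernel name, tile size, and 256-thread launch are illustrative assumptions.

// Hypothetical illustration; assumes a 256-thread block and a 1024-float tile.
__global__ void load_tile_example(const float* g_src) {
  __shared__ float tile[1024];

  // Each of 256 threads copies one 16-byte (4-float) chunk.
  int idx = 4 * threadIdx.x;
  uint32_t s_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(tile + idx));
  ld_global_st_shared<16>(s_addr, g_src + idx);

#if (__CUDA_ARCH__ >= 800)
  asm volatile("cp.async.commit_group;\n" ::);  // close the in-flight group
  asm volatile("cp.async.wait_group 0;\n" ::);  // block until it has landed
#endif
  __syncthreads();  // make the tile visible to every thread in the CTA
  // ... compute on `tile` ...
}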
template <typename DType, const int kNumPerAccess, typename ThreadLayout,
          typename GlobalLayout /*src*/, typename SharedLayout /*dst*/>
struct GlobalToSharedLoader {
@@ -213,4 +83,39 @@ private:
  TiledCopy tiled_copy_;
};

+/// NOTE: This configuration is specialized for copying a small tile whose
+/// size is smaller than the amount of data that all threads in a CTA access
+/// concurrently in a single pass.
+template <typename DType, const int kNumel, typename Base = AccessInfo<DType>>
+struct GlobalToSharedInputLoader : public Base {
+  static constexpr int kWarpTileShape = Base::kNumPerAccess * WARP_SIZE;
+  static constexpr int kThreads = kNumel / kWarpTileShape * WARP_SIZE;
+
+  DEVICE void operator()(const DType* src_, DType* dst_, int start_warp = 0) {
+    int warp_id = threadIdx.x / WARP_SIZE - start_warp;
+    int lane_id = threadIdx.x % WARP_SIZE;
+    int offset = warp_id * kWarpTileShape + lane_id * Base::kNumPerAccess;
+
+    ld_global_st_shared<Base::kAccessInBytes>(
+        static_cast<uint32_t>(__cvta_generic_to_shared(dst_ + offset)),
+        src_ + offset);
+  }
+};
+
+/// NOTE: This configuration is specialized for copying a small tile whose
+/// size is smaller than the amount of data that all threads in a CTA access
+/// concurrently in a single pass.
+template <typename DType, const int kNumel, typename Base = AccessInfo<DType>>
+struct SharedToGlobalInputStorer : public Base {
+  static constexpr int kWarpTileShape = Base::kNumPerAccess * WARP_SIZE;
+  static constexpr int kThreads = kNumel / kWarpTileShape * WARP_SIZE;
+
+  DEVICE void operator()(const DType* src_, DType* dst_, int start_warp = 0) {
+    int warp_id = threadIdx.x / WARP_SIZE - start_warp;
+    int lane_id = threadIdx.x % WARP_SIZE;
+    int offset = warp_id * kWarpTileShape + lane_id * Base::kNumPerAccess;
+
+    ld_shared_st_global<Base::kAccessInBytes>(
+        dst_ + offset,
+        static_cast<uint32_t>(__cvta_generic_to_shared(src_ + offset)));
+  }
+};
+
}  // namespace vptq::kernels::copy
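To make the constants in the two structs above concrete: with DType = __half and a 16-byte access width, Base::kNumPerAccess is 8, so one warp moves kWarpTileShape = 8 * 32 = 256 elements per access, and a 1024-element tile needs kThreads = 1024 / 256 * 32 = 128 threads (4 warps). Below is a hedged sketch of driving the new loader and storer back to back under that assumption; the kernel, shapes, and launch configuration are illustrative, not from this commit.

// Hypothetical illustration; assumes Base::kAccessInBytes == 16 and that at
// least loader.kThreads (128) threads participate in the launch.
__global__ void roundtrip_input_tile(const __half* g_src, __half* g_dst) {
  constexpr int kNumel = 1024;
  __shared__ __half s_tile[kNumel];

  vptq::kernels::copy::GlobalToSharedInputLoader<__half, kNumel> loader;
  vptq::kernels::copy::SharedToGlobalInputStorer<__half, kNumel> storer;

  if (threadIdx.x < loader.kThreads) loader(g_src, s_tile);
#if (__CUDA_ARCH__ >= 800)
  // The loader issues cp.async on SM80+: fence and wait before reading back.
  asm volatile("cp.async.commit_group;\n" ::);
  asm volatile("cp.async.wait_group 0;\n" ::);
#endif
  __syncthreads();
  if (threadIdx.x < storer.kThreads) storer(s_tile, g_dst);
}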