@@ -10,9 +10,12 @@ namespace vptq::copy {
10
10
using namespace cute ;
11
11
12
12
// / TODO(ying); the current implementation supports load row-major data only.
13
- template <typename DType, const int kThreads , const int64_t kRows ,
14
- const int64_t kCols , typename Base = AccessInfo<DType>>
13
+ template <typename DType, const int kThreads , const int64_t kRows_ ,
14
+ const int64_t kCols_ , typename Base = AccessInfo<DType>>
15
15
struct GlobalToSharedLoader : public Base {
16
+ static constexpr int kRows = kRows_ ;
17
+ static constexpr int kCols = kCols_ ;
18
+
16
19
DEVICE void operator ()(const DType* src_, DType* dst_) {
17
20
int tid = threadIdx .x ;
18
21
@@ -32,12 +35,10 @@ struct GlobalToSharedLoader : public Base {
32
35
}
33
36
34
37
private:
35
- // source
36
38
using GlobalLayout =
37
39
cute::Layout<Shape<Int<kRows >, Int<kCols >>, Stride<Int<kCols >, _1>>;
38
40
GlobalLayout src_layout_;
39
41
40
- // destination
41
42
using SharedLayout =
42
43
cute::Layout<Shape<Int<kRows >, Int<kCols >>, Stride<Int<kCols >, _1>>;
43
44
@@ -69,4 +70,60 @@ private:
69
70
TiledCopy tiled_copy_;
70
71
};
71
72
73
+ // / TODO(ying); the current implementation supports load row-major data only.
74
+ template <typename DType, const int kThreads , const int64_t kRows_ ,
75
+ const int64_t kCols_ , typename Base = AccessInfo<DType>>
76
+ struct SharedToGlobalStorer : public Base {
77
+ static constexpr int kRows = kRows_ ;
78
+ static constexpr int kCols = kCols_ ;
79
+
80
+ DEVICE void operator ()(const DType* src_, DType* dst_) {
81
+ int tid = threadIdx .x ;
82
+
83
+ auto stile = make_tensor (make_smem_ptr (src_), src_layout_);
84
+ auto gtile = make_tensor (make_gmem_ptr (dst_), dst_layout_);
85
+
86
+ auto loader = tiled_copy_.get_thread_slice (tid);
87
+
88
+ auto src = loader.partition_S (stile);
89
+ auto dst = loader.partition_D (gtile);
90
+
91
+ #pragma unroll
92
+ for (int i = 0 ; i < int (size<1 >(src)); ++i)
93
+ #pragma unroll
94
+ for (int j = 0 ; j < int (size<2 >(src)); ++j)
95
+ cute::copy (tiled_copy_, src (cute::_, i, j), dst (cute::_, i, j));
96
+ }
97
+
98
+ private:
99
+ using SharedLayout =
100
+ cute::Layout<Shape<Int<kRows >, Int<kCols >>, Stride<Int<kCols >, _1>>;
101
+ // using LayoutAtom =
102
+ // decltype(composition(cute::Swizzle<2, 3, 3>{},
103
+ // cute::Layout<Shape<_4, _64>, Stride<_64, _1>>{}));
104
+ // using SharedLayout = decltype(tile_to_shape(
105
+ // LayoutAtom{}, Shape<Int<kRows>, Int<kCols>>{}, cute::Step<_2, _1>{}));
106
+ SharedLayout src_layout_;
107
+
108
+ using GlobalLayout =
109
+ cute::Layout<Shape<Int<kRows >, Int<kCols >>, Stride<Int<kCols >, _1>>;
110
+ GlobalLayout dst_layout_;
111
+
112
+ // tiled copy
113
+ static constexpr int kThreadCols =
114
+ kCols * Base::kElementBits / Base::kAccessInBits ;
115
+ static_assert (kThreadCols > 0 );
116
+ static constexpr int kThreadRows = kThreads / kThreadCols ;
117
+
118
+ using ThreadLayout = cute::Layout<Shape<Int<kThreadRows >, Int<kThreadCols >>,
119
+ Stride<Int<kThreadCols >, _1>>;
120
+ using ValueLayout = cute::Layout<Shape<_1, _8>>;
121
+
122
+ using CopyInst = Copy_Atom<DefaultCopy, DType>;
123
+
124
+ using TiledCopy =
125
+ decltype (make_tiled_copy(CopyInst{}, ThreadLayout{}, ValueLayout{}));
126
+ TiledCopy tiled_copy_;
127
+ };
128
+
72
129
} // namespace vptq::copy
0 commit comments