2
2
// Licensed under the MIT License.
3
3
#pragma once
4
4
5
- #include " copy/mod.cuh"
5
+ #include " kernels/copy/mod.cuh"
6
+ #include " kernels/decode.cuh"
6
7
7
8
#include < cute/tensor.hpp>
8
9
@@ -16,8 +17,9 @@ namespace {
16
17
template <const int a, const int b>
17
18
static constexpr int divup = (a + b - 1 ) / b;
18
19
19
- template <typename DType, const int kTileSize , const int kVecLen ,
20
- const int kNumCentroids , const int kNumResCentroids >
20
+ template <typename DType, typename IdType, typename ResIdType,
21
+ const int kTileSize , const int kVecLen , const int kNumCentroids ,
22
+ const int kNumResCentroids >
21
23
struct SharedStorageImpl {
22
24
// /==== Shared memory for inputs ====///
23
25
static constexpr int kSizeCodebook = kNumCentroids * kVecLen ;
@@ -29,8 +31,11 @@ struct SharedStorageImpl {
29
31
static constexpr int kSizeInputs = 3 * kTileSize ;
30
32
array_aligned<DType, kSizeInputs , 128 > inputs;
31
33
32
- static constexpr int kSizeIndices = kTileSize * 2 ;
33
- array_aligned<uint16_t , kTileSize * 2 > indices;
34
+ // TODO(ying): Support residual indices are stored in uint8_t
35
+ static_assert (std::is_same_v<IdType, ResIdType>,
36
+ " The data type of indices for main and residual centroids must "
37
+ " be the same." );
38
+ array_aligned<IdType, kTileSize * 2 > indices;
34
39
35
40
// /==== Shared mempory for intermediate results ====///
36
41
static constexpr int kSizeWeights = kTileSize * kVecLen ;
@@ -42,7 +47,7 @@ struct SharedStorageImpl {
42
47
static constexpr int kSmemSize = ((kSizeCodebook + kSizeCodebookRes +
43
48
kSizeInputs + kSizeWeights + kSizeOut ) *
44
49
sizeof (DType)) +
45
- kSizeIndices * sizeof (uint16_t );
50
+ 2 * kTileSize * sizeof (IdType );
46
51
};
47
52
48
53
template <typename DType, const int kThreads , const int kNumCentroids ,
@@ -83,7 +88,8 @@ struct CodebookTraits : public Base {
83
88
84
89
} // namespace
85
90
86
- template <typename DType, const int kThreads , //
91
+ template <typename DType, typename IdType, typename ResIdType,
92
+ const int kThreads , //
87
93
const int kTileSize_ , const int kVecLen_ , //
88
94
const int kNumCentroids_ , const int kNumResCentroids_ ,
89
95
typename Base = copy::AccessInfo<DType>>
@@ -95,8 +101,9 @@ struct QuantGemvKeTraits : public Base {
95
101
static constexpr int kTileSize = kTileSize_ ;
96
102
97
103
// / allocate shared memory
98
- using SharedStorage = SharedStorageImpl<DType, kTileSize , kVecLen ,
99
- kNumCentroids , kNumResCentroids >;
104
+ using SharedStorage =
105
+ SharedStorageImpl<DType, IdType, ResIdType, kTileSize , kVecLen ,
106
+ kNumCentroids , kNumResCentroids >;
100
107
// / configurations for loading codebooks
101
108
using MainCentroidTraits =
102
109
CodebookTraits<DType, kThreads , kNumCentroids , kVecLen >;
@@ -131,14 +138,33 @@ struct QuantGemvKeTraits : public Base {
131
138
132
139
// / configurations for loading tiled indices
133
140
static constexpr int kThreadsIndex =
134
- kTileSize * sizeof (uint16_t ) / Base::kAccessInBytes ;
141
+ kTileSize * sizeof (IdType ) / Base::kAccessInBytes ;
135
142
static_assert (kThreadsIndex <= kThreads ,
136
143
" The current implementation requires that the number of "
137
144
" threads used to load a single index tile must be less than or "
138
145
" equal to the number of threads in the block." );
139
- using IndexLoader = copy::GlobalToSharedInputLoader<uint16_t , kTileSize >;
140
- // storer is defined for debugging purposes
141
- using IndexStorer = copy::SharedToGlobalInputStorer<uint16_t , kTileSize >;
146
+
147
+ // TODO(ying): The current implementation requires that the indices for both
148
+ // main and residual centroids are stored in the same data type. This will be
149
+ // addressed in the next version.
150
+ static_assert (std::is_same_v<IdType, ResIdType>,
151
+ " The data type of indices for main and residual centroids must "
152
+ " be the same." );
153
+ using IndexLoader = copy::GlobalToSharedInputLoader<IdType, 2 * kTileSize >;
154
+ using IndexStorer = copy::SharedToGlobalInputStorer<IdType, 2 * kTileSize >;
155
+
156
+ // / configurations for decoding indices
157
+ // Ensure the indices can be stored aligned with shared memory banks, and a
158
+ // single thread decode at least `kIdsPerBank` indices.
159
+ static constexpr int kBankBytes = 4 ;
160
+ static_assert (kBankBytes % sizeof (ResIdType) == 0 );
161
+ static constexpr int kIdsPerBank = kBankBytes / sizeof (ResIdType);
162
+ // how many indices are decoded by a single thread
163
+ static_assert (kTileSize % (kThreads * kIdsPerBank ) == 0 );
164
+ static constexpr int kDecodeNumPerThread = kTileSize / kThreads ;
165
+
166
+ using Decoder =
167
+ WeightDecoder<DType, IdType, ResIdType, kDecodeNumPerThread , kVecLen >;
142
168
};
143
169
144
170
} // namespace vptq::kernels
0 commit comments