Skip to content

Commit 8de304d

Browse files
committed
Vectorized load from shared to register.
1 parent 84978a6 commit 8de304d

12 files changed

+266
-59
lines changed

.vscode/settings.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"--style={based_on_s'tyle: google, column_limit: 80, indent_width: 4}"
44
],
55
"files.associations": {
6-
"optional": "cpp"
6+
"optional": "cpp",
7+
"atomic": "cpp"
78
}
89
}

csrc/kernels/copy/copy_traits.cuh

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
// Licensed under the MIT License.
33
#pragma once
44

5-
#include "kernels/copy/layout.cuh"
5+
#include "kernels/layout.cuh"
66

77
namespace vptq::kernels::copy {
88
namespace tl = vptq::tile_layout;
99

1010
template <typename DType>
1111
struct AccessInfo {
12-
// the maximal width of vectorized access.
12+
// the maximal width of vectorized access in bits and bytes
1313
static constexpr int kAccessInBits = 128;
1414
static constexpr int kAccessInBytes = 16;
1515

File renamed without changes.

csrc/kernels/copy/mod.cuh

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
#include "config.cuh"
66
#include "kernels/copy/atom.cuh"
7-
#include "kernels/copy/copy.cuh"
87
#include "kernels/copy/copy_traits.cuh"
9-
#include "kernels/copy/layout.cuh"
8+
#include "kernels/copy/global_to_shared.cuh"
109
#include "kernels/copy/sync.cuh"
10+
#include "kernels/copy/vectorized.cuh"
1111
#include "kernels/copy/warp.cuh"

csrc/kernels/copy/vectorized.cuh

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
#pragma once
4+
5+
namespace vptq::kernels::copy {
6+
7+
// TODO(ying): Define additional user-defined vectorized types if necessary
8+
template <typename DType, int kN>
9+
struct GetPackType;
10+
11+
template <>
12+
struct GetPackType<__half, 2> {
13+
using type = __half2;
14+
};
15+
16+
template <>
17+
struct GetPackType<__bfloat16, 2> {
18+
using type = __bfloat162;
19+
};
20+
21+
template <>
22+
struct GetPackType<uint8_t, 4> {
23+
using type = int;
24+
};
25+
26+
template <>
27+
struct GetPackType<uint16_t, 2> {
28+
using type = int;
29+
};
30+
31+
template <>
32+
struct GetPackType<uint16_t, 4> {
33+
using type = int2;
34+
};
35+
36+
template <>
37+
struct GetPackType<uint, 4> {
38+
using type = uint4; // uint4 has native 128 bits load/store support
39+
};
40+
41+
template <>
42+
struct GetPackType<float4, 4> {
43+
using type = float4; // float4 has native 128 bits load/store support
44+
};
45+
46+
template <typename DType, int kN>
47+
using PackType = typename GetPackType<DType, kN>::type;
48+
49+
/// Vectorized copy for a single access.
50+
/// @param DType_ The data type of the elements to copy.
51+
/// @param kN The number of elements to pack into a vectorized copy. This
52+
/// count times sizeof(DType_) must not exceed 128 bits (16 bytes).
53+
template <typename DType_, int kN>
54+
struct PackedCopy {
55+
using DType = DType_;
56+
using Packed = PackType<DType, kN>;
57+
58+
// the maximum read/write transaction size in bytes for a thread
59+
static constexpr int kMaxVecBytes = 16;
60+
61+
static_assert(sizeof(DType) * kN <= kMaxVecBytes,
62+
"The total number of bytes must be less than or equal to the "
63+
"maximum width of a vectorized instruction.");
64+
65+
// This ctor does nothing but ensures the object is created in device memory
66+
DEVICE PackedCopy() {}
67+
68+
DEVICE void operator()(const DType* src_, DType* dst_) {
69+
const Packed* src = reinterpret_cast<const Packed*>(src_);
70+
Packed* dst = reinterpret_cast<Packed*>(dst_);
71+
*dst = *src;
72+
}
73+
};
74+
75+
} // namespace vptq::kernels::copy

csrc/kernels/copy/warp.cuh

+16-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// Licensed under the MIT License.
33
#pragma once
44

5-
#include "kernels/copy/layout.cuh"
5+
#include "kernels/layout.cuh"
66

77
namespace vptq::kernels::copy {
88

@@ -15,16 +15,30 @@ struct WarpCounter {
1515
next_warp_ = kNumWarpsPerTile;
1616
}
1717

18+
// TODO(ying): simplify these calculations
1819
HOST_DEVICE int cur() const { return cur_warp_; }
1920

2021
HOST_DEVICE int next() const { return next_warp_; }
2122

22-
HOST_DEVICE void operator++() { // TODO(ying): simplify these calculations
23+
HOST_DEVICE int next(int i) const {
24+
int wid = next_warp_ + i * kNumWarpsPerTile;
25+
wid = wid > kNumWarps ? wid % kNumWarps : wid;
26+
return wid;
27+
}
28+
29+
HOST_DEVICE void operator++() {
2330
cur_warp_ = next_warp_ % kNumWarps;
2431
next_warp_ += kNumWarpsPerTile;
2532
next_warp_ = next_warp_ > kNumWarps ? next_warp_ % kNumWarps : next_warp_;
2633
}
2734

35+
HOST_DEVICE WarpCounter& operator+=(int n) {
36+
cur_warp_ = next_warp_ % kNumWarps;
37+
next_warp_ += (n * kNumWarpsPerTile);
38+
next_warp_ = next_warp_ > kNumWarps ? next_warp_ % kNumWarps : next_warp_;
39+
return *this;
40+
}
41+
2842
private:
2943
int cur_warp_;
3044
int next_warp_;

csrc/kernels/decode.cuh

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
#pragma once
4+
5+
#include "kernels/copy/copy_traits.cuh"
6+
#include "kernels/copy/vectorized.cuh"
7+
#include "util/debug.cuh"
8+
9+
namespace vptq::kernels {
10+
using namespace copy;
11+
12+
template <typename DType_, typename IdType_, typename ResIdType_,
13+
const int kNumPerThread_, const int kVecLen_,
14+
typename Base = AccessInfo<DType_>>
15+
struct WeightDecoder {
16+
using DType = DType_;
17+
using IdType = IdType_;
18+
using ResIdType = ResIdType_;
19+
20+
// TODO(ying): The current implementation requires that the indices for both
21+
// main and residual centroids are stored in the same data type, such as both
22+
// being uint16_t. If the main indices are in uint16_t and the residual
23+
// indices are in uint8_t, additional handling will be required. This will be
24+
// addressed in the next version.
25+
static_assert(std::is_same_v<IdType, ResIdType>,
26+
"The data type of indices for main and residual centroids must "
27+
"be the same.");
28+
29+
static constexpr int kNumPerThread = kNumPerThread_;
30+
static constexpr int kVecLen = kVecLen_;
31+
32+
DEVICE void operator()(DType* output, // output
33+
const DType* codebook, // codebook for main centroids
34+
const DType* codebook_res, // codebook for residual
35+
const IdType* ids, // indices for main centroids
36+
const ResIdType* res_ids, // indices for residual
37+
const DType* alpha, const DType* beta) {
38+
// threads in a CTA are laid out in 1-D fashion.
39+
int offset = threadIdx.x * kNumPerThread;
40+
const IdType* ids_ = ids + offset; // indices for the current thread
41+
// residual indices for the current thread
42+
const ResIdType* res_ids_ = res_ids + offset;
43+
44+
// load indices and residual indices into registers
45+
// indices on thread local registers
46+
IdType reg_ids[kNumPerThread];
47+
ResIdType reg_residual_ids[kNumPerThread];
48+
49+
#pragma unroll
50+
for (int i = 0; i < kNumPerThread; i += kPackedNum) {
51+
copy_ids(&ids_[i] /*src*/, &reg_ids[i] /*dst*/);
52+
copy_ids(&res_ids_[i] /*src*/, &reg_residual_ids[i] /*dst*/);
53+
}
54+
}
55+
56+
private:
57+
// Indices are packed into 4 bytes in the current implementation, stored in a
58+
// shared memory bank. This can be tuned if needed.
59+
static constexpr int kPackedIdsBytes = 4;
60+
static constexpr int kPackedNum = kPackedIdsBytes / sizeof(IdType);
61+
static_assert(kPackedNum, "kPackedNum must be greater than 0");
62+
using VecCopy = PackedCopy<IdType, kPackedNum>;
63+
VecCopy copy_ids;
64+
};
65+
66+
} // namespace vptq::kernels
File renamed without changes.

csrc/kernels/quant_gemv_traits.cuh

+39-13
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
// Licensed under the MIT License.
33
#pragma once
44

5-
#include "copy/mod.cuh"
5+
#include "kernels/copy/mod.cuh"
6+
#include "kernels/decode.cuh"
67

78
#include <cute/tensor.hpp>
89

@@ -16,8 +17,9 @@ namespace {
1617
template <const int a, const int b>
1718
static constexpr int divup = (a + b - 1) / b;
1819

19-
template <typename DType, const int kTileSize, const int kVecLen,
20-
const int kNumCentroids, const int kNumResCentroids>
20+
template <typename DType, typename IdType, typename ResIdType,
21+
const int kTileSize, const int kVecLen, const int kNumCentroids,
22+
const int kNumResCentroids>
2123
struct SharedStorageImpl {
2224
///==== Shared memory for inputs ====///
2325
static constexpr int kSizeCodebook = kNumCentroids * kVecLen;
@@ -29,8 +31,11 @@ struct SharedStorageImpl {
2931
static constexpr int kSizeInputs = 3 * kTileSize;
3032
array_aligned<DType, kSizeInputs, 128> inputs;
3133

32-
static constexpr int kSizeIndices = kTileSize * 2;
33-
array_aligned<uint16_t, kTileSize * 2> indices;
34+
// TODO(ying): Support residual indices stored in uint8_t
35+
static_assert(std::is_same_v<IdType, ResIdType>,
36+
"The data type of indices for main and residual centroids must "
37+
"be the same.");
38+
array_aligned<IdType, kTileSize * 2> indices;
3439

3540
///==== Shared memory for intermediate results ====///
3641
static constexpr int kSizeWeights = kTileSize * kVecLen;
@@ -42,7 +47,7 @@ struct SharedStorageImpl {
4247
static constexpr int kSmemSize = ((kSizeCodebook + kSizeCodebookRes +
4348
kSizeInputs + kSizeWeights + kSizeOut) *
4449
sizeof(DType)) +
45-
kSizeIndices * sizeof(uint16_t);
50+
2 * kTileSize * sizeof(IdType);
4651
};
4752

4853
template <typename DType, const int kThreads, const int kNumCentroids,
@@ -83,7 +88,8 @@ struct CodebookTraits : public Base {
8388

8489
} // namespace
8590

86-
template <typename DType, const int kThreads, //
91+
template <typename DType, typename IdType, typename ResIdType,
92+
const int kThreads, //
8793
const int kTileSize_, const int kVecLen_, //
8894
const int kNumCentroids_, const int kNumResCentroids_,
8995
typename Base = copy::AccessInfo<DType>>
@@ -95,8 +101,9 @@ struct QuantGemvKeTraits : public Base {
95101
static constexpr int kTileSize = kTileSize_;
96102

97103
/// allocate shared memory
98-
using SharedStorage = SharedStorageImpl<DType, kTileSize, kVecLen,
99-
kNumCentroids, kNumResCentroids>;
104+
using SharedStorage =
105+
SharedStorageImpl<DType, IdType, ResIdType, kTileSize, kVecLen,
106+
kNumCentroids, kNumResCentroids>;
100107
/// configurations for loading codebooks
101108
using MainCentroidTraits =
102109
CodebookTraits<DType, kThreads, kNumCentroids, kVecLen>;
@@ -131,14 +138,33 @@ struct QuantGemvKeTraits : public Base {
131138

132139
/// configurations for loading tiled indices
133140
static constexpr int kThreadsIndex =
134-
kTileSize * sizeof(uint16_t) / Base::kAccessInBytes;
141+
kTileSize * sizeof(IdType) / Base::kAccessInBytes;
135142
static_assert(kThreadsIndex <= kThreads,
136143
"The current implementation requires that the number of "
137144
"threads used to load a single index tile must be less than or "
138145
"equal to the number of threads in the block.");
139-
using IndexLoader = copy::GlobalToSharedInputLoader<uint16_t, kTileSize>;
140-
// storer is defined for debugging purposes
141-
using IndexStorer = copy::SharedToGlobalInputStorer<uint16_t, kTileSize>;
146+
147+
// TODO(ying): The current implementation requires that the indices for both
148+
// main and residual centroids are stored in the same data type. This will be
149+
// addressed in the next version.
150+
static_assert(std::is_same_v<IdType, ResIdType>,
151+
"The data type of indices for main and residual centroids must "
152+
"be the same.");
153+
using IndexLoader = copy::GlobalToSharedInputLoader<IdType, 2 * kTileSize>;
154+
using IndexStorer = copy::SharedToGlobalInputStorer<IdType, 2 * kTileSize>;
155+
156+
/// configurations for decoding indices
157+
// Ensure the indices can be stored aligned with shared memory banks, and a
158+
// single thread decodes at least `kIdsPerBank` indices.
159+
static constexpr int kBankBytes = 4;
160+
static_assert(kBankBytes % sizeof(ResIdType) == 0);
161+
static constexpr int kIdsPerBank = kBankBytes / sizeof(ResIdType);
162+
// how many indices are decoded by a single thread
163+
static_assert(kTileSize % (kThreads * kIdsPerBank) == 0);
164+
static constexpr int kDecodeNumPerThread = kTileSize / kThreads;
165+
166+
using Decoder =
167+
WeightDecoder<DType, IdType, ResIdType, kDecodeNumPerThread, kVecLen>;
142168
};
143169

144170
} // namespace vptq::kernels

0 commit comments

Comments
 (0)