Skip to content

Commit ee981fc

Browse files
committed
Add storer and verify loaded codebook.
1 parent 24651e8 commit ee981fc

9 files changed

+193
-70
lines changed

csrc/config.cuh

+10-12
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// Licensed under the MIT License.
33
#pragma once
44

5-
#if defined(__CUDA_ARCH__)
5+
#if (defined(__CUDA_ARCH__) || defined(USE_ROCM))
66
#define HOST_DEVICE __forceinline__ __host__ __device__
77
#define DEVICE __forceinline__ __device__
88
#define HOST __forceinline__ __host__
@@ -12,30 +12,28 @@
1212
#define HOST inline
1313
#endif
1414

15-
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
16-
#define CP_ASYNC_SM80_ENABLED
17-
#endif
18-
1915
#if defined(USE_ROCM)
2016
#include <hip/hip_bf16.h>
2117
#include <hip/hip_fp16.h>
2218

23-
#define VPTQ_LDG(arg) __ldg(arg)
24-
#define SHFL_DOWN(val, offset) __shfl_down(val, offset)
25-
#define WARP_SIZE warpSize
26-
2719
typedef __hip_bfloat162 __bfloat162;
2820
typedef __hip_bfloat16 __bfloat16;
2921

22+
#define VPTQ_LDG(arg) __ldg(arg)
23+
#define SHFL_DOWN(val, offset) __shfl_down(val, offset)
24+
#define WARP_SIZE warpSize
3025
#else
3126
#include <cuda_bf16.h>
3227
#include <cuda_fp16.h>
3328

29+
typedef __nv_bfloat162 __bfloat162;
30+
typedef __nv_bfloat16 __bfloat16;
31+
3432
#define WARP_SIZE 32
3533
#define VPTQ_LDG(arg) *(arg)
3634
#define SHFL_DOWN(val, offset) __shfl_down_sync(0xffffffff, val, offset)
35+
#endif
3736

38-
typedef __nv_bfloat162 __bfloat162;
39-
typedef __nv_bfloat16 __bfloat16;
40-
37+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
38+
#define CP_ASYNC_SM80_ENABLED
4139
#endif

csrc/copy/copy.cuh

+61-4
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@ namespace vptq::copy {
1010
using namespace cute;
1111

1212
/// TODO(ying); the current implementation supports load row-major data only.
13-
template <typename DType, const int kThreads, const int64_t kRows,
14-
const int64_t kCols, typename Base = AccessInfo<DType>>
13+
template <typename DType, const int kThreads, const int64_t kRows_,
14+
const int64_t kCols_, typename Base = AccessInfo<DType>>
1515
struct GlobalToSharedLoader : public Base {
16+
static constexpr int kRows = kRows_;
17+
static constexpr int kCols = kCols_;
18+
1619
DEVICE void operator()(const DType* src_, DType* dst_) {
1720
int tid = threadIdx.x;
1821

@@ -32,12 +35,10 @@ struct GlobalToSharedLoader : public Base {
3235
}
3336

3437
private:
35-
// source
3638
using GlobalLayout =
3739
cute::Layout<Shape<Int<kRows>, Int<kCols>>, Stride<Int<kCols>, _1>>;
3840
GlobalLayout src_layout_;
3941

40-
// destination
4142
using SharedLayout =
4243
cute::Layout<Shape<Int<kRows>, Int<kCols>>, Stride<Int<kCols>, _1>>;
4344

@@ -69,4 +70,60 @@ private:
6970
TiledCopy tiled_copy_;
7071
};
7172

73+
/// TODO(ying); the current implementation supports load row-major data only.
74+
template <typename DType, const int kThreads, const int64_t kRows_,
75+
const int64_t kCols_, typename Base = AccessInfo<DType>>
76+
struct SharedToGlobalStorer : public Base {
77+
static constexpr int kRows = kRows_;
78+
static constexpr int kCols = kCols_;
79+
80+
DEVICE void operator()(const DType* src_, DType* dst_) {
81+
int tid = threadIdx.x;
82+
83+
auto stile = make_tensor(make_smem_ptr(src_), src_layout_);
84+
auto gtile = make_tensor(make_gmem_ptr(dst_), dst_layout_);
85+
86+
auto loader = tiled_copy_.get_thread_slice(tid);
87+
88+
auto src = loader.partition_S(stile);
89+
auto dst = loader.partition_D(gtile);
90+
91+
#pragma unroll
92+
for (int i = 0; i < int(size<1>(src)); ++i)
93+
#pragma unroll
94+
for (int j = 0; j < int(size<2>(src)); ++j)
95+
cute::copy(tiled_copy_, src(cute::_, i, j), dst(cute::_, i, j));
96+
}
97+
98+
private:
99+
using SharedLayout =
100+
cute::Layout<Shape<Int<kRows>, Int<kCols>>, Stride<Int<kCols>, _1>>;
101+
// using LayoutAtom =
102+
// decltype(composition(cute::Swizzle<2, 3, 3>{},
103+
// cute::Layout<Shape<_4, _64>, Stride<_64, _1>>{}));
104+
// using SharedLayout = decltype(tile_to_shape(
105+
// LayoutAtom{}, Shape<Int<kRows>, Int<kCols>>{}, cute::Step<_2, _1>{}));
106+
SharedLayout src_layout_;
107+
108+
using GlobalLayout =
109+
cute::Layout<Shape<Int<kRows>, Int<kCols>>, Stride<Int<kCols>, _1>>;
110+
GlobalLayout dst_layout_;
111+
112+
// tiled copy
113+
static constexpr int kThreadCols =
114+
kCols * Base::kElementBits / Base::kAccessInBits;
115+
static_assert(kThreadCols > 0);
116+
static constexpr int kThreadRows = kThreads / kThreadCols;
117+
118+
using ThreadLayout = cute::Layout<Shape<Int<kThreadRows>, Int<kThreadCols>>,
119+
Stride<Int<kThreadCols>, _1>>;
120+
using ValueLayout = cute::Layout<Shape<_1, _8>>;
121+
122+
using CopyInst = Copy_Atom<DefaultCopy, DType>;
123+
124+
using TiledCopy =
125+
decltype(make_tiled_copy(CopyInst{}, ThreadLayout{}, ValueLayout{}));
126+
TiledCopy tiled_copy_;
127+
};
128+
72129
} // namespace vptq::copy

csrc/quant_gemv.cuh

+4-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33
#pragma once
44

5+
#include "util/convert.cuh"
56
#include "util/cuda_utils.cuh"
67

78
namespace vptq {
@@ -146,7 +147,7 @@ __global__ void WqA16WithOutliers_PackIndice(
146147
#pragma unroll
147148
for (int gi = 0; gi < GROUPSIZE; ++gi) {
148149
float reduce_out = 0.f;
149-
reduce_out = cuda::ConvertToFloat(tmp_output[gi]);
150+
reduce_out = to_float(tmp_output[gi]);
150151
reduce_out = cuda::warpReduceSum<WARP_SIZE>(reduce_out);
151152
if (landid == 0) {
152153
shared_output[gi][warpid] = reduce_out;
@@ -172,10 +173,10 @@ __global__ void WqA16WithOutliers_PackIndice(
172173
if (landid == 0 && (in_y * GROUPSIZE + wid) < out_features) {
173174
if constexpr (Do_Reduce) {
174175
out[(wid)*gridDim.z] =
175-
cuda::ConvertFromFloat<scalar_t>(reduce_out, zero_value) +
176+
from_float<scalar_t>(reduce_out, zero_value) +
176177
((bidz == 0 && bias != 0) ? bias[wid] : zero_value);
177178
} else {
178-
out[wid] = cuda::ConvertFromFloat<scalar_t>(reduce_out, zero_value) +
179+
out[wid] = from_float<scalar_t>(reduce_out, zero_value) +
179180
((bias != 0) ? bias[wid] : zero_value);
180181
}
181182
}

csrc/quant_gemv_v2.cu

+11-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ struct QuantGemvKeTraits : public Base {
1717
using LoaderG2S =
1818
copy::GlobalToSharedLoader<DType, kThreads, kNumCentroids / kPackedVecs,
1919
kVecLen * kPackedVecs>;
20+
using StorerS2G =
21+
copy::SharedToGlobalStorer<DType, kThreads, kNumCentroids / kPackedVecs,
22+
kVecLen * kPackedVecs>;
2023
};
2124

2225
/**
@@ -83,7 +86,11 @@ torch::Tensor quant_gemv_v2(
8386
"Supported vector length in vectorized quantization: 4, 8, 12, or 16.");
8487

8588
torch::Tensor output;
86-
output = at::empty({in_features, out_features}, centroids.options());
89+
// output = at::empty({in_features, out_features}, centroids.options());
90+
91+
// NOTE: this is for test!!!
92+
output =
93+
at::empty({num_codebooks, num_centroids, vec_len}, centroids.options());
8794

8895
auto stream = at::cuda::getCurrentCUDAStream().stream();
8996

@@ -117,7 +124,9 @@ torch::Tensor quant_gemv_v2(
117124

118125
std::cout << "centroid number: " << kNumCentroids
119126
<< "; vector length: " << kVecLen
120-
<< "; smem_size: " << smem_size / 1024 << "KB" << std::endl;
127+
<< "; smem_size: " << smem_size / 1024
128+
<< "KB; max smem size: " << kMaxSmemPerBlock / 1024 << "KB"
129+
<< std::endl;
121130

122131
using Config =
123132
QuantGemvKeTraits<nv_type, kThreads, kNumCentroids, kVecLen>;

csrc/quant_gemv_v2.cuh

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#pragma once
44

55
#include "copy/sync.cuh"
6+
#include "util/convert.cuh"
67
#include "util/debug.cuh"
78

89
namespace vptq {
@@ -22,10 +23,13 @@ __global__ void quant_gemv_v2_kernel(
2223
auto* buf = reinterpret_cast<DType*>(buf_);
2324

2425
typename KeTraits::LoaderG2S loader;
26+
typename KeTraits::StorerS2G storer;
27+
2528
loader(centroids, buf);
2629
__copy_async();
2730
__syncthreads();
2831

32+
storer(buf, output);
2933
return;
3034
}
3135

csrc/util/convert.cuh

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
#pragma once
4+
5+
#include "config.cuh"
6+
7+
namespace vptq {
8+
template <typename T>
9+
T DEVICE from_float(float v, T vv) {
10+
(void)(vv);
11+
if constexpr (std::is_same<T, __bfloat16>::value) {
12+
return vv = __float2bfloat16(v);
13+
} else if constexpr (std::is_same<T, float>::value) {
14+
return vv = v;
15+
} else {
16+
static_assert(std::is_same<T, __half>::value);
17+
return vv = __float2half(v);
18+
}
19+
}
20+
21+
template <typename T>
22+
float DEVICE to_float(T v) {
23+
if constexpr (std::is_same<T, __bfloat16>::value) {
24+
return __bfloat162float(v);
25+
} else if constexpr (std::is_same<T, float>::value) {
26+
return v;
27+
} else {
28+
static_assert(std::is_same<T, __half>::value);
29+
return __half2float(v);
30+
}
31+
}
32+
} // namespace vptq

csrc/util/cuda_utils.cuh

+7-33
Original file line numberDiff line numberDiff line change
@@ -45,33 +45,8 @@ struct TypeVec2<float> {
4545
typedef float2 type;
4646
};
4747

48-
template <typename T>
49-
T __device__ __forceinline__ ConvertFromFloat(float v, T vv) {
50-
(void)(vv);
51-
if constexpr (std::is_same<T, __bfloat16>::value) {
52-
return vv = __float2bfloat16(v);
53-
} else if constexpr (std::is_same<T, float>::value) {
54-
return vv = v;
55-
} else {
56-
static_assert(std::is_same<T, __half>::value);
57-
return vv = __float2half(v);
58-
}
59-
}
60-
61-
template <typename T>
62-
float __device__ __forceinline__ ConvertToFloat(T v) {
63-
if constexpr (std::is_same<T, __bfloat16>::value) {
64-
return __bfloat162float(v);
65-
} else if constexpr (std::is_same<T, float>::value) {
66-
return v;
67-
} else {
68-
static_assert(std::is_same<T, __half>::value);
69-
return __half2float(v);
70-
}
71-
}
72-
7348
template <unsigned int WarpSize>
74-
__device__ __forceinline__ float warpReduceSum(float sum) {
49+
DEVICE float warpReduceSum(float sum) {
7550
if constexpr (WarpSize >= 64)
7651
sum += SHFL_DOWN(sum, 32); // 0-16, 1-17, 2-18, etc.
7752
if constexpr (WarpSize >= 32)
@@ -86,8 +61,8 @@ __device__ __forceinline__ float warpReduceSum(float sum) {
8661
}
8762

8863
template <int GROUPSIZE, typename T>
89-
__device__ __forceinline__ void ldg_vec_x(
90-
T* __restrict__ dst_t32, const uint32_t* __restrict__ src_u32) {
64+
DEVICE void ldg_vec_x(T* __restrict__ dst_t32,
65+
const uint32_t* __restrict__ src_u32) {
9166
uint32_t* dst_u32 = (uint32_t*)dst_t32;
9267
if constexpr (std::is_same<T, float>::value ||
9368
std::is_same<T, float2>::value) {
@@ -133,8 +108,7 @@ __device__ __forceinline__ void ldg_vec_x(
133108
}
134109

135110
template <int WBITS>
136-
__device__ __forceinline__ uint32_t iterator_packed_tensor(const uint32_t* ptr,
137-
int idx) {
111+
DEVICE uint32_t iterator_packed_tensor(const uint32_t* ptr, int idx) {
138112
if constexpr (WBITS == 32) {
139113
return ptr[idx];
140114
} else if constexpr (WBITS == 16) {
@@ -160,7 +134,7 @@ __device__ __forceinline__ uint32_t iterator_packed_tensor(const uint32_t* ptr,
160134
} // namespace cuda
161135

162136
template <typename T>
163-
T __device__ __forceinline__ FMA2(T a, T b, T c) {
137+
T DEVICE FMA2(T a, T b, T c) {
164138
if constexpr (std::is_same<T, __bfloat162>::value) {
165139
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
166140
float x =
@@ -197,7 +171,7 @@ T __device__ __forceinline__ FMA(T a, T b, T c) {
197171
}
198172

199173
template <typename T>
200-
T __device__ __forceinline__ ADD2(T a, T b) {
174+
T DEVICE ADD2(T a, T b) {
201175
if constexpr (std::is_same<T, __bfloat162>::value) {
202176
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(USE_ROCM)
203177
float x = __bfloat162float(a.x) + __bfloat162float(b.x);
@@ -215,7 +189,7 @@ T __device__ __forceinline__ ADD2(T a, T b) {
215189
}
216190

217191
template <typename T>
218-
T __device__ __forceinline__ ZERO_VALUE(T a) {
192+
T DEVICE ZERO_VALUE(T a) {
219193
if constexpr (std::is_same<T, __bfloat16>::value) {
220194
#if defined(USE_ROCM)
221195
return __float2bfloat16(0.0f);

0 commit comments

Comments
 (0)