
Commit 03b4187

memory and bf16 (#23)
- shrink memory
- support bf16
1 parent 41db6e0 · commit 03b4187

5 files changed: +100 -63 lines changed

csrc/dequant_impl_packed.cu

Lines changed: 72 additions & 50 deletions
@@ -1,7 +1,6 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include <cuda_bf16.h>
 #include <cmath>
 #include <math_constants.h>
 #include <ATen/cuda/CUDAContext.h>
@@ -35,15 +34,15 @@ __global__ void WqA16WithOutliers_PackIndice(
     tidx += bidz * cuda::kBlockSize * Do_Reduce;
   }
   int in_y = bidx;
-  extern __shared__ scalar_t shared_memory[];  // 3xin_features, dynamic
-  scalar_t* shared_input = shared_memory;      // in_features, dynamic
+  __shared__ scalar_t shared_memory[1];    // 3xin_features, dynamic
+  scalar_t* shared_input = shared_memory;  // in_features, dynamic
   // scalar_t* shared_w_scales = shared_memory+in_features;// in_features, dynamic
   scalar_t* shared_w_bias = shared_memory + in_features;  // in_features, dynamic
   __shared__ float shared_output[GROUPSIZE][cuda::kBlockSize / 32 + 1];
-  scalar_t tmp_output[GROUPSIZE] = {0};
+  scalar_t tmp_output[GROUPSIZE];
 #pragma unroll
   for (int i = 0; i < GROUPSIZE; i++) {
-    tmp_output[i] = scalar_t(0);
+    tmp_output[i] = scalar_t(0.0f);
   }
   input_data = input_data + in_features * bidy;
   out = out + out_features * bidy * gridDim.z;
@@ -154,11 +153,7 @@ __global__ void WqA16WithOutliers_PackIndice(
 #pragma unroll
   for (int gi = 0; gi < GROUPSIZE; gi++) {
     float reduce_out = 0.f;
-    if constexpr (!std::is_same_v<scalar_t, c10::BFloat16>) {
-      reduce_out = __half2float(tmp_output[gi]);
-    } else {
-      reduce_out = __bfloat162float(tmp_output[gi]);
-    }
+    reduce_out = cuda::ConvertToFloat(tmp_output[gi]);
     reduce_out = cuda::warpReduceSum<32>(reduce_out);
     if (landid == 0) {
       shared_output[gi][warpid] = reduce_out;
@@ -181,10 +176,11 @@ __global__ void WqA16WithOutliers_PackIndice(
     reduce_out = cuda::warpReduceSum<cuda::kBlockSize / 32>(reduce_out);
     if (landid == 0 && (in_y * GROUPSIZE + wid) < out_features) {
       if constexpr (Do_Reduce) {
-        out[(wid)*gridDim.z] =
-            cuda::ConvertFromFloat<scalar_t>(reduce_out) + ((bidz == 0 && bias != 0) ? bias[wid] : scalar_t(0));
+        out[(wid)*gridDim.z] = cuda::ConvertFromFloat<scalar_t>(reduce_out, scalar_t(0.0f)) +
+                               ((bidz == 0 && bias != 0) ? bias[wid] : scalar_t(0.0f));
       } else {
-        out[wid] = cuda::ConvertFromFloat<scalar_t>(reduce_out) + ((bias != 0) ? bias[wid] : scalar_t(0));
+        out[wid] =
+            cuda::ConvertFromFloat<scalar_t>(reduce_out, scalar_t(0.0f)) + ((bias != 0) ? bias[wid] : scalar_t(0.0f));
       }
     }
   }
@@ -204,6 +200,7 @@ __global__ void DequantizeWithOutliers_PackIndice(scalar_t* out, const int32_t*
   int tid = (bid * cuda::kBlockSize + threadIdx.x);
   int in_x = tid % in_features;
   int in_y = tid / in_features;
+  using VecType = typename cuda::TypeVec2<scalar_t>::type;
 
   uint16_t mapped_index_x = invert_perm ? invert_perm[in_x] : in_x;
   const scalar_t scale = weight_scale[in_x];
@@ -247,25 +244,25 @@ __global__ void DequantizeWithOutliers_PackIndice(scalar_t* out, const int32_t*
       cuda::iterator_packed_tensor<IDXBITS + ResidualBits>((const uint32_t*)q_indice, mappped_inx_in_a_codebook);
 
   const uint16_t base_ind = merged_ind & ((1 << IDXBITS) - 1);
-  __half2 base[GROUPSIZE / 2];
+  VecType base[GROUPSIZE / 2];
   const scalar_t* centroids_start = centroids + base_ind * GROUPSIZE;
   cuda::ldg_vec_x<GROUPSIZE>((uint32_t*)(base), (const uint32_t*)(centroids_start));
 
   if constexpr (ResidualBits > 0) {
-    __half2 residual[GROUPSIZE / 2];
+    VecType residual[GROUPSIZE / 2];
     merged_ind >>= IDXBITS;
     const uint16_t res_ind = merged_ind & ((1 << ResidualBits) - 1);
     const scalar_t* residual_centroids_start = residual_centroids + res_ind * GROUPSIZE;
     cuda::ldg_vec_x<GROUPSIZE>((uint32_t*)(residual), (const uint32_t*)(residual_centroids_start));
 #pragma unroll
     for (int i = 0; i < GROUPSIZE / 2; i++) {
-      base[i] = __hadd2(*(((__half2*)base) + i), *(((__half2*)residual) + i));
+      base[i] = __hadd2(*(((VecType*)base) + i), *(((VecType*)residual) + i));
     }
   }
 
-  __half2 hres[GROUPSIZE / 2];
-  __half2 scale2 = __half2(scale, scale);
-  __half2 bias2 = __half2(bias, bias);
+  VecType hres[GROUPSIZE / 2];
+  VecType scale2 = VecType(scale, scale);
+  VecType bias2 = VecType(bias, bias);
 #pragma unroll
   for (int i = 0; i < GROUPSIZE / 2; i++) {
     hres[i] = __hfma2(base[i], scale2, bias2);
@@ -317,46 +314,61 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
   }
   int outliers_indices_size_n1 = outliers_indices.has_value() ? outliers_indices.value().size(-1) : 0;
   int outliers_centroids_size_n1 = outliers_centroids.has_value() ? outliers_centroids.value().size(-1) : 1;
-  using scalar_t = at::Half;
 
   const uint16_t* perm_ptr = perm.has_value() ? (const uint16_t*)(perm.value().data_ptr<int16_t>()) : nullptr;
   const int16_t* outliers_indices_ptr =
       outliers_indices.has_value() ? outliers_indices.value().data_ptr<int16_t>() : nullptr;
-  const scalar_t* residual_centroids_ptr =
-      residual_centroids.has_value() ? residual_centroids.value().data_ptr<scalar_t>() : nullptr;
-  const scalar_t* outliers_centroids_ptr =
-      outliers_centroids.has_value() ? outliers_centroids.value().data_ptr<scalar_t>() : nullptr;
   auto stream = at::cuda::getCurrentCUDAStream().stream();
-#define callDequantWithOutliers(IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
-  DequantizeWithOutliers_PackIndice<scalar_t, IDXBITS, ResidualBits, BASEGROUP, OUT_OUF_INF> \
-      <<<blocks, threads, 0, stream>>>(output.data_ptr<scalar_t>(), q_indice.data_ptr<int32_t>(), \
-                                       outliers_indices_ptr, centroids.data_ptr<scalar_t>(), residual_centroids_ptr, \
-                                       outliers_centroids_ptr, perm_ptr, weight_scale.data_ptr<scalar_t>(), \
-                                       weight_bias.data_ptr<scalar_t>(), out_size[0], out_size[1], \
-                                       outliers_indices_size_n1, outliers_centroids_size_n1, q_indice.stride(0), \
-                                       q_indice.stride(1), centroids.stride(0), q_indice.size(0));
+#define callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
+  { \
+    using nv_type = typename C10ToNvType<scalar_t>::type; \
+    DequantizeWithOutliers_PackIndice<nv_type, IDXBITS, ResidualBits, BASEGROUP, OUT_OUF_INF> \
+        <<<blocks, threads, 0, stream>>>( \
+            reinterpret_cast<nv_type*>(output.data_ptr<scalar_t>()), q_indice.data_ptr<int32_t>(), \
+            outliers_indices_ptr, reinterpret_cast<const nv_type*>(centroids.data_ptr<scalar_t>()), \
+            residual_centroids.has_value() \
+                ? reinterpret_cast<const nv_type*>(residual_centroids.value().data_ptr<scalar_t>()) \
+                : nullptr, \
+            outliers_centroids.has_value() \
+                ? reinterpret_cast<const nv_type*>(outliers_centroids.value().data_ptr<scalar_t>()) \
+                : nullptr, \
+            perm_ptr, reinterpret_cast<const nv_type*>(weight_scale.data_ptr<scalar_t>()), \
+            reinterpret_cast<const nv_type*>(weight_bias.data_ptr<scalar_t>()), out_size[0], out_size[1], \
+            outliers_indices_size_n1, outliers_centroids_size_n1, q_indice.stride(0), q_indice.stride(1), \
+            centroids.stride(0), q_indice.size(0)); \
+  }
+
+#define callDequantWithOutliers_dtype(IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
+  if (centroids.dtype() == at::ScalarType::Half) { \
+    using scalar_t = c10::Half; \
+    callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+  } else { \
+    using scalar_t = c10::BFloat16; \
+    callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+  }
+
 #define callDequantWithOutliers_bits(BASEGROUP, OUT_OUF_INF, ResidualBits) \
   switch (index_bits) { \
     case 16: \
-      callDequantWithOutliers(16, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(16, BASEGROUP, OUT_OUF_INF, ResidualBits); \
       break; \
     case 15: \
-      callDequantWithOutliers(15, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(15, BASEGROUP, OUT_OUF_INF, ResidualBits); \
      break; \
     case 14: \
-      callDequantWithOutliers(14, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(14, BASEGROUP, OUT_OUF_INF, ResidualBits); \
      break; \
     case 13: \
-      callDequantWithOutliers(13, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(13, BASEGROUP, OUT_OUF_INF, ResidualBits); \
      break; \
     case 12: \
-      callDequantWithOutliers(12, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(12, BASEGROUP, OUT_OUF_INF, ResidualBits); \
      break; \
     case 8: \
-      callDequantWithOutliers(8, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(8, BASEGROUP, OUT_OUF_INF, ResidualBits); \
      break; \
     case 4: \
-      callDequantWithOutliers(4, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+      callDequantWithOutliers_dtype(4, BASEGROUP, OUT_OUF_INF, ResidualBits); \
      break; \
     default: \
      TORCH_CHECK(false, "unspportetd index_bits:" + std::to_string(index_bits)); \
@@ -469,22 +481,32 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
   const uint16_t* outliers_indices_ptr =
       (const uint16_t*)(outliers_indices.has_value() ? outliers_indices.value().data_ptr<int16_t>() : nullptr);
   const uint16_t* perm_ptr = perm.has_value() ? (const uint16_t*)(perm.value().data_ptr<int16_t>()) : nullptr;
-  const c10::Half* bias_ptr = bias.has_value() ? (bias.value().data_ptr<c10::Half>()) : nullptr;
-#define CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
-  WqA16WithOutliers_PackIndice<scalar_t, IDXBITS, ResidualBits, BASEGROUP, 4, Do_Reduce> \
-      <<<blocks, threads, shared_memory_size, stream>>>( \
-          out_buf.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), q_indice.data_ptr<int32_t>(), \
-          outliers_indices_ptr, centroids.data_ptr<scalar_t>(), \
-          residual_centroids.has_value() ? residual_centroids.value().data_ptr<scalar_t>() : nullptr, \
-          outliers_centroids.has_value() ? outliers_centroids.value().data_ptr<scalar_t>() : nullptr, perm_ptr, \
-          weight_scale.data_ptr<scalar_t>(), weight_bias.data_ptr<scalar_t>(), bias_ptr, out_features, in_features, \
-          outliers_indices_size_n1, q_indice.stride(0), q_indice.stride(1), centroids.stride(0), q_indice.size(0));
+#define CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
+  { \
+    using nv_type = typename C10ToNvType<scalar_t>::type; \
+    WqA16WithOutliers_PackIndice<nv_type, IDXBITS, ResidualBits, BASEGROUP, 4, Do_Reduce> \
+        <<<blocks, threads, shared_memory_size, stream>>>( \
+            reinterpret_cast<nv_type*>(out_buf.data_ptr<scalar_t>()), \
+            reinterpret_cast<const nv_type*>(input.data_ptr<scalar_t>()), q_indice.data_ptr<int32_t>(), \
+            outliers_indices_ptr, reinterpret_cast<const nv_type*>(centroids.data_ptr<scalar_t>()), \
+            residual_centroids.has_value() \
+                ? reinterpret_cast<const nv_type*>(residual_centroids.value().data_ptr<scalar_t>()) \
+                : nullptr, \
+            outliers_centroids.has_value() \
+                ? reinterpret_cast<const nv_type*>(outliers_centroids.value().data_ptr<scalar_t>()) \
+                : nullptr, \
+            perm_ptr, reinterpret_cast<const nv_type*>(weight_scale.data_ptr<scalar_t>()), \
+            reinterpret_cast<const nv_type*>(weight_bias.data_ptr<scalar_t>()), \
+            bias.has_value() ? reinterpret_cast<const nv_type*>(bias.value().data_ptr<scalar_t>()) : nullptr, \
+            out_features, in_features, outliers_indices_size_n1, q_indice.stride(0), q_indice.stride(1), \
+            centroids.stride(0), q_indice.size(0)); \
+  }
 #define CallWqA16kernel_dtype(out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
   if (input.dtype() == at::ScalarType::Half) { \
     using scalar_t = c10::Half; \
     CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
   } else { \
-    using scalar_t = c10::Half; \
+    using scalar_t = c10::BFloat16; \
     CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
   }
 #define CallWqA16kernel_bits(out_buf, BASEGROUP, Do_Reduce, ResidualBits) \
csrc/utils.cuh

Lines changed: 12 additions & 8 deletions
@@ -21,19 +21,23 @@ struct TypeVec2<__nv_bfloat16> {
 };
 
 template <typename T>
-T __device__ __forceinline__ ConvertFromFloat(float v) {
-  if constexpr (std::is_same_v<T, __nv_bfloat16>) {
-    return __float2bfloat16(v);
+T __device__ __forceinline__ ConvertFromFloat(float v, T vv) {
+  if constexpr (std::is_same<T, __nv_bfloat16>::value) {
+    return vv = __float2bfloat16(v);
+  } else {
+    static_assert(std::is_same<T, __half>::value);
+    return vv = __float2half(v);
   }
-  return __float2half(v);
 }
 
 template <typename T>
-T __device__ __forceinline__ ConvertToFloat(float v) {
-  if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+float __device__ __forceinline__ ConvertToFloat(T v) {
+  if constexpr (std::is_same<T, __nv_bfloat16>::value) {
     return __bfloat162float(v);
+  } else {
+    static_assert(std::is_same<T, __half>::value);
+    return __half2float(v);
   }
-  return __half2float(v);
 }
 
 template <unsigned int WarpSize>
@@ -122,7 +126,7 @@ __device__ __forceinline__ uint32_t iterator_packed_tensor(const uint32_t* ptr,
   int second = end_bits / 32;
   start_bits = start_bits % 32;
   end_bits = end_bits % 32;
-  uint32_t v = (ptr[first] >> (start_bits)) & ((1 << WBITS) - 1);
+  uint32_t v = (ptr[first] >> (start_bits)) & (uint32_t(1 << WBITS) - 1);
   if (first == second || end_bits == 0) {
     return v;
   } else {
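Together with the TypeVec2 specializations at the top of the file, these helpers let kernels accumulate in float while loading and storing either 16-bit type. A self-contained sketch of the calling pattern (the helper bodies mirror the diff; the toy kernel is illustrative only, not part of this commit):

// Illustrative sketch: utils.cuh conversion helpers (bodies as in the diff)
// driving a toy kernel that compiles for both __half and __nv_bfloat16.
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <type_traits>

namespace cuda {

template <typename T>
T __device__ __forceinline__ ConvertFromFloat(float v, T vv) {
  // The dummy second argument lets T be deduced or pinned at the call
  // site, e.g. ConvertFromFloat(x, scalar_t(0.0f)).
  if constexpr (std::is_same<T, __nv_bfloat16>::value) {
    return vv = __float2bfloat16(v);
  } else {
    static_assert(std::is_same<T, __half>::value);
    return vv = __float2half(v);
  }
}

template <typename T>
float __device__ __forceinline__ ConvertToFloat(T v) {
  // T is deduced from the argument; the old signature took a float and
  // could never deduce the source type.
  if constexpr (std::is_same<T, __nv_bfloat16>::value) {
    return __bfloat162float(v);
  } else {
    static_assert(std::is_same<T, __half>::value);
    return __half2float(v);
  }
}

}  // namespace cuda

// Toy usage (not part of this commit): compute in float, store narrow.
template <typename T>
__global__ void ScaleKernel(T* data, float scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float v = cuda::ConvertToFloat(data[i]);
    data[i] = cuda::ConvertFromFloat(v * scale, T(0.0f));
  }
}

The iterator_packed_tensor change in the second hunk is a correctness tweak in the same spirit: casting before the "- 1" keeps the mask computation in unsigned arithmetic, so wide packed fields produce the intended mask instead of tripping signed-integer overflow.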

setup.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def get_version():
 
 
 def build_cuda_extensions():
-    compute_capabilities = [70, 75, 80, 86, 90]
+    compute_capabilities = [80, 86, 90]
     arch_flags = []
     for cap in compute_capabilities:
         arch_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]

vptq/app_utils.py

Lines changed: 2 additions & 1 deletion
@@ -122,12 +122,13 @@ def main():
     args = get_valid_args(parser)
     print(args)
 
+    #hf_args = {"dtype": torch.bfloat16}
     hf_args = {}
     token = os.getenv("HF_TOKEN", None)
     if token is not None:
         hf_args["token"] = token
 
-    model = VQAutoModelQuantization.from_pretrained(args.model, device_map="auto", **hf_args).half()
+    model = VQAutoModelQuantization.from_pretrained(args.model, device_map="auto", **hf_args)
     tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer or args.model, **hf_args)
 
     chat_loop(model, tokenizer, args)
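Dropping the trailing .half() means the model keeps the dtype it was loaded with (or whatever a dtype entry in hf_args requests, as the commented-out line above suggests), so a bf16 checkpoint now runs in bf16 end to end instead of being force-cast to fp16.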
