// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

- #include <cuda_bf16.h>
#include <cmath>
#include <math_constants.h>
#include <ATen/cuda/CUDAContext.h>
@@ -35,15 +34,15 @@ __global__ void WqA16WithOutliers_PackIndice(
tidx += bidz * cuda::kBlockSize * Do_Reduce;
}
int in_y = bidx;
- extern __shared__ scalar_t shared_memory[];  // 3xin_features, dynamic
- scalar_t* shared_input = shared_memory;  // in_features, dynamic
+ __shared__ scalar_t shared_memory[1];  // 3xin_features, dynamic
+ scalar_t* shared_input = shared_memory;  // in_features, dynamic
// scalar_t* shared_w_scales = shared_memory+in_features;// in_features, dynamic
scalar_t* shared_w_bias = shared_memory + in_features;  // in_features, dynamic
__shared__ float shared_output[GROUPSIZE][cuda::kBlockSize / 32 + 1];
- scalar_t tmp_output[GROUPSIZE] = {0};
+ scalar_t tmp_output[GROUPSIZE];
#pragma unroll
for (int i = 0; i < GROUPSIZE; i++) {
- tmp_output[i] = scalar_t(0);
+ tmp_output[i] = scalar_t(0.0f);
}
input_data = input_data + in_features * bidy;
out = out + out_features * bidy * gridDim.z;
@@ -154,11 +153,7 @@ __global__ void WqA16WithOutliers_PackIndice(
#pragma unroll
for (int gi = 0; gi < GROUPSIZE; gi++) {
float reduce_out = 0.f;
- if constexpr (!std::is_same_v<scalar_t, c10::BFloat16>) {
- reduce_out = __half2float(tmp_output[gi]);
- } else {
- reduce_out = __bfloat162float(tmp_output[gi]);
- }
+ reduce_out = cuda::ConvertToFloat(tmp_output[gi]);
reduce_out = cuda::warpReduceSum<32>(reduce_out);
if (landid == 0) {
shared_output[gi][warpid] = reduce_out;
@@ -181,10 +176,11 @@ __global__ void WqA16WithOutliers_PackIndice(
reduce_out = cuda::warpReduceSum<cuda::kBlockSize / 32>(reduce_out);
if (landid == 0 && (in_y * GROUPSIZE + wid) < out_features) {
if constexpr (Do_Reduce) {
- out[(wid)*gridDim.z] =
- cuda::ConvertFromFloat<scalar_t>(reduce_out) + ((bidz == 0 && bias != 0) ? bias[wid] : scalar_t(0));
+ out[(wid)*gridDim.z] = cuda::ConvertFromFloat<scalar_t>(reduce_out, scalar_t(0.0f)) +
+ ((bidz == 0 && bias != 0) ? bias[wid] : scalar_t(0.0f));
} else {
- out[wid] = cuda::ConvertFromFloat<scalar_t>(reduce_out) + ((bias != 0) ? bias[wid] : scalar_t(0));
+ out[wid] =
+ cuda::ConvertFromFloat<scalar_t>(reduce_out, scalar_t(0.0f)) + ((bias != 0) ? bias[wid] : scalar_t(0.0f));
}
}
}
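// Sketch (assumed, not shown in this diff): cuda::ConvertToFloat / cuda::ConvertFromFloat
// are helpers defined elsewhere in the extension. Assuming they wrap the CUDA half/bfloat16
// conversion intrinsics, and that the unused second argument of ConvertFromFloat exists only
// to select the return type, a minimal version could look like this:
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace cuda {
__device__ __forceinline__ float ConvertToFloat(__half v) { return __half2float(v); }
__device__ __forceinline__ float ConvertToFloat(__nv_bfloat16 v) { return __bfloat162float(v); }

template <typename T>
__device__ __forceinline__ T ConvertFromFloat(float v, T);
template <>
__device__ __forceinline__ __half ConvertFromFloat<__half>(float v, __half) {
  return __float2half(v);
}
template <>
__device__ __forceinline__ __nv_bfloat16 ConvertFromFloat<__nv_bfloat16>(float v, __nv_bfloat16) {
  return __float2bfloat16(v);
}
}  // namespace cuda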
@@ -204,6 +200,7 @@ __global__ void DequantizeWithOutliers_PackIndice(scalar_t* out, const int32_t*
int tid = (bid * cuda::kBlockSize + threadIdx.x);
int in_x = tid % in_features;
int in_y = tid / in_features;
+ using VecType = typename cuda::TypeVec2<scalar_t>::type;

uint16_t mapped_index_x = invert_perm ? invert_perm[in_x] : in_x;
const scalar_t scale = weight_scale[in_x];
@@ -247,25 +244,25 @@ __global__ void DequantizeWithOutliers_PackIndice(scalar_t* out, const int32_t*
cuda::iterator_packed_tensor<IDXBITS + ResidualBits>((const uint32_t*)q_indice, mappped_inx_in_a_codebook);

const uint16_t base_ind = merged_ind & ((1 << IDXBITS) - 1);
- __half2 base[GROUPSIZE / 2];
+ VecType base[GROUPSIZE / 2];
const scalar_t* centroids_start = centroids + base_ind * GROUPSIZE;
cuda::ldg_vec_x<GROUPSIZE>((uint32_t*)(base), (const uint32_t*)(centroids_start));

if constexpr (ResidualBits > 0) {
- __half2 residual[GROUPSIZE / 2];
+ VecType residual[GROUPSIZE / 2];
merged_ind >>= IDXBITS;
const uint16_t res_ind = merged_ind & ((1 << ResidualBits) - 1);
const scalar_t* residual_centroids_start = residual_centroids + res_ind * GROUPSIZE;
cuda::ldg_vec_x<GROUPSIZE>((uint32_t*)(residual), (const uint32_t*)(residual_centroids_start));
#pragma unroll
for (int i = 0; i < GROUPSIZE / 2; i++) {
- base[i] = __hadd2(*(((__half2*)base) + i), *(((__half2*)residual) + i));
+ base[i] = __hadd2(*(((VecType*)base) + i), *(((VecType*)residual) + i));
}
}

- __half2 hres[GROUPSIZE / 2];
- __half2 scale2 = __half2(scale, scale);
- __half2 bias2 = __half2(bias, bias);
+ VecType hres[GROUPSIZE / 2];
+ VecType scale2 = VecType(scale, scale);
+ VecType bias2 = VecType(bias, bias);
#pragma unroll
for (int i = 0; i < GROUPSIZE / 2; i++) {
hres[i] = __hfma2(base[i], scale2, bias2);
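// Sketch (assumed, not shown in this diff): cuda::TypeVec2 maps a 16-bit scalar type to its
// packed two-element vector type, so the dequant path above can handle __half2 and
// __nv_bfloat162 uniformly; on sm_80+, __hadd2/__hfma2 from <cuda_bf16.h> supply the
// bfloat16 overloads used when VecType is __nv_bfloat162. A minimal version:
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace cuda {
template <typename T>
struct TypeVec2;

template <>
struct TypeVec2<__half> {
  using type = __half2;
};

template <>
struct TypeVec2<__nv_bfloat16> {
  using type = __nv_bfloat162;
};
}  // namespace cuda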
@@ -317,46 +314,61 @@ torch::Tensor lauch_deqantize_outliers_cuda_packkernel(
}
int outliers_indices_size_n1 = outliers_indices.has_value() ? outliers_indices.value().size(-1) : 0;
int outliers_centroids_size_n1 = outliers_centroids.has_value() ? outliers_centroids.value().size(-1) : 1;
- using scalar_t = at::Half;

const uint16_t* perm_ptr = perm.has_value() ? (const uint16_t*)(perm.value().data_ptr<int16_t>()) : nullptr;
const int16_t* outliers_indices_ptr =
outliers_indices.has_value() ? outliers_indices.value().data_ptr<int16_t>() : nullptr;
- const scalar_t* residual_centroids_ptr =
- residual_centroids.has_value() ? residual_centroids.value().data_ptr<scalar_t>() : nullptr;
- const scalar_t* outliers_centroids_ptr =
- outliers_centroids.has_value() ? outliers_centroids.value().data_ptr<scalar_t>() : nullptr;
auto stream = at::cuda::getCurrentCUDAStream().stream();
- #define callDequantWithOutliers(IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
- DequantizeWithOutliers_PackIndice<scalar_t, IDXBITS, ResidualBits, BASEGROUP, OUT_OUF_INF> \
- <<<blocks, threads, 0, stream>>>(output.data_ptr<scalar_t>(), q_indice.data_ptr<int32_t>(), \
- outliers_indices_ptr, centroids.data_ptr<scalar_t>(), residual_centroids_ptr, \
- outliers_centroids_ptr, perm_ptr, weight_scale.data_ptr<scalar_t>(), \
- weight_bias.data_ptr<scalar_t>(), out_size[0], out_size[1], \
- outliers_indices_size_n1, outliers_centroids_size_n1, q_indice.stride(0), \
- q_indice.stride(1), centroids.stride(0), q_indice.size(0));
+ #define callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
+ { \
+ using nv_type = typename C10ToNvType<scalar_t>::type; \
+ DequantizeWithOutliers_PackIndice<nv_type, IDXBITS, ResidualBits, BASEGROUP, OUT_OUF_INF> \
+ <<<blocks, threads, 0, stream>>>( \
+ reinterpret_cast<nv_type*>(output.data_ptr<scalar_t>()), q_indice.data_ptr<int32_t>(), \
+ outliers_indices_ptr, reinterpret_cast<const nv_type*>(centroids.data_ptr<scalar_t>()), \
+ residual_centroids.has_value() \
+ ? reinterpret_cast<const nv_type*>(residual_centroids.value().data_ptr<scalar_t>()) \
+ : nullptr, \
+ outliers_centroids.has_value() \
+ ? reinterpret_cast<const nv_type*>(outliers_centroids.value().data_ptr<scalar_t>()) \
+ : nullptr, \
+ perm_ptr, reinterpret_cast<const nv_type*>(weight_scale.data_ptr<scalar_t>()), \
+ reinterpret_cast<const nv_type*>(weight_bias.data_ptr<scalar_t>()), out_size[0], out_size[1], \
+ outliers_indices_size_n1, outliers_centroids_size_n1, q_indice.stride(0), q_indice.stride(1), \
+ centroids.stride(0), q_indice.size(0)); \
+ }
+
+ #define callDequantWithOutliers_dtype(IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits) \
+ if (centroids.dtype() == at::ScalarType::Half) { \
+ using scalar_t = c10::Half; \
+ callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+ } else { \
+ using scalar_t = c10::BFloat16; \
+ callDequantWithOutliers(scalar_t, IDXBITS, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+ }
+
#define callDequantWithOutliers_bits(BASEGROUP, OUT_OUF_INF, ResidualBits) \
switch (index_bits) { \
case 16: \
- callDequantWithOutliers(16, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+ callDequantWithOutliers_dtype(16, BASEGROUP, OUT_OUF_INF, ResidualBits); \
break; \
case 15: \
- callDequantWithOutliers(15, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+ callDequantWithOutliers_dtype(15, BASEGROUP, OUT_OUF_INF, ResidualBits); \
break; \
case 14: \
- callDequantWithOutliers(14, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+ callDequantWithOutliers_dtype(14, BASEGROUP, OUT_OUF_INF, ResidualBits); \
break; \
case 13: \
- callDequantWithOutliers(13, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+ callDequantWithOutliers_dtype(13, BASEGROUP, OUT_OUF_INF, ResidualBits); \
break; \
case 12: \
- callDequantWithOutliers(12, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+ callDequantWithOutliers_dtype(12, BASEGROUP, OUT_OUF_INF, ResidualBits); \
break; \
case 8: \
- callDequantWithOutliers(8, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+ callDequantWithOutliers_dtype(8, BASEGROUP, OUT_OUF_INF, ResidualBits); \
break; \
case 4: \
- callDequantWithOutliers(4, BASEGROUP, OUT_OUF_INF, ResidualBits); \
+ callDequantWithOutliers_dtype(4, BASEGROUP, OUT_OUF_INF, ResidualBits); \
break; \
default: \
TORCH_CHECK(false, "unspportetd index_bits:" + std::to_string(index_bits)); \
@@ -469,22 +481,32 @@ torch::Tensor lauch_gemv_outliers_cuda_packkernel(
const uint16_t* outliers_indices_ptr =
(const uint16_t*)(outliers_indices.has_value() ? outliers_indices.value().data_ptr<int16_t>() : nullptr);
const uint16_t* perm_ptr = perm.has_value() ? (const uint16_t*)(perm.value().data_ptr<int16_t>()) : nullptr;
- const c10::Half* bias_ptr = bias.has_value() ? (bias.value().data_ptr<c10::Half>()) : nullptr;
- #define CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
- WqA16WithOutliers_PackIndice<scalar_t, IDXBITS, ResidualBits, BASEGROUP, 4, Do_Reduce> \
- <<<blocks, threads, shared_memory_size, stream>>>( \
- out_buf.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), q_indice.data_ptr<int32_t>(), \
- outliers_indices_ptr, centroids.data_ptr<scalar_t>(), \
- residual_centroids.has_value() ? residual_centroids.value().data_ptr<scalar_t>() : nullptr, \
- outliers_centroids.has_value() ? outliers_centroids.value().data_ptr<scalar_t>() : nullptr, perm_ptr, \
- weight_scale.data_ptr<scalar_t>(), weight_bias.data_ptr<scalar_t>(), bias_ptr, out_features, in_features, \
- outliers_indices_size_n1, q_indice.stride(0), q_indice.stride(1), centroids.stride(0), q_indice.size(0));
+ #define CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
+ { \
+ using nv_type = typename C10ToNvType<scalar_t>::type; \
+ WqA16WithOutliers_PackIndice<nv_type, IDXBITS, ResidualBits, BASEGROUP, 4, Do_Reduce> \
+ <<<blocks, threads, shared_memory_size, stream>>>( \
+ reinterpret_cast<nv_type*>(out_buf.data_ptr<scalar_t>()), \
+ reinterpret_cast<const nv_type*>(input.data_ptr<scalar_t>()), q_indice.data_ptr<int32_t>(), \
+ outliers_indices_ptr, reinterpret_cast<const nv_type*>(centroids.data_ptr<scalar_t>()), \
+ residual_centroids.has_value() \
+ ? reinterpret_cast<const nv_type*>(residual_centroids.value().data_ptr<scalar_t>()) \
+ : nullptr, \
+ outliers_centroids.has_value() \
+ ? reinterpret_cast<const nv_type*>(outliers_centroids.value().data_ptr<scalar_t>()) \
+ : nullptr, \
+ perm_ptr, reinterpret_cast<const nv_type*>(weight_scale.data_ptr<scalar_t>()), \
+ reinterpret_cast<const nv_type*>(weight_bias.data_ptr<scalar_t>()), \
+ bias.has_value() ? reinterpret_cast<const nv_type*>(bias.value().data_ptr<scalar_t>()) : nullptr, \
+ out_features, in_features, outliers_indices_size_n1, q_indice.stride(0), q_indice.stride(1), \
+ centroids.stride(0), q_indice.size(0)); \
+ }
#define CallWqA16kernel_dtype(out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits) \
if (input.dtype() == at::ScalarType::Half) { \
using scalar_t = c10::Half; \
CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
} else { \
- using scalar_t = c10::Half; \
+ using scalar_t = c10::BFloat16; \
CallWqA16kernel(scalar_t, out_buf, IDXBITS, BASEGROUP, Do_Reduce, ResidualBits); \
}
#define CallWqA16kernel_bits(out_buf, BASEGROUP, Do_Reduce, ResidualBits) \
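// Sketch (assumed, not shown in this diff): the C10ToNvType trait used by the
// callDequantWithOutliers / CallWqA16kernel macros above to go from the c10 dtype picked by
// the *_dtype dispatch macros to the CUDA-native element type the kernels are instantiated
// with. A minimal version of the assumed mapping:
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>

template <typename T>
struct C10ToNvType;

template <>
struct C10ToNvType<c10::Half> {
  using type = __half;
};

template <>
struct C10ToNvType<c10::BFloat16> {
  using type = __nv_bfloat16;
};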