1
- #include " common.h"
2
- #include " dispatch_macros.h"
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
3
4
#include " quant_gemv_v2.cuh"
5
+ #include " util/common.h"
6
+ #include " util/cuda_utils.cuh"
7
+ #include " util/dispatch_macros.h"
4
8
5
9
namespace vptq {
6
10
11
+ /**
12
+ * @brief Quantized GEMV kernel.
13
+ * @param act The input activations.
14
+ * @param bias The bias.
15
+ * @param indices The indices.
16
+ * @param centroids The codebook for the main vector quantized weights.
17
+ * Stored in row-major order. Element type: fp16, bf16.
18
+ * Shape: (num_codebooks, num_centroids, vec_len).
19
+ * @param residual_centroids The residual centroids.
20
+ * @param scale_weights The scale weights.
21
+ * @param scale_bias The scale bias.
22
+ * @param in_features The number of input features.
23
+ * @param out_features The number of output features.
24
+ */
7
25
torch::Tensor quant_gemv_v2 (
8
26
const torch::Tensor& act, const c10::optional<torch::Tensor>& bias,
9
27
const torch::Tensor& indices, const torch::Tensor& centroids,
@@ -44,6 +62,9 @@ torch::Tensor quant_gemv_v2(
44
62
const int64_t num_centroids = centroids.size (1 );
45
63
const int64_t vec_len = centroids.size (2 );
46
64
65
+ TORCH_CHECK_LT (batch, 16 )
66
+ << " In GEMV, the batch size is suggested to be less than 16." ;
67
+
47
68
TORCH_CHECK_EQ (num_codebooks, 1 ) << " Only support one codebook." ;
48
69
49
70
TORCH_CHECK (
@@ -60,7 +81,7 @@ torch::Tensor quant_gemv_v2(
60
81
dim3 blocks (batch, num_codebooks, block_z);
61
82
// FIXME(ying): refine the choice of threads in a thread block.
62
83
// For test at the moment.
63
- dim3 threads (256 , 1 , 1 );
84
+ dim3 threads (256 , 1 , 1 ); // eight warps (256 threads / 32 threads per warp) in a thread block.
64
85
65
86
std::cout << " num_codebooks: " << num_codebooks << std::endl
66
87
<< " num_centroids: " << num_centroids << std::endl
0 commit comments