 #include "quant/qdq_6.cuh"
 #include "quant/qdq_8.cuh"
 
-#define BLOCK_KN_SIZE 128
-#define BLOCK_M_SIZE_MAX 8
-#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
+#define GPTQ_BLOCK_KN_SIZE 128
+#define GPTQ_BLOCK_M_SIZE_MAX 8
+#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32)
+
+#define EXL2_BLOCK_KN_SIZE 64
+#define EXL2_BLOCK_M_SIZE_MAX 8
+#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
+
 #define CLEAR_N_SIZE 256
 
 #include "q_gemm_kernel.cuh"
 #include "q_gemm_kernel_gptq.cuh"
 
-#include "compat_gemm.cuh"
-
 void gemm_half_q_half_cuda_part
 (
     const half* a,
@@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part
     int size_n,
     int size_k,
     int m_count,
-    bool clear
+    bool clear,
+    const half* r_weights,
+    int r_weights_stride,
+    bool mul_r_weights
 )
 {
     if (!b->is_gptq)
     {
         dim3 blockDim, gridDim;
-        blockDim.x = BLOCK_KN_SIZE;
+        blockDim.x = EXL2_BLOCK_KN_SIZE;
         blockDim.y = 1;
         blockDim.z = 1;
-        gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
+        gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
         gridDim.y = DIVIDE(size_m, m_count);
-        gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
+        gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);
 
-        fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
+        fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
 
         kernel<<<gridDim, blockDim>>>
         (
@@ -55,32 +61,35 @@ void gemm_half_q_half_cuda_part
             size_n,
             size_k,
             b->groups,
-            b->groupsize,
+            b->cuda_q_group_map,
             b->cuda_q_perm,
             b->rows_8,
             b->rows_6,
             b->rows_5,
             b->rows_4,
             b->rows_3,
             b->rows_2,
-            clear
+            clear,
+            r_weights,
+            r_weights_stride
         );
     }
     else
     {
         dim3 blockDim, gridDim;
-        blockDim.x = BLOCK_KN_SIZE;
+        blockDim.x = GPTQ_BLOCK_KN_SIZE;
         blockDim.y = 1;
         blockDim.z = 1;
-        gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
+        gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
         gridDim.y = DIVIDE(size_m, m_count);
-        gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
+        gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
 
-        fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
+        fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights);
 
-        // DBGX((uint64_t) b->cuda_q_perm);
-        // DBGI(b->rows_4);
-        // DBGI(b->height);
+        // DBGX((uint64_t) r_weights);
+        // if (r_weights)
+        //     print_global_mem(r_weights, 1, 1, 1);
+        // DBGI(r_weights_stride);
 
         kernel<<<gridDim, blockDim>>>
         (
@@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part
             size_n,
             size_k,
             b->groups,
-            b->groupsize,
+            b->gptq_groupsize,
             b->cuda_q_perm,
             b->rows_4,
-            clear
+            clear,
+            r_weights,
+            r_weights_stride
         );
     }
 }
@@ -112,13 +123,14 @@ void gemm_half_q_half_cuda
     int size_k,
     bool clear,
     half* temp_dq,
-    bool force_cuda
+    bool force_cuda,
+    const half* r_weights,
+    const int r_weights_stride,
+    bool mul_r_weights
 )
 {
     if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
     {
-        // printf("cublas\n");
-
         // Reconstruct FP16 matrix, then cuBLAS
 
         if (!temp_dq) temp_dq = b->temp_dq;
@@ -139,12 +151,12 @@ void gemm_half_q_half_cuda
         // const float alpha = 1.0f;
         // const float beta = clear ? 0.0f : 1.0f;
         // cublasSgemmEx(cublas_handle,
-        //     CUBLAS_OP_N,
-        //     CUBLAS_OP_N,
-        //     size_n, size_m, size_k,
-        //     &alpha, temp_dq, CUDA_R_16F, size_n,
-        //     a, CUDA_R_16F, size_k,
-        //     &beta, c, CUDA_R_16F, size_n);
+        //               CUBLAS_OP_N,
+        //               CUBLAS_OP_N,
+        //               size_n, size_m, size_k,
+        //               &alpha, temp_dq, CUDA_R_16F, size_n,
+        //               a, CUDA_R_16F, size_k,
+        //               &beta, c, CUDA_R_16F, size_n);
 
         // const float alpha = 1.0f;
         // const float beta = clear ? 0.0f : 1.0f;
@@ -158,24 +170,21 @@ void gemm_half_q_half_cuda
     }
     else
     {
-        // printf("cuda\n");
-
         // Quantized matmul
 
-        // if (clear) clear_tensor_cuda(c, size_m, size_n);
-
-        int max_chunks = size_m / BLOCK_M_SIZE_MAX;
-        int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
+        int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
+        int max_chunks = size_m / block_m_size_max;
+        int last_chunk = max_chunks * block_m_size_max;
         int last_chunk_size = size_m - last_chunk;
 
         if (max_chunks)
         {
-            gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
+            gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);
         }
 
         if (last_chunk_size)
         {
-            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
+            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);
         }
     }
 }
@@ -201,11 +210,10 @@ void clear_tensor_cuda
     int size_n
 )
 {
-    return;
-    dim3 blockDim, gridDim;
-    blockDim.x = CLEAR_N_SIZE;
-    blockDim.y = 1;
-    gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
-    gridDim.y = size_m;
-    clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
+    // dim3 blockDim, gridDim;
+    // blockDim.x = CLEAR_N_SIZE;
+    // blockDim.y = 1;
+    // gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
+    // gridDim.y = size_m;
+    // clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
 }
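
Note on the chunking above: gemm_half_q_half_cuda now picks the m-block size per format (GPTQ vs. EXL2) and splits size_m into full-size launches plus one remainder launch. Below is a standalone sketch of that arithmetic, not part of the patch; the constant mirrors EXL2_BLOCK_M_SIZE_MAX and the row count is an arbitrary example.

// Sketch only: reproduces the max_chunks / last_chunk arithmetic from
// gemm_half_q_half_cuda for one example value of size_m.
#include <cstdio>

int main()
{
    const int block_m_size_max = 8;  // EXL2_BLOCK_M_SIZE_MAX in the patch
    const int size_m = 21;           // example row count

    int max_chunks = size_m / block_m_size_max;      // full-size launches
    int last_chunk = max_chunks * block_m_size_max;  // rows they cover
    int last_chunk_size = size_m - last_chunk;       // remainder launch

    if (max_chunks)
        printf("bulk: rows [0, %d), m_count = %d\n", last_chunk, block_m_size_max);
    if (last_chunk_size)
        printf("tail: rows [%d, %d), m_count = %d\n", last_chunk, size_m, last_chunk_size);
    return 0;
}

For size_m = 21 this prints a bulk launch over rows [0, 16) with m_count = 8 and a tail launch over rows [16, 21) with m_count = 5, matching the two gemm_half_q_half_cuda_part calls in the patch.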
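The updated pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights) call suggests the kernels are now also specialized on the two r_weights flags, not just m_count. A minimal sketch of that boolean-template selection pattern follows; all names are hypothetical stand-ins, since the real instantiations live in q_gemm_kernel.cuh, which this diff does not show.

// Sketch: select a compile-time kernel specialization from runtime flags.
#include <cuda_fp16.h>

typedef void (*fp_demo_kernel)(const half*, half*, int);

template <bool use_r_weights, bool mul_r_weights>
__global__ void demo_kernel(const half* a, half* c, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    half v = a[i];
    // Dead branches are eliminated per instantiation at compile time,
    // so the no-weights variant pays no cost for the feature.
    if (use_r_weights && mul_r_weights) v = __hmul(v, __float2half(2.0f));
    c[i] = v;
}

fp_demo_kernel pick_demo_kernel(bool use_r_weights, bool mul_r_weights)
{
    if (!use_r_weights) return demo_kernel<false, false>;
    if (!mul_r_weights) return demo_kernel<true, false>;
    return demo_kernel<true, true>;
}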