@@ -373,24 +373,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
373
373
// Dequantize 16 values of a q6_K block into a 4x4 register tile.
//
//   xb  - block to decode; reads xb->d, xb->ql, xb->qh and xb->scales.
//   il  - selects which group of 16 values inside the block is decoded
//         (presumably 0..15 -- TODO confirm against the callers).
//   reg - output tile; reg[i][j] is written for i, j in 0..3.
//
// Every output equals d_all * sc * (q - 32), where q is a 6-bit value built
// from a 4-bit field of ql and a 2-bit field of qh. It is evaluated below as
// dl * q - ml with ml = d_all * sc * 32.
//
// NOTE(review): ql/qh are reinterpreted as 16-bit words, which assumes both
// arrays are at least 2-byte aligned inside block_q6_K -- confirm against the
// struct definition.
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
    const half d_all = xb->d;
    device const uint16_t * ql = (device const uint16_t *)xb->ql;
    device const uint16_t * qh = (device const uint16_t *)xb->qh;
    device const int8_t * scales = (device const int8_t *)xb->scales;

    // Offsets are expressed in 16-bit words (half of the equivalent byte offsets).
    ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
    qh = qh + 16*(il/8) + 8*(il&1);
    float sc = scales[(il%2) + 2 * ((il/2))];
    il = (il/2) & 3;

    // kmask1 picks one 2-bit field out of every byte of the packed qh word,
    // kmask2 picks the low or the high nibble out of every byte of the ql word.
    const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
    const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
    const float ml = d_all * sc * 32.f;

    // Byte lane k of q is read without shifting it down, so its value is
    // scaled by 256^k; dl0..dl3 fold the matching 1/256^k into the multiplier.
    const float dl0 = d_all * sc;
    const float dl1 = dl0 / 256.f;
    const float dl2 = dl0 / (256.f * 256.f);
    const float dl3 = dl0 / (256.f * 256.f * 256.f);

    // Shifts that move the selected qh field to bits 4..5 and the selected
    // ql nibble to bits 0..3 of each byte lane.
    const uint8_t shr_h = il>2 ? 2 : 0;
    const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
    const uint8_t shr_l = il>1 ? 4 : 0;

    for (int i = 0; i < 4; ++i) {
        // Widen to 32 bits BEFORE shifting: a uint16_t operand is promoted to
        // signed int, so shifting a value with bit 15 set left by 16 would
        // overflow into the sign bit.
        const uint32_t low  = ((uint32_t)ql[2*i] | ((uint32_t)ql[2*i+1] << 16)) & kmask2;
        const uint32_t high = ((uint32_t)qh[2*i] | ((uint32_t)qh[2*i+1] << 16)) & kmask1;
        const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
        reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
        reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
        reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
        reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
    }
}
396
405
0 commit comments