Skip to content

Commit e0bdb2f

Browse files
committed
Merge branch 'upstream' into concedo_experimental
# Conflicts: # README.md # examples/imatrix/README.md # scripts/compare-llama-bench.py
2 parents 5a79dd5 + 6dde178 commit e0bdb2f

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

ggml/src/ggml-metal/ggml-metal.metal

+21-12
Original file line numberDiff line numberDiff line change
@@ -373,24 +373,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
373373
template <typename type4x4>
374374
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
375375
const half d_all = xb->d;
376-
device const uint8_t * ql = (device const uint8_t *)xb->ql;
377-
device const uint8_t * qh = (device const uint8_t *)xb->qh;
376+
device const uint16_t * ql = (device const uint16_t *)xb->ql;
377+
device const uint16_t * qh = (device const uint16_t *)xb->qh;
378378
device const int8_t * scales = (device const int8_t *)xb->scales;
379379

380-
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
381-
qh = qh + 32*(il/8) + 16*(il&1);
380+
ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
381+
qh = qh + 16*(il/8) + 8*(il&1);
382382
float sc = scales[(il%2) + 2 * ((il/2))];
383383
il = (il/2) & 3;
384384

385-
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
386-
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
387-
const float coef = il>1 ? 1.f/16.f : 1.f;
385+
const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
386+
const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
388387
const float ml = d_all * sc * 32.f;
389-
const float dl = d_all * sc * coef;
390-
for (int i = 0; i < 16; ++i) {
391-
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
392-
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
393-
reg[i/4][i%4] = dl * q - ml;
388+
const float dl0 = d_all * sc;
389+
const float dl1 = dl0 / 256.f;
390+
const float dl2 = dl0 / (256.f * 256.f);
391+
const float dl3 = dl0 / (256.f * 256.f * 256.f);
392+
const uint8_t shr_h = il>2 ? 2 : 0;
393+
const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
394+
const uint8_t shr_l = il>1 ? 4 : 0;
395+
for (int i = 0; i < 4; ++i) {
396+
const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
397+
const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
398+
const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
399+
reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
400+
reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
401+
reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
402+
reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
394403
}
395404
}
396405

0 commit comments

Comments
 (0)