Skip to content

Commit e141627

Browse files
committed
feat(turbo-kv): add turbo_decode_k/v — batch dequantize for SDPA attention
Implements decode path: packed uint8 compressed KV history → float32 for concatenation with hot window before passing to standard SDPA. Supports D=128 (68B/50B records) and D=256 (136B/100B records).
1 parent 957f763 commit e141627

6 files changed

Lines changed: 184 additions & 0 deletions

File tree

LocalPackages/mlx-swift/Source/Cmlx/include/mlx/c/fast.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,16 @@ int mlx_fast_turbo_encode(
216216
int k_bits,
217217
const mlx_stream s);
218218

219+
/**
 * Decode TurboKV-compressed key records into a float32 array.
 *
 * res:    receives the decoded array.
 * packed: packed records; the last dimension must be 68 or 136 bytes.
 * s:      stream handed to the underlying decode.
 *
 * Returns 0 on success. The mlx-c implementation in this change returns 1
 * after reporting the message through mlx_error() when the decode throws.
 */
int mlx_fast_turbo_decode_k(
220+
mlx_array* res,
221+
const mlx_array packed,
222+
const mlx_stream s);
223+
224+
/**
 * Decode TurboKV-compressed value records into a float32 array.
 *
 * res:    receives the decoded array.
 * packed: packed records; the last dimension must be 50 or 100 bytes.
 * s:      stream handed to the underlying decode.
 *
 * Returns 0 on success. The mlx-c implementation in this change returns 1
 * after reporting the message through mlx_error() when the decode throws.
 */
int mlx_fast_turbo_decode_v(
225+
mlx_array* res,
226+
const mlx_array packed,
227+
const mlx_stream s);
228+
219229
int mlx_fast_prefault(mlx_array x);
220230

221231
int mlx_fast_pread_into(

LocalPackages/mlx-swift/Source/Cmlx/mlx-c/mlx/c/fast.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,8 +814,44 @@ extern "C" int mlx_fast_turbo_encode(
814814
return 0;
815815
}
816816

817+
// C entry point for mlx::core::fast::turbo_decode_k.
// Stores the decoded array in *res and returns 0; if the decode throws, the
// message is forwarded to mlx_error() and 1 is returned instead.
extern "C" int mlx_fast_turbo_decode_k(
    mlx_array* res,
    const mlx_array packed,
    const mlx_stream s) {
  int status = 0;
  try {
    auto decoded = mlx::core::fast::turbo_decode_k(
        mlx_array_get_(packed), mlx_stream_get_(s));
    mlx_array_set_(*res, decoded);
  } catch (std::exception& err) {
    mlx_error(err.what());
    status = 1;
  }
  return status;
}
833+
834+
// C entry point for mlx::core::fast::turbo_decode_v.
// Stores the decoded array in *res and returns 0; if the decode throws, the
// message is forwarded to mlx_error() and 1 is returned instead.
extern "C" int mlx_fast_turbo_decode_v(
    mlx_array* res,
    const mlx_array packed,
    const mlx_stream s) {
  int status = 0;
  try {
    auto decoded = mlx::core::fast::turbo_decode_v(
        mlx_array_get_(packed), mlx_stream_get_(s));
    mlx_array_set_(*res, decoded);
  } catch (std::exception& err) {
    mlx_error(err.what());
    status = 1;
  }
  return status;
}
850+
851+
817852
extern "C" int mlx_fast_prefault(
818853
mlx_array x) {
854+
819855
try {
820856
mlx::core::prefault(mlx_array_get_(x));
821857
} catch (std::exception& e) {

LocalPackages/mlx-swift/Source/Cmlx/mlx-c/mlx/c/fast.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,19 @@ int mlx_fast_turbo_encode(
215215
int k_bits,
216216
const mlx_stream s);
217217

218+
/**
 * Decode TurboKV-compressed key records into a float32 array.
 *
 * res:    receives the decoded array.
 * packed: packed records; the last dimension must be 68 or 136 bytes.
 * s:      stream handed to the underlying decode.
 *
 * Returns 0 on success; returns 1 after reporting the message through
 * mlx_error() when the decode throws.
 */
int mlx_fast_turbo_decode_k(
219+
mlx_array* res,
220+
const mlx_array packed,
221+
const mlx_stream s);
222+
223+
/**
 * Decode TurboKV-compressed value records into a float32 array.
 *
 * res:    receives the decoded array.
 * packed: packed records; the last dimension must be 50 or 100 bytes.
 * s:      stream handed to the underlying decode.
 *
 * Returns 0 on success; returns 1 after reporting the message through
 * mlx_error() when the decode throws.
 */
int mlx_fast_turbo_decode_v(
224+
mlx_array* res,
225+
const mlx_array packed,
226+
const mlx_stream s);
227+
218228
int mlx_fast_prefault(mlx_array x);
219229

230+
220231
// pread() directly into the already-evaluated MLX array's unified memory buffer.
221232
// This gives full NVMe sequential throughput without OS page-fault overhead.
222233
// The array MUST already be evaluated (concrete pointer exists).

LocalPackages/mlx-swift/Source/Cmlx/mlx/mlx/fast.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,4 +1054,91 @@ array turbo_encode_v(const array& values, StreamOrDevice s_) {
10541054
return array(buf.data(), out_shape, uint8);
10551055
}
10561056

1057+
1058+
// ── TurboQuant Decode ─────────────────────────────────────────────────────────
1059+
// Batch-decode packed uint8 compressed history back to float32 tensors.
1060+
// Used by KVCacheSimple when routing compressed history through standard SDPA.
1061+
// Supports head_dim=128 (record=68B K / 50B V) and head_dim=256 (2 sub-groups).
1062+
1063+
// Batch-decode packed TurboKV key records into float32.
//
// packed: array whose last dimension holds one packed record per head
//         vector: TURBO_K_RECORD bytes (one sub-group) or
//         2 * TURBO_K_RECORD bytes (two sub-groups). It is cast to uint8
//         before being read.
// returns: float32 array with the same leading dims as `packed` and a last
//          dim of n_subgroups * TURBO_D.
// throws: std::invalid_argument when the last dim is neither accepted size.
//
// Each sub-group record is read as: 48 bytes of indices, 16 bytes of QJL
// signs, then two 2-byte fields (norm_fp16, rnorm_fp16).
array turbo_decode_k(const array& packed, StreamOrDevice s_) {
  auto s = to_stream(s_);

  const int record_bytes = static_cast<int>(packed.shape(-1));
  if (record_bytes != TURBO_K_RECORD && record_bytes != TURBO_K_RECORD * 2) {
    throw std::invalid_argument(
        "[turbo_decode_k] last dim must be 68 (D=128) or 136 (D=256), got " +
        std::to_string(record_bytes));
  }
  const int n_subgroups = record_bytes / TURBO_K_RECORD;
  const int head_dim = n_subgroups * ::mlx::core::fast::TURBO_D;

  // Materialise the packed bytes for host access. The loop below walks the
  // buffer linearly, so force a row-contiguous layout first: a sliced or
  // transposed view would otherwise be read with the wrong stride.
  auto packed_u8 =
      contiguous(astype(packed, uint8, s), /* allow_col_major */ false, s);
  eval(packed_u8);
  const uint8_t* src = packed_u8.data<uint8_t>();

  // All offsets are computed in size_t: record_index * head_dim exceeds the
  // range of a 32-bit int once the cache holds a few million records.
  const size_t rec_stride = static_cast<size_t>(record_bytes);
  const size_t out_stride = static_cast<size_t>(head_dim);
  const size_t n_records = packed_u8.size() / rec_stride;
  std::vector<float> buf(n_records * out_stride);

  for (size_t i = 0; i < n_records; ++i) {
    for (int g = 0; g < n_subgroups; ++g) {
      const uint8_t* sub_src = src + i * rec_stride + g * TURBO_K_RECORD;
      // Copy field by field into a zeroed struct instead of casting the
      // byte pointer to the struct type.
      ::mlx::core::fast::TurboQuantK rec;
      std::memset(&rec, 0, sizeof(rec));
      std::memcpy(rec.indices, sub_src, 48);
      std::memcpy(rec.qjl_signs, sub_src + 48, 16);
      std::memcpy(&rec.norm_fp16, sub_src + 64, 2);
      std::memcpy(&rec.rnorm_fp16, sub_src + 66, 2);
      ::mlx::core::fast::turbo_dequantize_k(
          rec,
          buf.data() + i * out_stride +
              static_cast<size_t>(g) * ::mlx::core::fast::TURBO_D,
          ::mlx::core::fast::TURBO_D);
    }
  }

  Shape out_shape = packed.shape();
  out_shape.back() = head_dim;
  // Return float32; Swift caller casts to model dtype (fp16/bf16) as needed
  return array(buf.data(), out_shape, float32);
}
1104+
1105+
// Batch-decode packed TurboKV value records into float32.
//
// packed: array whose last dimension holds one packed record per head
//         vector: TURBO_V_RECORD bytes (one sub-group) or
//         2 * TURBO_V_RECORD bytes (two sub-groups). It is cast to uint8
//         before being read.
// returns: float32 array with the same leading dims as `packed` and a last
//          dim of n_subgroups * TURBO_D.
// throws: std::invalid_argument when the last dim is neither accepted size.
//
// Each sub-group record is read as: 48 bytes of indices followed by a
// 2-byte norm_fp16 field.
array turbo_decode_v(const array& packed, StreamOrDevice s_) {
  auto s = to_stream(s_);

  const int record_bytes = static_cast<int>(packed.shape(-1));
  if (record_bytes != TURBO_V_RECORD && record_bytes != TURBO_V_RECORD * 2) {
    throw std::invalid_argument(
        "[turbo_decode_v] last dim must be 50 (D=128) or 100 (D=256), got " +
        std::to_string(record_bytes));
  }
  const int n_subgroups = record_bytes / TURBO_V_RECORD;
  const int head_dim = n_subgroups * ::mlx::core::fast::TURBO_D;

  // Materialise the packed bytes for host access. The loop below walks the
  // buffer linearly, so force a row-contiguous layout first: a sliced or
  // transposed view would otherwise be read with the wrong stride.
  auto packed_u8 =
      contiguous(astype(packed, uint8, s), /* allow_col_major */ false, s);
  eval(packed_u8);
  const uint8_t* src = packed_u8.data<uint8_t>();

  // All offsets are computed in size_t: record_index * head_dim exceeds the
  // range of a 32-bit int once the cache holds a few million records.
  const size_t rec_stride = static_cast<size_t>(record_bytes);
  const size_t out_stride = static_cast<size_t>(head_dim);
  const size_t n_records = packed_u8.size() / rec_stride;
  std::vector<float> buf(n_records * out_stride);

  for (size_t i = 0; i < n_records; ++i) {
    for (int g = 0; g < n_subgroups; ++g) {
      const uint8_t* sub_src = src + i * rec_stride + g * TURBO_V_RECORD;
      // Copy field by field into a zeroed struct instead of casting the
      // byte pointer to the struct type.
      ::mlx::core::fast::TurboQuantV rec;
      std::memset(&rec, 0, sizeof(rec));
      std::memcpy(rec.indices, sub_src, 48);
      std::memcpy(&rec.norm_fp16, sub_src + 48, 2);
      ::mlx::core::fast::turbo_dequantize_v(
          rec,
          buf.data() + i * out_stride +
              static_cast<size_t>(g) * ::mlx::core::fast::TURBO_D,
          ::mlx::core::fast::TURBO_D);
    }
  }

  Shape out_shape = packed.shape();
  out_shape.back() = head_dim;
  // The result is float32; no cast to another dtype happens here.
  return array(buf.data(), out_shape, float32);
}
1142+
10571143
} // namespace mlx::core::fast
1144+

LocalPackages/mlx-swift/Source/Cmlx/mlx/mlx/fast.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,20 @@ MLX_API array turbo_encode_k(const array& keys, StreamOrDevice s = {});
118118
*/
119119
MLX_API array turbo_encode_v(const array& values, StreamOrDevice s = {});
120120

121+
/**
 * Decode TurboKV compressed K-cache back to float32.
 *
 * packed: uint8 with last dim 68 (D=128) or 136 (D=256)
 * returns: float32 array with last dim = head_dim (128 or 256); every
 *          leading dim is carried over from `packed`
 * throws: std::invalid_argument when the last dim is neither 68 nor 136
 *
 * The implementation evaluates `packed` and decodes it on the host before
 * returning, so the call blocks until the input is available.
 */
127+
MLX_API array turbo_decode_k(const array& packed, StreamOrDevice s = {});
128+
129+
/**
 * Decode TurboKV compressed V-cache back to float32.
 *
 * packed: uint8 with last dim 50 (D=128) or 100 (D=256)
 * returns: float32 array with last dim = head_dim (128 or 256); every
 *          leading dim is carried over from `packed`
 * throws: std::invalid_argument when the last dim is neither 50 nor 100
 *
 * The implementation evaluates `packed` and decodes it on the host before
 * returning, so the call blocks until the input is available.
 */
135+
MLX_API array turbo_decode_v(const array& packed, StreamOrDevice s = {});
136+
121137
} // namespace mlx::core::fast

LocalPackages/mlx-swift/Source/MLX/MLXFast.swift

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,30 @@ public enum MLXFast {
278278
return (kTuple, vTuple)
279279
}
280280

281+
/// Decode a packed TurboKV key history (uint8 records) into float32.
///
/// - Parameters:
///   - packed: `[..., 68]` uint8 for D=128, or `[..., 136]` for D=256
///   - stream: stream or device passed through to the C decode call
/// - Returns: `[..., headDim]` float32 — caller casts to model dtype as needed
public static func turboDecodeK(
    packed: MLXArray, stream: StreamOrDevice = .default
) -> MLXArray {
    var decoded = mlx_array_new()
    // The C status code is not inspected here.
    _ = mlx_fast_turbo_decode_k(&decoded, packed.ctx, stream.ctx)
    return MLXArray(decoded)
}
292+
293+
/// Decode a packed TurboKV value history (uint8 records) into float32.
///
/// - Parameters:
///   - packed: `[..., 50]` uint8 for D=128, or `[..., 100]` for D=256
///   - stream: stream or device passed through to the C decode call
/// - Returns: `[..., headDim]` float32 — caller casts to model dtype as needed
public static func turboDecodeV(
    packed: MLXArray, stream: StreamOrDevice = .default
) -> MLXArray {
    var decoded = mlx_array_new()
    // The C status code is not inspected here.
    _ = mlx_fast_turbo_decode_v(&decoded, packed.ctx, stream.ctx)
    return MLXArray(decoded)
}
304+
281305
// ── SSD Flash-Stream Metrics ──────────────────────────────────────────────
282306

283307
/// Snapshot of cumulative SSD streaming throughput stats.

0 commit comments

Comments
 (0)