add TurboQuant KV cache compression#232
Conversation
| ], | ||
| dependencies: [ | ||
| .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.3")), | ||
| .package(url: "https://github.com/ekryski/mlx-swift", branch: "alpha"), |
There was a problem hiding this comment.
You cannot change this library, I mean this needs to reference the official ml-explore mlx-swift, if you need to add there, you need to first add a pull request there and get it merged, this is no sense.
There was a problem hiding this comment.
This PR is still in draft, pointing to the wrong library. I would hold off on review. Defiant pointing at the wrong library
There was a problem hiding this comment.
This will be fixed in the next push. No intention of switching the dependencies.
There was a problem hiding this comment.
Out-of-bounds threadgroup memory in fused encode kernels (can corrupt results/crash for head_dim > 128).
In both dense and WHT encode kernels, shared_norm is hardcoded to 4 entries, but indexed by sg_id = d/32 and read up to num_groups = (Dim+31)/32. For Dim=256, this accesses indices 0...7.
There was a problem hiding this comment.
Good catch, fixed in next push. shared_norm is now sized [(Dim + 31) / 32] instead of hardcoded 4. Current models all use head_dim=128 so this wasn't hit in practice but would break on 256.
| } | ||
|
|
||
| // KV head mapping (use first query's head — same assumption as non-causal NR0) | ||
| uint q_head_idx_0 = (query_group * NR0) / L; |
There was a problem hiding this comment.
NR0 path derives kv_idx from only the first query row in the group. If a group spans multiple query heads (possible when grouping is not aligned to L), some rows will read/write against the wrong KV head. Guarding only on totalQ % nr0 == 0 is insufficient for correctness.
There was a problem hiding this comment.
Added a queryChunkLength % nr0 == 0 in, next push, guard so the NR0 path only activates when L aligns with the group size. Falls back to per-row dispatch otherwise.
There was a problem hiding this comment.
TurboQuant cache cannot be restored correctly from prompt cache. TurboQuantKVCache is not represented in cacheClassName, so it is serialized as KVCache. On load it is restored as KVCacheSimple, but TurboQuant compressed state carries 3/4 arrays rather than KVCacheSimple’s required 2 arrays, leading to invalid restoration behavior. Add explicit TurboQuant class name mapping and restore path.
There was a problem hiding this comment.
Fixed in nex push. Added TurboQuantKVCache to cacheClassName and restoreCacheFromMetaState. metaState now carries bits, keyBits, valueBits, and seed so the cache can be reconstructed correctly on load.
|
@aleroot I appreciate the early reviews but this is still in draft. |
| self.maxTokens = parameters.maxTokens | ||
| self.numDraftTokens = numDraftTokens | ||
|
|
||
| self.quantizeKVCache = { cache in |
There was a problem hiding this comment.
kvScheme dropped in speculative decoding quantization.TurboQuant scheme selection is ignored for speculative generation despite being part of GenerateParameters...
There was a problem hiding this comment.
Yep, missed that. Threaded kvScheme through the speculative quantization closure now.
aleroot
left a comment
There was a problem hiding this comment.
The implementation is ambitious and has strong ideas, and honestly is something I wanted to work on as well... At the moment this pull request is just a prototype as it does not look like a full inner-product TurboQuant formulation ...
There are no tests as well that are really proving is working.
Appreciate your early feedback. Still a work in progress and know that this is an ambitious PR as is. Thank you for the draft comments! |
Sorry, I did not notice it was a draft . |
All good let's work together on this. It's a big lift and it's hard to separate into individual PRs. Lots of work ahead. |
Follow up as i was on the road for my last comment: Still polishing this, the PPL data in the description shows it working across Llama, Qwen, Mistral, and Qwen3.5 MoE. Happy to collaborate on getting this to a state that works for the project. If you have ideas on the formulation or testing approach I'm all ears. (we can also do email or discord or whatever works) |
4921784 to
2864b83
Compare
Adds kvScheme to GenerateParameters alongside existing kvBits. Provides a string-based scheme selector for KV cache compression strategies. Built-in: "affine4", "affine8" (equivalent to kvBits 4/8). kvScheme overrides kvBits when set. Unrecognized schemes pass through for custom KVCache implementations. Plumbing only, no new compression algorithms. Prepares the API for WHT-based and other non-affine KV compression schemes.
2864b83 to
d4ef67a
Compare
WHT rotation + group-scaled Lloyd-Max quantization for the KV cache, routed through the kvScheme parameter. Two-phase design: raw FP16 during prefill (zero overhead), compress on the first decode step, incremental encode afterwards. Codec: each group of 16 post-rotation elements is amax-normalized, quantized against an N(0,1)-optimal codebook, and stored with a matched-norm scale (||group|| / ||centroids||) so dequantization preserves per-group L2 norms — per-group scales are what survive the channel outliers of real K/V distributions. Validated schemes (teacher-forced PPL/KLD vs the FP16 cache, 214 forced steps on real text): Qwen2.5-1.5B-4bit turbo0v4 PPL +0.001 KLD 0.006 top-1 211/214 Qwen2.5-1.5B-4bit turbo0v2 PPL +0.45 KLD 0.29 top-1 186/214 Llama-3.2-3B-4bit turbo0v4 PPL +0.004 KLD 0.003 top-1 215/215 Key-quantizing schemes (turbo4 etc.) are withheld: 4-bit grouped key quantization collapses real-text quality (PPL 89 vs 1.18 baseline) even though every component passes unit parity — keys need finer treatment before those schemes are exposed. A raw-buffer bypass run through the identical bookkeeping reproduces FP16 exactly (KLD 0.0), isolating the gap to key quantization alone. The decode path is pure MLX ops (dequantize rotated keys/values, score against rotated queries — the rotation cancels in the dot product). The fused Metal kernels from the earlier draft predate the group-scale format and are removed; they can return re-templated as a perf follow-up. Tests: rotation orthogonality (WHT + dense), codec round trip, cache offsets/trim, compressed-vs-FP16 attention parity (incl. GQA head mapping with head-distinguishable values and the kvScheme conversion route), and an end-to-end generation test on a tiny random-weight Llama. No checkpoint needed; CI-runnable.
d4ef67a to
61ca2a1
Compare
Summary
Implements TurboQuant-style KV cache compression — WHT rotation + group-scaled Lloyd-Max quantization — routed through the
kvSchemeparameter.TurboQuantKVCache— two-phase cache: raw FP16 during prefill (zero overhead), compress on the first decode step, incremental encode afterwardsKVCache.swift,AttentionUtils.swift,Evaluate.swift— scheme routing throughkvSchemeStacked on #230 (the
kvSchemeplumbing) — review this PR for the second commit only until #230 lands.Codec
Each group of 16 post-rotation elements is amax-normalized, quantized against an N(0,1)-optimal codebook, and stored with a matched-norm scale (
‖group‖ / ‖centroids‖) so dequantization preserves per-group L2 norms. Per-group scales are what survive the channel outliers of real K/V distributions — a single per-vector norm lets one hot channel destroy the rest of the vector. The decode path is pure MLX ops: keys/values dequantize in rotated space and score against rotated queries (the rotation cancels in the dot product).Schemes
turbo0v4turbo0v2Validation (M5 Max, teacher-forced PPL/KLD vs the FP16 cache, 214 forced steps on real text)
turbo0v4turbo0v2turbo0v4A raw-buffer bypass through the identical cache bookkeeping reproduces FP16 exactly (KLD 0.0), so the remaining deltas above are purely quantization.
Why no key quantization (yet)
4-bit grouped key quantization collapses real-text quality (PPL 89 vs 1.18 baseline on Qwen2.5-1.5B) even though every component passes unit parity — small key-direction errors reshuffle peaked attention and compound across layers. The
turbo4/turbo4v2-style schemes are withheld fromkvSchemeuntil keys get finer treatment; that work also brings back fused Metal kernels (the earlier draft's kernels predate the group-scale format and were removed — encode/decode are pure MLX ops in this PR).Tests
CI-runnable, no checkpoint: rotation orthogonality (WHT + dense), codec round trip, cache offset/trim behavior, compressed-vs-FP16 attention parity (including GQA head-mapping with head-distinguishable values and the
kvSchemeconversion route), and an end-to-end generation test on a tiny random-weight Llama.