Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion Libraries/MLXLMCommon/AttentionUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,23 @@ public func attentionWithCacheUpdate(
mask: mask
)
}
if let quantizedKVCache = cache as? QuantizedKVCacheProtocol {
if let turboCache = cache as? TurboQuantKVCache {
let L = queries.dim(2)
if L > 1 {
// Prefill (L>1): raw update + standard SDPA. Zero overhead.
let (cachedKeys, cachedValues) = turboCache.update(keys: keys, values: values)
return MLXFast.scaledDotProductAttention(
queries: queries, keys: cachedKeys, values: cachedValues,
scale: scale, mask: mask
)
}
// Decode (L=1): compressed cache path. First call triggers
// compressRawCache() inside compressedAttention.
return turboCache.compressedAttention(
queries: queries, keys: keys, values: values,
scale: scale, mask: mask
)
} else if let quantizedKVCache = cache as? QuantizedKVCacheProtocol {
let (quantizedKeys, quantizedValues) = quantizedKVCache.updateQuantized(
keys: keys, values: values)
return quantizedScaledDotProductAttention(
Expand Down
17 changes: 15 additions & 2 deletions Libraries/MLXLMCommon/Evaluate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ public struct GenerateParameters: Sendable {
/// Step to begin using a quantized KV cache when kvBits is non-nil (default: 0)
public var quantizedKVStart: Int

/// KV cache compression scheme. Overrides kvBits when set.
/// Built-in: "affine4", "affine8" (equivalent to kvBits 4/8), plus the
/// TurboQuant family ("turbo4", "turbo3", "turbo2", "turbo4v2", "turbo4v3").
public var kvScheme: String?

/// Sampling temperature
public var temperature: Float

Expand Down Expand Up @@ -108,6 +113,7 @@ public struct GenerateParameters: Sendable {
kvBits: Int? = nil,
kvGroupSize: Int = 64,
quantizedKVStart: Int = 0,
kvScheme: String? = nil,
temperature: Float = 0.6,
topP: Float = 1.0,
topK: Int = 0,
Expand All @@ -125,6 +131,7 @@ public struct GenerateParameters: Sendable {
self.kvBits = kvBits
self.kvGroupSize = kvGroupSize
self.quantizedKVStart = quantizedKVStart
self.kvScheme = kvScheme
self.temperature = temperature
self.topP = topP
self.topK = topK
Expand Down Expand Up @@ -536,6 +543,7 @@ public struct TokenIterator: TokenIteratorProtocol {
let kvBits: Int?
let kvGroupSize: Int
let quantizedKVStart: Int
let kvScheme: String?

// Internal metrics
public var promptPrefillTime: TimeInterval = 0.0
Expand Down Expand Up @@ -564,6 +572,7 @@ public struct TokenIterator: TokenIteratorProtocol {
self.kvBits = parameters.kvBits
self.kvGroupSize = parameters.kvGroupSize
self.quantizedKVStart = parameters.quantizedKVStart
self.kvScheme = parameters.kvScheme

self.promptPrefillTime = try measure {
try prepare(input: .init(text: y), windowSize: parameters.prefillStepSize)
Expand Down Expand Up @@ -597,6 +606,7 @@ public struct TokenIterator: TokenIteratorProtocol {
self.kvBits = parameters.kvBits
self.kvGroupSize = parameters.kvGroupSize
self.quantizedKVStart = parameters.quantizedKVStart
self.kvScheme = parameters.kvScheme

self.promptPrefillTime = try measure {
try prepare(input: input, windowSize: parameters.prefillStepSize)
Expand Down Expand Up @@ -630,6 +640,7 @@ public struct TokenIterator: TokenIteratorProtocol {
self.kvBits = nil
self.kvGroupSize = 64
self.quantizedKVStart = 0
self.kvScheme = nil

self.promptPrefillTime = try measure {
try prepare(input: input, windowSize: prefillStepSize)
Expand Down Expand Up @@ -680,7 +691,8 @@ public struct TokenIterator: TokenIteratorProtocol {
cache: &cache,
kvBits: kvBits,
kvGroupSize: kvGroupSize,
quantizedKVStart: quantizedKVStart
quantizedKVStart: quantizedKVStart,
kvScheme: kvScheme
)

return convertToToken(logits: result.logits)
Expand Down Expand Up @@ -798,7 +810,8 @@ public struct SpeculativeTokenIterator: TokenIteratorProtocol {
cache: &cache,
kvBits: parameters.kvBits,
kvGroupSize: parameters.kvGroupSize,
quantizedKVStart: parameters.quantizedKVStart
quantizedKVStart: parameters.quantizedKVStart,
kvScheme: parameters.kvScheme
)
}

Expand Down
68 changes: 60 additions & 8 deletions Libraries/MLXLMCommon/KVCache.swift

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TurboQuant cache cannot be restored correctly from prompt cache. TurboQuantKVCache is not represented in cacheClassName, so it is serialized as KVCache. On load it is restored as KVCacheSimple, but TurboQuant compressed state carries 3/4 arrays rather than KVCacheSimple’s required 2 arrays, leading to invalid restoration behavior. Add explicit TurboQuant class name mapping and restore path.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in the next push. Added TurboQuantKVCache to cacheClassName and restoreCacheFromMetaState. metaState now carries bits, keyBits, valueBits, and seed so the cache can be reconstructed correctly on load.

Original file line number Diff line number Diff line change
Expand Up @@ -1529,6 +1529,7 @@ private func cacheClassName(_ cache: KVCache) -> String {
case is ArraysCache: return "ArraysCache"
case is RotatingKVCache: return "RotatingKVCache"
case is QuantizedKVCache: return "QuantizedKVCache"
case is TurboQuantKVCache: return "TurboQuantKVCache"
case is KVCacheSimple: return "KVCache"
case is CacheList: return "CacheList"
default: return "KVCache"
Expand Down Expand Up @@ -1682,6 +1683,22 @@ private func restoreCacheFromMetaState(
cache.restoreFromMetaState(state: state, savedMetaState: metaState)
return cache

case "TurboQuantKVCache":
guard metaState.count >= 5,
let bits = Int(metaState[1]),
let keyBits = Int(metaState[2]),
let valueBits = Int(metaState[3]),
let seed = UInt64(metaState[4])
else {
throw KVCacheError(
message: "Invalid TurboQuantKVCache metaState")
}
let cache = TurboQuantKVCache(
bits: bits, keyBits: keyBits, valueBits: valueBits, seed: seed)
cache.state = state
cache.metaState = metaState
return cache

case "CacheList":
return try CacheList.fromState(state: state, metaState: metaState)

Expand Down Expand Up @@ -1927,8 +1944,18 @@ public func quantizedScaledDotProductAttention(

/// Resolves a `kvScheme` name to the affine quantization settings it selects.
///
/// Matching is an exact, case-sensitive string comparison.
///
/// - Parameter scheme: Scheme name, or `nil` when no scheme was requested.
/// - Returns: `(bits, groupSize)` for the built-in affine schemes
///   (`"affine4"` gives 4 bits, `"affine8"` gives 8 bits, both with group
///   size 64). Returns `nil` for `nil` and for any other scheme name, so
///   callers can fall back to their own handling of unrecognized schemes.
public func resolveAffineScheme(_ scheme: String?) -> (bits: Int, groupSize: Int)? {
    switch scheme {
    case "affine4": return (4, 64)
    case "affine8": return (8, 64)
    default: return nil
    }
}

/// Dynamically quantize KV caches during generation if conditions are met
///
/// Converts regular caches to quantized caches when:
/// - kvBits is specified
/// - kvBits is specified (or kvScheme resolves to a built-in affine scheme)
/// - The cache is not already quantized
/// - The cache offset is greater than quantizedKVStart
///
Expand All @@ -1937,13 +1964,41 @@ public func quantizedScaledDotProductAttention(
/// - kvBits: Number of bits for quantization (nil = no quantization)
/// - kvGroupSize: Group size for quantization
/// - quantizedKVStart: Token count threshold to begin quantizing
/// - kvScheme: Scheme selector; overrides kvBits when it names a built-in
/// affine scheme ("affine4", "affine8") or a TurboQuant scheme
/// ("turbo4", "turbo4v2", ...). Unrecognized schemes are left to custom
/// cache implementations and do not quantize here.
public func maybeQuantizeKVCache(
cache: inout [KVCache],
kvBits: Int?,
kvGroupSize: Int = 64,
quantizedKVStart: Int = 0
quantizedKVStart: Int = 0,
kvScheme: String? = nil
) {
guard let kvBits = kvBits, !cache.isEmpty else { return }
// TurboQuant schemes convert eligible layers to TurboQuantKVCache
// (handled in TurboQuantKVCache.swift to keep this file scheme-agnostic).
if let scheme = kvScheme, let turbo = resolveTurboScheme(scheme) {
maybeTurboQuantizeKVCache(
cache: &cache,
keyBits: turbo.keyBits,
valueBits: turbo.valueBits,
quantizedKVStart: quantizedKVStart
)
return
}

// Resolve effective bits: kvScheme overrides kvBits.
let effectiveBits: Int
let effectiveGroupSize: Int
if let scheme = kvScheme, let resolved = resolveAffineScheme(scheme) {
effectiveBits = resolved.bits
effectiveGroupSize = resolved.groupSize
} else if let kvBits {
effectiveBits = kvBits
effectiveGroupSize = kvGroupSize
} else {
return
}

// Find the first quantizable (non-Mamba, non-already-quantized) cache entry
guard let firstQuantizable = cache.first(where: { $0 is KVCacheSimple }),
Expand All @@ -1954,26 +2009,23 @@ public func maybeQuantizeKVCache(
}

for i in 0 ..< cache.count {
// Handle cache types that support quantization
if let simpleCache = cache[i] as? KVCacheSimple {
let state = simpleCache.state
if state.count == 2 {
let keyHeadDim = state[0].dim(3)
let valueHeadDim = state[1].dim(3)
guard
resolvedKVQuantizationGroupSize(
requested: kvGroupSize,
requested: effectiveGroupSize,
keyHeadDim: keyHeadDim,
valueHeadDim: valueHeadDim
) != nil
else {
continue
}
}
cache[i] = simpleCache.toQuantized(groupSize: kvGroupSize, bits: kvBits)
cache[i] = simpleCache.toQuantized(groupSize: effectiveGroupSize, bits: effectiveBits)
}
// TODO: RotatingKVCache.toQuantized() is not implemented yet, like in Python.
// When implemented, add: else if let rotatingCache = cache[i] as? RotatingKVCache { ... }
// MambaCache and CacheList don't use traditional KV quantization
}
}
Loading