Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions Libraries/MLXLLM/LLMModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,23 @@ extension LLMModel {
let prefillStepSize = windowSize ?? 512
var y = input.text

// Prepare the prompt in chunks if larger than the prefill size.
// asyncEval lets the CPU build chunk N+1's graph while the GPU evaluates
// chunk N.
var state: LMOutput.State?
while y.tokens.size > prefillStepSize {
let input = y[.newAxis, ..<prefillStepSize]
let output = self(input, cache: cache.isEmpty ? nil : cache, state: state)
state = output.state
asyncEval(cache)
y = y[prefillStepSize...]
withPreparedCache(cache, lengths: y.sequenceLengths) {
// Prepare the prompt in chunks if larger than the prefill size.
// asyncEval lets the CPU build chunk N+1's graph while the GPU evaluates
// chunk N.
var state: LMOutput.State?
while y.tokens.size > prefillStepSize {
let input = y[.newAxis, ..<prefillStepSize]
let output = self(input, cache: cache.isEmpty ? nil : cache, state: state)
state = output.state
asyncEval(cache)
y = y[prefillStepSize...]
}

// Single sync after the loop to flush any remaining async work.
eval(cache)
}

// Single sync after the loop to flush any remaining async work.
eval(cache)

return .tokens(y)
}

Expand Down
4 changes: 2 additions & 2 deletions Libraries/MLXLLM/Models/FalconH1.swift
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,9 @@ class FalconH1Mixer: Module {
let t = paddedInput.dim(1)
let ends = clip(lengths, min: 0, max: t - nKeep)
let positions = (ends[0..., .newAxis] + MLXArray(0 ..< nKeep))[.ellipsis, .newAxis]
cache[0] = MLX.takeAlong(paddedInput, positions, axis: 1)
cache[0] = contiguous(MLX.takeAlong(paddedInput, positions, axis: 1))
} else {
cache[0] = paddedInput[0..., (-nKeep)...]
cache[0] = contiguous(paddedInput[0..., (-nKeep)..., 0...])
}
}

Expand Down
3 changes: 2 additions & 1 deletion Libraries/MLXLLM/Models/GraniteMoeHybrid.swift
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class GraniteMoeHybridMamba2Mixer: Module {
if let cache {
let end = padded.dim(1)
let start = max(0, end - (convKernelSize - 1))
cache[0] = padded[0..., start ..< end, 0...]
cache[0] = contiguous(padded[0..., start ..< end, 0...])
}

let convOutput = conv1d(padded)
Expand Down Expand Up @@ -181,6 +181,7 @@ class GraniteMoeHybridMamba2Mixer: Module {

if let cache {
cache[1] = nextState
cache.advance(hiddenStates.dim(1))
}

let flattenedY = y.flattened(start: 2)
Expand Down
3 changes: 2 additions & 1 deletion Libraries/MLXLLM/Models/Jamba.swift
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,9 @@ class JambaMambaMixer: Module {
x, convState: convState, ssmState: ssmState)

if let cache = cache {
cache[0] = newConvState
cache[0] = contiguous(newConvState)
cache[1] = newSsmState
cache.advance(x.dim(1))
}

return output
Expand Down
3 changes: 2 additions & 1 deletion Libraries/MLXLLM/Models/LFM2.swift
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ class LFM2ShortConv: Module {

Bx = concatenated([state!, Bx], axis: -2)
if let cache {
cache[0] = Bx[0..., (Bx.dim(1) - (lCache - 1))..., 0...]
cache[0] = contiguous(Bx[0..., (Bx.dim(1) - (lCache - 1))..., 0...])
cache.advance(x.dim(1))
}

let convOut = conv(Bx)
Expand Down
3 changes: 2 additions & 1 deletion Libraries/MLXLLM/Models/LFM2MoE.swift
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ class LFM2MoEShortConv: Module {
Bx = concatenated([state!, Bx], axis: -2)
if let cache {
let start = Bx.dim(1) - (lCache - 1)
cache[0] = Bx[0..., start..., 0...]
cache[0] = contiguous(Bx[0..., start..., 0...])
cache.advance(x.dim(1))
}

let convOut = conv(Bx)
Expand Down
3 changes: 2 additions & 1 deletion Libraries/MLXLLM/Models/NemotronH.swift
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ private class NemotronHMamba2Mixer: Module, NemotronHMixer {
if let cache {
let end = padded.dim(1)
let start = max(0, end - (convKernelSize - 1))
cache[0] = padded[0..., start ..< end, 0...]
cache[0] = contiguous(padded[0..., start ..< end, 0...])
}

let convOutput = conv1d(padded)
Expand Down Expand Up @@ -232,6 +232,7 @@ private class NemotronHMamba2Mixer: Module, NemotronHMixer {

if let cache {
cache[1] = nextState
cache.advance(hiddenStates.dim(1))
}

let flattenedY = y.flattened(start: 2)
Expand Down
25 changes: 8 additions & 17 deletions Libraries/MLXLLM/Models/Qwen3.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Qwen3Attention: Module {
@ModuleInfo(key: "q_norm") var qNorm: RMSNorm
@ModuleInfo(key: "k_norm") var kNorm: RMSNorm

let rope: RoPE
let rope: RoPELayer

public init(_ args: Qwen3Configuration) {
self.args = args
Expand All @@ -44,22 +44,13 @@ class Qwen3Attention: Module {
_qNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)
_kNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)

let ropeScale: Float
if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
let factor = ropeScaling["factor"]
{
if let v = factor.asFloat() {
ropeScale = 1 / v
} else {
fatalError("ropeScaling.factor must be a float")
}
} else {
ropeScale = 1
}

self.rope = RoPE(
dimensions: headDim, traditional: false, base: args.ropeTheta,
scale: ropeScale)
self.rope = initializeRope(
dims: headDim,
base: args.ropeTheta,
traditional: false,
scalingConfig: args.ropeScaling,
maxPositionEmbeddings: args.maxPositionEmbeddings
)
}

public func callAsFunction(
Expand Down
3 changes: 2 additions & 1 deletion Libraries/MLXLLM/Models/Qwen35.swift
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ final class Qwen35GatedDeltaNet: Module {

let convInput = concatenated([convState, qkv], axis: 1)
if let cache {
cache[0] = convInput[0..., (-(convKernelSize - 1))...]
cache[0] = contiguous(convInput[0..., (-(convKernelSize - 1))..., 0...])
}

let convOut = silu(conv1d(convInput))
Expand Down Expand Up @@ -287,6 +287,7 @@ final class Qwen35GatedDeltaNet: Module {

if let cache {
cache[1] = state
cache.advance(S)
Comment thread
aleroot marked this conversation as resolved.
}

out = norm(out, gate: z)
Expand Down
9 changes: 7 additions & 2 deletions Libraries/MLXLLM/Models/Qwen3Next.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ func sigmoidMultiply(_ x: MLXArray, _ gate: MLXArray) -> MLXArray {
x * sigmoid(gate)
}

private func preciseSwiGLU(_ hiddenStates: MLXArray, gate: MLXArray, x: MLXArray) -> MLXArray {
(silu(gate.asType(.float32)) * x.asType(.float32)).asType(hiddenStates.dtype)
}

// MARK: - Model Components

final class Qwen3NextRMSNormGated: Module {
Expand All @@ -31,7 +35,7 @@ final class Qwen3NextRMSNormGated: Module {
func callAsFunction(_ hiddenStates: MLXArray, gate: MLXArray? = nil) -> MLXArray {
var x = MLXFast.rmsNorm(hiddenStates, weight: weight, eps: eps)
if let gate {
x = x * silu(gate)
x = preciseSwiGLU(hiddenStates, gate: gate, x: x)
}
return x
}
Expand Down Expand Up @@ -258,7 +262,7 @@ public final class Qwen3NextGatedDeltaNet: Module {

let convInput = concatenated([convState, mixedQKV], axis: 1)
if let cache {
cache[0] = convInput[0..., (1 - convKernelSize)..., 0...]
cache[0] = contiguous(convInput[0..., (1 - convKernelSize)..., 0...])
}

let convOut = silu(conv1d(convInput))
Expand Down Expand Up @@ -290,6 +294,7 @@ public final class Qwen3NextGatedDeltaNet: Module {

if let cache {
cache[1] = newState
cache.advance(S)
}

let normalized = norm(out, gate: z)
Expand Down
5 changes: 3 additions & 2 deletions Libraries/MLXLMCommon/Evaluate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -671,8 +671,9 @@ public struct TokenIterator: TokenIteratorProtocol {

/// Evaluate the next token and return the new token (y), updating cache state
mutating func step(previous: LMInput.Text) -> MLXArray {
let result = model(
previous[text: .newAxis], cache: cache.isEmpty ? nil : cache, state: state)
let result = withPreparedCache(cache, lengths: previous.sequenceLengths) {
model(previous[text: .newAxis], cache: cache.isEmpty ? nil : cache, state: state)
}
self.state = result.state

// Apply dynamic cache quantization after each step
Expand Down
80 changes: 56 additions & 24 deletions Libraries/MLXLMCommon/KVCache.swift
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,46 @@ public protocol KVCache: Evaluatable {

/// Create an independent deep copy of this cache.
func copy() -> any KVCache

/// Prepare cache metadata for a batched sequence.
func prepare(lengths: [Int]?)

/// Prepare cache metadata for a batched sequence.
func prepare(lengths: MLXArray?)

/// Clear transient cache metadata after generation.
func finalize()
}

extension KVCache {
public var ropeOffset: RoPEOffset {
.scalar(offset)
}

public func prepare(lengths: [Int]?) {}

public func prepare(lengths: MLXArray?) {}

public func finalize() {}
}

public func withPreparedCache<Result>(
_ cache: [any KVCache],
lengths: [Int]?,
_ body: () throws -> Result
) rethrows -> Result {
guard let lengths else {
return try body()
}
for cache in cache {
cache.prepare(lengths: lengths)
}
defer {
for cache in cache {
cache.finalize()
}
}
return try body()
}

/// Protocol for caches that support efficient quantized operations
Expand Down Expand Up @@ -173,6 +207,12 @@ open class BaseKVCache: KVCache {
fatalError("copy() must be implemented by subclass")
}

open func prepare(lengths: [Int]?) {}

open func prepare(lengths: MLXArray?) {}

open func finalize() {}

/// Default implementation for caches without special mask requirements
open func makeMask(
n: Int, windowSize: Int?, returnArray: Bool
Expand Down Expand Up @@ -1163,7 +1203,7 @@ public class ChunkedKVCache: KVCacheSimple {

/// Base cache for array-based state storage
public class ArraysCache: BaseKVCache {
private var cache: [MLXArray?]
fileprivate var cache: [MLXArray?]
internal var leftPadding: MLXArray?
internal var lengths: MLXArray?

Expand Down Expand Up @@ -1198,10 +1238,7 @@ public class ArraysCache: BaseKVCache {
}

internal func copyContents(to new: ArraysCache) {
let s = self.state
if !s.isEmpty {
new.state = s.map { $0[.ellipsis] }
}
new.cache = cache.map { $0?[.ellipsis] }
new.offset = self.offset
new.leftPadding = self.leftPadding
new.lengths = self.lengths
Expand Down Expand Up @@ -1244,15 +1281,15 @@ public class ArraysCache: BaseKVCache {
lengths = concatenate(lengths, other.lengths)
}

public func prepare(lengths: [Int]?) {
public override func prepare(lengths: [Int]?) {
self.lengths = lengths.map { MLXArray($0) }
}

public func prepare(lengths: MLXArray?) {
public override func prepare(lengths: MLXArray?) {
self.lengths = lengths
}

public func finalize() {
public override func finalize() {
lengths = nil
leftPadding = nil
}
Expand All @@ -1275,6 +1312,11 @@ public class ArraysCache: BaseKVCache {
return leftPadding.asArray(Int.self)
}

internal var lengthsValues: [Int]? {
guard let lengths else { return nil }
return lengths.asArray(Int.self)
}

internal var presentSlotIndices: [Int] {
cache.enumerated().compactMap { (i, v) in v != nil ? i : nil }
}
Expand Down Expand Up @@ -1412,26 +1454,16 @@ public class CacheList: BaseKVCache {
return new
}

public func prepare(lengths: [Int]?) {
forEachArraysCache { $0.prepare(lengths: lengths) }
}

public func prepare(lengths: MLXArray?) {
forEachArraysCache { $0.prepare(lengths: lengths) }
public override func prepare(lengths: [Int]?) {
caches.forEach { $0.prepare(lengths: lengths) }
}

public func finalize() {
forEachArraysCache { $0.finalize() }
public override func prepare(lengths: MLXArray?) {
caches.forEach { $0.prepare(lengths: lengths) }
}

private func forEachArraysCache(_ body: (ArraysCache) -> Void) {
for cache in caches {
if let arrays = cache as? ArraysCache {
body(arrays)
} else if let list = cache as? CacheList {
list.forEachArraysCache(body)
}
}
public override func finalize() {
caches.forEach { $0.finalize() }
}

public override var isTrimmable: Bool {
Expand Down
9 changes: 9 additions & 0 deletions Libraries/MLXLMCommon/LanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ public struct LMInput {
) -> Text {
Text(tokens: tokens[indices, stream: stream], mask: mask)
}

/// Per-batch sequence lengths derived from the optional attention mask.
public var sequenceLengths: [Int]? {
if let mask {
return mask.asType(.int32).sum(axis: -1).asArray(Int.self)
}
guard tokens.ndim == 2 else { return nil }
return Array(repeating: tokens.dim(1), count: tokens.dim(0))
}
}

/// Representation of prepared input image(s).
Expand Down
Loading