feat: MTP speculative decoding — Phase 1 foundation#39
Implements the core infrastructure for Multi-Token Prediction (MTP) speculative decoding, enabling 2x+ inference speedups using model-native draft heads instead of an external draft model. Closes ml-explore#102

Changes:
- `MTPConfig.swift` (new): Add `retainMTPWeights` env flag gated on `SWIFTLM_MTP_ENABLE=1`, providing a single source of truth for all sanitizers.
- `LanguageModel.swift`: Define `MTPLanguageModel` protocol inheriting `LanguageModel`. Adds `callMTP(_ inputs:cache:)` as a standardized contract for exposing internal MTP prediction heads.
- `Evaluate.swift`: Implement `MTPTokenIterator`, mirroring the `SpeculativeTokenIterator` structure but eliminating the external draft model dependency. Flow: draft from cached MTP head logits, verify in one batched forward pass, rejection-sample, then `trimPromptCache` for the rejected suffix. Compatible with TurboKV, SSD streaming, and the existing `LogitProcessor` pipeline.
- Model sanitizers (conditional weight retention): Updated `sanitize()` in all affected models to skip MTP weight filtering when `MTPConfig.retainMTPWeights` is true: Qwen35, Qwen3Next, DeepseekV4, MiMo, MiMoV2Flash, and VLM Qwen35.

What is next (Phase 2):
- Conform Qwen3.5 and DeepSeek V4 to `MTPLanguageModel` by exposing mtp weights as callable layers.
- Wire `--mtp` flag in `Server.swift` to instantiate `MTPTokenIterator`.
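The draft, verify, and trim flow named in the description can be illustrated with a small, dependency-free sketch. It assumes greedy (temperature 0) acceptance rather than the rejection sampling the PR uses, and `VerifyOutcome` and `verifyGreedy` are hypothetical names that do not appear in the PR.

```swift
/// Outcome of one speculation round (illustrative type, not part of the PR).
struct VerifyOutcome: Equatable {
    let emitted: [Int]   // tokens handed to the caller this round
    let trimCount: Int   // cache entries to drop for the rejected suffix
}

/// `draft` holds the k tokens proposed by the MTP heads.
/// `target` holds the k + 1 tokens the main model picks at each verified position.
/// Assumption: greedy matching stands in for rejection sampling.
func verifyGreedy(draft: [Int], target: [Int]) -> VerifyOutcome {
    precondition(target.count == draft.count + 1,
                 "expected one target token per verified position")
    var accepted = 0
    while accepted < draft.count && draft[accepted] == target[accepted] {
        accepted += 1
    }
    // Accepted drafts, plus the main model's own token at the first mismatch
    // (or the extra token when every draft matched).
    let emitted = Array(draft.prefix(accepted)) + [target[accepted]]
    return VerifyOutcome(emitted: emitted, trimCount: draft.count - accepted)
}
```

Under these assumptions, every round emits at least one token, and the trim count equals the number of rejected draft positions that the batched forward pass already wrote into the cache.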
Pull request overview
This PR lays Phase 1 groundwork for MTP (Multi-Token Prediction) speculative decoding by introducing a shared MTP configuration flag, a protocol surface for models that can expose internal MTP heads, an MTP-based token iterator, and sanitizer changes to optionally retain mtp.* weights based on an environment variable.
Changes:
- Add `MTPConfig` (env-gated via `SWIFTLM_MTP_ENABLE=1`) and thread it through model weight sanitizers to optionally keep MTP weights.
- Introduce `MTPLanguageModel` and implement `MTPTokenIterator` to draft using internal MTP-head logits and verify with a single batched forward pass.
- Update multiple model sanitizers (LLM + VLM) to conditionally drop or retain MTP-related checkpoint keys.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| Libraries/MLXLMCommon/MTPConfig.swift | Adds a global env-driven switch for retaining MTP weights during sanitization. |
| Libraries/MLXLMCommon/LanguageModel.swift | Introduces MTPLanguageModel protocol with callMTP API for MTP-head logits. |
| Libraries/MLXLMCommon/Evaluate.swift | Adds MTPTokenIterator implementing MTP speculative decoding behavior. |
| Libraries/MLXLLM/Models/Qwen35.swift | Conditionally strips mtp. weights based on MTPConfig. |
| Libraries/MLXLLM/Models/Qwen3Next.swift | Conditionally strips mtp. weights based on MTPConfig. |
| Libraries/MLXLLM/Models/DeepseekV4.swift | Conditionally strips “MTP layers” (layer index ≥ numMainLayers) based on MTPConfig. |
| Libraries/MLXLLM/Models/MiMo.swift | Conditionally strips model.mtp_layers. weights based on MTPConfig. |
| Libraries/MLXLLM/Models/MiMoV2Flash.swift | Conditionally strips model.mtp* weights based on MTPConfig. |
| Libraries/MLXVLM/Models/Qwen35.swift | Conditionally strips mtp. weights based on MTPConfig. |
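A minimal sketch of the env-gated retention described in the table, assuming the flag is read once from the process environment and that a sanitizer filters on an `mtp.` key prefix. `MTPConfigSketch` and `sanitizeSketch` are hypothetical stand-ins; the real `MTPConfig` and `sanitize()` implementations are not shown in this conversation.

```swift
import Foundation

/// Hypothetical stand-in for the PR's MTPConfig.
enum MTPConfigSketch {
    /// True only when SWIFTLM_MTP_ENABLE is exactly "1".
    /// The environment is injectable so the gate can be tested.
    static func retainMTPWeights(
        environment: [String: String] = ProcessInfo.processInfo.environment
    ) -> Bool {
        environment["SWIFTLM_MTP_ENABLE"] == "1"
    }
}

/// Hypothetical sanitizer: drops `mtp.`-prefixed keys unless retention is on.
/// Weights are modeled as Float values purely for illustration.
func sanitizeSketch(_ weights: [String: Float], retainMTP: Bool) -> [String: Float] {
    if retainMTP { return weights }
    return weights.filter { !$0.key.hasPrefix("mtp.") }
}
```

Keeping the decision in one place means each model's sanitizer only needs a single boolean check, which matches the "single source of truth" goal stated in the PR description.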
if draftTokens.isEmpty {
    let mtpResult = model.callMTP(y.tokens, cache: cache)
    guard !mtpResult.isEmpty else { return }

    let mainLogits = mtpResult[0]
    var logits = mainLogits[0..., -1, 0...]
    logits = processor?.process(logits: logits) ?? logits
    let token = sampler.sample(logits: logits)
    processor?.didSample(token: token)

    pendingTokens.append(token.item(Int.self))
    y = .init(tokens: token)

    // Save future MTP logits for next iteration
    self.mtpLogits = mtpResult.count > 1 ? Array(mtpResult.dropFirst()) : nil
    return
}

// Verification: main model processes proposals in one pass
for layer in cache {
    if let mamba = layer as? MambaCache { mamba.checkpoint() }
}

let verifyTokens = [y.tokens] + draftTokens
let verifyInput = LMInput.Text(tokens: concatenated(verifyTokens))
let verifyStart = verifyInput.tokens.dim(0) - (draftTokens.count + 1)

let mtpResult = model.callMTP(verifyInput.tokens, cache: cache)
guard !mtpResult.isEmpty else { return }

// Save future MTP logits for next iteration
self.mtpLogits = mtpResult.count > 1 ? Array(mtpResult.dropFirst()) : nil

y = .init(tokens: token)

// Save future MTP logits for next iteration
self.mtpLogits = mtpResult.count > 1 ? Array(mtpResult.dropFirst()) : nil

/// An iterator that generates tokens using Multi-Token Prediction (MTP) for speculative decoding.
/// It uses internal MTP heads of the main model instead of an external draft model.
public struct MTPTokenIterator: TokenIteratorProtocol {

    var y: LMInput.Text
    let model: any MTPLanguageModel

    var state: LMOutput.State?
    public let streamingError: SSDStreamingError? = nil
    var cache: [KVCache]
    let quantizeKVCache: (inout [KVCache]) -> Void

    var processor: LogitProcessor?
    let sampler: LogitSampler

    var tokenCount = 0
    let maxTokens: Int?

    // Number of tokens the MTP heads predict (k)
    let numMTPTokens: Int

    // Logits from the previous step's MTP heads
    var mtpLogits: [MLXArray]?

    // Buffer of accepted tokens from the current speculation round
    private var pendingTokens = [Int]()
    private var pendingIndex = 0

    // Internal metrics
    var promptPrefillTime: TimeInterval = 0.0

    /// Initialize a `MTPTokenIterator` with the given input.
## Qwen35TextModel — full MTPLanguageModel conformance
- Add `num_nextn_predict_layers` field to `Qwen35TextConfiguration`
  (decoded from `config.json`; defaults to 0 for non-MTP checkpoints).
- Add `Qwen35MTPLayer` module with the reference architecture:
    enorm + hnorm (RMSNorm) → eh_proj (concat proj)
    → Qwen35DecoderLayer → shared_head (vocab projection)
- Add `@ModuleInfo(key: "mtp") var mtp: [Qwen35MTPLayer]` to
  `Qwen35TextModel`; array is populated only when
  `MTPConfig.retainMTPWeights` is true, so non-MTP usage is zero-cost.
- Implement `callMTP(_ inputs:cache:) -> [MLXArray]`:
    1. Embed tokens (reused across all MTP heads)
    2. Run main forward pass → main logits
    3. Chain each Qwen35MTPLayer on the previous hidden state
    4. Return [main_logits, mtp_0_logits, mtp_1_logits, ...]
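To make the `[main_logits, mtp_0_logits, mtp_1_logits, ...]` return layout concrete, here is a hedged sketch that treats each entry as a plain `[Float]` row of last-position logits and picks tokens greedily. The helper names `argmax` and `splitMTPResult` are hypothetical, and greedy selection is an assumption; the PR samples through its `LogitSampler`.

```swift
/// Index of the largest logit (first one wins on ties).
func argmax(_ logits: [Float]) -> Int {
    precondition(!logits.isEmpty, "logits must not be empty")
    var best = 0
    for i in logits.indices where logits[i] > logits[best] { best = i }
    return best
}

/// Splits a callMTP-style result into the main model's next token and the
/// draft tokens proposed by each MTP head, in head order.
/// Returns nil for an empty result, mirroring the guard in the iterator.
func splitMTPResult(_ result: [[Float]]) -> (next: Int, drafts: [Int])? {
    guard let main = result.first else { return nil }
    return (next: argmax(main), drafts: result.dropFirst().map(argmax))
}
```

With a single-entry result (a model without MTP heads, or the fallback path) the draft list is simply empty, so the caller degrades to ordinary one-token decoding.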
## DeepseekV4Model — MTPLanguageModel stub
- Conform `DeepseekV4Model` to `MTPLanguageModel`.
- callMTP() falls back to main logits until the `mtpLayers` array is
allocated inside `DeepseekV4ModelInner` (tracked in Phase 3).
## Evaluate.swift — generateMTP() public API
- Add `generateMTP(input:cache:parameters:context:numMTPTokens:)`
  public function that:
    • Casts context.model to `any MTPLanguageModel`
    • Instantiates `MTPTokenIterator`
    • Pipes through `generateLoopTask` with full tool-call support
    • Falls back to standard `generate()` if cast fails (safe for
      non-MTP models, no caller-side guard needed)
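The cast-then-fall-back behaviour listed above can be sketched with toy protocols. Everything here (`LanguageModelSketch`, `MTPLanguageModelSketch`, `generateMTPSketch`, and the string return values) is hypothetical and only illustrates the routing, not the real `generateMTP` signature.

```swift
/// Toy stand-in for the base model protocol.
protocol LanguageModelSketch {
    func generate(_ prompt: [Int]) -> String
}

/// Toy stand-in for a protocol refinement that exposes MTP heads.
protocol MTPLanguageModelSketch: LanguageModelSketch {
    func generateWithMTP(_ prompt: [Int], numMTPTokens: Int) -> String
}

struct PlainModel: LanguageModelSketch {
    func generate(_ prompt: [Int]) -> String { "standard" }
}

struct MTPModel: MTPLanguageModelSketch {
    func generate(_ prompt: [Int]) -> String { "standard" }
    func generateWithMTP(_ prompt: [Int], numMTPTokens: Int) -> String {
        "mtp(k=\(numMTPTokens))"
    }
}

/// Routes to the MTP path when the runtime cast succeeds, otherwise falls
/// back to the standard path so callers need no guard of their own.
func generateMTPSketch(model: any LanguageModelSketch, prompt: [Int], numMTPTokens: Int) -> String {
    guard let mtp = model as? any MTPLanguageModelSketch else {
        return model.generate(prompt)
    }
    return mtp.generateWithMTP(prompt, numMTPTokens: numMTPTokens)
}
```

Because the fallback lives inside the function, a caller can request MTP unconditionally and still get correct output from a model without MTP heads.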
## GenerationConfig — enableMTP / numMTPTokens flags
- `enableMTP: Bool = false` — per-request MTP toggle, Codable/persisted.
- `numMTPTokens: Int = 1` — draft depth per speculation round.
## InferenceEngine — MTP routing
- `generate()` checks `config.enableMTP && model is MTPLanguageModel`
and routes to `MLXLMCommon.generateMTP()` instead of the standard
generate path.
…kenIterator
10 tests covering:
- MTPConfig.retainMTPWeights env-var gate
- MTPLanguageModel protocol hierarchy (compile-time + runtime cast)
- Qwen35TextConfiguration num_nextn_predict_layers decoding
- mtp array empty-when-unset guard
- callMTP fallback shape correctness
- callMTP main logits == callAsFunction determinism
- callMTP multi-batch shape [B, S, V]
- MTPTokenIterator exact maxTokens count
- MTPTokenIterator temperature=0 identity with TokenIterator
- KVCache offset advances after generation
- generateMTP fallback for models with no MTP heads
Also: make Qwen35MTPLayer and mtp array public for test
visibility; make numNextnPredictLayers public in config;
fix deprecated createAttentionMask call site.
…ersistence
- Added makeMTPCaches to MTPLanguageModel protocol for allocating persistent KV caches for MTP heads.
- Updated MTPTokenIterator to use persistent MTP caches across speculation rounds. Resolves the 'Recursive Depth Collapse' bug where depth 5 acceptance plummeted due to KV cache resetting.
- Updated Qwen35TextModel and DeepseekV4Model to support makeMTPCaches and accept mtpCaches in callMTP.
- Updated DeepseekV4ModelInner to conditionally allocate its mtpLayers array.
- Updated DeepseekV4 sanitize method to correctly remap model.layers.{numMain+i} to model.mtpLayers.{i}.
- Appended DeepseekV4-specific tests to MTPSpeculativeDecodingTests.
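The `model.layers.{numMain+i}` to `model.mtpLayers.{i}` remap mentioned above could look roughly like the following. This is a sketch under the assumption that keys are dot-separated and the layer index directly follows the `model.layers.` prefix; `remapMTPKey` is a hypothetical helper, not the PR's `sanitize` code.

```swift
/// Rewrites `model.layers.<n>.<rest>` to `model.mtpLayers.<n - numMainLayers>.<rest>`
/// when n is at or beyond the main layer count; all other keys pass through.
func remapMTPKey(_ key: String, numMainLayers: Int) -> String {
    let prefix = "model.layers."
    guard key.hasPrefix(prefix) else { return key }
    let rest = key.dropFirst(prefix.count)
    guard let dot = rest.firstIndex(of: "."),
          let index = Int(rest[..<dot]),
          index >= numMainLayers
    else { return key }
    // rest[dot...] keeps the leading "." of the remaining key path.
    return "model.mtpLayers.\(index - numMainLayers)\(rest[dot...])"
}
```

For example, with 61 main layers (an illustrative count), layer 61 would become MTP layer 0 while layers 0 through 60 keep their original keys.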
… support
- Force MLX evaluation of mtpLogits to prevent recursive compute graph explosion (OOM).
- Apply dynamic KV cache quantization to all MTP draft heads during rewinds.
- Optimize SwitchGLU SSD streaming with async pre-reads and fused buffers.
- Add Qwen3.6-35B model definition and MTP speculation hooks.
Added OOM fixes for MTP Speculative Decoding (recursive graph collapse and missing quantization for draft heads). Code is pushed to this branch and the PR is automatically updated.
- Add fp8_gather_gemv fused Metal kernel in SwitchLayers.swift with dynamic
ROWS_PER_TG grid batching to prevent Metal dispatch overflows during 512-token
prefill (grid limit was exceeded with fixed threadgroup dims)
- Refactor SwitchLinear.init() to allocate [1,1,1] UInt8 placeholder instead
of full [numExperts, outputDims, inputDims] Float32 tensor — eliminating the
~2700 GB virtual memory graph that caused JetSam OOM pre-boot on 64GB UMA
- Add allocateExpertBuffers() using explicit integer dims (inputDims/outputDims)
decoupled from weight.shape() to fix shape mismatch in SSD streaming path
- Qwen35MoE.sanitize(): when stream-experts is active, skip MLX.stacked() for
both primary MoE layers and MTP layers — strip expert weight keys immediately
to prevent 57 GB FP8 shard eager-mapping into UMA
- Step 5 dequantization: eagerly prune source tensor references before MLX.eval
on each non-expert Linear projection to suppress peak RAM spikes during load
- Metal kernel: add explicit (device const uint8_t *) / (device const bfloat *)
pointer casts to prevent compile-time type errors on dummy UInt8 buffers
- FP8Linear.swift: standalone FP8 dequantization helper for block-scaled
weight_scale_inv tensors (128-element blocks, bfloat16 output)
- Load.swift: preserve weight_scale_inv stacking for SSD-streaming path;
assign weightScaleInv on SwitchLinear from stackedScales even when
stream-experts skips weight stacking
Fixes: connection-refused on 512-token prefill, JetSam pre-boot abort,
Metal grid overflow for large expert projections
Validated: Qwen/Qwen3.6-35B-A3B FP8 loads and generates on M5 Pro 64GB
(SSD-streaming path: 64.6 GB GPU_MEM, 9 tokens generated correctly)
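As a rough illustration of the block-scaled dequantization that the `FP8Linear.swift` bullet describes, the sketch below works on a flat `[Float]` array with one scale per block. It assumes dequantization multiplies each stored value by its block's `weight_scale_inv` entry, and it simplifies the layout to one dimension; the real helper operates on MLX tensors and emits bfloat16, and `dequantizeBlockScaled` is a hypothetical name.

```swift
/// Multiplies each value by the scale of the block it belongs to.
/// Assumption: dequantized = stored value * weight_scale_inv[block].
/// A ragged final block is allowed, playing the role of pad-then-crop.
func dequantizeBlockScaled(
    _ values: [Float],
    scaleInv: [Float],
    blockSize: Int = 128
) -> [Float] {
    precondition(blockSize > 0, "blockSize must be positive")
    let blocks = (values.count + blockSize - 1) / blockSize
    precondition(scaleInv.count == blocks, "need exactly one scale per block")
    return values.enumerated().map { pair in
        pair.element * scaleInv[pair.offset / blockSize]
    }
}
```

The default block size of 128 follows the commit message; the small block sizes used when exercising the function are only there to keep the arithmetic easy to check by hand.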
…fine MoE load verification
…matches, implement dynamic KV slicing, and add target dimensionality padding for incompatible assistant models.
…inputs with main model predictions and invalidating mtpLogits on cache rewind
Pull request overview
Copilot reviewed 23 out of 23 changed files in this pull request and generated 14 comments.
Comments suppressed due to low confidence (1)
Libraries/MLXLMCommon/SwitchLayers.swift:206
- In the fused stacked-buffer path, `_combinedGateUpScales!` is force-unwrapped and `runFusedGateUpMatmul` force-casts `gateProj as! QuantizedSwitchLinear`, but the current eligibility guard no longer requires `gateProj`/`upProj` to be `QuantizedSwitchLinear`. If `useFusedGateUp` is enabled with non-quantized projections, this will crash on cold init or at matmul. Restore a quantized-type guard for the fused path (or make fused mode gracefully fall back when scales/biases aren’t available).
guard idx.size <= 32,
      let gateSSD = gateProj.resolveSSDInfo(),
      let upSSD = upProj.resolveSSDInfo(),
      let downSSD = downProj.resolveSSDInfo() else {
    return nil // ineligible: fall through to legacy path
}
let CACHE_SLOTS = SwitchGLU.MAX_CACHE_SLOTS
let isFused = SwitchGLU.useFusedGateUp
if _stackedGate == nil && _stackedGateUp == nil {
    if isFused {
        // Combined gate+up buffer: shape [CACHE_SLOTS, 2*intermediate, hidden].
        _stackedGateUp = MLXArray.zeros(
            [CACHE_SLOTS, 2 * gateProj.weight.dim(1), gateProj.weight.dim(2)]
        ).asType(gateProj.weight.dtype)
        _stackedDown = MLXArray.zeros(
            [CACHE_SLOTS, downProj.weight.dim(1), downProj.weight.dim(2)]
        ).asType(downProj.weight.dtype)
        // Pre-concatenate gate+up scales/biases (one-time at cold init).
        if let qGate = gateProj as? QuantizedSwitchLinear, let qUp = upProj as? QuantizedSwitchLinear {
            _combinedGateUpScales = MLX.concatenated([qGate.scales, qUp.scales], axis: 1)
            if let gb = qGate.biases, let ub = qUp.biases {
                _combinedGateUpBiases = MLX.concatenated([gb, ub], axis: 1)
            }
        }
        _slotExpert = Array(repeating: nil, count: CACHE_SLOTS)
        _slotLastUsed = Array(repeating: 0, count: CACHE_SLOTS)
        _tokenCounter = 0
        var coldEvalList: [MLXArray] = [idx, _stackedGateUp!, _stackedDown!, _combinedGateUpScales!]
        if let cb = _combinedGateUpBiases { coldEvalList.append(cb) }
        MLX.eval(coldEvalList)
        _stackedGateUpBytesPerProj = _stackedGateUp!.nbytes / CACHE_SLOTS / 2
        _stackedBytesPerExpert = _stackedGateUpBytesPerProj

/// Default: call the two-argument overload with no MTP caches.
/// Models that don't override `makeMTPCaches` get a zero-element array.
public func callMTP(_ inputs: MLXArray, cache: [KVCache]?, mtpCaches: [[KVCache]]?) -> [MLXArray] {
    callMTP(inputs, cache: cache)
}

/// Shim for backward compat: calls the three-argument form with nil mtpCaches.

}

let acceptProb = Swift.min(1.0, pTargetX / Swift.max(pDraftX, 1e-9))
let u = Float.random(in: 0..<1)
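
The acceptance rule in the snippet above can be isolated into a small, testable function. This is a sketch: the uniform draw `u` is passed in so the outcome is deterministic, and the final `u < acceptProb` comparison is the standard speculative-sampling test, assumed here because the snippet is cut off before the comparison. `acceptDraft` is a hypothetical name.

```swift
/// Returns true when a drafted token should be kept.
/// pTarget: probability the main model assigns to the drafted token.
/// pDraft:  probability the draft head assigned to it (clamped away from zero).
/// u:       a uniform sample in [0, 1), injected for reproducibility.
func acceptDraft(pTarget: Float, pDraft: Float, u: Float) -> Bool {
    let acceptProb = min(1.0, pTarget / max(pDraft, 1e-9))
    return u < acceptProb
}
```

When the main model is at least as confident as the draft head, the ratio clamps to 1 and the token is always kept; otherwise it is kept with probability `pTarget / pDraft`.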

let iterator: any TokenIteratorProtocol
if let mtpModel = draftModel as? DualModelMTP {
    // Set up the dual-model MTP reference
    mtpModel.mainModelRef = context.model as? any BaseLanguageModel
    iterator = try MTPTokenIterator(
        input: input,
        model: mtpModel,
        cache: cache,

// SYNCHRONIZATION POINT
// Ensure the GPU has finished reading the stacked buffers from the previous token's
// computeExpertsFused before we overwrite those slots with new expert weights from the SSD.
Stream.gpu.synchronize()
print("[SwitchLayers] SSD Sync: GPU drained. Misses=\(missesNeedingPread.count)")
fflush(stdout)

// Extract weight_scale_inv for switch_mlp layers BEFORE update to avoid Unhandled Keys
var stackedScales = [String: MLXArray]()
for key in weights.keys {
    if key.contains(".switch_mlp.") && key.hasSuffix(".weight_scale_inv") {
        if let val = weights[key] {
            stackedScales[key] = val
            weights.removeValue(forKey: key)
        }
    }
}

let slotPerTokenArr: [Int32] = [0, 1, 2, 3, 4, 5, 6, 7]
let slotPerToken = MLXArray(slotPerTokenArr).asType(.uint32)
let slots = slotPerToken.asArray(Int32.self)
print(slots)

// This package uses a local path reference so the exact commit is
// controlled by whichever repo (SwiftLM) has both as submodules.
// In standalone CI, the checkout step clones SharpAI/mlx-swift
// into ../mlx-swift so this path resolves correctly.
// ─────────────────────────────────────────────────────────────────────────
.package(url: "https://github.com/SharpAI/mlx-swift.git", branch: "main"),
.package(path: "../mlx-swift"),

.package(url: "https://github.com/swiftlang/swift-syntax.git", from: "600.0.0-latest"),
],

    let dequantized = scaled.reshaped([1, m + padBottom, n + padSide])[0..., 0 ..< m, 0 ..< n]
    w = dequantized.asType(x.dtype)
} else {
    if i == 0 { print("[SwitchLayers] computeExperts: NO weightScaleInv found! w shape=\(w.shape), dtype=\(w.dtype)") }

print("[SwitchLayers] computeExpertsFused: FATAL ERROR: NO weightScaleInv found! w shape=\(w.shape), dtype=\(w.dtype)")
fflush(stdout)

// Map community MTP checkpoint keys (e.g. language_model.mtp.fc) to array indices (language_model.mtp.0.fc)
// Some checkpoints use .mtp.fc instead of the array index .mtp.0.fc
let updatedKey = k.contains(".mtp.") && !k.contains(".mtp.0.") ? k.replacingOccurrences(of: ".mtp.", with: ".mtp.0.") : k
let updatedVal = v

if updatedKey != k {
    weights.removeValue(forKey: k)
    weights[updatedKey] = v
}
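
One possible variation on the key mapping above, offered as a sketch rather than as the PR's behaviour: it inserts the `0` index only when the component after `mtp` is not already numeric, so a key such as `language_model.mtp.1.fc` is left untouched instead of becoming `language_model.mtp.0.1.fc`. `normalizeMTPKey` is a hypothetical helper.

```swift
/// Inserts an explicit "0" layer index after the "mtp" component when the
/// checkpoint key omits one. Keys that already carry any numeric index, and
/// keys without an interior "mtp" component, are returned unchanged.
func normalizeMTPKey(_ key: String) -> String {
    var parts = key.split(separator: ".", omittingEmptySubsequences: false).map(String.init)
    // Mirror the `.mtp.` match: "mtp" must sit between two other components.
    guard let i = parts.firstIndex(of: "mtp"), i > 0, i + 1 < parts.count else {
        return key
    }
    // Already indexed (mtp.0, mtp.1, ...): nothing to do.
    if Int(parts[i + 1]) != nil { return key }
    parts.insert("0", at: i + 1)
    return parts.joined(separator: ".")
}
```

Whether multi-head checkpoints with indices above 0 actually occur for these models is not stated in the PR, so the extra check is a defensive assumption.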
… MTPTokenIterator
28a931c to d0b7121
- Force MLX evaluation of `mtpLogits` to prevent recursive compute graph explosion (OOM).
- Optimize `SwitchGLU` SSD streaming with async pre-reads and fused buffers.

Resolves SharpAI/SwiftLM#102