feat: MTP speculative decoding — Phase 1 foundation#39

Merged
solderzzc merged 14 commits into main from feat/mtp-speculative-decoding on May 12, 2026

@solderzzc solderzzc commented May 5, 2026

Member
  • Force MLX evaluation of mtpLogits to prevent recursive compute graph explosion (OOM).
  • Apply dynamic KV cache quantization to all MTP draft heads during rewinds.
  • Optimize SwitchGLU SSD streaming with async pre-reads and fused buffers.
  • Add Qwen3.6-35B model definition and MTP speculation hooks.

Resolves SharpAI/SwiftLM#102

Implements the core infrastructure for Multi-Token Prediction (MTP)
speculative decoding, enabling 2x+ inference speedups using model-native
draft heads instead of an external draft model.

Closes ml-explore#102

Changes:

MTPConfig.swift (new): Add retainMTPWeights env flag gated on
SWIFTLM_MTP_ENABLE=1, providing a single source of truth for
all sanitizers.
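
A minimal sketch of how such an env-gated flag could look. Only the flag name `retainMTPWeights` and the `SWIFTLM_MTP_ENABLE=1` gate come from the description; the type name `MTPConfigSketch` and the helper `shouldRetainMTPWeights(environment:)` are hypothetical and exist only so the gate can be exercised without touching the real process environment.

```swift
import Foundation

/// Hypothetical stand-in for the PR's `MTPConfig` (not the actual source).
enum MTPConfigSketch {
    /// Pure helper: true only when the variable is set to exactly "1".
    static func shouldRetainMTPWeights(environment: [String: String]) -> Bool {
        environment["SWIFTLM_MTP_ENABLE"] == "1"
    }

    /// Evaluated lazily on first access, then fixed for the process lifetime,
    /// so every sanitizer sees the same answer.
    static let retainMTPWeights: Bool =
        shouldRetainMTPWeights(environment: ProcessInfo.processInfo.environment)
}
```

Reading the variable once into a static constant is one way to get the "single source of truth" property: a sanitizer cannot observe a different value than the model initializer did.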

LanguageModel.swift: Define MTPLanguageModel protocol inheriting from
LanguageModel. Add callMTP(_ inputs:cache:) as the standardized contract
for exposing internal MTP prediction heads.

Evaluate.swift: Implement MTPTokenIterator mirroring
SpeculativeTokenIterator structure but eliminating the external
draft model dependency. Flow: draft from cached MTP head logits,
verify in one batched forward pass, rejection-sample, trimPromptCache
for rejected suffix. Compatible with TurboKV, SSD streaming, and
existing LogitProcessor pipeline.
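
The verify and rejection-sample steps of that flow can be sketched in isolation. This is a simplified illustration, not the iterator's code: `VerifyOutcome` and `verifyDrafts` are hypothetical names, probabilities are passed in as plain arrays instead of being read from logits, the uniform samples are supplied by the caller to keep the function deterministic, and the resampling of a replacement token after a rejection is omitted. The acceptance probability uses the same `min(1, pTarget / max(pDraft, 1e-9))` form that appears in a snippet later in this thread.

```swift
/// Outcome of verifying one speculation round (hypothetical type).
struct VerifyOutcome: Equatable {
    let acceptedCount: Int   // drafts kept, counted from the front
    let trimCount: Int       // rejected-suffix entries to drop from the cache
}

/// Walks the drafted tokens in order and stops at the first rejection.
/// `pTarget[i]` and `pDraft[i]` are the probabilities the main model and the
/// draft head assigned to draft token i; `uniforms[i]` is a sample from [0, 1).
func verifyDrafts(pTarget: [Double], pDraft: [Double], uniforms: [Double]) -> VerifyOutcome {
    precondition(pTarget.count == pDraft.count && pDraft.count == uniforms.count)
    var accepted = 0
    for i in pTarget.indices {
        // Clamp the denominator so a zero draft probability cannot divide by zero.
        let acceptProb = min(1.0, pTarget[i] / max(pDraft[i], 1e-9))
        if uniforms[i] < acceptProb {
            accepted += 1
        } else {
            break   // everything after the first rejection is discarded
        }
    }
    return VerifyOutcome(acceptedCount: accepted, trimCount: pTarget.count - accepted)
}
```

In this sketch `trimCount` is simply the number of rejected drafts; how many cache positions a real iterator passes to trimPromptCache depends on what the verification pass wrote, which the description does not spell out.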

Model sanitizers (conditional weight retention): Updated sanitize()
in all affected models to skip MTP weight filtering when
MTPConfig.retainMTPWeights is true: Qwen35, Qwen3Next, DeepseekV4,
MiMo, MiMoV2Flash, and VLM Qwen35.
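
As a rough illustration of conditional weight retention, the sketch below filters a weight dictionary unless the flag is set. The function name `sanitizeSketch` and the key test (`mtp.` as a prefix or as an inner path component) are assumptions for the example; the real sanitizers use model-specific key patterns, and the tensor type is left generic so the snippet runs without MLX.

```swift
/// Hypothetical sanitizer step, generic over the tensor type.
func sanitizeSketch<Tensor>(
    _ weights: [String: Tensor],
    retainMTPWeights: Bool
) -> [String: Tensor] {
    // With retention enabled, pass every key through untouched.
    guard !retainMTPWeights else { return weights }
    // Otherwise drop keys that start with `mtp.` or contain `.mtp.`.
    return weights.filter { entry in
        !(entry.key.hasPrefix("mtp.") || entry.key.contains(".mtp."))
    }
}
```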

What is next (Phase 2): Conform Qwen3.5 and DeepSeek V4 to
MTPLanguageModel by exposing mtp weights as callable layers.
Wire --mtp flag in Server.swift to instantiate MTPTokenIterator.
Copilot AI review requested due to automatic review settings May 5, 2026 17:09

Copilot AI left a comment

Pull request overview

This PR lays Phase 1 groundwork for MTP (Multi-Token Prediction) speculative decoding by introducing a shared MTP configuration flag, a protocol surface for models that can expose internal MTP heads, an MTP-based token iterator, and sanitizer changes to optionally retain mtp.* weights based on an environment variable.

Changes:

  • Add MTPConfig (env-gated via SWIFTLM_MTP_ENABLE=1) and thread it through model weight sanitizers to optionally keep MTP weights.
  • Introduce MTPLanguageModel and implement MTPTokenIterator to draft using internal MTP-head logits and verify with a single batched forward pass.
  • Update multiple model sanitizers (LLM + VLM) to conditionally drop or retain MTP-related checkpoint keys.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| `Libraries/MLXLMCommon/MTPConfig.swift` | Adds a global env-driven switch for retaining MTP weights during sanitization. |
| `Libraries/MLXLMCommon/LanguageModel.swift` | Introduces `MTPLanguageModel` protocol with `callMTP` API for MTP-head logits. |
| `Libraries/MLXLMCommon/Evaluate.swift` | Adds `MTPTokenIterator` implementing MTP speculative decoding behavior. |
| `Libraries/MLXLLM/Models/Qwen35.swift` | Conditionally strips `mtp.` weights based on `MTPConfig`. |
| `Libraries/MLXLLM/Models/Qwen3Next.swift` | Conditionally strips `mtp.` weights based on `MTPConfig`. |
| `Libraries/MLXLLM/Models/DeepseekV4.swift` | Conditionally strips “MTP layers” (layer index ≥ `numMainLayers`) based on `MTPConfig`. |
| `Libraries/MLXLLM/Models/MiMo.swift` | Conditionally strips `model.mtp_layers.` weights based on `MTPConfig`. |
| `Libraries/MLXLLM/Models/MiMoV2Flash.swift` | Conditionally strips `model.mtp*` weights based on `MTPConfig`. |
| `Libraries/MLXVLM/Models/Qwen35.swift` | Conditionally strips `mtp.` weights based on `MTPConfig`. |

Comment on lines +1112 to +1140
    if draftTokens.isEmpty {
        let mtpResult = model.callMTP(y.tokens, cache: cache)
        guard !mtpResult.isEmpty else { return }

        let mainLogits = mtpResult[0]
        var logits = mainLogits[0..., -1, 0...]
        logits = processor?.process(logits: logits) ?? logits
        let token = sampler.sample(logits: logits)
        processor?.didSample(token: token)

        pendingTokens.append(token.item(Int.self))
        y = .init(tokens: token)

        // Save future MTP logits for next iteration
        self.mtpLogits = mtpResult.count > 1 ? Array(mtpResult.dropFirst()) : nil
        return
    }

    // Verification: main model processes proposals in one pass
    for layer in cache {
        if let mamba = layer as? MambaCache { mamba.checkpoint() }
    }

    let verifyTokens = [y.tokens] + draftTokens
    let verifyInput = LMInput.Text(tokens: concatenated(verifyTokens))
    let verifyStart = verifyInput.tokens.dim(0) - (draftTokens.count + 1)

    let mtpResult = model.callMTP(verifyInput.tokens, cache: cache)
    guard !mtpResult.isEmpty else { return }
Comment on lines +1004 to +1035
    /// An iterator that generates tokens using Multi-Token Prediction (MTP) for speculative decoding.
    /// It uses internal MTP heads of the main model instead of an external draft model.
    public struct MTPTokenIterator: TokenIteratorProtocol {

        var y: LMInput.Text
        let model: any MTPLanguageModel

        var state: LMOutput.State?
        public let streamingError: SSDStreamingError? = nil
        var cache: [KVCache]
        let quantizeKVCache: (inout [KVCache]) -> Void

        var processor: LogitProcessor?
        let sampler: LogitSampler

        var tokenCount = 0
        let maxTokens: Int?

        // Number of tokens the MTP heads predict (k)
        let numMTPTokens: Int

        // Logits from the previous step's MTP heads
        var mtpLogits: [MLXArray]?

        // Buffer of accepted tokens from the current speculation round
        private var pendingTokens = [Int]()
        private var pendingIndex = 0

        // Internal metrics
        var promptPrefillTime: TimeInterval = 0.0

        /// Initialize a `MTPTokenIterator` with the given input.
Aegis-AI added 6 commits May 5, 2026 10:23
## Qwen35TextModel — full MTPLanguageModel conformance

- Add `num_nextn_predict_layers` field to `Qwen35TextConfiguration`
  (decoded from `config.json`; defaults to 0 for non-MTP checkpoints).
- Add `Qwen35MTPLayer` module with the reference architecture:
    enorm + hnorm (RMSNorm)  →  eh_proj (concat proj)
    →  Qwen35DecoderLayer  →  shared_head (vocab projection)
- Add `@ModuleInfo(key: "mtp") var mtp: [Qwen35MTPLayer]` to
  `Qwen35TextModel`; array is populated only when
  `MTPConfig.retainMTPWeights` is true, so non-MTP usage is zero-cost.
- Implement `callMTP(_ inputs:cache:) -> [MLXArray]`:
    1. Embed tokens (reused across all MTP heads)
    2. Run main forward pass → main logits
    3. Chain each Qwen35MTPLayer on the previous hidden state
    4. Return [main_logits, mtp_0_logits, mtp_1_logits, ...]
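
The chaining in steps 2 to 4 can be pictured with a toy model. Everything here is a stand-in: `Vector`, `ToyHead`, and `callMTPSketch` are hypothetical names, tensors are plain `[Double]` arrays, and the token-embedding input that the real `Qwen35MTPLayer` combines through `enorm`/`hnorm`/`eh_proj` is left out. The sketch only shows the output ordering and how each head consumes the previous hidden state.

```swift
typealias Vector = [Double]

/// Hypothetical head: takes the previous hidden state and returns a new
/// hidden state together with its logits.
struct ToyHead {
    let step: (Vector) -> (hidden: Vector, logits: Vector)
}

func callMTPSketch(
    mainForward: () -> (hidden: Vector, logits: Vector),
    heads: [ToyHead]
) -> [Vector] {
    let main = mainForward()
    var outputs = [main.logits]       // index 0 is always the main model's logits
    var hidden = main.hidden
    for head in heads {
        let next = head.step(hidden)  // each head sees the previous hidden state
        outputs.append(next.logits)
        hidden = next.hidden
    }
    return outputs
}
```

With no heads the result has exactly one element, which matches the emptiness check and `dropFirst()` handling in the iterator snippet quoted earlier in this thread.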

## DeepseekV4Model — MTPLanguageModel stub

- Conform `DeepseekV4Model` to `MTPLanguageModel`.
- callMTP() falls back to main logits until the `mtpLayers` array is
  allocated inside `DeepseekV4ModelInner` (tracked in Phase 3).

## Evaluate.swift — generateMTP() public API

- Add `generateMTP(input:cache:parameters:context:numMTPTokens:)`
  public function that:
    • Casts context.model to `any MTPLanguageModel`
    • Instantiates `MTPTokenIterator`
    • Pipes through `generateLoopTask` with full tool-call support
    • Falls back to standard `generate()` if cast fails (safe for
      non-MTP models, no caller-side guard needed)
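
A small sketch of the cast-then-fall-back pattern described above, using hypothetical stand-in protocols (`LanguageModelSketch`, `MTPLanguageModelSketch`) rather than the real MLXLMCommon types. It only demonstrates that a runtime protocol check is enough to pick the path, so callers need no guard of their own.

```swift
// Hypothetical stand-ins for LanguageModel / MTPLanguageModel.
protocol LanguageModelSketch { var name: String { get } }
protocol MTPLanguageModelSketch: LanguageModelSketch {}

struct PlainModel: LanguageModelSketch { let name = "plain" }
struct MTPModel: MTPLanguageModelSketch { let name = "mtp" }

enum GenerationPath: Equatable { case mtp, standard }

func choosePath(for model: any LanguageModelSketch) -> GenerationPath {
    // The check succeeds only for models that conform to the MTP protocol.
    if model is any MTPLanguageModelSketch {
        return .mtp
    }
    // Any other model silently takes the standard generation path.
    return .standard
}
```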

## GenerationConfig — enableMTP / numMTPTokens flags

- `enableMTP: Bool = false` — per-request MTP toggle, Codable/persisted.
- `numMTPTokens: Int = 1`   — draft depth per speculation round.
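
For illustration, a Codable struct with these two fields might be written as below. The type name `GenerationConfigSketch` is hypothetical, and the choice to decode missing keys to their defaults (so configs persisted before the fields existed still load) is an assumption, not something the commit message states.

```swift
import Foundation

struct GenerationConfigSketch: Codable, Equatable {
    var enableMTP: Bool = false
    var numMTPTokens: Int = 1

    init(enableMTP: Bool = false, numMTPTokens: Int = 1) {
        self.enableMTP = enableMTP
        self.numMTPTokens = numMTPTokens
    }

    enum CodingKeys: String, CodingKey { case enableMTP, numMTPTokens }

    // Assumed behavior: absent keys fall back to the defaults above.
    init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        enableMTP = try container.decodeIfPresent(Bool.self, forKey: .enableMTP) ?? false
        numMTPTokens = try container.decodeIfPresent(Int.self, forKey: .numMTPTokens) ?? 1
    }
}
```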

## InferenceEngine — MTP routing

- `generate()` checks `config.enableMTP && model is MTPLanguageModel`
  and routes to `MLXLMCommon.generateMTP()` instead of the standard
  generate path.
…kenIterator

10 tests covering:
- MTPConfig.retainMTPWeights env-var gate
- MTPLanguageModel protocol hierarchy (compile-time + runtime cast)
- Qwen35TextConfiguration num_nextn_predict_layers decoding
- mtp array empty-when-unset guard
- callMTP fallback shape correctness
- callMTP main logits == callAsFunction determinism
- callMTP multi-batch shape [B, S, V]
- MTPTokenIterator exact maxTokens count
- MTPTokenIterator temperature=0 identity with TokenIterator
- KVCache offset advances after generation
- generateMTP fallback for models with no MTP heads

Also: make Qwen35MTPLayer and mtp array public for test
      visibility; make numNextnPredictLayers public in config;
      fix deprecated createAttentionMask call site.
…ersistence

- Added makeMTPCaches to MTPLanguageModel protocol for allocating persistent KV caches for MTP heads.
- Updated MTPTokenIterator to use persistent MTP caches across speculation rounds. Resolves the 'Recursive Depth Collapse' bug where depth 5 acceptance plummeted due to KV cache resetting.
- Updated Qwen35TextModel and DeepseekV4Model to support makeMTPCaches and accept mtpCaches in callMTP.
- Updated DeepseekV4ModelInner to conditionally allocate its mtpLayers array.
- Updated DeepseekV4 sanitize method to correctly remap model.layers.{numMain+i} to model.mtpLayers.{i}.
- Appended DeepseekV4-specific tests to MTPSpeculativeDecodingTests.
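
The layer-index remap mentioned for the DeepseekV4 sanitizer can be sketched as a pure string transform. `remapMTPLayerKey` is a hypothetical helper, not the function in the PR, and the layer count used in any call is up to the caller; only the `model.layers.{numMain+i}` to `model.mtpLayers.{i}` rule comes from the commit message.

```swift
/// Rewrites `model.layers.{numMainLayers + i}.<rest>` to `model.mtpLayers.{i}.<rest>`.
/// Keys for main layers, or keys outside `model.layers.`, are returned unchanged.
func remapMTPLayerKey(_ key: String, numMainLayers: Int) -> String {
    let prefix = "model.layers."
    guard key.hasPrefix(prefix) else { return key }
    let rest = key.dropFirst(prefix.count)
    // The layer index is everything up to the next dot.
    guard let dot = rest.firstIndex(of: "."),
          let index = Int(rest[rest.startIndex..<dot]),
          index >= numMainLayers else { return key }
    return "model.mtpLayers.\(index - numMainLayers)\(rest[dot...])"
}
```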
… support

- Force MLX evaluation of mtpLogits to prevent recursive compute graph explosion (OOM).
- Apply dynamic KV cache quantization to all MTP draft heads during rewinds.
- Optimize SwitchGLU SSD streaming with async pre-reads and fused buffers.
- Add Qwen3.6-35B model definition and MTP speculation hooks.
@solderzzc

Member Author

Added OOM fixes for MTP Speculative Decoding (recursive graph collapse and missing quantization for draft heads). Code is pushed to this branch and the PR is automatically updated.

Aegis-AI added 6 commits May 6, 2026 23:09
- Add fp8_gather_gemv fused Metal kernel in SwitchLayers.swift with dynamic
  ROWS_PER_TG grid batching to prevent Metal dispatch overflows during 512-token
  prefill (grid limit was exceeded with fixed threadgroup dims)

- Refactor SwitchLinear.init() to allocate [1,1,1] UInt8 placeholder instead
  of full [numExperts, outputDims, inputDims] Float32 tensor — eliminating the
  ~2700 GB virtual memory graph that caused JetSam OOM pre-boot on 64GB UMA

- Add allocateExpertBuffers() using explicit integer dims (inputDims/outputDims)
  decoupled from weight.shape() to fix shape mismatch in SSD streaming path

- Qwen35MoE.sanitize(): when stream-experts is active, skip MLX.stacked() for
  both primary MoE layers and MTP layers — strip expert weight keys immediately
  to prevent 57 GB FP8 shard eager-mapping into UMA

- Step 5 dequantization: eagerly prune source tensor references before MLX.eval
  on each non-expert Linear projection to suppress peak RAM spikes during load

- Metal kernel: add explicit (device const uint8_t *) / (device const bfloat *)
  pointer casts to prevent compile-time type errors on dummy UInt8 buffers

- FP8Linear.swift: standalone FP8 dequantization helper for block-scaled
  weight_scale_inv tensors (128-element blocks, bfloat16 output)

- Load.swift: preserve weight_scale_inv stacking for SSD-streaming path;
  assign weightScaleInv on SwitchLinear from stackedScales even when
  stream-experts skips weight stacking

Fixes: connection-refused on 512-token prefill, JetSam pre-boot abort,
       Metal grid overflow for large expert projections

Validated: Qwen/Qwen3.6-35B-A3B FP8 loads and generates on M5 Pro 64GB
           (SSD-streaming path: 64.6 GB GPU_MEM, 9 tokens generated correctly)
…matches, implement dynamic KV slicing, and add target dimensionality padding for incompatible assistant models.
…inputs with main model predictions and invalidating mtpLogits on cache rewind

Copilot AI left a comment

Pull request overview

Copilot reviewed 23 out of 23 changed files in this pull request and generated 14 comments.

Comments suppressed due to low confidence (1)

Libraries/MLXLMCommon/SwitchLayers.swift:206

  • In the fused stacked-buffer path, _combinedGateUpScales! is force-unwrapped and runFusedGateUpMatmul force-casts gateProj as! QuantizedSwitchLinear, but the current eligibility guard no longer requires gateProj/upProj to be QuantizedSwitchLinear. If useFusedGateUp is enabled with non-quantized projections, this will crash on cold init or at matmul. Restore a quantized-type guard for the fused path (or make fused mode gracefully fall back when scales/biases aren’t available).
        guard idx.size <= 32,
              let gateSSD = gateProj.resolveSSDInfo(),
              let upSSD = upProj.resolveSSDInfo(),
              let downSSD = downProj.resolveSSDInfo() else {
            return nil  // ineligible — fall through to legacy path
        }

        let CACHE_SLOTS = SwitchGLU.MAX_CACHE_SLOTS
        let isFused = SwitchGLU.useFusedGateUp

        if _stackedGate == nil && _stackedGateUp == nil {
            if isFused {
                // Combined gate+up buffer: shape [CACHE_SLOTS, 2*intermediate, hidden].
                _stackedGateUp = MLXArray.zeros(
                    [CACHE_SLOTS, 2 * gateProj.weight.dim(1), gateProj.weight.dim(2)]
                ).asType(gateProj.weight.dtype)
                _stackedDown = MLXArray.zeros(
                    [CACHE_SLOTS, downProj.weight.dim(1), downProj.weight.dim(2)]
                ).asType(downProj.weight.dtype)
                // Pre-concatenate gate+up scales/biases (one-time at cold init).
                if let qGate = gateProj as? QuantizedSwitchLinear, let qUp = upProj as? QuantizedSwitchLinear {
                    _combinedGateUpScales = MLX.concatenated([qGate.scales, qUp.scales], axis: 1)
                    if let gb = qGate.biases, let ub = qUp.biases {
                        _combinedGateUpBiases = MLX.concatenated([gb, ub], axis: 1)
                    }
                }
                _slotExpert = Array(repeating: nil, count: CACHE_SLOTS)
                _slotLastUsed = Array(repeating: 0, count: CACHE_SLOTS)
                _tokenCounter = 0
                var coldEvalList: [MLXArray] = [idx, _stackedGateUp!, _stackedDown!, _combinedGateUpScales!]
                if let cb = _combinedGateUpBiases { coldEvalList.append(cb) }
                MLX.eval(coldEvalList)
                _stackedGateUpBytesPerProj = _stackedGateUp!.nbytes / CACHE_SLOTS / 2
                _stackedBytesPerExpert = _stackedGateUpBytesPerProj

Comment on lines +281 to +287
/// Default: call the two-argument overload with no MTP caches.
/// Models that don't override `makeMTPCaches` get a zero-element array.
public func callMTP(_ inputs: MLXArray, cache: [KVCache]?, mtpCaches: [[KVCache]]?) -> [MLXArray] {
callMTP(inputs, cache: cache)
}

/// Shim for backward compat — calls the three-argument form with nil mtpCaches.
}

let acceptProb = Swift.min(1.0, pTargetX / Swift.max(pDraftX, 1e-9))
let u = Float.random(in: 0..<1)
Comment on lines +1856 to +1863
    let iterator: any TokenIteratorProtocol
    if let mtpModel = draftModel as? DualModelMTP {
        // Set up the dual-model MTP reference
        mtpModel.mainModelRef = context.model as? any BaseLanguageModel
        iterator = try MTPTokenIterator(
            input: input,
            model: mtpModel,
            cache: cache,
Comment on lines +380 to +387

// SYNCHRONIZATION POINT
// Ensure the GPU has finished reading the stacked buffers from the previous token's
// computeExpertsFused before we overwrite those slots with new expert weights from the SSD.
Stream.gpu.synchronize()
print("[SwitchLayers] SSD Sync: GPU drained. Misses=\(missesNeedingPread.count)")
fflush(stdout)

Comment on lines +72 to +81
    // Extract weight_scale_inv for switch_mlp layers BEFORE update to avoid Unhandled Keys
    var stackedScales = [String: MLXArray]()
    for key in weights.keys {
        if key.contains(".switch_mlp.") && key.hasSuffix(".weight_scale_inv") {
            if let val = weights[key] {
                stackedScales[key] = val
                weights.removeValue(forKey: key)
            }
        }
    }
let slotPerTokenArr: [Int32] = [0, 1, 2, 3, 4, 5, 6, 7]
let slotPerToken = MLXArray(slotPerTokenArr).asType(.uint32)
let slots = slotPerToken.asArray(Int32.self)
print(slots)
Comment thread Package.swift
Comment on lines 47 to 55
// This package uses a local path reference so the exact commit is
// controlled by WhichEver repo (SwiftLM) has both as submodules.
// In standalone CI, the checkout step clones SharpAI/mlx-swift
// into ../mlx-swift so this path resolves correctly.
// ─────────────────────────────────────────────────────────────────────────
.package(url: "https://github.com/SharpAI/mlx-swift.git", branch: "main"),
.package(path: "../mlx-swift"),

.package(url: "https://github.com/swiftlang/swift-syntax.git", from: "600.0.0-latest"),
],
let dequantized = scaled.reshaped([1, m + padBottom, n + padSide])[0..., 0 ..< m, 0 ..< n]
w = dequantized.asType(x.dtype)
} else {
if i == 0 { print("[SwitchLayers] computeExperts: NO weightScaleInv found! w shape=\(w.shape), dtype=\(w.dtype)") }
Comment on lines +1318 to +1319
print("[SwitchLayers] computeExpertsFused: FATAL ERROR: NO weightScaleInv found! w shape=\(w.shape), dtype=\(w.dtype)")
fflush(stdout)
Comment on lines +763 to +771
    // Map community MTP checkpoint keys (e.g. language_model.mtp.fc) to array indices (language_model.mtp.0.fc)
    // Some checkpoints use .mtp.fc instead of the array index .mtp.0.fc
    let updatedKey = k.contains(".mtp.") && !k.contains(".mtp.0.") ? k.replacingOccurrences(of: ".mtp.", with: ".mtp.0.") : k
    let updatedVal = v

    if updatedKey != k {
        weights.removeValue(forKey: k)
        weights[updatedKey] = v
    }
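
One thing the snippet above implies: a key that is already indexed with something other than `0` (say `.mtp.1.fc`) still satisfies the condition and would become `.mtp.0.1.fc`. Whether such checkpoints exist is not stated in this thread. As a hedged sketch, a stricter variant could treat any numeric component after `.mtp.` as already indexed; `normalizeMTPKey` is a hypothetical helper, not code from the PR.

```swift
import Foundation

/// Inserts a `0` index after `.mtp.` only when the next path component
/// is not already a number.
func normalizeMTPKey(_ key: String) -> String {
    guard let range = key.range(of: ".mtp.") else { return key }
    let tail = key[range.upperBound...]
    let firstComponent = tail.prefix { $0 != "." }
    // Already indexed (e.g. `.mtp.0.` or `.mtp.1.`): leave the key untouched.
    if !firstComponent.isEmpty,
       firstComponent.allSatisfy({ $0.isASCII && $0.isNumber }) {
        return key
    }
    return "\(key[..<range.lowerBound]).mtp.0.\(tail)"
}
```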
@solderzzc solderzzc force-pushed the feat/mtp-speculative-decoding branch from 28a931c to d0b7121 Compare May 12, 2026 19:13
@solderzzc solderzzc merged commit 3362bab into main May 12, 2026
6 checks passed
@solderzzc solderzzc deleted the feat/mtp-speculative-decoding branch May 12, 2026 19:38

Development

Successfully merging this pull request may close these issues.

Feature Request: Integrate MTP Speculative Decoding (MTPLX-style) for 2x+ Speedup

2 participants