Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ private func assertDraftBlockMatchesFixture(name: String) async throws {
lastToken: lastToken,
lastHidden: lastHidden,
sharedKV: sharedKV,
positionDeltas: nil,
queryOffset: queryOffset,
blockSize: blockSize,
sampler: ArgMaxSampler()
Expand Down
81 changes: 75 additions & 6 deletions Libraries/MLXLLM/Models/Qwen35.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ public struct Qwen35TextConfiguration: Codable, Sendable {
var headDim: Int?
var ropeScaling: [String: StringOrNumber]?
var fullAttentionInterval: Int = 4
var mtpNumHiddenLayers: Int = 0
var mtpUseDedicatedEmbeddings: Bool = false

// MoE fields
var numExperts: Int = 0
Expand Down Expand Up @@ -71,6 +73,8 @@ public struct Qwen35TextConfiguration: Codable, Sendable {
case headDim = "head_dim"
case ropeScaling = "rope_scaling"
case fullAttentionInterval = "full_attention_interval"
case mtpNumHiddenLayers = "mtp_num_hidden_layers"
case mtpUseDedicatedEmbeddings = "mtp_use_dedicated_embeddings"
case numExperts = "num_experts"
case numExpertsPerTok = "num_experts_per_tok"
case decoderSparseStep = "decoder_sparse_step"
Expand Down Expand Up @@ -117,6 +121,10 @@ public struct Qwen35TextConfiguration: Codable, Sendable {
self.headDim = try container.decodeIfPresent(Int.self, forKey: .headDim)
self.fullAttentionInterval =
try container.decodeIfPresent(Int.self, forKey: .fullAttentionInterval) ?? 4
self.mtpNumHiddenLayers =
try container.decodeIfPresent(Int.self, forKey: .mtpNumHiddenLayers) ?? 0
self.mtpUseDedicatedEmbeddings =
try container.decodeIfPresent(Bool.self, forKey: .mtpUseDedicatedEmbeddings) ?? false

// MoE fields
self.numExperts = try container.decodeIfPresent(Int.self, forKey: .numExperts) ?? 0
Expand Down Expand Up @@ -342,7 +350,8 @@ final class Qwen35Attention: Module {
}

func callAsFunction(
_ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
_ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?,
positionOffset: Int? = nil
) -> MLXArray {
let B = x.dim(0)
let L = x.dim(1)
Expand All @@ -359,7 +368,7 @@ final class Qwen35Attention: Module {
keys = kNorm(keys.reshaped(B, L, kvHeads, -1)).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, kvHeads, -1).transposed(0, 2, 1, 3)

let offset = cache?.ropeOffset
let offset = positionOffset.map { RoPEOffset.scalar($0) } ?? cache?.ropeOffset
queries = applyRotaryPosition(rope, to: queries, offset: offset)
keys = applyRotaryPosition(rope, to: keys, offset: offset)

Expand Down Expand Up @@ -445,8 +454,9 @@ final class Qwen35DecoderLayer: Module {

@ModuleInfo(key: "mlp") var mlp: Module

init(_ args: Qwen35TextConfiguration, layerIdx: Int) {
self.isLinear = (layerIdx + 1) % args.fullAttentionInterval != 0
init(_ args: Qwen35TextConfiguration, layerIdx: Int, forceFullAttention: Bool = false) {
self.isLinear =
forceFullAttention ? false : (layerIdx + 1) % args.fullAttentionInterval != 0

if isLinear {
_linearAttn.wrappedValue = Qwen35GatedDeltaNet(args)
Expand Down Expand Up @@ -479,13 +489,16 @@ final class Qwen35DecoderLayer: Module {
_ x: MLXArray,
attentionMask: MLXFast.ScaledDotProductAttentionMaskMode,
ssmMask: MLXArray?,
cache: KVCache?
cache: KVCache?,
positionOffset: Int? = nil
) -> MLXArray {
let r: MLXArray
if isLinear {
r = linearAttn!(inputLayerNorm(x), mask: ssmMask, cache: cache as? MambaCache)
} else {
r = selfAttn!(inputLayerNorm(x), mask: attentionMask, cache: cache)
r = selfAttn!(
inputLayerNorm(x), mask: attentionMask, cache: cache,
positionOffset: positionOffset)
}

let h = x + r
Expand Down Expand Up @@ -578,6 +591,32 @@ public class Qwen35TextModel: Module, LLMModel, KVCacheDimensionProvider {
return out
}

public func callAsFunction(
_ input: LMInput.Text, cache: [KVCache]?, state: LMOutput.State?
) -> LMOutput {
let emitDrafterState = state?[mtpEmitFlagKey] ?? false
let hiddenStates = model(input.tokens, cache: cache)

let logits: MLXArray
if let lmHead {
logits = lmHead(hiddenStates)
} else {
logits = model.embedTokens.asLinear(hiddenStates)
}

guard emitDrafterState else {
return LMOutput(logits: logits)
}

var outState = state ?? LMOutput.State()
outState[mtpLastHiddenStatesKey] = hiddenStates
outState[mtpSharedKVStatesKey] = qwen35SharedKVState(
cache: cache, fullAttentionIndex: model.faIdx)
outState[mtpSharedKVOffsetsKey] = qwen35SharedKVOffsets(
cache: cache, fullAttentionIndex: model.faIdx)
return LMOutput(logits: logits, state: outState)
}

public func newCache(parameters: GenerateParameters?) -> [KVCache] {
return model.layers.map { layer in
if layer.isLinear {
Expand Down Expand Up @@ -626,6 +665,30 @@ public class Qwen35TextModel: Module, LLMModel, KVCacheDimensionProvider {
}
}

private func qwen35SharedKVState(
cache: [KVCache]?,
fullAttentionIndex: Int
) -> [String: (MLXArray, MLXArray)] {
guard let cache, fullAttentionIndex < cache.count else {
return [:]
}
let state = cache[fullAttentionIndex].state
guard state.count == 2 else {
return [:]
}
return ["full_attention": (state[0], state[1])]
}

private func qwen35SharedKVOffsets(
cache: [KVCache]?,
fullAttentionIndex: Int
) -> [String: Int]? {
guard let cache, fullAttentionIndex < cache.count else {
return nil
}
return ["full_attention": cache[fullAttentionIndex].offset]
}

extension Qwen35TextModel: LoRAModel {
public var loraLayers: [Module] {
model.layers
Expand All @@ -651,6 +714,12 @@ public class Qwen35Model: Module, LLMModel, KVCacheDimensionProvider {
languageModel(inputs, cache: cache)
}

public func callAsFunction(
_ input: LMInput.Text, cache: [KVCache]?, state: LMOutput.State?
) -> LMOutput {
languageModel(input, cache: cache, state: state)
}

public func newCache(parameters: GenerateParameters?) -> [KVCache] {
languageModel.newCache(parameters: parameters)
}
Expand Down
133 changes: 133 additions & 0 deletions Libraries/MLXLLM/Models/Qwen35MTP.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright © 2026 Apple Inc.

import Foundation
import MLX
import MLXLMCommon
import MLXNN

final class Qwen35MTPPredictor: Module {
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding?
@ModuleInfo(key: "fc") var fc: Linear
@ModuleInfo(key: "layers") var layers: [Qwen35DecoderLayer]
@ModuleInfo(key: "norm") var norm: RMSNorm
@ModuleInfo(key: "pre_fc_norm_embedding") var preFCNormEmbedding: RMSNorm
@ModuleInfo(key: "pre_fc_norm_hidden") var preFCNormHidden: RMSNorm

init(_ args: Qwen35TextConfiguration) {
var mtpArgs = args
mtpArgs.hiddenLayers = max(args.mtpNumHiddenLayers, 1)
mtpArgs.fullAttentionInterval = 1

if args.mtpUseDedicatedEmbeddings {
_embedTokens.wrappedValue = Embedding(
embeddingCount: args.vocabularySize,
dimensions: args.hiddenSize
)
}
_fc.wrappedValue = Linear(args.hiddenSize * 2, args.hiddenSize, bias: false)
_layers.wrappedValue = (0 ..< mtpArgs.hiddenLayers).map {
Qwen35DecoderLayer(mtpArgs, layerIdx: $0, forceFullAttention: true)
}
_norm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
_preFCNormEmbedding.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
_preFCNormHidden.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
super.init()
}

func newCache() -> [KVCache] {
layers.map { _ in KVCacheSimple() }
}

func callAsFunction(
inputsEmbeds: MLXArray,
hiddenStates previousHidden: MLXArray,
cache: KVCache?,
stepIndex: Int,
positionOffset: Int
) -> MLXArray {
var hiddenStates = concatenated(
[preFCNormEmbedding(inputsEmbeds), preFCNormHidden(previousHidden)], axis: -1)
hiddenStates = fc(hiddenStates)

let faMask = createAttentionMask(h: hiddenStates, cache: cache)
let layer = layers[stepIndex % layers.count]

hiddenStates = layer(
hiddenStates,
attentionMask: faMask,
ssmMask: nil,
cache: cache,
positionOffset: positionOffset)

return norm(hiddenStates)
}
}

public final class Qwen35MTPDraftModel: Module, MTPDrafterModel {
public let configuration: Qwen35TextConfiguration

@ModuleInfo(key: "mtp") var mtp: Qwen35MTPPredictor

public init(_ configuration: Qwen35TextConfiguration) {
self.configuration = configuration
_mtp.wrappedValue = Qwen35MTPPredictor(configuration)
super.init()
}

public convenience init(_ configuration: Qwen35Configuration) {
self.init(configuration.textConfig)
}

public func draftBlock(
target: any LanguageModel,
lastToken: MLXArray,
lastHidden: MLXArray,
sharedKV _: [String: (MLXArray, MLXArray)],
positionDeltas _: MLXArray?,
queryOffset: Int,
blockSize: Int,
sampler: any LogitSampler
) -> MLXArray {
let (targetEmbedTokens, lmHead) = targetEmbeddingAndHead(target)
let inputEmbedding = mtp.embedTokens ?? targetEmbedTokens
return draftMTPTokenBlock(
targetEmbedTokens: targetEmbedTokens,
lmHead: lmHead,
inputEmbedding: inputEmbedding,
lastToken: lastToken,
lastHidden: lastHidden,
queryOffset: queryOffset,
blockSize: blockSize,
sampler: sampler,
newCache: mtp.newCache
) { inputsEmbeds, hiddenStates, cache, stepIndex, positionOffset in
mtp(
inputsEmbeds: inputsEmbeds,
hiddenStates: hiddenStates,
cache: cache,
stepIndex: stepIndex,
positionOffset: positionOffset)
}
}

public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
qwenMTPSanitizeWeights(
weights: weights,
mtpNumHiddenLayers: configuration.mtpNumHiddenLayers,
numExperts: configuration.numExperts
)
}

private func targetEmbeddingAndHead(_ target: any LanguageModel) -> (Embedding, Linear?) {
if let model = target as? Qwen35Model {
return (model.languageModel.model.embedTokens, model.languageModel.lmHead)
}
if let model = target as? Qwen35TextModel {
return (model.model.embedTokens, model.lmHead)
}
fatalError(
"Qwen35MTPDraftModel requires a Qwen35 target, got \(type(of: target))")
}
}
57 changes: 57 additions & 0 deletions Libraries/MLXLLM/Qwen35TextMTPRegistration.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright © 2026 Apple Inc.

import Foundation
import MLXLMCommon

/// Registers Qwen3.5/Qwen3.6 text MTP drafter model types.
///
/// Callers should invoke this once before loading a Qwen text drafter through
/// ``MTPDrafterModelFactory``.
public enum Qwen35TextMTPRegistration {
public static func register() async {
await MTPDrafterTypeRegistry.shared.registerModelType(
"qwen3_5_text",
creator: { data in
let config = try JSONDecoder.json5().decode(
Qwen35TextConfiguration.self, from: data)
return Qwen35MTPDraftModel(config)
}
)
await MTPDrafterTypeRegistry.shared.registerModelType(
"qwen3_5",
matches: qwen35TextMTPConfiguration,
creator: { data in
let config = try JSONDecoder.json5().decode(
Qwen35Configuration.self, from: data)
return Qwen35MTPDraftModel(config)
}
)
await MTPDrafterTypeRegistry.shared.registerModelType(
"qwen3_5_moe",
matches: qwen35TextMTPConfiguration,
creator: { data in
let config = try JSONDecoder.json5().decode(
Qwen35Configuration.self, from: data)
return Qwen35MTPDraftModel(config)
}
)
}
}

private func qwen35TextMTPConfiguration(_ data: Data) -> Bool {
guard let shape = try? JSONDecoder.json5().decode(Qwen35MTPConfigurationShape.self, from: data)
else {
return true
}
return shape.visionConfig == nil
}

private struct Qwen35MTPConfigurationShape: Decodable {
var visionConfig: VisionConfig?

enum CodingKeys: String, CodingKey {
case visionConfig = "vision_config"
}

struct VisionConfig: Decodable {}
}
16 changes: 16 additions & 0 deletions Libraries/MLXLMCommon/MTPDrafterModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ public protocol MTPDrafterModel: BaseLanguageModel {
/// - sharedKV: Dict keyed by `layer_type` (`"full_attention"` /
/// `"sliding_attention"`) mapping to `(keys, values)` `MLXArray`s for
/// the last layer of that layer-type in the target.
/// - positionDeltas: Optional target-emitted position delta state for
/// multimodal RoPE models whose continuation positions are not
/// derivable from cache length alone.
/// - queryOffset: Constant absolute position for the round (the
/// position the bonus token sits at in the target's KV cache).
/// Passed as a Swift `Int` rather than an `MLXArray` to avoid the
Expand All @@ -53,6 +56,7 @@ public protocol MTPDrafterModel: BaseLanguageModel {
lastToken: MLXArray,
lastHidden: MLXArray,
sharedKV: [String: (MLXArray, MLXArray)],
positionDeltas: MLXArray?,
queryOffset: Int,
blockSize: Int,
sampler: any LogitSampler
Expand Down Expand Up @@ -125,6 +129,18 @@ public let mtpLastHiddenStatesKey =
public let mtpSharedKVStatesKey =
LMOutput.Key<[String: (MLXArray, MLXArray)]>("mtp.sharedKVStates")

/// Target writes the absolute cache offset for each emitted shared-K/V
/// layer-type here. The K/V tensors themselves may have a capped sequence
/// axis when backed by a rotating cache, so their shape is not always the
/// absolute RoPE/query position.
public let mtpSharedKVOffsetsKey =
LMOutput.Key<[String: Int]>("mtp.sharedKVOffsets")

/// Target writes optional position delta state here for MTP drafters that
/// need to reproduce target-specific RoPE continuation positions.
public let mtpPositionDeltasKey =
LMOutput.Key<MLXArray>("mtp.positionDeltas")

/// The MTP iterator sets this key on the ``LMOutput/State`` it passes into
/// the main model on each call to opt the target into emitting
/// ``mtpLastHiddenStatesKey`` and ``mtpSharedKVStatesKey``. An absent key
Expand Down
Loading