Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Libraries/MLXLLM/LLMModelFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ public enum LLMTypeRegistry {
"gemma3n": create(Gemma3nTextConfiguration.self, Gemma3nTextModel.init),
"gemma4": create(Gemma4Configuration.self, Gemma4Model.init),
"gemma4_text": create(Gemma4TextConfiguration.self, Gemma4TextModel.init),
// "gemma4_assistant" is registered with an explicit closure rather than `create(...)`:
// the raw config data is decoded as the full `Gemma4Configuration` and the decoded
// value is passed to `Gemma4AssistantModel`. A decoding error propagates via `try`.
"gemma4_assistant": { data in
let fullConfig = try JSONDecoder.json5().decode(Gemma4Configuration.self, from: data)
return Gemma4AssistantModel(fullConfig)
},
"qwen2": create(Qwen2Configuration.self, Qwen2Model.init),
"qwen3": create(Qwen3Configuration.self, Qwen3Model.init),
"qwen3_moe": create(Qwen3MoEConfiguration.self, Qwen3MoEModel.init),
Expand Down
100 changes: 74 additions & 26 deletions Libraries/MLXLLM/Models/DeepseekV4.swift
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ class DeepseekV4MoE: Module, UnaryLayer {

// MARK: - Decoder Block (with mHC Hyper-Connections)

class DeepseekV4Block: Module {
public class DeepseekV4Block: Module {
let config: DeepseekV4Configuration

// Key "attn" matches checkpoint path `layers.{l}.attn.*`
Expand Down Expand Up @@ -712,15 +712,15 @@ public class DeepseekV4ModelInner: Module, LayerPartitionable, StreamableMoE {

public var gpuLayerCount: Int? = nil
public var streamExperts: Bool = false
public var totalLayerCount: Int { layers.count }
public var totalLayerCount: Int { layers.count - (MTPConfig.retainMTPWeights ? config.numNextnPredictLayers : 0) }

init(config: DeepseekV4Configuration) {
self.config = config
self._embedTokens.wrappedValue = Embedding(
embeddingCount: config.vocabSize, dimensions: config.hiddenSize)
// Exclude MTP (multi-token prediction) layers from the main transformer stack
let mainLayerCount = config.numHiddenLayers - config.numNextnPredictLayers
self.layers = (0 ..< mainLayerCount).map {
let retainMTP = MTPConfig.retainMTPWeights && config.numNextnPredictLayers > 0
let totalCount = config.numHiddenLayers - (retainMTP ? 0 : config.numNextnPredictLayers)
self.layers = (0 ..< totalCount).map {
_ in DeepseekV4Block(config: config)
}
self._norm.wrappedValue = RMSNorm(dimensions: config.hiddenSize, eps: config.rmsNormEps)
Expand Down Expand Up @@ -750,7 +750,7 @@ public class DeepseekV4ModelInner: Module, LayerPartitionable, StreamableMoE {
let hForMask = h.reshaped([B, S, hc * config.hiddenSize]) // [B, S, hc*D]
let attentionMask = createAttentionMask(h: hForMask, cache: cache?.first)

for (i, layer) in layers.enumerated() {
for (i, layer) in layers.prefix(totalLayerCount).enumerated() {
h = partitionedLayerCall(
index: i, gpuLayerCount: gpuLayerCount, stream: streamExperts
) {
Expand All @@ -777,7 +777,7 @@ public class DeepseekV4Model: Module, LLMModel, KVCacheDimensionProvider, LoRAMo
public var model: DeepseekV4ModelInner
@ModuleInfo(key: "lm_head") var lmHead: Linear

init(_ args: DeepseekV4Configuration) {
public init(_ args: DeepseekV4Configuration) {
self.args = args
self.kvHeads = Array(repeating: 1, count: args.numHiddenLayers - args.numNextnPredictLayers)
self.model = DeepseekV4ModelInner(config: args)
Expand Down Expand Up @@ -841,35 +841,83 @@ public class DeepseekV4Model: Module, LLMModel, KVCacheDimensionProvider, LoRAMo
// 3. Filter out MTP (multi-token prediction) layers and rotary_emb keys
// Also drop compressor/indexer sub-module keys (not yet implemented)
let numMainLayers = args.numHiddenLayers - args.numNextnPredictLayers
return newWeights.filter { key, _ in
// Drop MTP layer weights (layers at index >= numMainLayers)
var finalWeights = [String: MLXArray]()
for (key, value) in newWeights {
// Drop rotary embedding precomputed frequencies
if key.contains("rotary_emb.inv_freq") { continue }
// Drop compressor/indexer sub-module weights
// TODO: implement DeepseekV4Compressor and DeepseekV4Indexer modules.
if key.contains(".attn.compressor.") || key.contains(".attn.indexer.") { continue }
// Drop gate.tid2eid
if key.contains(".ffn.gate.tid2eid") { continue }

if key.starts(with: "model.layers.") {
let parts = key.split(separator: ".")
if parts.count >= 3, let layerIdx = Int(parts[2]) {
if layerIdx >= numMainLayers {
return false
if layerIdx >= numMainLayers && !MTPConfig.retainMTPWeights {
continue
}
}
}
// Drop rotary embedding precomputed frequencies
if key.contains("rotary_emb.inv_freq") { return false }
// Drop compressor/indexer sub-module weights — these implement long-range
// compressed attention and are not yet implemented in this Swift port.
// Affected layers are those with compress_ratio != 0 (layers 2+).
// TODO: implement DeepseekV4Compressor and DeepseekV4Indexer modules.
if key.contains(".attn.compressor.") || key.contains(".attn.indexer.") {
return false
}
// Note: .attn.attn_sink is a valid model parameter — do NOT filter it.
// Drop gate.tid2eid — hash-layer token-to-expert lookup table (not yet implemented).
// Hash layers (0..numHashLayers-1) use deterministic routing; we fall back to
// the learned gate.weight for these layers instead.
if key.contains(".ffn.gate.tid2eid") { return false }
return true
finalWeights[key] = value
}
return finalWeights
}

/// Modules offered for LoRA adaptation: every block held in the inner model's `layers`.
public var loraLayers: [Module] {
    let adaptableBlocks: [Module] = model.layers
    return adaptableBlocks
}
}

// MARK: - MTPLanguageModel Conformance for DeepseekV4Model

/// DeepSeek V4 uses a different MTP scheme: the MTP layers are the last
/// `numNextnPredictLayers` standard transformer blocks (`model.layers[numMainLayers...]`).
/// They share the same architecture as the main blocks but operate on the final hidden state.
/// The main `lm_head` is reused for all MTP depth projections.
extension DeepseekV4Model: MTPLanguageModel {
    /// Runs the main forward pass, then chains the retained MTP blocks on top of it.
    ///
    /// - Parameters:
    ///   - inputs: Input passed unchanged to the main model forward pass.
    ///   - cache: Cache list passed through to the main model forward pass.
    ///   - mtpCaches: One cache list per MTP depth (see `makeMTPCaches`). Only the
    ///     first cache of each list is used. When this is `nil`, or shorter than the
    ///     number of MTP blocks, the uncovered depths run without a cache.
    /// - Returns: The main-model logits first, followed by one logits array per MTP
    ///   depth. If `MTPConfig.retainMTPWeights` is off or there are no MTP blocks,
    ///   the array holds only the result of the standard forward pass.
    public func callMTP(_ inputs: MLXArray, cache: [KVCache]?, mtpCaches: [[KVCache]]?) -> [MLXArray] {
        // The MTP blocks are the trailing `numNextnPredictLayers` entries of `model.layers`.
        let mtpLayers = model.layers.suffix(args.numNextnPredictLayers)
        guard MTPConfig.retainMTPWeights, !mtpLayers.isEmpty else {
            return [callAsFunction(inputs, cache: cache)]
        }

        // Main forward pass. `DeepseekV4ModelInner` iterates only its first
        // `totalLayerCount` blocks, so the trailing MTP blocks are not run here.
        let mainHidden = model(inputs, cache: cache)
        let mainLogits = lmHead(mainHidden)
        var result = [mainLogits]

        // Chain the MTP blocks: each depth consumes the hidden state produced by the
        // previous one, starting from the main model's output.
        var prevHidden = mainHidden
        let B = prevHidden.dim(0), S = prevHidden.dim(1)
        let hc = args.hcMult
        for (depth, mtpLayer) in mtpLayers.enumerated() {
            // A missing per-depth entry is treated like `mtpCaches == nil` instead of
            // trapping on an out-of-range subscript.
            var layerCache: KVCache? = nil
            if let caches = mtpCaches, caches.indices.contains(depth) {
                layerCache = caches[depth].first
            }

            // Expand [B, S, D] -> [B, S, hc, D]
            var h = prevHidden.expandedDimensions(axis: 2)
            h = repeated(h, count: hc, axis: 2)

            // The mask helper is given the flattened [B, S, hc*D] view.
            let hForMask = h.reshaped([B, S, hc * args.hiddenSize])
            let attentionMask = createAttentionMask(h: hForMask, cache: layerCache)

            h = mtpLayer(h, mask: attentionMask, cache: layerCache)

            // Reduce back to [B, S, D]
            prevHidden = hcHead(
                x: h, hcFn: model.hc_head.fn, hcScale: model.hc_head.scale,
                hcBase: model.hc_head.base, eps: args.hcEps)

            // The main `lm_head` is reused for every MTP depth.
            let mtpLogits = lmHead(model.norm(prevHidden))
            result.append(mtpLogits)
        }

        return result
    }

    /// Creates the caches consumed by `callMTP`.
    ///
    /// - Parameter parameters: Currently unused by this implementation.
    /// - Returns: `numNextnPredictLayers` cache lists, each holding a single
    ///   `KVCacheSimple`.
    public func makeMTPCaches(parameters: GenerateParameters?) -> [[KVCache]] {
        return (0 ..< args.numNextnPredictLayers).map { _ in
            [KVCacheSimple()]
        }
    }
}
12 changes: 11 additions & 1 deletion Libraries/MLXLLM/Models/Gemma4.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,26 @@ public struct Gemma4Configuration: Codable, Sendable {
var modelType: String = "gemma4"
var textConfig: Gemma4TextConfiguration
var vocabSize: Int = 262144
var backboneHiddenSize: Int?
var numCentroids: Int?
var centroidIntermediateTopK: Int?

enum CodingKeys: String, CodingKey {
case modelType = "model_type"
case textConfig = "text_config"
case vocabSize = "vocab_size"
case backboneHiddenSize = "backbone_hidden_size"
case numCentroids = "num_centroids"
case centroidIntermediateTopK = "centroid_intermediate_top_k"
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.modelType = try container.decodeIfPresent(String.self, forKey: .modelType) ?? "gemma4"
self.vocabSize = try container.decodeIfPresent(Int.self, forKey: .vocabSize) ?? 262144
self.backboneHiddenSize = try container.decodeIfPresent(Int.self, forKey: .backboneHiddenSize)
self.numCentroids = try container.decodeIfPresent(Int.self, forKey: .numCentroids)
self.centroidIntermediateTopK = try container.decodeIfPresent(Int.self, forKey: .centroidIntermediateTopK)

// If text_config is present, decode from it; otherwise treat entire config as text config
if let textConfig = try container.decodeIfPresent(
Expand All @@ -49,7 +58,8 @@ public class Gemma4Model: Module, LLMModel, KVCacheDimensionProvider {
public var vocabularySize: Int { languageModel.vocabularySize }
public var kvHeads: [Int] { languageModel.kvHeads }

@ModuleInfo(key: "language_model") fileprivate var languageModel: Gemma4TextModel
@ModuleInfo(key: "language_model") public var languageModel: Gemma4TextModel
/// Forwards the wrapped text model's `lastHiddenState`.
public var lastHiddenState: MLXArray? {
    languageModel.lastHiddenState
}

public init(_ config: Gemma4Configuration) {
self._languageModel.wrappedValue = Gemma4TextModel(config.textConfig)
Expand Down
Loading
Loading