diff --git a/Libraries/MLXLLM/LLMModelFactory.swift b/Libraries/MLXLLM/LLMModelFactory.swift index fcb236043..7e9781507 100644 --- a/Libraries/MLXLLM/LLMModelFactory.swift +++ b/Libraries/MLXLLM/LLMModelFactory.swift @@ -33,6 +33,10 @@ public enum LLMTypeRegistry { "gemma3n": create(Gemma3nTextConfiguration.self, Gemma3nTextModel.init), "gemma4": create(Gemma4Configuration.self, Gemma4Model.init), "gemma4_text": create(Gemma4TextConfiguration.self, Gemma4TextModel.init), + "gemma4_assistant": { data in + let fullConfig = try JSONDecoder.json5().decode(Gemma4Configuration.self, from: data) + return Gemma4AssistantModel(fullConfig) + }, "qwen2": create(Qwen2Configuration.self, Qwen2Model.init), "qwen3": create(Qwen3Configuration.self, Qwen3Model.init), "qwen3_moe": create(Qwen3MoEConfiguration.self, Qwen3MoEModel.init), diff --git a/Libraries/MLXLLM/Models/DeepseekV4.swift b/Libraries/MLXLLM/Models/DeepseekV4.swift index a058248ee..5dd903d55 100644 --- a/Libraries/MLXLLM/Models/DeepseekV4.swift +++ b/Libraries/MLXLLM/Models/DeepseekV4.swift @@ -612,7 +612,7 @@ class DeepseekV4MoE: Module, UnaryLayer { // MARK: - Decoder Block (with mHC Hyper-Connections) -class DeepseekV4Block: Module { +public class DeepseekV4Block: Module { let config: DeepseekV4Configuration // Key "attn" matches checkpoint path `layers.{l}.attn.*` @@ -712,15 +712,15 @@ public class DeepseekV4ModelInner: Module, LayerPartitionable, StreamableMoE { public var gpuLayerCount: Int? = nil public var streamExperts: Bool = false - public var totalLayerCount: Int { layers.count } + public var totalLayerCount: Int { layers.count - (MTPConfig.retainMTPWeights ? config.numNextnPredictLayers : 0) } init(config: DeepseekV4Configuration) { self.config = config self._embedTokens.wrappedValue = Embedding( embeddingCount: config.vocabSize, dimensions: config.hiddenSize) - // Exclude MTP (multi-token prediction) layers from the main transformer stack - let mainLayerCount = config.numHiddenLayers - config.numNextnPredictLayers - self.layers = (0 ..< mainLayerCount).map { + let retainMTP = MTPConfig.retainMTPWeights && config.numNextnPredictLayers > 0 + let totalCount = config.numHiddenLayers - (retainMTP ? 0 : config.numNextnPredictLayers) + self.layers = (0 ..< totalCount).map { _ in DeepseekV4Block(config: config) } self._norm.wrappedValue = RMSNorm(dimensions: config.hiddenSize, eps: config.rmsNormEps) @@ -750,7 +750,7 @@ public class DeepseekV4ModelInner: Module, LayerPartitionable, StreamableMoE { let hForMask = h.reshaped([B, S, hc * config.hiddenSize]) // [B, S, hc*D] let attentionMask = createAttentionMask(h: hForMask, cache: cache?.first) - for (i, layer) in layers.enumerated() { + for (i, layer) in layers.prefix(totalLayerCount).enumerated() { h = partitionedLayerCall( index: i, gpuLayerCount: gpuLayerCount, stream: streamExperts ) { @@ -777,7 +777,7 @@ public class DeepseekV4Model: Module, LLMModel, KVCacheDimensionProvider, LoRAMo public var model: DeepseekV4ModelInner @ModuleInfo(key: "lm_head") var lmHead: Linear - init(_ args: DeepseekV4Configuration) { + public init(_ args: DeepseekV4Configuration) { self.args = args self.kvHeads = Array(repeating: 1, count: args.numHiddenLayers - args.numNextnPredictLayers) self.model = DeepseekV4ModelInner(config: args) @@ -841,35 +841,83 @@ public class DeepseekV4Model: Module, LLMModel, KVCacheDimensionProvider, LoRAMo // 3. Filter out MTP (multi-token prediction) layers and rotary_emb keys // Also drop compressor/indexer sub-module keys (not yet implemented) let numMainLayers = args.numHiddenLayers - args.numNextnPredictLayers - return newWeights.filter { key, _ in - // Drop MTP layer weights (layers at index >= numMainLayers) + var finalWeights = [String: MLXArray]() + for (key, value) in newWeights { + // Drop rotary embedding precomputed frequencies + if key.contains("rotary_emb.inv_freq") { continue } + // Drop compressor/indexer sub-module weights + // TODO: implement DeepseekV4Compressor and DeepseekV4Indexer modules. + if key.contains(".attn.compressor.") || key.contains(".attn.indexer.") { continue } + // Drop gate.tid2eid + if key.contains(".ffn.gate.tid2eid") { continue } + if key.starts(with: "model.layers.") { let parts = key.split(separator: ".") if parts.count >= 3, let layerIdx = Int(parts[2]) { - if layerIdx >= numMainLayers { - return false + if layerIdx >= numMainLayers && !MTPConfig.retainMTPWeights { + continue } } } - // Drop rotary embedding precomputed frequencies - if key.contains("rotary_emb.inv_freq") { return false } - // Drop compressor/indexer sub-module weights — these implement long-range - // compressed attention and are not yet implemented in this Swift port. - // Affected layers are those with compress_ratio != 0 (layers 2+). - // TODO: implement DeepseekV4Compressor and DeepseekV4Indexer modules. - if key.contains(".attn.compressor.") || key.contains(".attn.indexer.") { - return false - } - // Note: .attn.attn_sink is a valid model parameter — do NOT filter it. - // Drop gate.tid2eid — hash-layer token-to-expert lookup table (not yet implemented). - // Hash layers (0..numHashLayers-1) use deterministic routing; we fall back to - // the learned gate.weight for these layers instead. - if key.contains(".ffn.gate.tid2eid") { return false } - return true + finalWeights[key] = value } + return finalWeights } public var loraLayers: [Module] { model.layers } } + +// MARK: - MTPLanguageModel Conformance for DeepseekV4Model + +/// DeepSeek V4 uses a different MTP scheme: the MTP layers are the last +/// `numNextnPredictLayers` standard transformer blocks (`model.layers[numMainLayers...]`). +/// They share the same architecture as the main blocks but operate on the final hidden state. +/// The main `lm_head` is reused for all MTP depth projections. +extension DeepseekV4Model: MTPLanguageModel { + public func callMTP(_ inputs: MLXArray, cache: [KVCache]?, mtpCaches: [[KVCache]]?) -> [MLXArray] { + let mtpLayers = model.layers.suffix(args.numNextnPredictLayers) + guard MTPConfig.retainMTPWeights, !mtpLayers.isEmpty else { + return [callAsFunction(inputs, cache: cache)] + } + + // Run the main model body (excludes MTP layers \u2014 DeepseekV4ModelInner only + // instantiates `numMain` blocks, so this is the standard forward pass) + let mainHidden = model(inputs, cache: cache) + let mainLogits = lmHead(mainHidden) + var result = [mainLogits] + + // Chain MTP blocks stored in `model.mtpLayers` + var prevHidden = mainHidden + let B = prevHidden.dim(0), S = prevHidden.dim(1) + let hc = args.hcMult + for (i, mtpLayer) in mtpLayers.enumerated() { + let mtpCache = mtpCaches?[i] + // Expand [B, S, D] -> [B, S, hc, D] + var h = prevHidden.expandedDimensions(axis: 2) + h = repeated(h, count: hc, axis: 2) + + let hForMask = h.reshaped([B, S, hc * args.hiddenSize]) + let attentionMask = createAttentionMask(h: hForMask, cache: mtpCache?.first) + + h = mtpLayer(h, mask: attentionMask, cache: mtpCache?.first) + + // Reduce back to [B, S, D] + prevHidden = hcHead( + x: h, hcFn: model.hc_head.fn, hcScale: model.hc_head.scale, + hcBase: model.hc_head.base, eps: args.hcEps) + + let mtpLogits = lmHead(model.norm(prevHidden)) + result.append(mtpLogits) + } + + return result + } + + public func makeMTPCaches(parameters: GenerateParameters?) -> [[KVCache]] { + return (0 ..< args.numNextnPredictLayers).map { _ in + [KVCacheSimple()] + } + } +} diff --git a/Libraries/MLXLLM/Models/Gemma4.swift b/Libraries/MLXLLM/Models/Gemma4.swift index ea0c1c3db..5b65a30a4 100644 --- a/Libraries/MLXLLM/Models/Gemma4.swift +++ b/Libraries/MLXLLM/Models/Gemma4.swift @@ -18,17 +18,26 @@ public struct Gemma4Configuration: Codable, Sendable { var modelType: String = "gemma4" var textConfig: Gemma4TextConfiguration var vocabSize: Int = 262144 + var backboneHiddenSize: Int? + var numCentroids: Int? + var centroidIntermediateTopK: Int? enum CodingKeys: String, CodingKey { case modelType = "model_type" case textConfig = "text_config" case vocabSize = "vocab_size" + case backboneHiddenSize = "backbone_hidden_size" + case numCentroids = "num_centroids" + case centroidIntermediateTopK = "centroid_intermediate_top_k" } public init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: CodingKeys.self) self.modelType = try container.decodeIfPresent(String.self, forKey: .modelType) ?? "gemma4" self.vocabSize = try container.decodeIfPresent(Int.self, forKey: .vocabSize) ?? 262144 + self.backboneHiddenSize = try container.decodeIfPresent(Int.self, forKey: .backboneHiddenSize) + self.numCentroids = try container.decodeIfPresent(Int.self, forKey: .numCentroids) + self.centroidIntermediateTopK = try container.decodeIfPresent(Int.self, forKey: .centroidIntermediateTopK) // If text_config is present, decode from it; otherwise treat entire config as text config if let textConfig = try container.decodeIfPresent( @@ -49,7 +58,8 @@ public class Gemma4Model: Module, LLMModel, KVCacheDimensionProvider { public var vocabularySize: Int { languageModel.vocabularySize } public var kvHeads: [Int] { languageModel.kvHeads } - @ModuleInfo(key: "language_model") fileprivate var languageModel: Gemma4TextModel + @ModuleInfo(key: "language_model") public var languageModel: Gemma4TextModel + public var lastHiddenState: MLXArray? { return languageModel.lastHiddenState } public init(_ config: Gemma4Configuration) { self._languageModel.wrappedValue = Gemma4TextModel(config.textConfig) diff --git a/Libraries/MLXLLM/Models/Gemma4Text.swift b/Libraries/MLXLLM/Models/Gemma4Text.swift index 57ee20c35..6fe3fbcd0 100644 --- a/Libraries/MLXLLM/Models/Gemma4Text.swift +++ b/Libraries/MLXLLM/Models/Gemma4Text.swift @@ -53,7 +53,7 @@ public struct Gemma4TextConfiguration: Codable, Sendable { var slidingWindowPattern: Int = 5 var maxPositionEmbeddings: Int = 131072 var attentionKeqV: Bool = false - var finalLogitSoftcapping: Float = 30.0 + var finalLogitSoftcapping: Float? = 30.0 var useDoubleWideMlp: Bool = true var enableMoEBlock: Bool = false var numExperts: Int? @@ -137,7 +137,7 @@ public struct Gemma4TextConfiguration: Codable, Sendable { self.attentionKeqV = try container.decodeIfPresent(Bool.self, forKey: .attentionKeqV) ?? false self.finalLogitSoftcapping = - try container.decodeIfPresent(Float.self, forKey: .finalLogitSoftcapping) ?? 30.0 + try container.decodeIfPresent(Float.self, forKey: .finalLogitSoftcapping) self.useDoubleWideMlp = try container.decodeIfPresent(Bool.self, forKey: .useDoubleWideMlp) ?? true self.enableMoEBlock = @@ -254,13 +254,13 @@ private class Gemma4Attention: Module { let scale: Float @ModuleInfo(key: "q_proj") var qProj: Linear - @ModuleInfo(key: "k_proj") var kProj: Linear + @ModuleInfo(key: "k_proj") var kProj: Linear? @ModuleInfo(key: "v_proj") var vProj: Linear? @ModuleInfo(key: "o_proj") var oProj: Linear @ModuleInfo(key: "q_norm") var qNorm: RMSNorm - @ModuleInfo(key: "k_norm") var kNorm: RMSNorm - @ModuleInfo(key: "v_norm") var vNorm: RMSNormNoScale + @ModuleInfo(key: "k_norm") var kNorm: RMSNorm? + @ModuleInfo(key: "v_norm") var vNorm: RMSNormNoScale? @ModuleInfo var rope: RoPELayer @@ -288,15 +288,25 @@ private class Gemma4Attention: Module { self.scale = 1.0 self._qProj.wrappedValue = Linear(dim, nHeads * effectiveHeadDim, bias: false) - self._kProj.wrappedValue = Linear(dim, nKvHeads * effectiveHeadDim, bias: false) - if !useKeqV { - self._vProj.wrappedValue = Linear(dim, nKvHeads * effectiveHeadDim, bias: false) + + // A layer owns its own K/V if it is NOT a KV-shared layer. + // In the Gemma 4 architecture, the main model has K/V weights for all layers even if num_kv_shared_layers > 0. + // However, the assistant model has numHiddenLayers == numKvSharedLayers and NO K/V weights at all. + let isAssistant = config.numHiddenLayers == config.numKvSharedLayers + let hasKv = !isAssistant + + if hasKv { + self._kProj.wrappedValue = Linear(dim, nKvHeads * effectiveHeadDim, bias: false) + if !useKeqV { + self._vProj.wrappedValue = Linear(dim, nKvHeads * effectiveHeadDim, bias: false) + } + self._kNorm.wrappedValue = RMSNorm(dimensions: effectiveHeadDim, eps: config.rmsNormEps) + self._vNorm.wrappedValue = RMSNormNoScale(eps: config.rmsNormEps) } + self._oProj.wrappedValue = Linear(nHeads * effectiveHeadDim, dim, bias: false) self._qNorm.wrappedValue = RMSNorm(dimensions: effectiveHeadDim, eps: config.rmsNormEps) - self._kNorm.wrappedValue = RMSNorm(dimensions: effectiveHeadDim, eps: config.rmsNormEps) - self._vNorm.wrappedValue = RMSNormNoScale(eps: config.rmsNormEps) // RoPE: sliding uses default, full uses proportional with partial rotation if isSliding { @@ -328,15 +338,26 @@ private class Gemma4Attention: Module { var queries = qProj(x).reshaped(B, L, nHeads, effectiveHeadDim) queries = qNorm(queries) - let keys: MLXArray - let values: MLXArray let activePositionOffset = positionOffset ?? gemma4CapturePositionOffset(from: cache) + var adjustedMask = mask + let kvState: Gemma4LLMKVState if let (sharedK, sharedV) = sharedKV { // KV-shared layers use pre-computed KV from an earlier layer - keys = sharedK - values = sharedV + kvState = .regular(keys: sharedK, values: sharedV) + + // For sharedKV, we still need to adjust the mask if cache is shorter than mask + if case .array(let maskArray) = mask { + let keysSeqLen = kvState.seqLen + if maskArray.dim(-1) > keysSeqLen { + adjustedMask = .array(maskArray[.ellipsis, 0 ..< keysSeqLen]) + } + } + } else { + guard let kProj = kProj, let kNorm = kNorm, let vNorm = vNorm else { + fatalError("Layer \(layerIdx) is a KV-shared layer but received no sharedKV") + } var k = kProj(x).reshaped(B, L, nKvHeads, effectiveHeadDim) k = kNorm(k) k = k.transposed(0, 2, 1, 3) @@ -348,18 +369,9 @@ private class Gemma4Attention: Module { v = vNorm(v) v = v.transposed(0, 2, 1, 3) } else { - // When K-eq-V, k is already transposed to [B, nKvHeads, L, D]. - // Applying vNorm (last-axis, layout-agnostic) and then transposing - // again would yield [B, L, nKvHeads, D] — the wrong layout. - // Skip the extra transpose; the norm is still applied correctly. v = vNorm(k) } - // Dispatch to the correct KV-cache update based on concrete cache type. - // QuantizedKVCache traps on `.update(keys:values:)` — we must call - // `.updateQuantized(keys:values:)` and then route to - // `quantizedScaledDotProductAttention` below. - let kvState: Gemma4LLMKVState if let quantizedCache = cache as? QuantizedKVCacheProtocol { let (qKeys, qValues) = quantizedCache.updateQuantized(keys: k, values: v) kvState = .quantized( @@ -375,21 +387,20 @@ private class Gemma4Attention: Module { } else { kvState = .regular(keys: k, values: v) } - - queries = queries.transposed(0, 2, 1, 3) - queries = gemma4ApplyRotaryPosition(rope, to: queries, offset: activePositionOffset) - - // Adjust mask if cache is shorter than mask (mask was built for a longer sequence). - // Only slice — never pad: if mask is already shorter we leave it alone. - var adjustedMask = mask + + // Adjust mask if cache is shorter than mask if case .array(let maskArray) = mask { let keysSeqLen = kvState.seqLen if maskArray.dim(-1) > keysSeqLen { adjustedMask = .array(maskArray[.ellipsis, 0 ..< keysSeqLen]) } } + } - let output: MLXArray = + queries = queries.transposed(0, 2, 1, 3) + queries = gemma4ApplyRotaryPosition(rope, to: queries, offset: activePositionOffset) + + let output: MLXArray = switch kvState { case .regular(let rKeys, let rValues): MLXFast.scaledDotProductAttention( @@ -446,31 +457,6 @@ private class Gemma4Attention: Module { ) } - // ── sharedKV path ── - // (queries already computed above; keys/values come from an earlier layer) - queries = queries.transposed(0, 2, 1, 3) - queries = gemma4ApplyRotaryPosition(rope, to: queries, offset: activePositionOffset) - - var adjustedMask = mask - if case .array(let maskArray) = mask { - let keysSeqLen = keys.dim(2) - if maskArray.dim(-1) > keysSeqLen { - adjustedMask = .array(maskArray[.ellipsis, 0 ..< keysSeqLen]) - } - } - - let output = MLXFast.scaledDotProductAttention( - queries: queries, - keys: keys, - values: values, - scale: scale, - mask: adjustedMask ?? .none - ) - .transposed(0, 2, 1, 3) - .reshaped(B, L, -1) - - return (oProj(output), (keys, values), activePositionOffset) - } } // MARK: - MLP @@ -732,6 +718,9 @@ private class Gemma4TextModelInner: Module { // KV sharing mapping: for each layer, which earlier layer provides KVs let previousKvs: [Int] let firstKvSharedLayerIdx: Int + + public var lastHiddenState: MLXArray? + public var hiddenStateBeforeNorm: MLXArray? init(_ config: Gemma4TextConfiguration) { self.config = config @@ -849,10 +838,26 @@ private class Gemma4TextModelInner: Module { var intermediates = [(kv: (MLXArray, MLXArray)?, positionOffset: Gemma4PositionOffset?)]( repeating: (nil, nil), count: config.numHiddenLayers) + let isAssistant = (config.numKvSharedLayers == config.numHiddenLayers) + for (idx, layer) in layers.enumerated() { - let prevIdx = previousKvs[idx] - let sharedKV = intermediates[prevIdx].kv - let sharedPositionOffset = intermediates[prevIdx].positionOffset + var sharedKV: (MLXArray, MLXArray)? = nil + var sharedPositionOffset: Gemma4PositionOffset? = nil + + if isAssistant, let fullCache = cache, fullCache.count > config.numHiddenLayers { + // Determine which layer of the main model to share KV from + let mainIdx = layer.layerType == "sliding_attention" ? fullCache.count - 2 : fullCache.count - 1 + let cacheElement = fullCache[mainIdx] + if let c = cacheElement as? KVCacheSimple, let k = c.keys, let v = c.values { + sharedKV = (k, v) + } else if let c = cacheElement as? RotatingKVCache, let k = c.keys, let v = c.values { + sharedKV = (k, v) + } + } else { + let prevIdx = previousKvs[idx] + sharedKV = intermediates[prevIdx].kv + sharedPositionOffset = intermediates[prevIdx].positionOffset + } let mask = maskByType[layer.layerType] let (out, kvPair, positionOffset) = layer( @@ -867,7 +872,10 @@ private class Gemma4TextModelInner: Module { intermediates[idx] = (kvPair, positionOffset) } - return norm(h) + self.hiddenStateBeforeNorm = h + h = norm(h) + self.lastHiddenState = h + return h } } @@ -877,6 +885,9 @@ public class Gemma4TextModel: Module, LLMModel, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] + public var lastHiddenState: MLXArray? { return model.lastHiddenState } + public var hiddenStateBeforeNorm: MLXArray? { return model.hiddenStateBeforeNorm } + fileprivate let config: Gemma4TextConfiguration fileprivate let model: Gemma4TextModelInner @@ -900,19 +911,24 @@ public class Gemma4TextModel: Module, LLMModel, KVCacheDimensionProvider { } else { out = model.embedTokens.asLinear(out) } - out = tanh(out / config.finalLogitSoftcapping) * config.finalLogitSoftcapping + if let cap = config.finalLogitSoftcapping { + out = tanh(out / cap) * cap + } return out } public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { var sanitized = [String: MLXArray]() for (k, v) in weights { - // Skip vision/audio/rotary weights + // Skip vision/audio/rotary weights and unsupported MTP keys if k.contains("self_attn.rotary_emb") || k.contains("input_max") || k.contains("input_min") || k.contains("output_max") || k.contains("output_min") + || k.hasPrefix("pre_projection") + || k.hasPrefix("post_projection") + || k.hasPrefix("masked_embedding") { continue } @@ -971,3 +987,338 @@ extension Gemma4TextModel: LoRAModel { model.layers.map { $0.selfAttn } } } + +// MARK: - Assistant + +public class Gemma4AssistantModel: Module, LLMModel, DualModelMTP, KVCacheDimensionProvider { + public let vocabularySize: Int + public let kvHeads: [Int] + + public let config: Gemma4TextConfiguration + fileprivate let model: Gemma4TextModelInner + + @ModuleInfo(key: "lm_head") var lmHead: Linear? + + public var _preProjectionWeight: MLXArray? + public var _postProjectionWeight: MLXArray? + + public var preProjectionWeight: MLXArray? { _preProjectionWeight } + public var postProjectionWeight: MLXArray? { _postProjectionWeight } + + // Masked embedder state (centroid-based sparse logit projection) + var _centroidWeight: MLXArray? // [num_centroids, hidden] — centroids linear weight + var _tokenOrdering: MLXArray? // [vocab_size] int32 — canonical token ordering (ordered->canonical) + var _invTokenOrdering: MLXArray? // [vocab_size] int32 — inverse token ordering (canonical->ordered) + var numCentroids: Int = 2048 + var centroidTopK: Int = 32 + var vocabSizePerCentroid: Int = 128 // vocab_size / num_centroids + + // Reference to the main model so we can call it inside callMTP + public var mainModelRef: (any BaseLanguageModel)? = nil + + public init(_ fullConfig: Gemma4Configuration) { + let config = fullConfig.textConfig + self.config = config + self.vocabularySize = config.vocabSize + self.kvHeads = (0 ..< config.numHiddenLayers).map { _ in config.numKeyValueHeads } + self.model = Gemma4TextModelInner(config) + + self.numCentroids = fullConfig.numCentroids ?? 2048 + self.centroidTopK = fullConfig.centroidIntermediateTopK ?? 32 + self.vocabSizePerCentroid = config.vocabSize / self.numCentroids + + if !config.tieWordEmbeddings { + self._lmHead.wrappedValue = Linear(config.hiddenSize, config.vocabSize, bias: false) + } + super.init() + } + + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + var sanitized = weights + if let w = weights["pre_projection.weight"] { + self._preProjectionWeight = w + sanitized.removeValue(forKey: "pre_projection.weight") + } + if let w = weights["post_projection.weight"] { + self._postProjectionWeight = w + sanitized.removeValue(forKey: "post_projection.weight") + } + + // Load masked embedder weights for centroid-based sparse logit projection + if let w = weights["masked_embedding.centroids.weight"] { + self._centroidWeight = w + sanitized.removeValue(forKey: "masked_embedding.centroids.weight") + } + if let w = weights["masked_embedding.token_ordering"] { + self._tokenOrdering = w.asType(.int32) + // Precompute inverse ordering: inv[canonical_id] = ordered_position + // This enables O(1) conversion from ordered logits to canonical logits + self._invTokenOrdering = argSort(w.asType(.int32), axis: 0) + sanitized.removeValue(forKey: "masked_embedding.token_ordering") + } + + return sanitized + } + + /// Compute logits using the centroid-based sparse masked embedder. + /// Matches HF Gemma4AssistantMaskedEmbedder.forward(). + /// - hNormed: [B, 1, hidden=256] + /// Returns [B, 1, vocab] + func maskedEmbedderLogits(_ hNormed: MLXArray) -> MLXArray { + guard let centroidW = _centroidWeight, let tokenOrdering = _tokenOrdering else { + // Fallback to full projection + return model.embedTokens.asLinear(hNormed) + } + + let B = hNormed.dim(0) + let S = hNormed.dim(1) + let vocabSize = config.vocabSize + + // centroid_logits = hNormed @ centroidW.T → [B, S, num_centroids] + let centroidLogits = matmul(hNormed, centroidW.T) + + // top_k_indices = argTopK(centroid_logits, k=centroidTopK) → [B, S, topK] + // MLX doesn't have argTopK directly; use argSort descending and take first topK + let sortedCentroidIdx = argSort(centroidLogits, axis: -1) // ascending + let reversedIdx = sortedCentroidIdx[.ellipsis, (sortedCentroidIdx.dim(-1) - centroidTopK)...] + // reversedIdx is [B, S, topK] — indices of top-K centroids + + // token_ordering reshaped: [num_centroids, vocabSizePerCentroid] + let tokenOrderingReshaped = tokenOrdering.reshaped([numCentroids, vocabSizePerCentroid]) + + // Gather canonical positions for each selected centroid + // For each of the topK centroid indices, gather its vocabSizePerCentroid token positions + // selected_canonical: [B, S, topK, vocabSizePerCentroid] + let topKFlat = reversedIdx.reshaped([-1]) // [B*S*topK] + let selectedCanonical = tokenOrderingReshaped[topKFlat] // [B*S*topK, vocabSizePerCentroid] + let selectedCanonicalShaped = selectedCanonical.reshaped([B, S, centroidTopK, vocabSizePerCentroid]) + + // Gather embeddings at those positions: embed_tokens.weight[canonical] → [B*S*topK*K, hidden] + let embedWeight = model.embedTokens.weight // [vocab, 256] + let selectedFlat = selectedCanonicalShaped.reshaped([-1]).asType(.int32) // [B*S*topK*K] + let selectedEmbeds = embedWeight[selectedFlat] // [B*S*topK*K, 256] + let totalCandidates = centroidTopK * vocabSizePerCentroid + let selectedEmbedsShaped = selectedEmbeds.reshaped([B, S, totalCandidates, config.hiddenSize]) + + // dot products: [B, S, 1, hidden] @ [B, S, hidden, topK*K] → [B, S, topK*K] + let hExpanded = hNormed.expandedDimensions(axis: -2) // [B, S, 1, hidden] + let selectedLogits = matmul(hExpanded, selectedEmbedsShaped.transposed(0, 1, 3, 2)).squeezed(axis: -2) + // selectedLogits: [B, S, topK*K] + + // Build output tensor: fill with min - 1.0, scatter selectedLogits to canonical positions + let minVal = selectedLogits.min(axes: [-1], keepDims: true) // [B, S, 1] + var output = broadcast(minVal - 1.0, to: [B, S, vocabSize]) // [B, S, vocab] + + // Scatter selectedLogits into output at scatterIdx positions. + // We use a workaround: create an index array and use scatter-add pattern. + // selectedLogits: [B, S, topK*K], scatterIdx: [B, S, topK*K] (token indices) + // For each (b,s,k): output[b, s, scatterIdx[b,s,k]] = selectedLogits[b,s,k] + // Use mlx scatter via the __setitem__ approach: + let scatterIdx2D = selectedCanonicalShaped.reshaped([B * S, totalCandidates]).asType(.int32) + let selectedLogits2D = selectedLogits.reshaped([B * S, totalCandidates]) + var output2D = output.reshaped([B * S, vocabSize]) + for bsIdx in 0 ..< B * S { + let idxRow = scatterIdx2D[bsIdx] // [totalCandidates] + let valRow = selectedLogits2D[bsIdx] // [totalCandidates] + output2D[bsIdx, idxRow] = valRow + } + output = output2D.reshaped([B, S, vocabSize]) + + return output + } + + + public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray { + // Fallback for standard autoregressive call, though not used in MTP flow + let h = model(inputs, cache: cache) + if let lmHead { + return lmHead(h) + } + return model.embedTokens.asLinear(h) + } + + public func callMTP(_ inputs: MLXArray, cache: [KVCache]?, mtpCaches: [[KVCache]]?) -> [MLXArray] { + guard let mainModel = mainModelRef else { + fatalError("mainModelRef must be set on Gemma4AssistantModel before calling callMTP") + } + + let posOffset = cache?.first.map { gemma4CapturePositionOffset(from: $0) } + + // 1. Run the main model to get main logits and backbone hidden state + guard let llmMain = mainModel as? any LLMModel else { + fatalError("mainModelRef must be an LLMModel") + } + let mainLogits = llmMain(inputs, cache: cache) + + // Extract the NORMALIZED hidden state from the backbone + var hBackbone: MLXArray + if let g4m = mainModel as? Gemma4Model, let lhs = g4m.lastHiddenState { + hBackbone = lhs + } else if let g4tm = mainModel as? Gemma4TextModel, let lhs = g4tm.lastHiddenState { + hBackbone = lhs + } else { + fatalError("[MTP] Could not extract normalized hidden state from main model") + } + + var allLogits = [mainLogits] + + // pre_projection: [256, 3072] — expects concat(hBackbone, embedToken) both 1536-dim → 3072 + // post_projection: [1536, 256] — maps assistant 256-dim state back to 1536 backbone dim + + // For depth=0, we don't have a draft token yet — we use the LAST token from inputs as the "current" token. + // hBackbone[..., -1:, ...] is the hidden state after the last real token. + // We embed the last input token to form the first concatenation. + let backboneDim = hBackbone.dim(-1) // 1536 + + // Get the last hidden state (the one that will predict the next token) + let seqLen = hBackbone.dim(1) + var hLast = hBackbone[0..., (seqLen-1)..