Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 64 additions & 28 deletions Libraries/MLXLLM/Models/Gemma4Text.swift
Original file line number Diff line number Diff line change
Expand Up @@ -192,21 +192,6 @@ private class RMSNormNoScale: Module {
}
}

/// Bias-free linear projection whose output is multiplied by a constant factor.
private class ScaledLinear: Module {
    /// Projection matrix laid out as `[outFeatures, inFeatures]`; created as zeros
    /// (presumably replaced when checkpoint weights are loaded — confirm against the loader).
    let weight: MLXArray
    /// Constant multiplier applied to the projected output.
    let scalar: Float

    /// - Parameters:
    ///   - inFeatures: Size of the last axis of the input.
    ///   - outFeatures: Size of the last axis of the output.
    ///   - scalar: Factor applied to the result of the matrix product.
    init(inFeatures: Int, outFeatures: Int, scalar: Float) {
        let weightShape = [outFeatures, inFeatures]
        self.scalar = scalar
        self.weight = MLXArray.zeros(weightShape)
        super.init()
    }

    /// Computes `(x · weightᵀ) * scalar`.
    func callAsFunction(_ x: MLXArray) -> MLXArray {
        let projected = matmul(x, weight.T)
        return projected * scalar
    }
}

// MARK: - Attention

private class Gemma4Attention: Module {
Expand All @@ -221,12 +206,15 @@ private class Gemma4Attention: Module {
let scale: Float

@ModuleInfo(key: "q_proj") var qProj: Linear
@ModuleInfo(key: "k_proj") var kProj: Linear
// Optional: KV-shared layers reuse an earlier layer's K/V and own no k_proj/v_proj.
@ModuleInfo(key: "k_proj") var kProj: Linear?
@ModuleInfo(key: "v_proj") var vProj: Linear?
@ModuleInfo(key: "o_proj") var oProj: Linear

@ModuleInfo(key: "q_norm") var qNorm: RMSNorm
@ModuleInfo(key: "k_norm") var kNorm: RMSNorm
// Optional: KV-shared layers don't compute K, so they carry no k_norm weight.
// (v_norm is RMSNormNoScale — parameter-free — so it never appears in checkpoints.)
@ModuleInfo(key: "k_norm") var kNorm: RMSNorm?
@ModuleInfo(key: "v_norm") var vNorm: RMSNormNoScale

@ModuleInfo var rope: RoPELayer
Expand Down Expand Up @@ -255,14 +243,26 @@ private class Gemma4Attention: Module {
self.scale = 1.0

self._qProj.wrappedValue = Linear(dim, nHeads * effectiveHeadDim, bias: false)
self._kProj.wrappedValue = Linear(dim, nKvHeads * effectiveHeadDim, bias: false)
if !useKeqV {
self._vProj.wrappedValue = Linear(dim, nKvHeads * effectiveHeadDim, bias: false)
// KV-shared layers (the last `num_kv_shared_layers`) reuse the K/V of an earlier
// layer of the same attention type, so they own no k_proj/v_proj. Quantized
// (QAT) checkpoints prune those tensors; create them only for the KV-owning
// layers so the module tree matches the checkpoint. (Older/PTQ checkpoints that
// still ship the redundant tensors are dropped in `sanitize`.) Same predicate as
// the double-wide MLP gate.
let firstKvSharedLayerIdx = config.numHiddenLayers - config.numKvSharedLayers
let isKvSharedLayer = layerIdx >= firstKvSharedLayerIdx && firstKvSharedLayerIdx > 0
if !isKvSharedLayer {
self._kProj.wrappedValue = Linear(dim, nKvHeads * effectiveHeadDim, bias: false)
if !useKeqV {
self._vProj.wrappedValue = Linear(dim, nKvHeads * effectiveHeadDim, bias: false)
}
}
self._oProj.wrappedValue = Linear(nHeads * effectiveHeadDim, dim, bias: false)

self._qNorm.wrappedValue = RMSNorm(dimensions: effectiveHeadDim, eps: config.rmsNormEps)
self._kNorm.wrappedValue = RMSNorm(dimensions: effectiveHeadDim, eps: config.rmsNormEps)
if !isKvSharedLayer {
self._kNorm.wrappedValue = RMSNorm(dimensions: effectiveHeadDim, eps: config.rmsNormEps)
}
self._vNorm.wrappedValue = RMSNormNoScale(eps: config.rmsNormEps)

// RoPE: sliding uses default, full uses proportional with partial rotation
Expand Down Expand Up @@ -304,6 +304,13 @@ private class Gemma4Attention: Module {
keys = sharedK
values = sharedV
} else {
// Only KV-owning layers reach this branch (KV-shared layers always receive `sharedKV`),
// so k_proj and k_norm are guaranteed to exist.
guard let kProj, let kNorm else {
fatalError(
"Gemma4Attention layer \(layerIdx) computed its own K/V but has no k_proj/k_norm; "
+ "KV-shared layers must be passed `sharedKV`.")
}
var k = kProj(x).reshaped(B, L, nKvHeads, effectiveHeadDim)
k = kNorm(k)
k = k.transposed(0, 2, 1, 3)
Expand Down Expand Up @@ -500,7 +507,7 @@ private class Gemma4TextModelInner: Module {

// Per-layer embeddings (PLE)
@ModuleInfo(key: "embed_tokens_per_layer") var embedTokensPerLayer: Embedding?
@ModuleInfo(key: "per_layer_model_projection") var perLayerModelProjection: ScaledLinear?
@ModuleInfo(key: "per_layer_model_projection") var perLayerModelProjection: Linear?
@ModuleInfo(key: "per_layer_projection_norm") var perLayerProjectionNorm: RMSNorm?

// KV sharing mapping: for each layer, which earlier layer provides KVs
Expand All @@ -524,10 +531,10 @@ private class Gemma4TextModelInner: Module {
self._embedTokensPerLayer.wrappedValue = Embedding(
embeddingCount: config.vocabSizePerLayerInput,
dimensions: config.numHiddenLayers * config.hiddenSizePerLayerInput)
self._perLayerModelProjection.wrappedValue = ScaledLinear(
inFeatures: config.hiddenSize,
outFeatures: config.numHiddenLayers * config.hiddenSizePerLayerInput,
scalar: pow(Float(config.hiddenSize), -0.5))
self._perLayerModelProjection.wrappedValue = Linear(
config.hiddenSize,
config.numHiddenLayers * config.hiddenSizePerLayerInput,
bias: false)
self._perLayerProjectionNorm.wrappedValue = RMSNorm(
dimensions: config.hiddenSizePerLayerInput, eps: config.rmsNormEps)
}
Expand Down Expand Up @@ -577,8 +584,15 @@ private class Gemma4TextModelInner: Module {
tokenPLE.dim(0), tokenPLE.dim(1),
config.numHiddenLayers, config.hiddenSizePerLayerInput)

// Model projection PLE
let modelPLE = modelProj(h).reshaped(
// Model projection PLE. The hidden_size**-0.5 scale (the reference
// impl's `per_layer_projection_scale`) is applied here, AFTER the
// projection, so the projection itself stays a plain Linear and
// quantizes like every other Linear in the model. The 8-bit E-series
// checkpoints ship per_layer_model_projection as a packed
// QuantizedLinear; a custom non-Linear module is invisible to
// quantize() and would mismatch those packed weights at load.
let perLayerProjectionScale = pow(Float(config.hiddenSize), -0.5)
let modelPLE = (modelProj(h) * perLayerProjectionScale).reshaped(
h.dim(0), h.dim(1),
config.numHiddenLayers, config.hiddenSizePerLayerInput)
let normedModelPLE = projNorm(modelPLE)
Expand Down Expand Up @@ -679,6 +693,7 @@ public class Gemma4TextModel: Module, LLMModel, KVCacheDimensionProvider {
}

public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
let firstKvSharedLayerIdx = config.numHiddenLayers - config.numKvSharedLayers
var sanitized = [String: MLXArray]()
for (k, v) in weights {
// Skip vision/audio/rotary weights
Expand All @@ -690,11 +705,32 @@ public class Gemma4TextModel: Module, LLMModel, KVCacheDimensionProvider {
{
continue
}
// Drop redundant k_proj/v_proj/k_norm for KV-shared layers: they reuse an
// earlier layer's K/V and own no K projection or K norm, so the module tree
// has none. QAT checkpoints already omit these; some (PTQ) checkpoints still
// ship them, and keeping them would be an unexpected weight. Dropping here
// makes both load against the same tree. (v_norm is parameter-free.)
if firstKvSharedLayerIdx > 0,
k.contains("self_attn.k_proj")
|| k.contains("self_attn.v_proj")
|| k.contains("self_attn.k_norm"),
let layerIdx = Self.decoderLayerIndex(in: k),
layerIdx >= firstKvSharedLayerIdx
{
continue
}
sanitized[k] = v
}
return sanitized
}

/// Parses the layer number out of a weight key of the form `…layers.N.…`.
///
/// Looks for the first occurrence of `"layers."` in `key` and reads the run of
/// number characters that immediately follows it.
///
/// - Parameter key: A checkpoint weight key.
/// - Returns: The parsed index, or `nil` when `key` has no `"layers."` segment or
///   the characters after it do not form an integer.
private static func decoderLayerIndex(in key: String) -> Int? {
    key.range(of: "layers.").flatMap { marker in
        let tail = key[marker.upperBound...]
        return Int(tail.prefix(while: \.isNumber))
    }
}

public func newCache(parameters: GenerateParameters?) -> [any KVCache] {
let firstKvShared = config.numHiddenLayers - config.numKvSharedLayers

Expand Down
65 changes: 52 additions & 13 deletions Libraries/MLXVLM/Models/Gemma4.swift
Original file line number Diff line number Diff line change
Expand Up @@ -620,11 +620,12 @@ private final class Gemma4TextAttention: Module {
let useKEqV: Bool

@ModuleInfo(key: "q_proj") var qProj: Linear
@ModuleInfo(key: "k_proj") var kProj: Linear
// Optional: KV-shared layers reuse an earlier layer's K/V and own no k_proj/v_proj/k_norm.
@ModuleInfo(key: "k_proj") var kProj: Linear?
@ModuleInfo(key: "v_proj") var vProj: Linear?
@ModuleInfo(key: "o_proj") var oProj: Linear
@ModuleInfo(key: "q_norm") var qNorm: Gemma4RMSNormZeroShift
@ModuleInfo(key: "k_norm") var kNorm: Gemma4RMSNormZeroShift
@ModuleInfo(key: "k_norm") var kNorm: Gemma4RMSNormZeroShift?
@ModuleInfo(key: "v_norm") var vNorm: Gemma4RMSNormNoScale
@ModuleInfo var rope: OffsetLayer

Expand All @@ -646,16 +647,22 @@ private final class Gemma4TextAttention: Module {
self.isKVSharedLayer = layerIdx >= firstKVSharedLayer && firstKVSharedLayer > 0

self._qProj.wrappedValue = Linear(config.hiddenSize, numHeads * headDim, bias: false)
self._kProj.wrappedValue = Linear(config.hiddenSize, numKVHeads * headDim, bias: false)
if !useKEqV {
self._vProj.wrappedValue = Linear(
config.hiddenSize, numKVHeads * headDim, bias: false)
// KV-shared layers (the last `num_kv_shared_layers`) reuse an earlier layer's K/V,
// so they own no k_proj/v_proj/k_norm. QAT checkpoints prune those; create them only
// for the KV-owning layers so the module tree matches the checkpoint. (Redundant
// ones in PTQ checkpoints are dropped in `sanitize`.)
if !isKVSharedLayer {
self._kProj.wrappedValue = Linear(config.hiddenSize, numKVHeads * headDim, bias: false)
if !useKEqV {
self._vProj.wrappedValue = Linear(
config.hiddenSize, numKVHeads * headDim, bias: false)
}
self._kNorm.wrappedValue = Gemma4RMSNormZeroShift(
dimensions: headDim, eps: config.rmsNormEps)
}
self._oProj.wrappedValue = Linear(numHeads * headDim, config.hiddenSize, bias: false)
self._qNorm.wrappedValue = Gemma4RMSNormZeroShift(
dimensions: headDim, eps: config.rmsNormEps)
self._kNorm.wrappedValue = Gemma4RMSNormZeroShift(
dimensions: headDim, eps: config.rmsNormEps)
self._vNorm.wrappedValue = Gemma4RMSNormNoScale(eps: config.rmsNormEps)

let ropeKey = isSliding ? "sliding_attention" : "full_attention"
Expand Down Expand Up @@ -690,6 +697,13 @@ private final class Gemma4TextAttention: Module {
currentOffset = offset ?? 0
kvState = sharedKV
} else {
// Only KV-owning layers reach here (KV-shared layers always get `sharedKV`),
// so k_proj and k_norm are guaranteed to exist.
guard let kProj, let kNorm else {
fatalError(
"Gemma4 layer \(layerIdx) computed its own K/V but has no k_proj/k_norm; "
+ "KV-shared layers must be passed sharedKV.")
}
currentOffset = cache?.offset ?? 0
var keys = kProj(x).reshaped(batch, length, numKVHeads, headDim)
var values =
Expand Down Expand Up @@ -889,7 +903,9 @@ private final class Gemma4TextBackbone: Module {
@ModuleInfo(key: "layers") var layers: [Gemma4TextDecoderLayer]
@ModuleInfo(key: "norm") var norm: Gemma4RMSNormZeroShift
@ModuleInfo(key: "embed_tokens_per_layer") var embedTokensPerLayer: Embedding?
@ModuleInfo(key: "per_layer_model_projection") var perLayerModelProjection: Gemma4ScaledLinear?
// A plain Linear (not a custom ScaledLinear) so `quantize()` can convert it for QAT
// checkpoints; the hidden_size**-0.5 scale is applied in the forward. (Mirrors #320.)
@ModuleInfo(key: "per_layer_model_projection") var perLayerModelProjection: Linear?
@ModuleInfo(key: "per_layer_projection_norm") var perLayerProjectionNorm:
Gemma4RMSNormZeroShift?

Expand Down Expand Up @@ -929,10 +945,10 @@ private final class Gemma4TextBackbone: Module {
embeddingCount: config.vocabularySizePerLayerInput,
dimensions: config.hiddenLayers * config.hiddenSizePerLayerInput
)
self._perLayerModelProjection.wrappedValue = Gemma4ScaledLinear(
inFeatures: config.hiddenSize,
outFeatures: config.hiddenLayers * config.hiddenSizePerLayerInput,
scalar: pow(Float(config.hiddenSize), -0.5)
self._perLayerModelProjection.wrappedValue = Linear(
config.hiddenSize,
config.hiddenLayers * config.hiddenSizePerLayerInput,
bias: false
)
self._perLayerProjectionNorm.wrappedValue = Gemma4RMSNormZeroShift(
dimensions: config.hiddenSizePerLayerInput, eps: config.rmsNormEps)
Expand Down Expand Up @@ -963,7 +979,10 @@ private final class Gemma4TextBackbone: Module {
return nil
}

// hidden_size**-0.5 scale that the old Gemma4ScaledLinear baked in (now applied here
// so the projection can be a plain, quantizable Linear — mirrors #320).
var perLayerProjection = perLayerModelProjection(inputsEmbeds)
* pow(Float(config.hiddenSize), -0.5)
perLayerProjection = perLayerProjection.reshaped(
Array(inputsEmbeds.shape.dropLast()) + [
config.hiddenLayers, config.hiddenSizePerLayerInput,
Expand Down Expand Up @@ -1151,10 +1170,30 @@ private final class Gemma4TextLanguageModel: Module, KVCacheDimensionProvider {
var sanitized: [String: MLXArray] = [:]
sanitized.reserveCapacity(weights.count + 1)

let firstKVSharedLayer = config.hiddenLayers - config.numKVSharedLayers
for (key, value) in weights {
if key.contains("rotary_emb") {
continue
}
// Drop redundant k_proj/v_proj/k_norm for KV-shared layers (they reuse an earlier
// layer's K/V and own none). QAT checkpoints already omit these; some (PTQ) ship
// them — dropping makes both load against the now-smaller module tree.
// SCOPE: text backbone only. The vision tower (`vision_tower.encoder.layers.N`)
// shares the `layers.N.self_attn.{k,v}_proj` naming, so without this guard the
// drop would amputate vision layers ≥ firstKVSharedLayer (15) and trigger a
// `keyNotFound` for their clip bounds / projections.
if firstKVSharedLayer > 0,
!key.contains("vision_tower"),
!key.contains("audio_tower"),
key.contains("self_attn.k_proj")
|| key.contains("self_attn.v_proj")
|| key.contains("self_attn.k_norm"),
let r = key.range(of: "layers."),
let li = Int(key[r.upperBound...].prefix { $0.isNumber }),
li >= firstKVSharedLayer
{
continue
}

var newKey = key
if newKey.hasPrefix("model.") {
Expand Down