Skip to content

Fix Gemma4 QAT (E-series) load: KV-shared layers have no k_proj/v_proj/k_norm#330

Open
Flo5k5 wants to merge 3 commits into
ml-explore:mainfrom
Flo5k5:fix/gemma4-kv-sharing-qat
Open

Fix Gemma4 QAT (E-series) load: KV-shared layers have no k_proj/v_proj/k_norm#330
Flo5k5 wants to merge 3 commits into
ml-explore:mainfrom
Flo5k5:fix/gemma4-kv-sharing-qat

Conversation

@Flo5k5

@Flo5k5 Flo5k5 commented Jun 5, 2026

Copy link
Copy Markdown

What

Gemma 4 E-series (E2B/E4B) shares K/V across the last num_kv_shared_layers layers.
Those layers reuse the K/V of an earlier layer of the same attention type — they still
compute their own queries (q_proj/q_norm) but own no k_proj, v_proj or
k_norm. Gemma4TextModelInner already routes their KV via previousKvs/sharedKV,
and the forward path never touches kProj/kNorm/vProj for them.

But Gemma4Attention declared k_proj/k_norm as non-optional and created
k_proj/v_proj/k_norm unconditionally for every layer, so the loader demanded
layers.N.self_attn.{k_proj,v_proj,k_norm} for the KV-shared layers too.

Full-precision / PTQ checkpoints still carry those (now-redundant) tensors, so they
happened to load. Quantization-aware (QAT) checkpoints prune them, so e.g.
mlx-community/gemma-4-E2B-it-qat-4bit fails at layer 15 (the first shared layer):

keyNotFound layers.15.self_attn.v_proj.weight    (then k_norm.weight)

Confirmed against the checkpoint's model.safetensors.index.json: an owner layer (14)
ships q/k/v_proj, q_norm, k_norm, o_proj; a shared layer (15) ships only
q_proj, q_norm, o_proj. v_norm is RMSNormNoScale (parameter-free) so it never
appears in any checkpoint and needs no change.

Change (Gemma4Text.swift)

  • Make kProj and kNorm optional (vProj was already optional).
  • Create k_proj/v_proj/k_norm only for the KV-owning layers, using the same
    layerIdx >= numHiddenLayers - numKvSharedLayers predicate already used for the
    double-wide-MLP gate.
  • Guard the compute path (only KV-owning layers reach it; KV-shared layers always
    receive sharedKV).
  • In sanitize, drop any redundant k_proj/v_proj/k_norm of KV-shared layers, so
    PTQ checkpoints that still ship them keep loading against the now-smaller module tree.

Net change is +50/−7 in one file; the forward/inference logic is untouched (it already
handled sharing correctly — only the module declarations over-claimed weights).

Relationship to #320

This builds on top of @nyteshade's #320 (make per_layer_model_projection
quantizable
) — that commit is included here to preserve attribution. #320 fixes the
first QAT load error (the PLE projection); this fixes the second (KV sharing).
Both are required for the QAT E-series checkpoints to load. Happy to rebase onto main
once #320 merges (then this becomes a single-commit follow-up).

Verification

macOS (Apple Silicon), mlx-community/gemma-4-E2B-it-qat-4bit:

  • Before: hard keyNotFound at layer 15 during load.

  • After: loads and generates. Sample (text-only prompt through a tool-enabled
    session):

    [mood: thinking] I can check your sleep, steps, and heart rate for you. I'll start with your health context.

PTQ checkpoints (mlx-community/gemma-4-e2b-it-4bit) continue to load unchanged (the
sanitize drop handles the redundant tensors they still carry).

nyteshade and others added 3 commits May 28, 2026 13:01
per_layer_model_projection was a custom ScaledLinear (a plain Module holding a
raw weight + scalar), which quantize() cannot see — it only converts
Linear/Embedding. With quantized E-series checkpoints (e.g. gemma-4-E4B 8-bit)
the projection stayed full-precision and crashed at load:

  mismatchedSize per_layer_model_projection.weight
  expected [num_layers*hidden_per_layer, hidden] got [..., hidden/pack_factor]

(the packed 8-bit QuantizedLinear weight [10752, 640] never matched the
full-precision expectation [10752, 2560]).

Mirror the Python reference (mlx_lm gemma3n): make per_layer_model_projection a
plain Linear (so it quantizes like every other Linear) and apply the
hidden_size**-0.5 scale (per_layer_projection_scale) in the forward pass rather
than baking it into a bespoke module.

Verified on macOS: gemma-4-E4B-it-MLX-8bit now loads and runs (prefill ~1670
t/s on Apple Silicon); previously a hard crash at load.
Gemma 4 E-series shares K/V across the last `num_kv_shared_layers` layers:
those layers reuse the K/V of an earlier layer of the same attention type. They
still compute their own queries (q_proj/q_norm) but own NO k_proj, v_proj or
k_norm — only o_proj/q_proj/q_norm. The forward pass already handles this
(KV-shared layers receive `sharedKV` and never touch kProj/kNorm/vProj), but
`Gemma4Attention` declared k_proj/k_norm as non-optional and created all of
k_proj/v_proj/k_norm unconditionally, so the loader demanded
`layers.N.self_attn.{k_proj,v_proj,k_norm}` for the shared layers too.

Full-precision/PTQ checkpoints still carry those redundant tensors, so they
loaded. Quantization-aware (QAT) checkpoints prune them, so loading
gemma-4-E2B-it-qat-4bit failed at layer 15 (first shared layer):

  keyNotFound layers.15.self_attn.v_proj.weight   (then k_norm.weight)

Confirmed against the checkpoint index: a shared layer ships only
o_proj/q_proj/q_norm; it lacks k_proj.{weight,scales,biases},
v_proj.{weight,scales,biases} and k_norm.weight. (v_norm is RMSNormNoScale —
parameter-free — so it never appears in any checkpoint.)

Make k_proj and k_norm optional and create k_proj/v_proj/k_norm only for the
KV-owning layers (same predicate as the existing double-wide-MLP gate), guard
the compute path, and drop any redundant k_proj/v_proj/k_norm of KV-shared
layers in `sanitize` so PTQ checkpoints that still ship them keep loading
against the now-smaller tree.

Together with the per_layer_model_projection fix (ml-explore#320), this lets the QAT
E-series checkpoints load and run.

Verified on macOS: gemma-4-E2B-it-qat-4bit now loads and responds.
Mirror the two MLXLLM/Gemma4Text fixes onto the MLXVLM Gemma4 text backbone so the
quantization-aware E2B checkpoint loads as a *vision* model:

- KV-sharing: kProj/kNorm are now optional and created only for the KV-owning
  layers (the last num_kv_shared_layers reuse an earlier layer's K/V and own no
  k_proj/v_proj/k_norm); guard the compute path; drop redundant keys in sanitize.
- PLE: per_layer_model_projection is a plain (quantizable) Linear with the
  hidden_size**-0.5 scale applied in the forward (mirrors ml-explore#320).

The sanitize drop is scoped to the text backbone — it must NOT match the vision
tower, which shares the layers.N.self_attn.{k,v}_proj naming. Without the
`!key.contains("vision_tower")` guard the drop amputated vision layers >= 15
and triggered keyNotFound on their clip bounds (output_max).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants