
Fix Gemma 4 QAT load: KV-shared layers carry no k_proj/v_proj/k_norm #342

Open
Flo5k5 wants to merge 1 commit into ml-explore:main from Flo5k5:upstream-pr/gemma4-qat-kvshared-text

@Flo5k5 Flo5k5 commented Jun 12, 2026

Proposed changes

Gemma 4 E-series shares K/V across the last num_kv_shared_layers layers: those layers reuse the K/V of an earlier layer of the same attention type. They compute their own queries (q_proj/q_norm) but own no k_proj, v_proj or k_norm — only o_proj/q_proj/q_norm.
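The layer split described above can be sketched as a small predicate. This is a minimal illustration, not the code in this PR: `KVSharingConfig`, its field names and the sizes 35/20 are hypothetical, chosen only so that layer 15 comes out as the first shared layer, as in the failure reported below.

```swift
// Hypothetical config type; the field mirrors the `num_kv_shared_layers` key.
struct KVSharingConfig {
    let numHiddenLayers: Int
    let numKVSharedLayers: Int

    /// Index of the first layer that reuses K/V from an earlier layer.
    var firstKVSharedLayer: Int {
        max(0, numHiddenLayers - numKVSharedLayers)
    }

    /// True when `layerIndex` is one of the trailing KV-shared layers.
    func isKVShared(_ layerIndex: Int) -> Bool {
        numKVSharedLayers > 0 && layerIndex >= firstKVSharedLayer
    }

    /// Parameter-carrying attention submodules a layer is expected to own.
    /// v_norm is omitted because it has no parameters.
    func ownedAttentionModules(_ layerIndex: Int) -> [String] {
        isKVShared(layerIndex)
            ? ["q_proj", "q_norm", "o_proj"]
            : ["q_proj", "q_norm", "k_proj", "k_norm", "v_proj", "o_proj"]
    }
}

// Illustrative sizes only (assumption), picked so layer 15 is the first shared layer.
let exampleConfig = KVSharingConfig(numHiddenLayers: 35, numKVSharedLayers: 20)
```

With these assumed sizes, `exampleConfig.isKVShared(14)` is false and `exampleConfig.isKVShared(15)` is true, so only layers 0 through 14 would own k_proj/v_proj/k_norm.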

Gemma4Attention declared k_proj/k_norm as non-optional and created k_proj/v_proj/k_norm unconditionally, so the loader demanded layers.N.self_attn.{k_proj,v_proj,k_norm} for the shared layers too. Full-precision / PTQ checkpoints still carry those redundant tensors, so they loaded. Quantization-aware (QAT) checkpoints prune them, so loading gemma-4-E2B-it-qat-4bit failed at layer 15 (the first shared layer):

keyNotFound layers.15.self_attn.v_proj.weight   (then k_norm.weight)

Confirmed against the checkpoint index: a shared layer ships only o_proj/q_proj/q_norm; it lacks k_proj.{weight,scales,biases}, v_proj.{weight,scales,biases} and k_norm.weight. (v_norm is RMSNormNoScale — parameter-free — so it never appears in any checkpoint.)
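A check of this kind could look roughly like the sketch below. It assumes the tensor names have already been read out of the checkpoint index into an array of strings; `attentionModulesPresent` is a hypothetical helper, and any keys passed to it here are illustrative rather than copied from the real index.

```swift
import Foundation

/// Names of the attention submodules that appear for `layer` in a list of
/// checkpoint tensor keys such as "layers.15.self_attn.q_proj.weight".
func attentionModulesPresent(in keys: [String], layer: Int) -> Set<String> {
    let prefix = "layers.\(layer).self_attn."
    var found = Set<String>()
    for key in keys {
        // Skip keys that do not belong to this layer's attention block.
        guard let range = key.range(of: prefix) else { continue }
        // The first path component after the prefix is the submodule name.
        let rest = key[range.upperBound...]
        if let name = rest.split(separator: ".").first {
            found.insert(String(name))
        }
    }
    return found
}
```

For a KV-shared layer in a QAT checkpoint the expected result would be the set of o_proj, q_proj and q_norm only, while a KV-owning layer would also list k_proj, v_proj and k_norm.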

This change:

  • makes k_proj/k_norm optional and creates k_proj/v_proj/k_norm only for the KV-owning layers (same predicate as the existing double-wide-MLP gate);
  • guards the compute path — KV-shared layers always receive sharedKV, so the K/V-owning branch is unreachable for them (with a clear fatalError if a misconfiguration ever sends one there);
  • drops any redundant k_proj/v_proj/k_norm of KV-shared layers in sanitize, so PTQ checkpoints that still ship them keep loading against the now-smaller tree.
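The `sanitize` step in the last bullet can be pictured as a key filter. The sketch below is an approximation under stated assumptions, not the code in this PR: `dropSharedKVWeights` is a hypothetical name, the weights are modeled as a plain dictionary with a generic value type instead of real arrays, and the first shared layer index is passed in by the caller.

```swift
/// Removes k_proj / v_proj / k_norm entries that belong to KV-shared layers,
/// so a checkpoint that still ships them matches the smaller module tree.
func dropSharedKVWeights<Value>(
    _ weights: [String: Value], firstKVSharedLayer: Int
) -> [String: Value] {
    let dropped: Set<String> = ["k_proj", "v_proj", "k_norm"]
    return weights.filter { entry in
        let parts = entry.key.split(separator: ".").map(String.init)
        // Match the pattern layers.<N>.self_attn.<module>...; keep anything else.
        guard let i = parts.firstIndex(of: "layers"), i + 3 < parts.count,
              let layer = Int(parts[i + 1]), parts[i + 2] == "self_attn"
        else { return true }
        // Drop only the redundant K/V tensors of shared layers.
        return !(layer >= firstKVSharedLayer && dropped.contains(parts[i + 3]))
    }
}
```

Filtering by layer index and submodule name leaves the q_proj/q_norm/o_proj tensors of shared layers, and every tensor of KV-owning layers, untouched.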

Together with the per_layer_model_projection fix (#320), this lets the QAT E-series checkpoints load and run. Verified on macOS: gemma-4-E2B-it-qat-4bit now loads and responds. The change has been running in a downstream app loading the Gemma 4 E2B/E4B QAT checkpoints on Apple Silicon.

A sibling fix exists for the VLM path (MLXVLM/Models/Gemma4.swift), which I'll send as a follow-up PR rebased on the recent KV-shared / drafter work in main.

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files — authored in-tree through the fork's formatting; happy to address any swift-format nits
  • I have added tests — this is a checkpoint-load-path fix, exercised via the integration model-loading tests rather than a unit test
  • I have updated the necessary documentation (n/a)

Gemma 4 E-series shares K/V across the last `num_kv_shared_layers` layers:
those layers reuse the K/V of an earlier layer of the same attention type. They
still compute their own queries (q_proj/q_norm) but own NO k_proj, v_proj or
k_norm — only o_proj/q_proj/q_norm. The forward pass already handles this
(KV-shared layers receive `sharedKV` and never touch kProj/kNorm/vProj), but
`Gemma4Attention` declared k_proj/k_norm as non-optional and created all of
k_proj/v_proj/k_norm unconditionally, so the loader demanded
`layers.N.self_attn.{k_proj,v_proj,k_norm}` for the shared layers too.

Full-precision/PTQ checkpoints still carry those redundant tensors, so they
loaded. Quantization-aware (QAT) checkpoints prune them, so loading
gemma-4-E2B-it-qat-4bit failed at layer 15 (first shared layer):

  keyNotFound layers.15.self_attn.v_proj.weight   (then k_norm.weight)

Confirmed against the checkpoint index: a shared layer ships only
o_proj/q_proj/q_norm; it lacks k_proj.{weight,scales,biases},
v_proj.{weight,scales,biases} and k_norm.weight. (v_norm is RMSNormNoScale —
parameter-free — so it never appears in any checkpoint.)

Make k_proj and k_norm optional and create k_proj/v_proj/k_norm only for the
KV-owning layers (same predicate as the existing double-wide-MLP gate), guard
the compute path, and drop any redundant k_proj/v_proj/k_norm of KV-shared
layers in `sanitize` so PTQ checkpoints that still ship them keep loading
against the now-smaller tree.

Together with the per_layer_model_projection fix (ml-explore#320), this lets the QAT
E-series checkpoints load and run.

Verified on macOS: gemma-4-E2B-it-qat-4bit now loads and responds.