Fix Gemma 4 QAT load: KV-shared layers carry no k_proj/v_proj/k_norm #342
Open
Flo5k5 wants to merge 1 commit into
Gemma 4 E-series shares K/V across the last `num_kv_shared_layers` layers:
those layers reuse the K/V of an earlier layer of the same attention type. They
still compute their own queries (q_proj/q_norm) but own NO k_proj, v_proj or
k_norm — only o_proj/q_proj/q_norm. The forward pass already handles this
(KV-shared layers receive `sharedKV` and never touch kProj/kNorm/vProj), but
`Gemma4Attention` declared k_proj/k_norm as non-optional and created all of
k_proj/v_proj/k_norm unconditionally, so the loader demanded
`layers.N.self_attn.{k_proj,v_proj,k_norm}` for the shared layers too.
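The layer split described above can be sketched in plain Swift. This is a toy model, not the actual `Gemma4Attention`: `ToyModule`, `ToyConfig`, `ToyAttention` and `KVSource` are hypothetical names, the config values are chosen only so that the first shared layer is 15, and the index predicate is an assumption derived from "the last `num_kv_shared_layers` layers".

```swift
// Toy stand-in for a projection or norm module; not an MLX type.
struct ToyModule {
    let name: String
}

// Hypothetical config; the numbers are illustrative, not read from a checkpoint.
struct ToyConfig {
    let numHiddenLayers: Int
    let numKvSharedLayers: Int

    // Assumed predicate: the last `numKvSharedLayers` layers reuse earlier K/V.
    func isKVShared(layer: Int) -> Bool {
        layer >= numHiddenLayers - numKvSharedLayers
    }
}

enum KVSource { case own, shared }

struct ToyAttention {
    // Every layer computes its own queries and output projection.
    let qProj = ToyModule(name: "q_proj")
    let qNorm = ToyModule(name: "q_norm")
    let oProj = ToyModule(name: "o_proj")
    // Only KV-owning layers hold these, so they are optional.
    let kProj: ToyModule?
    let vProj: ToyModule?
    let kNorm: ToyModule?

    init(layer: Int, config: ToyConfig) {
        let ownsKV = !config.isKVShared(layer: layer)
        kProj = ownsKV ? ToyModule(name: "k_proj") : nil
        vProj = ownsKV ? ToyModule(name: "v_proj") : nil
        kNorm = ownsKV ? ToyModule(name: "k_norm") : nil
    }

    // Submodules a loader would expect weights for in this layer.
    var submoduleNames: [String] {
        let all: [ToyModule?] = [qProj, qNorm, oProj, kProj, vProj, kNorm]
        return all.compactMap { $0?.name }
    }

    // Picks where K/V come from; traps if a shared layer lacks sharedKV.
    func kvSource(hasSharedKV: Bool) -> KVSource {
        if hasSharedKV { return .shared }
        guard kProj != nil, vProj != nil, kNorm != nil else {
            fatalError("KV-shared layer reached the K/V-owning branch")
        }
        return .own
    }
}

let toyConfig = ToyConfig(numHiddenLayers: 35, numKvSharedLayers: 20)
let firstShared = ToyAttention(layer: 15, config: toyConfig)
print(firstShared.submoduleNames)
```

With the K/V modules modeled as optionals, a shared layer's expected parameter set shrinks to the three submodules the QAT checkpoints actually ship.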
Full-precision/PTQ checkpoints still carry those redundant tensors, so they
loaded. Quantization-aware (QAT) checkpoints prune them, so loading
gemma-4-E2B-it-qat-4bit failed at layer 15 (first shared layer):
    keyNotFound layers.15.self_attn.v_proj.weight (then k_norm.weight)
Confirmed against the checkpoint index: a shared layer ships only
o_proj/q_proj/q_norm; it lacks k_proj.{weight,scales,biases},
v_proj.{weight,scales,biases} and k_norm.weight. (v_norm is RMSNormNoScale —
parameter-free — so it never appears in any checkpoint.)
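A check of this kind against a checkpoint index can be sketched as follows. The helper names and the sample key list are hypothetical; the keys are modeled on the names quoted above and are not copied from a real index, and the `model.` prefix is an assumption.

```swift
// Parses "...layers.N.self_attn.<submodule>..." out of a tensor name.
// Returns nil for keys that are not attention tensors of a numbered layer.
func attentionSubmodule(of key: String) -> (layer: Int, submodule: String)? {
    let parts = key.split(separator: ".").map(String.init)
    guard let i = parts.firstIndex(of: "layers"),
          i + 3 < parts.count,
          let layer = Int(parts[i + 1]),
          parts[i + 2] == "self_attn"
    else { return nil }
    return (layer, parts[i + 3])
}

// Collects the attention submodules that a given layer ships.
func shippedSubmodules(layer: Int, keys: [String]) -> Set<String> {
    var found = Set<String>()
    for key in keys {
        if let hit = attentionSubmodule(of: key), hit.layer == layer {
            found.insert(hit.submodule)
        }
    }
    return found
}

// Illustrative keys only: layer 14 owns K/V, layer 15 is KV-shared.
let sampleKeys = [
    "model.layers.14.self_attn.k_proj.weight",
    "model.layers.14.self_attn.v_proj.weight",
    "model.layers.14.self_attn.k_norm.weight",
    "model.layers.14.self_attn.q_proj.weight",
    "model.layers.15.self_attn.q_proj.weight",
    "model.layers.15.self_attn.q_proj.scales",
    "model.layers.15.self_attn.q_proj.biases",
    "model.layers.15.self_attn.q_norm.weight",
    "model.layers.15.self_attn.o_proj.weight",
]

let kvModules: Set<String> = ["k_proj", "v_proj", "k_norm"]
let missingAt15 = kvModules.subtracting(shippedSubmodules(layer: 15, keys: sampleKeys))
print(missingAt15.sorted())
```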
Make k_proj and k_norm optional and create k_proj/v_proj/k_norm only for the
KV-owning layers (same predicate as the existing double-wide-MLP gate), guard
the compute path, and drop any redundant k_proj/v_proj/k_norm of KV-shared
layers in `sanitize` so PTQ checkpoints that still ship them keep loading
against the now-smaller tree.
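The `sanitize` step described above can be sketched as a plain dictionary filter. This is not the code in the PR: `dropSharedKVWeights` and `firstSharedLayer` are hypothetical names, the function is generic over the value type so it runs without MLX, and the `model.` key prefix is an assumption.

```swift
// Removes k_proj/v_proj/k_norm tensors that belong to KV-shared layers,
// so a checkpoint that still ships them matches the smaller module tree.
func dropSharedKVWeights<T>(_ weights: [String: T], firstSharedLayer: Int) -> [String: T] {
    let redundant: Set<String> = ["k_proj", "v_proj", "k_norm"]
    return weights.filter { entry in
        let parts = entry.key.split(separator: ".").map(String.init)
        // Keep anything that is not an attention tensor of a numbered layer.
        guard let i = parts.firstIndex(of: "layers"),
              i + 3 < parts.count,
              let layer = Int(parts[i + 1]),
              parts[i + 2] == "self_attn"
        else { return true }
        // Drop only redundant K/V tensors at or past the first shared layer.
        return !(layer >= firstSharedLayer && redundant.contains(parts[i + 3]))
    }
}

// Illustrative weights; the Int values stand in for tensors.
let ptqWeights = [
    "model.layers.14.self_attn.k_proj.weight": 1,
    "model.layers.15.self_attn.k_proj.weight": 2,
    "model.layers.15.self_attn.k_norm.weight": 3,
    "model.layers.15.self_attn.q_proj.weight": 4,
    "model.layers.15.mlp.gate_proj.weight": 5,
]
let cleaned = dropSharedKVWeights(ptqWeights, firstSharedLayer: 15)
print(cleaned.keys.sorted())
```

Filtering by key keeps the step a no-op for QAT checkpoints, which have already pruned those tensors.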
Together with the per_layer_model_projection fix (ml-explore#320), this lets the QAT
E-series checkpoints load and run.
Verified on macOS: gemma-4-E2B-it-qat-4bit now loads and responds.
Proposed changes

Gemma 4 E-series shares K/V across the last `num_kv_shared_layers` layers: those layers reuse the K/V of an earlier layer of the same attention type. They compute their own queries (`q_proj`/`q_norm`) but own no `k_proj`, `v_proj` or `k_norm`, only `o_proj`/`q_proj`/`q_norm`. `Gemma4Attention` declared `k_proj`/`k_norm` as non-optional and created `k_proj`/`v_proj`/`k_norm` unconditionally, so the loader demanded `layers.N.self_attn.{k_proj,v_proj,k_norm}` for the shared layers too. Full-precision / PTQ checkpoints still carry those redundant tensors, so they loaded, but quantization-aware (QAT) checkpoints prune them, so loading `gemma-4-E2B-it-qat-4bit` failed at layer 15 (the first shared layer):

`keyNotFound layers.15.self_attn.v_proj.weight` (then `k_norm.weight`)

Confirmed against the checkpoint index: a shared layer ships only `o_proj`/`q_proj`/`q_norm`; it lacks `k_proj.{weight,scales,biases}`, `v_proj.{weight,scales,biases}` and `k_norm.weight`. (`v_norm` is `RMSNormNoScale`, which is parameter-free, so it never appears in any checkpoint.)

This change:

- makes `k_proj`/`k_norm` optional and creates `k_proj`/`v_proj`/`k_norm` only for the KV-owning layers (same predicate as the existing double-wide-MLP gate);
- guards the compute path: KV-shared layers receive `sharedKV`, so the K/V-owning branch is unreachable for them (with a clear `fatalError` if a misconfiguration ever sends one there);
- drops any redundant `k_proj`/`v_proj`/`k_norm` of KV-shared layers in `sanitize`, so PTQ checkpoints that still ship them keep loading against the now-smaller tree.

Together with the `per_layer_model_projection` fix (#320), this lets the QAT E-series checkpoints load and run. Verified on macOS: `gemma-4-E2B-it-qat-4bit` now loads and responds. The change has been running in a downstream app loading the Gemma 4 E2B/E4B QAT checkpoints on Apple Silicon.

A sibling fix exists for the VLM path (`MLXVLM/Models/Gemma4.swift`), which I'll send as a follow-up PR rebased on the recent KV-shared / drafter work in `main`.

Checklist

- `pre-commit run --all-files`: authored in-tree through the fork's formatting; happy to address any `swift-format` nits