Fix Gemma4 QAT (E-series) load: KV-shared layers have no k_proj/v_proj/k_norm#330
Open
Flo5k5 wants to merge 3 commits into
Open
Fix Gemma4 QAT (E-series) load: KV-shared layers have no k_proj/v_proj/k_norm#330Flo5k5 wants to merge 3 commits into
Flo5k5 wants to merge 3 commits into
Conversation
per_layer_model_projection was a custom ScaledLinear (a plain Module holding a raw weight + scalar), which quantize() cannot see — it only converts Linear/Embedding. With quantized E-series checkpoints (e.g. gemma-4-E4B 8-bit) the projection stayed full-precision and crashed at load: mismatchedSize per_layer_model_projection.weight expected [num_layers*hidden_per_layer, hidden] got [..., hidden/pack_factor] (the packed 8-bit QuantizedLinear weight [10752, 640] never matched the full-precision expectation [10752, 2560]). Mirror the Python reference (mlx_lm gemma3n): make per_layer_model_projection a plain Linear (so it quantizes like every other Linear) and apply the hidden_size**-0.5 scale (per_layer_projection_scale) in the forward pass rather than baking it into a bespoke module. Verified on macOS: gemma-4-E4B-it-MLX-8bit now loads and runs (prefill ~1670 t/s on Apple Silicon); previously a hard crash at load.
Gemma 4 E-series shares K/V across the last `num_kv_shared_layers` layers:
those layers reuse the K/V of an earlier layer of the same attention type. They
still compute their own queries (q_proj/q_norm) but own NO k_proj, v_proj or
k_norm — only o_proj/q_proj/q_norm. The forward pass already handles this
(KV-shared layers receive `sharedKV` and never touch kProj/kNorm/vProj), but
`Gemma4Attention` declared k_proj/k_norm as non-optional and created all of
k_proj/v_proj/k_norm unconditionally, so the loader demanded
`layers.N.self_attn.{k_proj,v_proj,k_norm}` for the shared layers too.
Full-precision/PTQ checkpoints still carry those redundant tensors, so they
loaded. Quantization-aware (QAT) checkpoints prune them, so loading
gemma-4-E2B-it-qat-4bit failed at layer 15 (first shared layer):
keyNotFound layers.15.self_attn.v_proj.weight (then k_norm.weight)
Confirmed against the checkpoint index: a shared layer ships only
o_proj/q_proj/q_norm; it lacks k_proj.{weight,scales,biases},
v_proj.{weight,scales,biases} and k_norm.weight. (v_norm is RMSNormNoScale —
parameter-free — so it never appears in any checkpoint.)
Make k_proj and k_norm optional and create k_proj/v_proj/k_norm only for the
KV-owning layers (same predicate as the existing double-wide-MLP gate), guard
the compute path, and drop any redundant k_proj/v_proj/k_norm of KV-shared
layers in `sanitize` so PTQ checkpoints that still ship them keep loading
against the now-smaller tree.
Together with the per_layer_model_projection fix (ml-explore#320), this lets the QAT
E-series checkpoints load and run.
Verified on macOS: gemma-4-E2B-it-qat-4bit now loads and responds.
Mirror the two MLXLLM/Gemma4Text fixes onto the MLXVLM Gemma4 text backbone so the quantization-aware E2B checkpoint loads as a *vision* model: - KV-sharing: kProj/kNorm are now optional and created only for the KV-owning layers (the last num_kv_shared_layers reuse an earlier layer's K/V and own no k_proj/v_proj/k_norm); guard the compute path; drop redundant keys in sanitize. - PLE: per_layer_model_projection is a plain (quantizable) Linear with the hidden_size**-0.5 scale applied in the forward (mirrors ml-explore#320). The sanitize drop is scoped to the text backbone — it must NOT match the vision tower, which shares the layers.N.self_attn.{k,v}_proj naming. Without the `!key.contains("vision_tower")` guard the drop amputated vision layers >= 15 and triggered keyNotFound on their clip bounds (output_max).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
Gemma 4 E-series (E2B/E4B) shares K/V across the last
num_kv_shared_layerslayers.Those layers reuse the K/V of an earlier layer of the same attention type — they still
compute their own queries (
q_proj/q_norm) but own nok_proj,v_projork_norm.Gemma4TextModelInneralready routes their KV viapreviousKvs/sharedKV,and the forward path never touches
kProj/kNorm/vProjfor them.But
Gemma4Attentiondeclaredk_proj/k_normas non-optional and createdk_proj/v_proj/k_normunconditionally for every layer, so the loader demandedlayers.N.self_attn.{k_proj,v_proj,k_norm}for the KV-shared layers too.Full-precision / PTQ checkpoints still carry those (now-redundant) tensors, so they
happened to load. Quantization-aware (QAT) checkpoints prune them, so e.g.
mlx-community/gemma-4-E2B-it-qat-4bitfails at layer 15 (the first shared layer):Confirmed against the checkpoint's
model.safetensors.index.json: an owner layer (14)ships
q/k/v_proj,q_norm,k_norm,o_proj; a shared layer (15) ships onlyq_proj,q_norm,o_proj.v_normisRMSNormNoScale(parameter-free) so it neverappears in any checkpoint and needs no change.
Change (
Gemma4Text.swift)kProjandkNormoptional (vProjwas already optional).k_proj/v_proj/k_normonly for the KV-owning layers, using the samelayerIdx >= numHiddenLayers - numKvSharedLayerspredicate already used for thedouble-wide-MLP gate.
receive
sharedKV).sanitize, drop any redundantk_proj/v_proj/k_normof KV-shared layers, soPTQ checkpoints that still ship them keep loading against the now-smaller module tree.
Net change is +50/−7 in one file; the forward/inference logic is untouched (it already
handled sharing correctly — only the module declarations over-claimed weights).
Relationship to #320
This builds on top of @nyteshade's #320 (make
per_layer_model_projectionquantizable) — that commit is included here to preserve attribution. #320 fixes the
first QAT load error (the PLE projection); this fixes the second (KV sharing).
Both are required for the QAT E-series checkpoints to load. Happy to rebase onto
mainonce #320 merges (then this becomes a single-commit follow-up).
Verification
macOS (Apple Silicon),
mlx-community/gemma-4-E2B-it-qat-4bit:Before: hard
keyNotFoundat layer 15 during load.After: loads and generates. Sample (text-only prompt through a tool-enabled
session):
PTQ checkpoints (
mlx-community/gemma-4-e2b-it-4bit) continue to load unchanged (thesanitizedrop handles the redundant tensors they still carry).