Fix Gemma4 E-series load: make per_layer_model_projection quantizable #320
Closed
nyteshade wants to merge 1 commit into
Flo5k5 added a commit to Flo5k5/mlx-swift-lm that referenced this pull request on Jun 6, 2026
Mirror the two MLXLLM/Gemma4Text fixes onto the MLXVLM Gemma4 text backbone so the quantization-aware E2B checkpoint loads as a *vision* model:

- KV-sharing: kProj/kNorm are now optional and created only for the KV-owning layers (the last num_kv_shared_layers reuse an earlier layer's K/V and own no k_proj/v_proj/k_norm); guard the compute path; drop redundant keys in sanitize.
- PLE: per_layer_model_projection is a plain (quantizable) Linear with the hidden_size**-0.5 scale applied in the forward pass (mirrors ml-explore#320).

The sanitize drop is scoped to the text backbone: it must NOT match the vision tower, which shares the layers.N.self_attn.{k,v}_proj naming. Without the `!key.contains("vision_tower")` guard the drop amputated vision layers >= 15 and triggered keyNotFound on their clip bounds (output_max).
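The scoping rule described in this commit message can be sketched in plain Swift. This is only an illustration under stated assumptions, not the code from the commit: the function names, the `firstSharedLayer` parameter, and the example key prefixes are hypothetical, and the weights are modeled as a simple dictionary rather than MLX arrays.

```swift
import Foundation

/// Returns true when `key` names a K/V weight of a text-backbone layer that
/// reuses an earlier layer's K/V and therefore owns no k_proj/v_proj/k_norm.
/// `firstSharedLayer` is a hypothetical parameter: the index of the first
/// KV-shared layer in the text backbone.
func isRedundantSharedKVKey(_ key: String, firstSharedLayer: Int) -> Bool {
    // Never match the vision tower: it shares the
    // layers.N.self_attn.{k,v}_proj naming with the text backbone.
    if key.contains("vision_tower") { return false }

    let parts = key.split(separator: ".").map(String.init)

    // Locate "layers" followed by a numeric layer index.
    guard let i = parts.firstIndex(of: "layers"),
          i + 1 < parts.count,
          let layer = Int(parts[i + 1]),
          layer >= firstSharedLayer
    else { return false }

    let dropped: Set<String> = ["k_proj", "v_proj", "k_norm"]
    return parts.contains("self_attn")
        && parts.contains(where: { dropped.contains($0) })
}

/// Drops the redundant keys and keeps everything else untouched.
func sanitize<V>(_ weights: [String: V], firstSharedLayer: Int) -> [String: V] {
    return weights.filter {
        !isRedundantSharedKVKey($0.key, firstSharedLayer: firstSharedLayer)
    }
}
```

With the `vision_tower` check removed, a vision key such as `vision_tower.encoder.layers.20.self_attn.k_proj.weight` (a made-up key for illustration) would be dropped along with the text-backbone keys, which is the failure mode the commit message describes.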
per_layer_model_projection was a custom ScaledLinear (a plain Module holding a raw weight + scalar), which quantize() cannot see: it only converts Linear/Embedding. With quantized E-series checkpoints (e.g. gemma-4-E4B 8-bit) the projection stayed full-precision and crashed at load:

    mismatchedSize per_layer_model_projection.weight
    expected [num_layers*hidden_per_layer, hidden] got [..., hidden/pack_factor]

(the packed 8-bit QuantizedLinear weight [10752, 640] never matched the full-precision expectation [10752, 2560]).

Mirror the Python reference (mlx_lm gemma3n): make per_layer_model_projection a plain Linear (so it quantizes like every other Linear) and apply the hidden_size**-0.5 scale (per_layer_projection_scale) in the forward pass rather than baking it into a bespoke module.

Verified on macOS: gemma-4-E4B-it-MLX-8bit now loads and runs (prefill ~1670 t/s on Apple Silicon); previously a hard crash at load.
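The shape arithmetic behind the mismatch can be checked with a small standalone Swift sketch. The type and function names are hypothetical, and it assumes the simple packing described in this PR (values of `bits` width stored in 32-bit words); it does not call any MLX API.

```swift
/// Shapes of the same projection weight in full precision and in packed form.
struct ProjectionShapes: Equatable {
    let fullPrecision: [Int]
    let packed: [Int]
}

/// Number of uint32 words per row when each of `inputDim` values takes `bits` bits.
func packedInputDim(inputDim: Int, bits: Int) -> Int {
    precondition((inputDim * bits) % 32 == 0, "row must fill whole uint32 words")
    return inputDim * bits / 32
}

/// Expected weight shapes for per_layer_model_projection.
func perLayerProjectionShapes(
    numLayers: Int, hiddenPerLayer: Int, hiddenSize: Int, bits: Int
) -> ProjectionShapes {
    // Output rows: one block of hiddenPerLayer features per layer.
    let outputDim = numLayers * hiddenPerLayer
    return ProjectionShapes(
        fullPrecision: [outputDim, hiddenSize],
        packed: [outputDim, packedInputDim(inputDim: hiddenSize, bits: bits)]
    )
}
```

For the E4B numbers quoted in the PR (42 layers, 256 per-layer inputs, hidden size 2560, 8 bits), this gives `[10752, 2560]` for the full-precision weight and `[10752, 640]` for the packed one.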
Force-pushed: 5f550e9 to 774607e
Collaborator
Rebased on main to fix conflict.
davidkoski reviewed on Jun 15, 2026
    // checkpoints ship per_layer_model_projection as a packed
    // QuantizedLinear; a custom non-Linear module is invisible to
    // quantize() and would mismatch those packed weights at load.
    let perLayerProjectionScale = pow(Float(config.hiddenSize), -0.5)
Collaborator
I think this becomes redundant with #309 -- that PR switches to Linear and has perLayerProjectionScale as a property of the layer. Same solution.
Collaborator
See #309 -- closing as a dup of that. Thank you!
Problem

Quantized Gemma 4 E-series (per-layer-embeddings) checkpoints, e.g. `gemma-4-E4B-it-MLX-8bit`, crash at load with a `mismatchedSize` error on `per_layer_model_projection.weight`: expected `[10752, 2560]`, got `[10752, 640]`.

Here `10752 = num_hidden_layers (42) × hidden_size_per_layer_input (256)`. The input dim is wrong: the model expects `2560` (= `hidden_size`, full precision) but the 8-bit checkpoint ships the packed `QuantizedLinear` weight with `640 = 2560 / 4` columns (4 values per `uint32`), plus `.scales`/`.biases`.

Cause

`per_layer_model_projection` is declared as a custom `ScaledLinear`: a plain `Module` that holds a raw `weight` and does `matmul(x, weight.T) * scalar`. `quantize()` only converts `Linear`/`Embedding`, so `ScaledLinear` is invisible to it. The projection stays full-precision while the checkpoint stores it quantized, which produces the shape mismatch at load. (Dense/bf16 E-series checkpoints load fine, which is why this went unnoticed.)

Fix

Mirror the Python reference (`mlx_lm` `gemma3n`), where `per_layer_model_projection` is a plain `nn.Linear` and `per_layer_projection_scale = hidden_size**-0.5` is applied separately in the forward pass:

- Make `per_layer_model_projection` a plain `Linear` (so it quantizes like every other `Linear`).
- Apply the `hidden_size**-0.5` scale after the projection in `callAsFunction`.
- Remove `ScaledLinear`.

Net: +14 / −22 lines.

Verification

On macOS / Apple Silicon, `gemma-4-E4B-it-MLX-8bit` now loads and runs, with prefill at ~1670 t/s on a 35.7k-token prompt. Previously it was a hard crash at load. Full-precision E-series checkpoints continue to load (the projection is still a `Linear`; only its quantization visibility changed).
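
The cause and the fix can be illustrated with a toy model in plain Swift. Everything here is a stand-in invented for illustration: `ToyModule`, `ToyLinear`, `ToyQuantizedLinear`, `ToyScaledLinear`, `toyQuantize`, and `projectThenScale` are hypothetical names, not MLX types or functions, and no real weight packing is performed. The sketch only assumes what the description states: the quantization pass converts modules by type, and the scale is applied after the projection.

```swift
import Foundation

protocol ToyModule {}

/// Stand-in for a plain linear layer; each row of `weight` is one output feature.
struct ToyLinear: ToyModule { let weight: [[Float]] }

/// Stand-in for a quantized layer; it only records what it replaced.
struct ToyQuantizedLinear: ToyModule { let original: ToyLinear }

/// Stand-in for the bespoke module: raw weight plus a baked-in scalar.
struct ToyScaledLinear: ToyModule {
    let weight: [[Float]]
    let scalar: Float
}

/// Type-based conversion: only ToyLinear is recognized, so any custom
/// module type passes through unchanged (it is "invisible" to the pass).
func toyQuantize(_ modules: [String: ToyModule]) -> [String: ToyModule] {
    return modules.mapValues { module -> ToyModule in
        if let linear = module as? ToyLinear {
            return ToyQuantizedLinear(original: linear)
        }
        return module
    }
}

/// Computes matmul(x, weight.T) and then applies the scale in the forward pass.
func projectThenScale(_ x: [Float], weight: [[Float]], scale: Float) -> [Float] {
    return weight.map { row -> Float in
        let dot = zip(row, x).reduce(Float(0)) { sum, pair in sum + pair.0 * pair.1 }
        return dot * scale
    }
}

/// hidden_size ** -0.5, computed the same way as in the reviewed diff line.
func perLayerProjectionScale(hiddenSize: Int) -> Float {
    return pow(Float(hiddenSize), -0.5)
}
```

In this toy, a dictionary holding one `ToyLinear` and one `ToyScaledLinear` comes out of `toyQuantize` with only the former converted, which mirrors why the custom module stayed full-precision. Keeping the scale outside the layer leaves the layer itself an ordinary linear module that the conversion pass can recognize.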