Fix Gemma4 E-series load: make per_layer_model_projection quantizable #320

Closed
nyteshade wants to merge 1 commit into ml-explore:main from nyteshade:fix/gemma4-ple-quantization

@nyteshade

Problem

Quantized Gemma 4 E-series (per-layer-embeddings) checkpoints — e.g. gemma-4-E4B-it-MLX-8bit — crash at load:

mismatchedSize(path: ["language_model","model","per_layer_model_projection","weight"],
  modules: [..., "ScaledLinear"],
  expectedShape: [10752, 2560], actualShape: [10752, 640])

10752 = num_hidden_layers (42) × hidden_size_per_layer_input (256). The input dimension does not match: the model expects 2560 (= hidden_size, full precision), but the 8-bit checkpoint ships a packed QuantizedLinear weight whose input dimension is 640 = 2560 / 4 (four 8-bit values per uint32), plus .scales/.biases.
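The shape arithmetic above can be reproduced with a small sketch. `packedWeightShape` is a hypothetical helper written for illustration, not an MLX API; it assumes values are packed into 32-bit words and only handles bit widths that divide 32.

```swift
// Hypothetical helper (not part of MLX): shape of a packed quantized weight,
// assuming values are packed into 32-bit words along the input dimension.
func packedWeightShape(outputDims: Int, inputDims: Int, bits: Int) -> [Int] {
    precondition(32 % bits == 0, "this sketch only handles bit widths that divide 32")
    let valuesPerWord = 32 / bits
    precondition(inputDims % valuesPerWord == 0, "input dims must fill whole words")
    return [outputDims, inputDims / valuesPerWord]
}

// Values quoted in the error message and the explanation above.
let numHiddenLayers = 42
let hiddenSizePerLayerInput = 256
let modelHiddenSize = 2560

let projectionOutputDims = numHiddenLayers * hiddenSizePerLayerInput
let expectedShape = [projectionOutputDims, modelHiddenSize]  // full-precision weight
let actualShape = packedWeightShape(
    outputDims: projectionOutputDims, inputDims: modelHiddenSize, bits: 8)  // packed 8-bit weight

print(expectedShape, actualShape)  // prints "[10752, 2560] [10752, 640]"
```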

Cause

per_layer_model_projection is declared as a custom ScaledLinear — a plain Module that holds a raw weight and does matmul(x, weight.T) * scalar. quantize() only converts Linear/Embedding, so ScaledLinear is invisible to it: the projection stays full-precision while the checkpoint stores it quantized → shape mismatch at load. (Dense/bf16 E-series checkpoints load fine, which is why this went unnoticed.)
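A toy model of that behaviour, as a sketch only: the `Toy*` types and `toyQuantize` are hypothetical stand-ins and do not reproduce the real quantize() implementation. It shows why a pass that matches on module type skips a custom wrapper.

```swift
// Toy stand-ins for the real module types; every name here is hypothetical.
class ToyModule {}
final class ToyLinear: ToyModule {}
final class ToyEmbedding: ToyModule {}
final class ToyScaledLinear: ToyModule {}  // custom wrapper, not a ToyLinear subclass
final class ToyQuantizedLinear: ToyModule {}
final class ToyQuantizedEmbedding: ToyModule {}

// Replace only the module types the pass knows about; leave everything else untouched.
func toyQuantize(_ modules: [String: ToyModule]) -> [String: ToyModule] {
    modules.mapValues { module -> ToyModule in
        switch module {
        case is ToyLinear: return ToyQuantizedLinear()
        case is ToyEmbedding: return ToyQuantizedEmbedding()
        default: return module
        }
    }
}

let modulesBefore: [String: ToyModule] = [
    "q_proj": ToyLinear(),
    "embed_tokens": ToyEmbedding(),
    "per_layer_model_projection": ToyScaledLinear(),
]
let modulesAfter = toyQuantize(modulesBefore)

// The custom wrapper comes out unchanged, so it still expects a full-precision weight.
for name in modulesAfter.keys.sorted() {
    print(name, type(of: modulesAfter[name]!))
}
```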

Fix

Mirror the Python reference (mlx_lm gemma3n), where per_layer_model_projection is a plain nn.Linear and per_layer_projection_scale = hidden_size**-0.5 is applied separately in the forward pass:

  • Make per_layer_model_projection a plain Linear (so it quantizes like every other Linear).
  • Apply the hidden_size**-0.5 scale after the projection in callAsFunction.
  • Remove the now-unused ScaledLinear.
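The shape of the fix listed above, sketched in plain Swift arrays rather than MLX so it runs on its own. `PlainLinear` and `PerLayerProjection` are hypothetical names, and the tiny weights are made up; the actual change uses the library's Linear inside the Gemma4 text model.

```swift
import Foundation

// Hypothetical stand-in for a bias-free Linear: y = W x, with weight laid out as [out][in].
struct PlainLinear {
    var weight: [[Float]]

    func callAsFunction(_ x: [Float]) -> [Float] {
        weight.map { row in zip(row, x).reduce(Float(0)) { $0 + $1.0 * $1.1 } }
    }
}

// The projection is an ordinary linear layer; the hidden_size**-0.5 scale
// lives outside it and is applied to the layer's output in the forward pass.
struct PerLayerProjection {
    let projection: PlainLinear
    let scale: Float

    init(weight: [[Float]], hiddenSize: Int) {
        projection = PlainLinear(weight: weight)
        scale = pow(Float(hiddenSize), -0.5)
    }

    func callAsFunction(_ x: [Float]) -> [Float] {
        projection(x).map { $0 * scale }
    }
}

// Illustrative sizes only: hiddenSize = 4 gives a scale of 0.5.
let toyProjection = PerLayerProjection(
    weight: [[1, 2, 3, 4], [0, 1, 0, 1]], hiddenSize: 4)
let projected = toyProjection([1, 1, 1, 1])
print(projected)
```

Because the scale is no longer part of the layer, the layer itself can be swapped for a quantized equivalent without touching the scaling step.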

Net: +14 / −22 lines.

Verification

On macOS / Apple Silicon, gemma-4-E4B-it-MLX-8bit now loads and runs — prefill ~1670 t/s on a 35.7k-token prompt. Previously a hard crash at load. Full-precision E-series checkpoints continue to load (the projection is still a Linear; only its quantization visibility changed).

Flo5k5 added a commit to Flo5k5/mlx-swift-lm that referenced this pull request Jun 6, 2026
Mirror the two MLXLLM/Gemma4Text fixes onto the MLXVLM Gemma4 text backbone so the
quantization-aware E2B checkpoint loads as a *vision* model:

- KV-sharing: kProj/kNorm are now optional and created only for the KV-owning
  layers (the last num_kv_shared_layers reuse an earlier layer's K/V and own no
  k_proj/v_proj/k_norm); guard the compute path; drop redundant keys in sanitize.
- PLE: per_layer_model_projection is a plain (quantizable) Linear with the
  hidden_size**-0.5 scale applied in the forward (mirrors ml-explore#320).

The sanitize drop is scoped to the text backbone — it must NOT match the vision
tower, which shares the layers.N.self_attn.{k,v}_proj naming. Without the
`!key.contains("vision_tower")` guard the drop amputated vision layers >= 15
and triggered keyNotFound on their clip bounds (output_max).
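
A minimal sketch of the scoping rule described in this commit message, not the commit's actual code. `shouldDropSharedKVKey` is a hypothetical helper, the key strings are made-up examples in the layers.N.self_attn naming, and the first shared layer index of 15 is only inferred from the ">= 15" remark above.

```swift
import Foundation

// Hypothetical helper: should this checkpoint key be dropped because it
// belongs to a KV-shared text layer that owns no k_proj/v_proj/k_norm?
func shouldDropSharedKVKey(_ key: String, firstSharedLayer: Int) -> Bool {
    // The vision tower shares the layers.N.self_attn.{k,v}_proj naming, so never touch it.
    guard !key.contains("vision_tower") else { return false }

    let parts = key.split(separator: ".").map(String.init)
    guard let layersIndex = parts.firstIndex(of: "layers"),
          layersIndex + 1 < parts.count,
          let layer = Int(parts[layersIndex + 1]),
          layer >= firstSharedLayer
    else { return false }

    let sharedNames: Set<String> = ["k_proj", "v_proj", "k_norm"]
    return parts.contains { sharedNames.contains($0) }
}

// Made-up keys for illustration.
let exampleKeys = [
    "language_model.model.layers.20.self_attn.k_proj.weight",  // shared text layer: drop
    "language_model.model.layers.3.self_attn.k_proj.weight",   // KV-owning text layer: keep
    "vision_tower.encoder.layers.20.self_attn.k_proj.weight",  // vision tower: keep
]
for key in exampleKeys {
    print(key, shouldDropSharedKVKey(key, firstSharedLayer: 15))
}
```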

per_layer_model_projection was a custom ScaledLinear (a plain Module holding a
raw weight + scalar), which quantize() cannot see — it only converts
Linear/Embedding. With quantized E-series checkpoints (e.g. gemma-4-E4B 8-bit)
the projection stayed full-precision and crashed at load:

  mismatchedSize per_layer_model_projection.weight
  expected [num_layers*hidden_per_layer, hidden] got [..., hidden/pack_factor]

(the packed 8-bit QuantizedLinear weight [10752, 640] never matched the
full-precision expectation [10752, 2560]).

Mirror the Python reference (mlx_lm gemma3n): make per_layer_model_projection a
plain Linear (so it quantizes like every other Linear) and apply the
hidden_size**-0.5 scale (per_layer_projection_scale) in the forward pass rather
than baking it into a bespoke module.

Verified on macOS: gemma-4-E4B-it-MLX-8bit now loads and runs (prefill ~1670
t/s on Apple Silicon); previously a hard crash at load.
@davidkoski force-pushed the fix/gemma4-ple-quantization branch from 5f550e9 to 774607e, June 15, 2026 17:52
@davidkoski

Collaborator

Rebased on main to fix conflict.

// checkpoints ship per_layer_model_projection as a packed
// QuantizedLinear; a custom non-Linear module is invisible to
// quantize() and would mismatch those packed weights at load.
let perLayerProjectionScale = pow(Float(config.hiddenSize), -0.5)

Collaborator


I think this becomes redundant with #309 -- that PR switches to Linear and has perLayerProjectionScale as a property of the layer. Same solution.

Author


Understood.

@davidkoski

Collaborator

See #309 -- closing as a dup of that. Thank you!

@davidkoski davidkoski closed this Jun 15, 2026