Skip to content

Gemma 3: chunked prompt prefill, skip lm_head on prompt positions#346

Open
beshkenadze wants to merge 1 commit into
ml-explore:mainfrom
beshkenadze:upstream-gemma3-prefill
Open

Gemma 3: chunked prompt prefill, skip lm_head on prompt positions#346
beshkenadze wants to merge 1 commit into
ml-explore:mainfrom
beshkenadze:upstream-gemma3-prefill

Conversation

@beshkenadze

Copy link
Copy Markdown

What

Speeds up prompt prefill (time-to-first-token) for all Gemma 3 text models by skipping the large lm_head on prompt positions and chunking the prefill for better CPU/GPU overlap.

Why

Gemma3TextModel.prepare returned the whole prompt to the TokenIterator, which then ran the 262k-vocabulary lm_head over every prompt position just to produce the first token. For Gemma 3's unusually large vocab this is a real TTFT tax that grows with prompt length, plus a large transient-logits memory spike.

How

  • Prefill all-but-the-last token through the inner Gemma3Model (no lm_head), updating only the KV cache; the discarded hidden states are never evaluated thanks to lazy eval. Only the final token goes through the full model to prime the iterator.
  • Chunk the prefill with asyncEval per chunk so chunk N+1's graph builds while the GPU evaluates chunk N. Chunk size honours an explicit GenerateParameters.prefillStepSize and otherwise defaults to 128, empirically tuned for this path on Apple Silicon. To make that default reachable, prefillStepSize becomes optional (nil = let the model choose); other models keep their 512 default via windowSize ?? 512 — no behaviour change for them.

Results

mlx-community/translategemma-4b-it-4bit (Gemma 3 text), Apple Silicon, 577-token prompt, greedy, median of 3:

prompt tok/s prefill time
before 177 3253 ms
after 463 1246 ms

2.6× faster prefill. Greedy output is byte-identical to before, including prompts longer than the 1024 sliding window where the rotating KV cache rotates.

Tuning (chunk-size sweep)

chunk tok/s
1024 (= sliding window) 288
256 384
128 (default) 463
64 185

128 wins: larger chunks lose asyncEval pipelining; smaller ones drown in per-launch overhead.

Scope / risk

  • Behaviour change limited to the Gemma 3 text path; other models unaffected.
  • Public API: GenerateParameters.prefillStepSize changes IntInt? (constructing with an explicit value is unchanged; only the default becomes "model decides").

Gemma3TextModel.prepare previously returned the whole prompt to the
TokenIterator, which then ran the 262k-vocabulary lm_head over every
prompt position just to produce the first token. Instead, prefill all
but the last token through the inner model (which has no lm_head),
updating only the KV cache, and hand just the final token to the
iterator.

The chunk size honours an explicit GenerateParameters.prefillStepSize
and otherwise defaults to 128, tuned for this path on Apple Silicon for
the best asyncEval CPU/GPU pipelining. To make that default reachable,
prefillStepSize becomes optional (nil = let the model choose); other
models keep their 512 default via `windowSize ?? 512`.

Measured on translategemma-4b-it-4bit (Apple Silicon, 577-token prompt,
greedy): prompt prefill 177 -> 463 tok/s (2.6x), 3253 -> 1246 ms. Greedy
output is byte-identical to before, including prompts longer than the
1024 sliding window where the rotating KV cache rotates.
@beshkenadze beshkenadze force-pushed the upstream-gemma3-prefill branch from 5faffc8 to c4aa0c4 Compare June 13, 2026 17:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant