Skip to content

Layr/promptpadding#3

Open
ronaldmannak wants to merge 4 commits into
Layr-Labs:mainfrom
PicoMLX:layr/promptpadding
Open

Layr/promptpadding#3
ronaldmannak wants to merge 4 commits into
Layr-Labs:mainfrom
PicoMLX:layr/promptpadding

Conversation

@ronaldmannak

Copy link
Copy Markdown

Proposed changes

Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Gajesh2007 added a commit that referenced this pull request Jun 20, 2026
…wins

Host-overhead reductions on the memory-bandwidth-bound autoregressive
decode path. No model output changes.

CHANGE #5 — penalty sampler runs on-device (highest value):
- makeRepetitionSampler now applies repetition/presence/frequency
  penalties with MLX gather (take) → elementwise transform (which) →
  scatter (indexed assignment), all on-device. Removes the per-token
  hard GPU sync (eval + asArray) and full-vocab GPU↔host round trip.
  Logits are upcast to float32 first so the math is bit-for-bit the
  original CPU formula (v>0 ? v/rp : v*rp; then -presence; then
  -frequency*count). The scatter is functional, so the input logits
  array is never mutated.
- TokenHistoryHolder maintains token counts incrementally on append()
  instead of rescanning the full history each step (was O(n) per token
  → O(n^2) per generation; now O(unique tokens) per call). Both history
  append sites route through the new append().

CHANGE #4 — cheap per-token wins:
- GenerationBatch.step() skips the full-vocab logSumExp normalization
  when no active row needs it (gated by skipLogprobNormalization, set by
  the Scheduler from the rows' sampling params, combined with AND on
  extend). Only top-p and the penalty sampler need normalized logprobs;
  temperature/top-k/min-p/categorical/argMax are shift-invariant, so
  results are unchanged. Defaults to false (normalize) for safety.
- step() reads the single sampled token via item() for B==1 instead of
  materializing a 1-element array; multi-row path unchanged.
- outputTokenIds streaming delta change left UNCHANGED: OutputCollector
  merge propagates new.outputTokenIds as the cumulative value and the
  field is a public documented "cumulative" contract, so emitting only a
  delta risks an external streaming regression. (See report.)

CHANGE #3 — resource debug off the hot path:
- DARKBLOOM_MLX_RESOURCE_DEBUG now defaults OFF (opt-in). The Memory.*
  resource/byte reads in the step loop are already gated behind the flag,
  so they no longer run per step by default. Capability preserved; enable
  with =1/true/yes/on.

Tests: CBRepetitionPenaltyDeviceTests asserts the on-device sampler
equals the original CPU formula within 1e-4 (incl. presence/frequency-
only and the empty / out-of-range guards) and that incremental counts
match a full rescan. Existing CBRepetitionPenaltyTests and
CBGenerationBatchShapeTests stay green.
Gajesh2007 added a commit that referenced this pull request Jun 20, 2026
* fix(kvcache): correct mask fill in quantizedScaledDotProductAttention

Masked positions used +Float.leastNormalMagnitude (~0) instead of a large
negative, leaking attention to future/padded tokens. Use -greatestFiniteMagnitude.
Local Darkbloom patch (DAR-313) to be upstreamed.

* feat(kvcache): QuantizedBatchKVCache — quantized batched/paged KV cache (DAR-314)

Quantized counterpart to BatchKVCache: BatchedCache + BatchPositionedKVCache +
QuantizedKVCacheProtocol. Incremental-append quantization (O(step)); update()
returns dequant fp16 fallback, updateQuantized() returns quantized tuples for the
kernel path; per-row left-padding + allocationStep growth. Numerically verified
(kernel_relL2=0 at 8-bit vs dequant-same-data) incl finalize/extend/filter.

* feat(scheduler): KVQuantizationConfig hook in cacheFactories (DAR-317)

SchedulerConfig.kvQuantization (nil = off, byte-identical). When set, full-attention
layers build QuantizedBatchKVCache; sliding/recurrent stay fp16. cacheFactoryKind
test seam added. Off-path preserves exact Mamba->Arrays->Rotating->default order.

* feat(kvcache): sink-safe DequantBatchKVCache for GPT-OSS (DAR-322)

Refactor QuantizedBatchKVCache into a shared base + kernel subclass
(QuantizedKVCacheProtocol, Gemma) + DequantBatchKVCache (NOT conforming, so
GPT-OSS attention takes the sink-safe regular path on dequantized fp16 while
storage stays quantized). KVQuantCacheKind selects the cache in cacheFactories.
Gemma kernel path unchanged (bit-exact).

* fix(kvcache): validate live KV quant attention paths

* fix(kvcache): address Codex review (#43) — dequant mask fast path + composite-cache guard

- DequantBatchKVCache.makeMask: return .none for single-token decode with no
  left padding (mirror BatchKVCache) so the dequant path takes SDPA's unmasked
  fast path. An explicit all-true mask routes through the divergent MLX #3384
  branch that flips top-1 logprobs on quantized Gemma 4 and traps continuous-
  batched decode in repetition loops. The kernel base class is left unchanged
  (its always-masked behavior is what the Gemma 4 g128 scheme was validated on).
- cacheFactoryKind: never replace a composite CacheList layer (e.g. BaichuanM1 /
  FalconH1 = MambaCache + KVCacheSimple) with a quantized KV cache; those models
  downcast the per-layer cache back to CacheList in forward(). Only plain
  full-attention layers are eligible for KV-quant.

* fix(kvcache): detach per-step batchOffset on quantized decode — IOGPU leak (#43)

QuantizedBatchKVCacheBase.updateQuantized chained lazy scalar additions to
batchOffset every decode step without detaching, unlike BatchKVCache.update.
For models that share one RoPE offset across layers (Gemma 4) most caches never
consume their own batchOffset, leaking a scalar buffer per step until the iogpu
numResources ceiling aborts the process. Mirror BatchKVCache's decode-only
asyncEval(batchOffset) detach — same DAR-325 class of leak, now fixed for the
quantized cache (covers both kernel and dequant paths via the shared
updateQuantized chokepoint).

* fix(kvcache): restrict KV-quant to plain KVCacheSimple + force cold path under quant

- cacheFactoryKind only quantizes plain full-attention KVCacheSimple layers;
  ChunkedKVCache and non-KVCacheSimple custom caches (e.g. DeepseekV4LayerCache)
  stay fp16 so model-side downcasts/semantics are preserved.
- admitWaiting forces the cold path when SchedulerConfig.kvQuantization is set so
  fp16 warm/restore rows never merge into a live quantized batch (extendBatched
  class-mismatch). Defense in depth; provider already disables prefix under quant.
- add non-Metal CBCacheFactorySelectionTests covering the selection matrix.

* perf(continuous-batching): on-GPU penalty sampler + per-token decode wins

Host-overhead reductions on the memory-bandwidth-bound autoregressive
decode path. No model output changes.

CHANGE #5 — penalty sampler runs on-device (highest value):
- makeRepetitionSampler now applies repetition/presence/frequency
  penalties with MLX gather (take) → elementwise transform (which) →
  scatter (indexed assignment), all on-device. Removes the per-token
  hard GPU sync (eval + asArray) and full-vocab GPU↔host round trip.
  Logits are upcast to float32 first so the math is bit-for-bit the
  original CPU formula (v>0 ? v/rp : v*rp; then -presence; then
  -frequency*count). The scatter is functional, so the input logits
  array is never mutated.
- TokenHistoryHolder maintains token counts incrementally on append()
  instead of rescanning the full history each step (was O(n) per token
  → O(n^2) per generation; now O(unique tokens) per call). Both history
  append sites route through the new append().

CHANGE #4 — cheap per-token wins:
- GenerationBatch.step() skips the full-vocab logSumExp normalization
  when no active row needs it (gated by skipLogprobNormalization, set by
  the Scheduler from the rows' sampling params, combined with AND on
  extend). Only top-p and the penalty sampler need normalized logprobs;
  temperature/top-k/min-p/categorical/argMax are shift-invariant, so
  results are unchanged. Defaults to false (normalize) for safety.
- step() reads the single sampled token via item() for B==1 instead of
  materializing a 1-element array; multi-row path unchanged.
- outputTokenIds streaming delta change left UNCHANGED: OutputCollector
  merge propagates new.outputTokenIds as the cumulative value and the
  field is a public documented "cumulative" contract, so emitting only a
  delta risks an external streaming regression. (See report.)

CHANGE #3 — resource debug off the hot path:
- DARKBLOOM_MLX_RESOURCE_DEBUG now defaults OFF (opt-in). The Memory.*
  resource/byte reads in the step loop are already gated behind the flag,
  so they no longer run per step by default. Capability preserved; enable
  with =1/true/yes/on.

Tests: CBRepetitionPenaltyDeviceTests asserts the on-device sampler
equals the original CPU formula within 1e-4 (incl. presence/frequency-
only and the empty / out-of-range guards) and that incremental counts
match a full rescan. Existing CBRepetitionPenaltyTests and
CBGenerationBatchShapeTests stay green.

* perf(kvcache): compile quantized-SDPA softmax + skip n==1 decode mask

Two decode-throughput levers for the hand-rolled quantized scaled-dot-product attention, aimed at fp16 decode parity (KV-quant is a memory-capacity feature, not a speedup).

Lever A: fuse the quantized-SDPA softmax (max/sub/exp/where/sum/div) via compile(shapeless:). It is the largest cluster of small elementwise kernels on the decode path; fusing cuts per-step launch overhead. shapeless: true handles the growing-context shape churn — the score tensor's kL axis grows every token, but a shapeless graph only recompiles on rank/dtype changes, not shape changes, so it compiles once and is reused. The cores are pure axis:-1 reductions/broadcasts (no batch/head/kL constants); the quantizedMM matmuls and GQA reshape stay outside (one large kernel each, and the reshape depends on the batch axis which shapeless would bake in stale). Both no-sink and sink variants are compiled. Numerics unchanged.

Lever B: skip the all-true causal mask for single-query decode in QuantizedBatchKVCacheBase.makeMask — for n==1 with no left padding and no sliding window the mask is a no-op, so return .none and skip the per-step createCausalMask build + where. Bit-identical to the old always-masked path; n>1 and windowed/left-padded cases unchanged. Mirrors BatchKVCache.makeMask.

Tests: Tests/MLXLMTests/QuantizedSDPATests.swift — GQA parity vs an independent dequantized reference for n==1 and n>1 (max|delta| ~2-4e-7), sink-softmax parity (~1e-7), n==1 .none-vs-all-true-mask parity (delta=0), and makeMask fast-path conditions.

* feat(kvcache): compose KV-quant with prefix-cache checkpoint restore (DAR-319 v2)

Checkpoint restore now rebuilds full-attention rows as QUANTIZED batched
caches (re-quantizing the captured fp16 prefix via the cold cache factory,
restoredFullAttentionCache) instead of fp16 BatchKVCache.merge, so a restored
row stays concrete-class-compatible with quantized cold rows under
extendBatched. Admission no longer forces the cold path under KV-quant for the
checkpoint-restore branch (engine-tier warm prefix stays fp16/cold).

Tests: restored-vs-cold parity under KV-quant + concurrent cold+restored
(the extendBatched type-compat regression) in CBCheckpointRestoreTests; all
92 CB tests + Quantized SDPA green.

* perf(kvcache): skip finish-time prompt-cache extract when no prefix cache

GenerationBatch only extracts (dequant + copy) a finished row's prompt cache
when a prefix cache will consume it (capturePromptCacheOnFinish, set by the
scheduler from prefixCache != nil). Avoids a per-finish fp16 copy of the KV
history on runs without prefix caching (e.g. KV-quant hybrids whose prefix
tier isn't the in-GPU block cache). Also removes internal tracker refs from
comments for readability. 92 CB tests green.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant