Validate BatchGenerator inputs#4
Open
ronaldmannak wants to merge 4 commits into
Open
Conversation
Gajesh2007
added a commit
that referenced
this pull request
Jun 20, 2026
…wins Host-overhead reductions on the memory-bandwidth-bound autoregressive decode path. No model output changes. CHANGE #5 — penalty sampler runs on-device (highest value): - makeRepetitionSampler now applies repetition/presence/frequency penalties with MLX gather (take) → elementwise transform (which) → scatter (indexed assignment), all on-device. Removes the per-token hard GPU sync (eval + asArray) and full-vocab GPU↔host round trip. Logits are upcast to float32 first so the math is bit-for-bit the original CPU formula (v>0 ? v/rp : v*rp; then -presence; then -frequency*count). The scatter is functional, so the input logits array is never mutated. - TokenHistoryHolder maintains token counts incrementally on append() instead of rescanning the full history each step (was O(n) per token → O(n^2) per generation; now O(unique tokens) per call). Both history append sites route through the new append(). CHANGE #4 — cheap per-token wins: - GenerationBatch.step() skips the full-vocab logSumExp normalization when no active row needs it (gated by skipLogprobNormalization, set by the Scheduler from the rows' sampling params, combined with AND on extend). Only top-p and the penalty sampler need normalized logprobs; temperature/top-k/min-p/categorical/argMax are shift-invariant, so results are unchanged. Defaults to false (normalize) for safety. - step() reads the single sampled token via item() for B==1 instead of materializing a 1-element array; multi-row path unchanged. - outputTokenIds streaming delta change left UNCHANGED: OutputCollector merge propagates new.outputTokenIds as the cumulative value and the field is a public documented "cumulative" contract, so emitting only a delta risks an external streaming regression. (See report.) CHANGE #3 — resource debug off the hot path: - DARKBLOOM_MLX_RESOURCE_DEBUG now defaults OFF (opt-in). The Memory.* resource/byte reads in the step loop are already gated behind the flag, so they no longer run per step by default. Capability preserved; enable with =1/true/yes/on. Tests: CBRepetitionPenaltyDeviceTests asserts the on-device sampler equals the original CPU formula within 1e-4 (incl. presence/frequency- only and the empty / out-of-range guards) and that incremental counts match a full rescan. Existing CBRepetitionPenaltyTests and CBGenerationBatchShapeTests stay green.
Gajesh2007
added a commit
that referenced
this pull request
Jun 20, 2026
* fix(kvcache): correct mask fill in quantizedScaledDotProductAttention Masked positions used +Float.leastNormalMagnitude (~0) instead of a large negative, leaking attention to future/padded tokens. Use -greatestFiniteMagnitude. Local Darkbloom patch (DAR-313) to be upstreamed. * feat(kvcache): QuantizedBatchKVCache — quantized batched/paged KV cache (DAR-314) Quantized counterpart to BatchKVCache: BatchedCache + BatchPositionedKVCache + QuantizedKVCacheProtocol. Incremental-append quantization (O(step)); update() returns dequant fp16 fallback, updateQuantized() returns quantized tuples for the kernel path; per-row left-padding + allocationStep growth. Numerically verified (kernel_relL2=0 at 8-bit vs dequant-same-data) incl finalize/extend/filter. * feat(scheduler): KVQuantizationConfig hook in cacheFactories (DAR-317) SchedulerConfig.kvQuantization (nil = off, byte-identical). When set, full-attention layers build QuantizedBatchKVCache; sliding/recurrent stay fp16. cacheFactoryKind test seam added. Off-path preserves exact Mamba->Arrays->Rotating->default order. * feat(kvcache): sink-safe DequantBatchKVCache for GPT-OSS (DAR-322) Refactor QuantizedBatchKVCache into a shared base + kernel subclass (QuantizedKVCacheProtocol, Gemma) + DequantBatchKVCache (NOT conforming, so GPT-OSS attention takes the sink-safe regular path on dequantized fp16 while storage stays quantized). KVQuantCacheKind selects the cache in cacheFactories. Gemma kernel path unchanged (bit-exact). * fix(kvcache): validate live KV quant attention paths * fix(kvcache): address Codex review (#43) — dequant mask fast path + composite-cache guard - DequantBatchKVCache.makeMask: return .none for single-token decode with no left padding (mirror BatchKVCache) so the dequant path takes SDPA's unmasked fast path. An explicit all-true mask routes through the divergent MLX #3384 branch that flips top-1 logprobs on quantized Gemma 4 and traps continuous- batched decode in repetition loops. The kernel base class is left unchanged (its always-masked behavior is what the Gemma 4 g128 scheme was validated on). - cacheFactoryKind: never replace a composite CacheList layer (e.g. BaichuanM1 / FalconH1 = MambaCache + KVCacheSimple) with a quantized KV cache; those models downcast the per-layer cache back to CacheList in forward(). Only plain full-attention layers are eligible for KV-quant. * fix(kvcache): detach per-step batchOffset on quantized decode — IOGPU leak (#43) QuantizedBatchKVCacheBase.updateQuantized chained lazy scalar additions to batchOffset every decode step without detaching, unlike BatchKVCache.update. For models that share one RoPE offset across layers (Gemma 4) most caches never consume their own batchOffset, leaking a scalar buffer per step until the iogpu numResources ceiling aborts the process. Mirror BatchKVCache's decode-only asyncEval(batchOffset) detach — same DAR-325 class of leak, now fixed for the quantized cache (covers both kernel and dequant paths via the shared updateQuantized chokepoint). * fix(kvcache): restrict KV-quant to plain KVCacheSimple + force cold path under quant - cacheFactoryKind only quantizes plain full-attention KVCacheSimple layers; ChunkedKVCache and non-KVCacheSimple custom caches (e.g. DeepseekV4LayerCache) stay fp16 so model-side downcasts/semantics are preserved. - admitWaiting forces the cold path when SchedulerConfig.kvQuantization is set so fp16 warm/restore rows never merge into a live quantized batch (extendBatched class-mismatch). Defense in depth; provider already disables prefix under quant. - add non-Metal CBCacheFactorySelectionTests covering the selection matrix. * perf(continuous-batching): on-GPU penalty sampler + per-token decode wins Host-overhead reductions on the memory-bandwidth-bound autoregressive decode path. No model output changes. CHANGE #5 — penalty sampler runs on-device (highest value): - makeRepetitionSampler now applies repetition/presence/frequency penalties with MLX gather (take) → elementwise transform (which) → scatter (indexed assignment), all on-device. Removes the per-token hard GPU sync (eval + asArray) and full-vocab GPU↔host round trip. Logits are upcast to float32 first so the math is bit-for-bit the original CPU formula (v>0 ? v/rp : v*rp; then -presence; then -frequency*count). The scatter is functional, so the input logits array is never mutated. - TokenHistoryHolder maintains token counts incrementally on append() instead of rescanning the full history each step (was O(n) per token → O(n^2) per generation; now O(unique tokens) per call). Both history append sites route through the new append(). CHANGE #4 — cheap per-token wins: - GenerationBatch.step() skips the full-vocab logSumExp normalization when no active row needs it (gated by skipLogprobNormalization, set by the Scheduler from the rows' sampling params, combined with AND on extend). Only top-p and the penalty sampler need normalized logprobs; temperature/top-k/min-p/categorical/argMax are shift-invariant, so results are unchanged. Defaults to false (normalize) for safety. - step() reads the single sampled token via item() for B==1 instead of materializing a 1-element array; multi-row path unchanged. - outputTokenIds streaming delta change left UNCHANGED: OutputCollector merge propagates new.outputTokenIds as the cumulative value and the field is a public documented "cumulative" contract, so emitting only a delta risks an external streaming regression. (See report.) CHANGE #3 — resource debug off the hot path: - DARKBLOOM_MLX_RESOURCE_DEBUG now defaults OFF (opt-in). The Memory.* resource/byte reads in the step loop are already gated behind the flag, so they no longer run per step by default. Capability preserved; enable with =1/true/yes/on. Tests: CBRepetitionPenaltyDeviceTests asserts the on-device sampler equals the original CPU formula within 1e-4 (incl. presence/frequency- only and the empty / out-of-range guards) and that incremental counts match a full rescan. Existing CBRepetitionPenaltyTests and CBGenerationBatchShapeTests stay green. * perf(kvcache): compile quantized-SDPA softmax + skip n==1 decode mask Two decode-throughput levers for the hand-rolled quantized scaled-dot-product attention, aimed at fp16 decode parity (KV-quant is a memory-capacity feature, not a speedup). Lever A: fuse the quantized-SDPA softmax (max/sub/exp/where/sum/div) via compile(shapeless:). It is the largest cluster of small elementwise kernels on the decode path; fusing cuts per-step launch overhead. shapeless: true handles the growing-context shape churn — the score tensor's kL axis grows every token, but a shapeless graph only recompiles on rank/dtype changes, not shape changes, so it compiles once and is reused. The cores are pure axis:-1 reductions/broadcasts (no batch/head/kL constants); the quantizedMM matmuls and GQA reshape stay outside (one large kernel each, and the reshape depends on the batch axis which shapeless would bake in stale). Both no-sink and sink variants are compiled. Numerics unchanged. Lever B: skip the all-true causal mask for single-query decode in QuantizedBatchKVCacheBase.makeMask — for n==1 with no left padding and no sliding window the mask is a no-op, so return .none and skip the per-step createCausalMask build + where. Bit-identical to the old always-masked path; n>1 and windowed/left-padded cases unchanged. Mirrors BatchKVCache.makeMask. Tests: Tests/MLXLMTests/QuantizedSDPATests.swift — GQA parity vs an independent dequantized reference for n==1 and n>1 (max|delta| ~2-4e-7), sink-softmax parity (~1e-7), n==1 .none-vs-all-true-mask parity (delta=0), and makeMask fast-path conditions. * feat(kvcache): compose KV-quant with prefix-cache checkpoint restore (DAR-319 v2) Checkpoint restore now rebuilds full-attention rows as QUANTIZED batched caches (re-quantizing the captured fp16 prefix via the cold cache factory, restoredFullAttentionCache) instead of fp16 BatchKVCache.merge, so a restored row stays concrete-class-compatible with quantized cold rows under extendBatched. Admission no longer forces the cold path under KV-quant for the checkpoint-restore branch (engine-tier warm prefix stays fp16/cold). Tests: restored-vs-cold parity under KV-quant + concurrent cold+restored (the extendBatched type-compat regression) in CBCheckpointRestoreTests; all 92 CB tests + Quantized SDPA green. * perf(kvcache): skip finish-time prompt-cache extract when no prefix cache GenerationBatch only extracts (dequant + copy) a finished row's prompt cache when a prefix cache will consume it (capturePromptCacheOnFinish, set by the scheduler from prefixCache != nil). Avoids a per-finish fp16 copy of the KV history on runs without prefix caching (e.g. KV-quant hybrids whose prefix tier isn't the in-GPU block cache). Also removes internal tracker refs from comments for readability. 92 CB tests green.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
BatchGeneratorErrorcases for invalid configuration and request inputsBatchGenerator.insert(...)throwing so invalid requests fail before mutating generator state0try insert(...)Rationale
BatchGenerator.insert(...)previously validated onlymaxTokens.count. MismatchedsamplersorstateMachinescould trap later during indexing, nonpositive batch/prefill settings could prevent progress, and empty prompt rows silently fell back to token0.This PR makes those failures explicit and deterministic. Validation happens before
uidCounterorunprocessedare mutated, so failed inserts do not leak UIDs or partially enqueue requests.API Note
This is source-breaking for callers of the new continuous batching API:
BatchGenerator.insert(...)now throwsBatchGeneratorError.