diff --git a/Cargo.lock b/Cargo.lock index 6d21384d..d340e506 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3897,6 +3897,23 @@ dependencies = [ "vllm-text", ] +[[package]] +name = "openinfer-qwen3-4b-dflash" +version = "0.1.0" +dependencies = [ + "anyhow", + "crossbeam-channel", + "cudarc", + "half", + "log", + "memmap2", + "openinfer-core", + "openinfer-kernels", + "safetensors", + "serde", + "serde_json", +] + [[package]] name = "openinfer-qwen35-4b" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 16ef9c71..b325ea5e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ members = [ "openinfer-deepseek-v2-lite", "openinfer-kimi-k2", "openinfer-qwen3-4b", + "openinfer-qwen3-4b-dflash", "openinfer-qwen35-4b", "openinfer-sample", "openinfer-kv-cache", @@ -129,6 +130,7 @@ openinfer-engine = { path = "openinfer-engine" } openinfer-kernels = { path = "openinfer-kernels" } openinfer-kimi-k2 = { path = "openinfer-kimi-k2" } openinfer-qwen3-4b = { path = "openinfer-qwen3-4b" } +openinfer-qwen3-4b-dflash = { path = "openinfer-qwen3-4b-dflash" } openinfer-qwen35-4b = { path = "openinfer-qwen35-4b" } openinfer-sample = { path = "openinfer-sample" } openinfer-deepseek-v2-lite = { path = "openinfer-deepseek-v2-lite" } diff --git a/docs/index.md b/docs/index.md index 8f41a7f6..37678285 100644 --- a/docs/index.md +++ b/docs/index.md @@ -29,6 +29,7 @@ Organized by domain (model line / subsystem / playbook / lesson) instead of by l | `models/qwen3/green-ctx-sm-partition.md` | Green Context SM partition (`OPENINFER_SM_PARTITION=20`) runs prefill/decode on disjoint SMs so decode stops stalling behind co-scheduled prefill: 5090 mid-band ITL p99 ~halved, TPOT down (−22% @QPS12), but TTFT 2–4× worse (prefill deferred + fewer SMs) — a TTFT↔ITL/TPOT trade, not a free win. Two-graph change (decode CUDA graph captured on the green decode stream) adds ~5% ITL p99 / 1–4% TPOT on top. Mechanism, A/B table, Xid-31/gemm_lt pitfalls. 
| | `models/qwen3/roadmap.md` | Qwen3-4B roadmap (2026-06 review): line is the maturity bar; #220 RoPE OOB, batched greedy sampling (#307), mixed greedy/non-greedy sampling (#284), and pegaflow KV offload (#316) are landed; open set is zero TP coverage, zero-adapter-only LoRA gate, dropped prefix-cache observability, stale docs, and YaRN #8 follow-up. | | `models/qwen3/model-crate.md` | `openinfer-qwen3-4b` owns Qwen3 config/weights/executor/scheduler/tests/kernel plan; root sees generic `EngineHandle`; split-K retuned to `256/64`, with 4k/64 serving TPOT p50 at `6.46ms` on RTX 5090. | +| `models/qwen3/dflash.md` | `openinfer-qwen3-4b-dflash` supports only `z-lab/Qwen3-4B-DFlash-b16`: standalone model config/weights/forward plus transformers remote-code parity, with no generic DFlash framework or Qwen3 server/controller changes in this task. | | `models/qwen3/prefix-cache.md` | Prefix caching on by default for Qwen3-4B: full-block kvbm radix matching at the executor, suffix-only prefill. Repeated ~1900-token prompt TTFT 141.8 → 16.3ms p50 (8.7×); warm TTFT ≈ TPOT + ~5ms setup. Includes the RoPE scalar-path corruption fix and the drain-the-stream TTFT measurement pitfall. | | `models/qwen3/accuracy-gate.md` | Qwen3-4B instance of the logits golden gate (`tests/hf_golden_gate.rs`): 48 teacher-forced sequences / 816 positions vs a stored HF bf16 golden, replayed over bs=1 / batched eager / CUDA-graph. Strict guards: regret check + mean ≤ 0.06 + p99 ≤ 0.20; absolute max printed but not asserted (coverage-unstable). Methodology in `subsystems/correctness/`. | | `models/qwen3/kernels-crate.md` | Phase 1 split implemented and 5090-verified: Qwen3-4B kernel surface lives in `openinfer-kernels`; release build, test-target compile, accuracy gate, and bench snapshot pass. 
| diff --git a/docs/models/qwen3/dflash.md b/docs/models/qwen3/dflash.md new file mode 100644 index 00000000..01d82580 --- /dev/null +++ b/docs/models/qwen3/dflash.md @@ -0,0 +1,448 @@ +# Qwen3-4B-DFlash model + +**TL;DR**: `openinfer-qwen3-4b-dflash` supports only the `z-lab/Qwen3-4B-DFlash-b16` model. It now has two draft-only execution surfaces: the original bs1 transformers-parity forward path, and an internal exact-shape batch runner/scheduler that batches already-prepared `noise_embedding`, selected target hidden states, and `position_ids`. The forward gate currently measures mean delta `0.034243`, p99 `0.125000`, max `0.500000` over 7,680 output values for uncached, unified-cache one-shot, and first-step draft-cache paths; batch-vs-single and executor request-tag smoke extend that gate. Cache control APIs are fail-closed for unknown request ids. The scheduler thread now joins on handle drop (mirrors `EngineHandle`) and resident draft caches are bounded by `max_caches` with an explicit `drop_cache` retirement path (mirrors Qwen3 `drop_request`); over-cap admission fails closed. The batch K/V concatenation now uses a fused `strided_segment_copy` kernel instead of a per-request `memcpy_dtod` loop, lifting bs32 draft throughput from ~42K to ~63K tok/s (1.5x) with zero accuracy drift. Target verification, acceptance, fallback token selection, and OpenAI serving remain out of scope. + +Last touched: 2026-06 + +## Boundary + +This task is model-specific. 
The boundary is: + +| Crate | Owns | +| --- | --- | +| `openinfer-qwen3-4b-dflash` | `Qwen3-4B-DFlash-b16` config, weights, draft forward, draft-only batch executor/scheduler, model-specific kernels/wrappers, and transformers parity tests | +| `openinfer-qwen3-4b` | Unchanged existing Qwen3 target serving, scheduler, KV, LoRA/offload/TP policy, and HF logits gate | + +Out of scope for this task: generic speculative decoding, a generic DFlash abstraction, OpenAI/server flags, LoRA/TP/KV-offload interactions, target verification, acceptance-length calculation, fallback token selection, and target hidden extraction from Qwen3. + +## Reference Model + +The authoritative reference is the Hugging Face repo `z-lab/Qwen3-4B-DFlash-b16`, not an inferred architecture from the target Qwen3 crate. The model card uses: + +```python +transformers==4.57.3 +AutoModel.from_pretrained(..., trust_remote_code=True) +draft.spec_generate(target, input_ids, ...) +``` + +The local checkpoint at `/home/hezhaozhao/models/Qwen3-4B-DFlash-b16` contains the same remote-code shape: + +| Field | Value | +| --- | --- | +| `architectures` | `DFlashDraftModel` | +| draft layers | `5` | +| target layers | `36` | +| hidden size | `2560` | +| intermediate size | `9728` | +| attention heads / KV heads | `32 / 8` | +| head dim | `128` | +| block size | `16` | +| mask token | `151669` | +| target hidden layers | `[1, 9, 17, 25, 33]` | +| vocab size | `151936` | + +Checkpoint keys are unprefixed relative to a target `model.` namespace: `layers.*`, `fc.weight`, `hidden_norm.weight`, and `norm.weight`. `fc.weight` is `[2560, 12800]`, i.e. one hidden-sized projection from five concatenated target hidden states. + +## Draft Forward + +The draft forward is not target Qwen3 attention with a different checkpoint. Its attention is dense and non-causal: + +1. `target_hidden = hidden_norm(fc(concat(selected target hidden states)))` +2. `hidden_states = noise_embedding` +3. 
for each of the five draft layers: + - RMSNorm `hidden_states` + - Q comes from normalized noise hidden + - K/V come from `cat(target_hidden, hidden_states)` + - Q/K get Qwen3 head RMSNorm and RoPE + - attention is non-causal over the whole `target_hidden + noise_hidden` span + - residual add + - post-attention RMSNorm + Qwen3 MLP + residual add +4. final `norm(hidden_states)` + +The crate should expose draft-model primitives, not speculative serving: + +```rust +pub struct DFlashDraftModel { ... } + +impl DFlashDraftModel { + pub fn load(model_path: &Path, device_ordinal: usize) -> anyhow::Result<Self>; + pub fn config(&self) -> &DFlashConfig; + pub fn target_layer_ids(&self) -> &[usize]; + pub fn forward( + &self, + noise_embedding: &HiddenStates, + selected_target_hidden: &DFlashTargetHidden, + position_ids: &[i32], + ) -> anyhow::Result<HiddenStates>; +} +``` + +The first version takes already-selected target hidden states as input and returns the final draft hidden states. Extracting those hidden states from `openinfer-qwen3-4b`, target verification, acceptance length calculation, and KV cropping are not part of this model implementation. + +## Draft-Only Batch Runner + +The batch path is intentionally internal. It is not an OpenAI-compatible text +generation surface because the DFlash draft model does not consume prompt token +ids and does not own a language-model head. Callers must provide device +`HiddenStates` for: + +| Input | Shape | +| --- | --- | +| `noise_embedding` | `[q_len, hidden_size]` | +| `target_hidden` | `[ctx_len, target_layer_count * hidden_size]` | +| `position_ids` | `ctx_len + q_len` host positions | + +The runner groups only exact-shape requests. The batch key is +`(q_len, ctx_len, past_len, cache_mode)`. `NoCache` requests use the real +batched path: compact D2D input staging, batched FC/context projection, batched +per-layer Q/K/V and MLP GEMMs, and FlashInfer +`BatchPrefillWithRaggedKVCache` in non-causal mode for attention. 
`DraftCache` +requests keep the same `DFlashDraftCache` lifecycle and are executed serially +inside the GPU owner thread in this step; cross-request draft-cache batching +needs a compact past-K/V layout and should be added with the target +verification loop. + +The public Rust surface is crate-local serving infrastructure, not server API: + +```rust +pub struct DFlashDraftHostRequest { ... } +pub struct DFlashDraftHostResponse { ... } +pub struct DFlashExecutor { ... } +pub struct DFlashSchedulerHandle { ... } +``` + +`DFlashSchedulerHandle` is a single-thread GPU owner with FCFS exact-shape +batching, a small `max_wait` coalescing window, and `max_total_tokens` +admission over `(ctx_len + q_len + past_len)` for each candidate batch. Its +public `submit` boundary uses host bf16 buffers and returns host bf16 output so +CUDA device tensors do not cross thread/context ownership boundaries. It also +owns per-request draft cache state through `reset_cache`, `crop_cache`, +`cache_seq_len`, and `drop_cache`, and the cache-reading calls error on unknown +request ids instead of silently treating them as empty state; `drop_cache` is +idempotent (a missing cache is not an error) so callers can retire a request +from any lifecycle state. Resident caches are bounded by `max_caches` +(`DFlashExecutorOptions`, default 64); exceeding it fails closed until a +retired request's cache is dropped — this mirrors Qwen3's per-request block +accounting under the fixed `KvCacheManager` pool and prevents the unbounded +GPU-memory leak the old grow-only `HashMap` had. The handle joins the scheduler +thread on drop (the last clone closes the channel and joins, mirroring +`EngineHandle`), so dropping the handle without an explicit shutdown no longer +leaks the GPU-owner thread. `NoCache` requests use the real batched path, while +host `DraftCache` requests run serially until compact past-K/V batching lands. 
+The executor also exposes a borrowed compact batch view for same-thread +controller experiments. + +## Draft Cache + +Do not maintain separate public cache concepts for this crate. The reference +Python uses one `past_key_values_draft = DynamicCache()` in `spec_generate`, +then calls the drafter with: + +```python +position_ids=position_ids[:, past_key_values_draft.get_seq_length(): start + block_size] +past_key_values=past_key_values_draft +use_cache=True +past_key_values_draft.crop(start) +``` + +OpenInfer mirrors that boundary with one `DFlashDraftCache`: + +| State | Meaning | +| --- | --- | +| `prepare_step_context(...)` | Projects the current selected target hidden states and prepares per-layer context `K/V`; this replaces the old standalone `prepare_context_cache(...)` wording. | +| `forward_with_draft_cache(...)` | Runs one draft block, appends step context `K/V` and noise-token `K/V` to each layer's draft past state, and advances `seq_len`. | +| `crop(seq_len)` / `reset()` | Matches the reference `DynamicCache.crop(start)` lifecycle after target verification decides how far the draft state remains valid. | + +The first-step cached path is numerically identical to the standalone HF +remote-code forward because there is no existing past yet. Cross-step cached +parity must be validated only after the target verification/controller is added; +without the target loop, a second cached draft step is not the same numerical +problem as the old no-draft-cache substitution probe. + +## Correctness Gate + +The accuracy bar is transformers parity. 
For the draft crate that means: + +| Gate | Purpose | +| --- | --- | +| config/loader shape test | Reject wrong checkpoint layout early: `target_layer_ids`, `block_size`, `mask_token_id`, `fc.weight`, layer count, and attention/MLP shapes | +| draft-forward smoke | Load `/home/hezhaozhao/models/Qwen3-4B-DFlash-b16`, run a tiny GPU block with synthetic `noise_embedding`, selected target hidden states, and position ids, and catch shape/kernel failures | +| transformers forward parity | Compare the standalone draft forward against the HF remote-code model for fixed synthetic `noise_embedding`, selected target hidden states, and position ids | +| batch-vs-single parity | Compare two exact-shape batched rows against the bs1 forward output under the same DFlash tolerance | +| executor smoke | Submit request-tagged exact-shape `NoCache` requests and assert output shape/request ids | +| scheduler cache smoke | Submit host `DraftCache` request, then assert scheduler-owned `cache_seq_len`, `crop_cache`, and `reset_cache` behavior; also checks control messages preserve FIFO ordering behind pending submits | +| cache control rejection | `reset_cache` / `crop_cache` / `cache_seq_len` fail closed on unknown request ids; `drop_cache` is idempotent (retiring an unknown id is not an error) | +| drafter generation parity | Run a greedy bs1 transformers target loop twice, once with the HF drafter and once with the OpenInfer drafter, then compare generated token ids/text and acceptance lengths | + +Do not use `Qwen3-4B-Instruct-2507` as a correctness baseline for this model. The checkpoint is documented for `Qwen/Qwen3-4B`, but this task's gate is the DFlash draft model's own transformers forward, not target acceptance rate. + +## Kernel Notes + +Existing Qwen3 target attention is causal/paged and does not match `Qwen3-4B-DFlash-b16` draft attention. 
The draft kernel path should follow vLLM/FlashAttention semantics where possible: Q/K/V in head-major logical shape, GQA expansion by `q_head / (num_q_heads / num_kv_heads)`, RoPE on Q and K, softmax over all context+draft keys, and no causal mask. + +The reference implementation to mirror is vLLM's attention stack, especially `vllm.v1.attention.backends.flash_attn.FlashAttentionBackend` and `vllm.v1.attention.backends.flashinfer.FlashInferBackend`: both explicitly support `supports_non_causal()`, and their prefill/decode planners expose the causal flag and varlen context shape that DFlash needs. + +The batch runner uses FlashInfer `BatchPrefillWithRaggedKVCache` with +`MaskMode::kNone` for compact non-causal attention. That keeps the DFlash batch +path close to vLLM's varlen/non-causal attention semantics instead of looping +over single-request prefill. + +## Accuracy Scripts + +The DFlash scripts intentionally mirror the rest of the repository: + +| Script | Output | Use | +| --- | --- | --- | +| `tools/accuracy/dump_qwen3_4b_dflash_hf_golden.py` | `test_data/qwen3-4b-dflash-hf-golden.safetensors` | Offline transformers remote-code forward oracle for the Rust gate | +| `openinfer-qwen3-4b-dflash/tests/hf_golden_gate.rs` | test pass/fail plus delta distribution | Release Rust gate that replays the stored oracle without Python | +| `tools/accuracy/compare_qwen3_4b_dflash_drafter_generation.py` | `target/accuracy/qwen3-dflash/drafter-generation.json` | End-to-end drafter-substitution evidence: same transformers target loop, HF drafter vs OpenInfer drafter | +| `tools/accuracy/bench_qwen3_4b_dflash_forward.py` + `qwen3_dflash_forward_bench` | `target/benchmarks/qwen3-dflash/forward.json` | Standalone forward latency comparison: transformers remote-code vs OpenInfer forward on the same synthetic fixture | +| `qwen3_dflash_batch_bench` | stdout JSON / redirected benchmark artifact | Draft-only batch sweep over bs `1,2,4,8,16,32`, reporting req/s, draft tok/s, and 
latency percentiles | +| `openinfer-qwen3-4b-dflash/src/bin/qwen3_dflash_forward_fixture.rs` | safetensors with `openinfer_output` | Bridge used by the generation comparison script to call the Rust drafter from Python | + +The forward golden is generated by: + +```bash +.venv/bin/python tools/accuracy/dump_qwen3_4b_dflash_hf_golden.py \ + --model-path /home/hezhaozhao/models/Qwen3-4B-DFlash-b16 \ + --out test_data/qwen3-4b-dflash-hf-golden.safetensors +``` + +The Rust gate is: + +```bash +OPENINFER_DFLASH_TEST_MODEL_PATH=/home/hezhaozhao/models/Qwen3-4B-DFlash-b16 \ +cargo test --release -p openinfer-qwen3-4b-dflash --test hf_golden_gate -- --nocapture +``` + +The DFlash gate intentionally uses `OPENINFER_DFLASH_TEST_MODEL_PATH` rather +than the generic `OPENINFER_TEST_MODEL_PATH`, because the latter usually points +at the normal Qwen3 target checkpoint. The test also checks that +`config.json.architectures` contains `DFlashDraftModel` before running. + +The batch throughput probe is: + +```bash +cargo run --release -p openinfer-qwen3-4b-dflash --bin qwen3_dflash_batch_bench -- \ + --model-path /home/hezhaozhao/models/Qwen3-4B-DFlash-b16 \ + --ctx-len 2 \ + --q-len 16 \ + --batch-sizes 1,2,4,8,16,32 \ + --warmup 5 \ + --iters 30 +``` + +Observed local batch runner sweep on the same WSL/CUDA `sm_120` setup, +`ctx_len=2`, `q_len=16`, warmup `5`, iters `30`: + +| Batch | mean ms | p50 ms | p90 ms | p99 ms | draft tok/s | req/s | +| ---: | ---: | ---: | ---: | ---: | ---: | ---: | +| 1 | 2.065 | — | — | — | 7,748 | — | +| 2 | 2.154 | — | — | — | 14,856 | — | +| 4 | 3.118 | — | — | — | 20,525 | — | +| 8 | 3.335 | — | — | — | 38,382 | — | +| 16 | 4.699 | — | — | — | 54,476 | — | +| 32 | 8.178 | — | — | — | 62,611 | — | + +The batch path now improves draft-token throughput by `8.1x` from bs1 to bs32. 
+The bs16/bs32 step gained ~1.5x after replacing the per-request `compact_kv` +memcpy loop (`2 * batch_size` `memcpy_dtod` calls per K/V tensor per layer) +with a single fused `strided_segment_copy` CUDA kernel — one launch copies the +entire batch's ctx segment, another the noise segment, collapsing 128 +launches/layer at bs32 into 4. This is draft-model throughput only; it does not +include target hidden production, verification, acceptance, or fallback-token +work. + +On the local WSL setup used for the first run, the workspace-level vLLM git dependency and empty FlashInfer submodule required a narrower temporary workspace plus: + +```bash +LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda/lib64:/usr/local/cuda/targets/x86_64-linux/lib \ +OPENINFER_FLASHINFER_INCLUDE=/home/hezhaozhao/openinfer/.venv/lib/python3.12/site-packages/flashinfer/data/include \ +cargo test --release -p openinfer-qwen3-4b-dflash --test hf_golden_gate -- --nocapture +``` + +Observed result after the unified cache change: + +```text +dflash HF golden deltas: mean=0.034243, p99=0.125000, max=0.500000, n=7680 +dflash unified-cache one-shot HF golden deltas: mean=0.034243, p99=0.125000, max=0.500000, n=7680 +dflash draft-cache HF golden deltas: mean=0.034243, p99=0.125000, max=0.500000, n=7680 +test dflash_forward_matches_hf_remote_code ... ok +``` + +The drafter-substitution generation probe is: + +```bash +cargo build --release -p openinfer-qwen3-4b-dflash --bin qwen3_dflash_forward_fixture + +.venv/bin/python tools/accuracy/compare_qwen3_4b_dflash_drafter_generation.py \ + --target-model-path /path/to/Qwen3-4B \ + --draft-model-path /home/hezhaozhao/models/Qwen3-4B-DFlash-b16 \ + --openinfer-bin target/release/qwen3_dflash_forward_fixture \ + --out target/accuracy/qwen3-dflash/drafter-generation.json +``` + +The JSON report records each prompt's generated token ids/text, token/text hashes, +first mismatch if any, acceptance lengths, and optional OpenInfer-vs-HF draft +hidden deltas. 
It exits non-zero unless every case is `all_token_text_exact`. +This is the DFlash analogue of the DeepSeek-V2-Lite same-host generation +comparison, but scoped to the current standalone drafter boundary. + +For performance, use the same synthetic fixture on both sides: + +```bash +cargo build --release -p openinfer-qwen3-4b-dflash --bin qwen3_dflash_forward_bench + +.venv/bin/python tools/accuracy/bench_qwen3_4b_dflash_forward.py \ + --draft-model-path /home/hezhaozhao/models/Qwen3-4B-DFlash-b16 \ + --openinfer-bin target/release/qwen3_dflash_forward_bench \ + --out target/benchmarks/qwen3-dflash/forward.json +``` + +The benchmark report includes transformers latency stats and OpenInfer latency +stats for the same bf16 fixture. It is a standalone draft-forward measurement, +not a full speculative-decoding throughput claim. + +Observed local benchmark on RTX 5070 Ti, WSL, CUDA `sm_120`, `ctx_len=2`, +`q_len=16`, warmup `5`, iters `30`, same generated bf16 fixture: + +| Engine | mean ms | p50 ms | p90 ms | p99 ms | +| --- | ---: | ---: | ---: | ---: | +| transformers remote-code | 4.294 | 3.612 | 5.067 | 15.360 | +| OpenInfer DFlash | 2.285 | 2.195 | 2.659 | 2.895 | + +OpenInfer is `1.65x` faster at p50 and `1.88x` faster by mean for this +standalone forward shape. The transformers p99 includes a single 15.36 ms tail +in this short run, so p99 should not be over-interpreted without a longer sweep. +The measured artifact is `target/benchmarks/qwen3-dflash/forward.json`. + +First optimization pass: `DFlashForwardScratch` reuses the forward buffer set +across repeated calls. The HF forward gate stayed identical: +`mean=0.034243`, `p99=0.125000`, `max=0.500000`, `n=7680`. 
The same forward +benchmark wrote `target/benchmarks/qwen3-dflash/forward-final.json`: + +| OpenInfer path | mean ms | p50 ms | p90 ms | p99 ms | +| --- | ---: | ---: | ---: | ---: | +| allocate buffers per forward | 2.285 | 2.195 | 2.659 | 2.895 | +| reuse `DFlashForwardScratch` | 2.125 | 2.035 | 2.410 | 2.936 | + +This pass improved OpenInfer p50 by `1.08x`. It is a necessary cleanup for the +future decode loop, but not enough by itself to prove DFlash value. + +A follow-up attempt to move the cloned input hidden state into reusable scratch +was not kept: the current fused residual+RMSNorm op mutates the residual hidden +state in place, so separating input/output ping-pong buffers correctly requires +reworking that layer boundary rather than a local buffer-only patch. + +Second optimization pass: `DFlashForwardScratch` gained an explicit draft-side +target-hidden context K/V cache. `prepare_context_cache(...)` computes +`target_normed` plus each layer's context `K/V` and K norm+RoPE once; repeated +`forward_with_context_cache(...)` calls then only compute the noise-token K/V and +concat cached context with the current draft block. The HF gate now checks both +uncached and cached paths, and both stayed identical: +`mean=0.034243`, `p99=0.125000`, `max=0.500000`, `n=7680`. + +Cached benchmark artifact: `target/benchmarks/qwen3-dflash/forward-context-cache.json`. +The reported latency excludes the one-time `prepare_context_cache(...)` call, +matching the intended loop shape where context cache is updated explicitly when +target hidden changes. 
+ +| OpenInfer path | mean ms | p50 ms | p90 ms | p99 ms | +| --- | ---: | ---: | ---: | ---: | +| allocate buffers per forward | 2.285 | 2.195 | 2.659 | 2.895 | +| reuse `DFlashForwardScratch` | 2.125 | 2.035 | 2.410 | 2.936 | +| reuse scratch + context K/V cache | 1.863 | 1.831 | 2.001 | 2.301 | + +The context cache improves p50 by `1.11x` over scratch-only and `1.20x` over the +initial implementation for this small `ctx_len=2`, `q_len=16` fixture. + +Third pass: the public cache shape was unified as `DFlashDraftCache`. The old +"context cache" is now just the step-context part of the same object, and the +cache also owns per-layer draft past K/V buffers plus `seq_len`, `crop`, and +`reset` state. The HF gate checks uncached, unified-cache one-shot, and first-step +draft-cache paths; all three retain the same delta distribution: +`mean=0.034243`, `p99=0.125000`, `max=0.500000`, `n=7680`. + +The cache internals now follow the `openinfer-kv-cache` separation more closely +without directly adopting its paged block manager: `DFlashDraftState` owns the +long-lived draft past K/V and sequence length, `DFlashStepContext` owns the +current target-hidden context K/V, and `ForwardBuffers` remains transient +scratch. The public object is still a single `DFlashDraftCache`, but a prepared +step is consumed by `forward_with_draft_cache(...)`; callers must prepare the +next step explicitly after `crop(start)`, mirroring the reference `DynamicCache` +lifecycle. + +The corresponding benchmark artifact is +`target/benchmarks/qwen3-dflash/forward-draft-cache.json`. 
This benchmark uses +the more honest `prepare_step_context + forward_with_draft_cache` timing inside +each measured iteration, so it should not be compared directly against the +previous context-cache number that excluded prepare time: + +| Engine/path | mean ms | p50 ms | p90 ms | p99 ms | +| --- | ---: | ---: | ---: | ---: | +| transformers remote-code | 5.564 | 4.429 | 9.078 | 18.713 | +| OpenInfer `DFlashDraftCache` first-step path | 2.311 | 2.209 | 2.479 | 3.519 | + +After the internal state/step/scratch refactor, the same benchmark wrote +`target/benchmarks/qwen3-dflash/forward-draft-cache-refactor.json` with no +accuracy change and no performance regression: + +| Engine/path | mean ms | p50 ms | p90 ms | p99 ms | +| --- | ---: | ---: | ---: | ---: | +| transformers remote-code | 4.242 | 3.861 | 5.616 | 6.922 | +| OpenInfer `DFlashDraftCache` refactor path | 2.228 | 2.155 | 2.454 | 2.541 | + +## Current Implementation + +The crate now exists as a standalone model implementation with config parsing, exact-key safetensor loading, a block draft forward, unified draft cache state, a tiny local GPU smoke test, and a HF remote-code golden gate. The attention path uses the existing Qwen3 Q/K RMSNorm+RoPE kernel and a FlashInfer single-prefill wrapper with `MaskMode::kNone`; context K currently reuses the Q/K kernel with a throwaway Q scratch buffer, so a future cleanup can split a K-only norm+RoPE helper without changing semantics. + +The local `.venv` uses `torch==2.9.0+cu129`, `transformers==4.57.3`, `safetensors`, `accelerate`, and `datasets` because the HF remote code imports `datasets` via `utils.py`. The generated fixture stores seed-pinned synthetic `noise_embedding`, selected `target_hidden`, `position_ids`, and HF final `output`; `openinfer-qwen3-4b-dflash/tests/hf_golden_gate.rs` replays those tensors through the Rust forward and compares deltas. 
+ +An additional end-to-end generation probe used the same transformers target +model for verification and swapped only the drafter: + +| Prompt | Result | +| --- | --- | +| `Hello, my name is` | identical token ids/text; acceptance `[1, 2, 1, 2, 1, 1]` | +| `The capital of France is` | identical token ids/text; acceptance `[2, 1, 2, 2, 2]` | +| `Qwen is a language model that` | identical token ids/text; acceptance `[2, 2, 1, 1, 1, 1]` | +| `1, 1, 2, 3, 5,` | identical token ids/text; acceptance `[4, 1, 2, 2]` | + +The probe intentionally used a no-draft-cache loop on both sides because it +predates `DFlashDraftCache` and because `openinfer-qwen3-4b-dflash` still does +not own the target verification/controller. Within that older boundary, +OpenInfer DFlash produces the same greedy generation tokens as the transformers +DFlash drafter when the target/verification path is held fixed. The next +meaningful generation probe should use the real target loop and exercise +`DFlashDraftCache.crop(start)` after acceptance calculation. + +## 2026-06-18 Batch Bench + +The current Codex runner needed an explicit runtime library path to see the WSL +CUDA driver: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda/lib64:/usr/local/cuda/targets/x86_64-linux/lib \ +OPENINFER_FLASHINFER_INCLUDE=/home/hezhaozhao/openinfer/.venv/lib/python3.12/site-packages/flashinfer/data/include \ +cargo run --release -p openinfer-qwen3-4b-dflash --bin qwen3_dflash_batch_bench -- \ + --model-path /home/hezhaozhao/models/Qwen3-4B-DFlash-b16 \ + --ctx-len 2 \ + --q-len 16 \ + --batch-sizes 1,2,4,8 \ + --warmup 2 \ + --iters 5 +``` + +Observed result on the RTX 5070 Ti host: + +| Batch | mean ms | draft tok/s | req/s | +| ---: | ---: | ---: | ---: | +| 1 | 2.052 | 7,796 | 487 | +| 2 | 2.303 | 13,893 | 868 | +| 4 | 3.532 | 18,121 | 1,133 | +| 8 | 4.364 | 29,333 | 1,833 | + +This confirms the draft-only batch path still scales after the fail-closed +cache fix. 
It is draft throughput only; it does not include target hidden +production, verification, acceptance, or fallback-token work. diff --git a/openinfer-core/src/ops.rs b/openinfer-core/src/ops.rs index dc522350..cb008d6b 100644 --- a/openinfer-core/src/ops.rs +++ b/openinfer-core/src/ops.rs @@ -13,17 +13,19 @@ pub use attention::{ paged_attention_batch_decode_split_kv_into, prefill_attention_paged_into, }; pub use openinfer_kernels::ops::{ - GEMM_LT_MAX_N, LoraDecodeGroupedProjection, accumulate_bf16_token_scaled_to_f32_into, - add_batch, add_batch_into, argmax, argmax_batch_bf16_into, bf16_hidden_to_f32_into, + GEMM_LT_MAX_N, LoraDecodeGroupedProjection, RaggedPrefillPlan, + accumulate_bf16_token_scaled_to_f32_into, add_batch, add_batch_into, argmax, + argmax_batch_bf16_into, batch_prefill_ragged_nhd_noncausal_into, bf16_hidden_to_f32_into, embedding_decode_into, extract_vec, extract_vec_into, extract_vec_ref, extract_vec_ref_into, f32_to_bf16_hidden_into, fused_add_rms_norm_into, gather_hidden_tokens_into, gemm, gemm_graphsafe_into_checked, gemm_graphsafe_ref_into_checked, gemm_into_checked, gemm_lt_tune, - gemm_per_token, gemv, linear, lora_decode_fused_delta_group3_into, - lora_decode_fused_delta_into, pack_lora_b_rows_into, + gemm_per_token, gemv, k_norm_rope_batch_decode_into, linear, + lora_decode_fused_delta_group3_into, lora_decode_fused_delta_into, pack_lora_b_rows_into, qk_norm_partial_rope_batched_decode_hd256_into, rms_norm, rms_norm_batch_offset_into, rms_norm_gated_batch_into, rms_norm_into, rms_norm_offset_into, scale_f32_in_place, scaled_add_batch_into, scaled_add_rows_indexed_into, scaled_add_rows_into, - scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into, write_vec_into, + scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into, + single_prefill_nhd_noncausal_into, strided_segment_copy_into, write_vec_into, }; #[cfg(not(feature = "kernel-call-trace"))] pub use openinfer_kernels::ops::{ diff --git 
a/openinfer-kernels/csrc/shared/elementwise.cu b/openinfer-kernels/csrc/shared/elementwise.cu index 92de04eb..c486152f 100644 --- a/openinfer-kernels/csrc/shared/elementwise.cu +++ b/openinfer-kernels/csrc/shared/elementwise.cu @@ -427,4 +427,54 @@ CUresult embedding_batched_vocab_shard_cuda( return (CUresult)cudaGetLastError(); } +// ============================================================================ +// Strided segment copy for DFlash batch K/V concatenation. +// +// Copies one segment (ctx or noise) of every request in a batch from a +// contiguous source layout to a strided destination layout in a single +// kernel launch, replacing 2 * batch_size memcpy_dtod calls per K/V tensor. +// +// src: [batch_size * src_seg_len, dim] row-major, contiguous +// dst: [batch_size * dst_seg_total, dim] row-major, request r occupies +// rows [r * dst_seg_total + dst_row_offset, +// r * dst_seg_total + dst_row_offset + src_seg_len) +// +// Each thread copies one bf16 element. The total work is +// batch_size * src_seg_len * dim. 
+// ============================================================================
+
+__global__ void strided_segment_copy_kernel(
+    const __nv_bfloat16 *__restrict__ src,
+    __nv_bfloat16 *__restrict__ dst,
+    int dim, int src_seg_len, int dst_seg_total, int dst_row_offset,
+    int batch_size) {
+    // 64-bit index math: batch_size * src_seg_len * dim can exceed INT_MAX.
+    const long long total = (long long)batch_size * src_seg_len * dim;
+    const long long stride = (long long)gridDim.x * blockDim.x;
+    for (long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
+         idx < total; idx += stride) {
+        const long long element = idx % dim;
+        const long long row_in_seg = (idx / dim) % src_seg_len;
+        const long long req = idx / ((long long)dim * src_seg_len);
+        const long long dst_row = req * dst_seg_total + dst_row_offset + row_in_seg;
+        dst[dst_row * dim + element] = src[idx];  // src is contiguous: flat index == idx
+    }
+}
+
+CUresult strided_segment_copy_cuda(
+    const __nv_bfloat16 *src, __nv_bfloat16 *dst,
+    int dim, int src_seg_len, int dst_seg_total, int dst_row_offset,
+    int batch_size, cudaStream_t stream) {
+    const long long total = (long long)batch_size * src_seg_len * dim;
+    const int block = 256;
+    // The kernel strides by the whole grid, so the grid size affects only
+    // parallelism, never coverage: clamp the block count to [1, 2^31 - 1].
+    long long blocks = (total + block - 1) / block;
+    if (blocks < 1) blocks = 1;
+    if (blocks > 2147483647LL) blocks = 2147483647LL;
+    strided_segment_copy_kernel<<<(unsigned)blocks, block, 0, stream>>>(
+        src, dst, dim, src_seg_len, dst_seg_total, dst_row_offset, batch_size);
+    return (CUresult)cudaGetLastError();
+}
+
 } // extern "C"
diff --git a/openinfer-kernels/csrc/shared/paged_attention.cu b/openinfer-kernels/csrc/shared/paged_attention.cu
index 4506a60f..f21181ed 100644
--- a/openinfer-kernels/csrc/shared/paged_attention.cu
+++ b/openinfer-kernels/csrc/shared/paged_attention.cu
@@ -22,6 +22,7 @@ using namespace flashinfer;
 using DType = __nv_bfloat16;
 using IdType = int32_t;
 using ParamsT = BatchDecodeParams;
+using BatchPrefillRaggedParamsT = BatchPrefillRaggedParams;
 using Variant = DefaultAttention(stream))); }
+// ---------------------------------------------------------------------------
+// Single-request non-causal prefill over contiguous NHD K/V.
+//
+// DFlash draft attention materializes K/V as token-major HiddenStates:
+//   q: [q_len, num_qo_heads, head_dim]
+//   k/v: [kv_len, num_kv_heads, head_dim]
+// This wrapper mirrors vLLM's non-causal FlashAttention/FlashInfer semantics:
+// no causal mask, no sliding window, and GQA handled by FlashInfer.
+// ---------------------------------------------------------------------------
+int single_prefill_nhd_noncausal_cuda(
+    void* q,
+    void* output,
+    void* k,
+    void* v,
+    int32_t num_qo_heads,
+    int32_t num_kv_heads,
+    int32_t head_dim,
+    int32_t q_len,
+    int32_t kv_len,
+    float sm_scale,
+    void* stream)
+{
+    // The dispatch below is instantiated for HEAD_DIM_QK = HEAD_DIM_VO = 128
+    // only, and the lengths are cast to unsigned: reject anything else up
+    // front instead of launching the 128-wide kernel on mismatched data.
+    if (head_dim != 128 || q_len < 0 || kv_len < 0) {
+        return static_cast<int>(cudaErrorInvalidValue);
+    }
+
+    // Token-major (NHD) strides: one token spans heads * head_dim elements.
+    uint32_t q_stride_n = num_qo_heads * head_dim;
+    uint32_t q_stride_h = head_dim;
+    uint32_t kv_stride_n = num_kv_heads * head_dim;
+    uint32_t kv_stride_h = head_dim;
+
+    // NOTE(review): the cast target types below were reconstructed from the
+    // file's `DType` alias — confirm against PrefillParamsT's constructor.
+    PrefillParamsT params(
+        reinterpret_cast<DType*>(q),
+        reinterpret_cast<DType*>(k),
+        reinterpret_cast<DType*>(v),
+        /*maybe_custom_mask=*/nullptr,
+        reinterpret_cast<DType*>(output),
+        /*lse=*/nullptr,
+        /*maybe_alibi_slopes=*/nullptr,
+        num_qo_heads,
+        num_kv_heads,
+        static_cast<uint32_t>(q_len),
+        static_cast<uint32_t>(kv_len),
+        q_stride_n,
+        q_stride_h,
+        kv_stride_n,
+        kv_stride_h,
+        static_cast<uint32_t>(head_dim),
+        /*window_left=*/-1,
+        /*logits_soft_cap=*/0.0f,
+        sm_scale,
+        /*rope_scale=*/1.0f,
+        /*rope_theta=*/1e6f);
+
+    // MaskMode::kNone + PosEncodingMode::kNone: plain non-causal attention,
+    // no positional encoding applied inside the attention kernel.
+    return static_cast<int>(
+        SinglePrefillWithKVCacheDispatched<
+            /*HEAD_DIM_QK=*/128,
+            /*HEAD_DIM_VO=*/128,
+            PosEncodingMode::kNone,
+            /*USE_FP16_QK_REDUCTION=*/false,
+            MaskMode::kNone,
+            Variant,
+            PrefillParamsT>(
+            params,
+            /*tmp=*/nullptr,
+            reinterpret_cast<cudaStream_t>(stream)));
+}
+
+// ---------------------------------------------------------------------------
+// Batched non-causal prefill over compact ragged NHD K/V.
+//
+// DFlash groups exact-shape draft requests into compact token-major tensors:
+//   q: [sum(q_len), num_qo_heads, head_dim]
+//   k/v: [sum(kv_len), num_kv_heads, head_dim]
+// with q_indptr/kv_indptr separating requests. This maps directly to
+// FlashInfer BatchPrefillWithRaggedKVCache with MaskMode::kNone.
+// --------------------------------------------------------------------------- +int batch_prefill_ragged_nhd_noncausal_cuda( + void* q, + void* output, + void* k, + void* v, + int32_t* q_indptr, + int32_t* kv_indptr, + int32_t* request_indices, + int32_t* qo_tile_indices, + int32_t* kv_tile_indices, + int32_t* kv_chunk_size_ptr, + uint32_t* total_num_rows, + int32_t num_qo_heads, + int32_t num_kv_heads, + int32_t head_dim, + int32_t total_q_len, + int32_t batch_size, + int32_t padded_batch_size, + float sm_scale, + void* stream) +{ + uint32_t q_stride_n = num_qo_heads * head_dim; + uint32_t q_stride_h = head_dim; + uint32_t kv_stride_n = num_kv_heads * head_dim; + uint32_t kv_stride_h = head_dim; + + BatchPrefillRaggedParamsT params( + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + /*maybe_custom_mask=*/nullptr, + q_indptr, + kv_indptr, + /*maybe_mask_indptr=*/nullptr, + /*maybe_q_rope_offset=*/nullptr, + /*maybe_k_rope_offset=*/nullptr, + reinterpret_cast(output), + /*lse=*/nullptr, + /*maybe_alibi_slopes=*/nullptr, + num_qo_heads, + num_kv_heads, + q_stride_n, + q_stride_h, + kv_stride_n, + kv_stride_h, + /*window_left=*/-1, + /*logits_soft_cap=*/0.0f, + sm_scale, + /*rope_scale=*/1.0f, + /*rope_theta=*/1e6f); + + params.request_indices = request_indices; + params.qo_tile_indices = qo_tile_indices; + params.kv_tile_indices = kv_tile_indices; + params.o_indptr = q_indptr; + params.kv_chunk_size_ptr = kv_chunk_size_ptr; + params.total_num_rows = total_num_rows; + params.max_total_num_rows = static_cast(total_q_len); + params.padded_batch_size = static_cast(padded_batch_size); + params.partition_kv = false; + + return static_cast( + BatchPrefillWithRaggedKVCacheDispatched< + /*CTA_TILE_Q=*/16, + /*HEAD_DIM_QK=*/128, + /*HEAD_DIM_VO=*/128, + PosEncodingMode::kNone, + /*USE_FP16_QK_REDUCTION=*/false, + MaskMode::kNone, + Variant, + BatchPrefillRaggedParamsT>( + params, + /*tmp_v=*/nullptr, + /*tmp_s=*/nullptr, + /*enable_pdl=*/false, + 
reinterpret_cast(stream))); +} + // --------------------------------------------------------------------------- // Single-request prefill for HEAD_DIM=256 — wraps FlashInfer SinglePrefillWithKVCache. // diff --git a/openinfer-kernels/csrc/shared/prefill_attention.cu b/openinfer-kernels/csrc/shared/prefill_attention.cu index a7b24b66..086883d8 100644 --- a/openinfer-kernels/csrc/shared/prefill_attention.cu +++ b/openinfer-kernels/csrc/shared/prefill_attention.cu @@ -136,4 +136,112 @@ void qk_norm_rope_batched_decode_cuda( ); } +// ============================================================================ +// K-only norm + RoPE variant for the DFlash batch path. +// +// The context-hidden K projection needs RMSNorm + RoPE, but there is no +// corresponding Q (the draft Q comes only from the noise tokens). Calling the +// joint QK kernel on the context K would waste num_q_heads / (num_q_heads + +// num_kv_heads) of the GPU work — 80% for Qwen3-4B's 16:4 GQA ratio — on a Q +// buffer whose result is immediately discarded. This variant launches only +// num_kv_heads blocks per token. +// +// It reuses the same in-place per-head RMSNorm + RoPE logic as the joint +// kernel, restricted to the K tensor. 
+// ============================================================================ + +__global__ void k_norm_rope_kernel( + __nv_bfloat16* __restrict__ k, // [kv_dim, seq_len] modified in-place + const __nv_bfloat16* __restrict__ k_norm_weight, // [head_dim] + const __nv_bfloat16* __restrict__ cos_cache, // [max_pos * head_dim] + const __nv_bfloat16* __restrict__ sin_cache, + int num_kv_heads, int head_dim, + int seq_len, int kv_dim, + const int* start_pos_d, // if non-null, overrides start_pos per token + float eps, + int cos_max_pos +) { + int head_local = blockIdx.x; + int token = blockIdx.y; + int d = threadIdx.x; + + int offset = head_local * head_dim + d + token * kv_dim; + float val = __bfloat162float(k[offset]); + + // RMSNorm: sum of squares via warp reduction + float sq = val * val; + sq = warp_reduce_sum(sq); + + int warp_id = d / WARP_SIZE; + int lane_id = d % WARP_SIZE; + __shared__ float warp_sums[4]; // head_dim/32 = 4 warps + if (lane_id == 0) warp_sums[warp_id] = sq; + __syncthreads(); + + __shared__ float s_inv_rms; + { + float v = (lane_id < 4) ? warp_sums[lane_id] : 0.0f; + float total = warp_reduce_sum(v); + if (lane_id == 0) s_inv_rms = rsqrtf(total / head_dim + eps); + } + __syncthreads(); + + __nv_bfloat16 normed = __float2bfloat16(val * s_inv_rms); + float normed_f = __bfloat162float(normed) * __bfloat162float(k_norm_weight[d]); + + __shared__ __nv_bfloat16 smem[HEAD_DIM]; + smem[d] = __float2bfloat16(normed_f); + __syncthreads(); + + int half = head_dim / 2; + int pos = start_pos_d ? 
__ldg(start_pos_d + token) : token; + if (pos < 0 || pos >= cos_max_pos) __trap(); + + __nv_bfloat16 result; + if (d < half) { + float lo = __bfloat162float(smem[d]); + float hi = __bfloat162float(smem[d + half]); + float c = __bfloat162float(cos_cache[pos * head_dim + d]); + float s = __bfloat162float(sin_cache[pos * head_dim + d]); + float lo_cos = __bfloat162float(__float2bfloat16(lo * c)); + float hi_sin = __bfloat162float(__float2bfloat16(hi * s)); + result = __float2bfloat16(lo_cos - hi_sin); + } else { + int pair_d = d - half; + float lo = __bfloat162float(smem[pair_d]); + float hi = __bfloat162float(smem[d]); + float c = __bfloat162float(cos_cache[pos * head_dim + pair_d]); + float s = __bfloat162float(sin_cache[pos * head_dim + pair_d]); + float lo_sin = __bfloat162float(__float2bfloat16(lo * s)); + float hi_cos = __bfloat162float(__float2bfloat16(hi * c)); + result = __float2bfloat16(lo_sin + hi_cos); + } + + k[offset] = result; +} + +void k_norm_rope_batched_decode_cuda( + __nv_bfloat16* k, // [kv_dim * batch_size] in-place + const __nv_bfloat16* k_norm_weight, + const __nv_bfloat16* cos_cache, + const __nv_bfloat16* sin_cache, + const int* positions, // [batch_size] per-request positions on GPU + int num_kv_heads, + int head_dim, + int batch_size, + float rms_eps, + int cos_max_pos, + cudaStream_t stream +) { + int kv_dim = num_kv_heads * head_dim; + dim3 grid(num_kv_heads, batch_size); + k_norm_rope_kernel<<>>( + k, k_norm_weight, cos_cache, sin_cache, + num_kv_heads, head_dim, + /*seq_len=*/batch_size, kv_dim, + /*start_pos_d=*/positions, + rms_eps, cos_max_pos + ); +} + } // extern "C" diff --git a/openinfer-kernels/src/ffi/shared.rs b/openinfer-kernels/src/ffi/shared.rs index aff4e5e2..25a37011 100644 --- a/openinfer-kernels/src/ffi/shared.rs +++ b/openinfer-kernels/src/ffi/shared.rs @@ -185,6 +185,21 @@ unsafe extern "C" { stream: CUstream, ) -> i32; + /// Strided segment copy for DFlash batch K/V concatenation. 
Copies one + /// segment (ctx or noise) of every request from a contiguous source to a + /// strided destination in a single launch. See `strided_segment_copy_cuda` + /// in `csrc/shared/elementwise.cu`. + pub fn strided_segment_copy_cuda( + src: *const Half, + dst: *mut Half, + dim: i32, + src_seg_len: i32, + dst_seg_total: i32, + dst_row_offset: i32, + batch_size: i32, + stream: CUstream, + ) -> CUresult; + pub fn cublas_init(); pub fn cublas_activate_device_handles() -> i32; pub fn cublas_destroy(); @@ -249,6 +264,25 @@ unsafe extern "C" { stream: CUstream, ); + /// K-only norm + RoPE for the DFlash batch context-K path. Same per-head + /// RMSNorm + RoPE as `qk_norm_rope_batched_decode_cuda` but launches only + /// `num_kv_heads` blocks per token — the draft path has no context Q, so + /// the joint kernel wastes the Q work. See `k_norm_rope_batched_decode_cuda` + /// in `csrc/shared/prefill_attention.cu`. + pub fn k_norm_rope_batched_decode_cuda( + k: *mut Half, + k_norm_weight: *const Half, + cos_cache: *const Half, + sin_cache: *const Half, + positions: *const i32, + num_kv_heads: i32, + head_dim: i32, + batch_size: i32, + rms_eps: f32, + cos_max_pos: i32, + stream: CUstream, + ); + // Scatter contiguous KV → paged layout (one layer, FlashInfer prefill append). 
pub fn paged_kv_scatter_cuda( kv_data: *const Half, @@ -496,6 +530,42 @@ unsafe extern "C" { stream: CUstream, ) -> i32; + pub fn single_prefill_nhd_noncausal_cuda( + q: *const Half, + output: *mut Half, + k: *const Half, + v: *const Half, + num_qo_heads: i32, + num_kv_heads: i32, + head_dim: i32, + q_len: i32, + kv_len: i32, + sm_scale: f32, + stream: CUstream, + ) -> i32; + + pub fn batch_prefill_ragged_nhd_noncausal_cuda( + q: *const Half, + output: *mut Half, + k: *const Half, + v: *const Half, + q_indptr: *const i32, + kv_indptr: *const i32, + request_indices: *const i32, + qo_tile_indices: *const i32, + kv_tile_indices: *const i32, + kv_chunk_size_ptr: *const i32, + total_num_rows: *const u32, + num_qo_heads: i32, + num_kv_heads: i32, + head_dim: i32, + total_q_len: i32, + batch_size: i32, + padded_batch_size: i32, + sm_scale: f32, + stream: CUstream, + ) -> i32; + pub fn repeat_f32_for_reduce_scatter_cuda( local: *const f32, repeated: *mut f32, diff --git a/openinfer-kernels/src/ops.rs b/openinfer-kernels/src/ops.rs index fa8c5362..0d8abc59 100644 --- a/openinfer-kernels/src/ops.rs +++ b/openinfer-kernels/src/ops.rs @@ -5,6 +5,7 @@ mod attention; mod deepep; #[cfg(feature = "deepseek-v2-lite")] mod deepseek_v2_lite; +mod dense_attention; mod elementwise; mod embedding; #[cfg(feature = "kimi-k2")] @@ -15,9 +16,10 @@ mod norm; mod sampling; pub use attention::{ - PrefillPagedPlan, paged_attention_batch_decode_hd256_into, paged_attention_batch_decode_into, - paged_attention_batch_decode_split_kv_into, prefill_attention_paged_into, - qk_norm_partial_rope_batched_decode_hd256_into, qk_norm_rope_batch_decode_into, + PrefillPagedPlan, k_norm_rope_batch_decode_into, paged_attention_batch_decode_hd256_into, + paged_attention_batch_decode_into, paged_attention_batch_decode_split_kv_into, + prefill_attention_paged_into, qk_norm_partial_rope_batched_decode_hd256_into, + qk_norm_rope_batch_decode_into, }; #[cfg(feature = "kimi-k2")] pub use deepep::{ @@ -25,13 +27,16 @@ 
pub use deepep::{ }; #[cfg(feature = "deepseek-v2-lite")] pub use deepseek_v2_lite::*; +pub use dense_attention::{ + RaggedPrefillPlan, batch_prefill_ragged_nhd_noncausal_into, single_prefill_nhd_noncausal_into, +}; pub use elementwise::{ accumulate_bf16_token_scaled_to_f32_into, add_batch, add_batch_into, bf16_hidden_to_f32_into, extract_vec, extract_vec_into, extract_vec_ref, extract_vec_ref_into, f32_to_bf16_hidden_into, gather_hidden_tokens_into, repeat_f32_for_reduce_scatter_into, scale_f32_in_place, scaled_add_batch_into, scaled_add_rows_indexed_into, scaled_add_rows_into, scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into, - silu_mul_fused_batch_into, write_vec_into, + silu_mul_fused_batch_into, strided_segment_copy_into, write_vec_into, }; pub use embedding::{embedding_batch, embedding_batch_vocab_shard, embedding_decode_into}; #[cfg(feature = "kimi-k2")] diff --git a/openinfer-kernels/src/ops/attention.rs b/openinfer-kernels/src/ops/attention.rs index 99516a13..122ed295 100644 --- a/openinfer-kernels/src/ops/attention.rs +++ b/openinfer-kernels/src/ops/attention.rs @@ -497,6 +497,50 @@ pub fn qk_norm_rope_batch_decode_into( } } +/// K-only norm + RoPE for the DFlash batch context-K path. +/// +/// Applies in-place RMSNorm + RoPE to `k` only — the draft path's context K +/// projection has no corresponding Q, so the joint `qk_norm_rope` kernel would +/// waste `num_q_heads / (num_q_heads + num_kv_heads)` of its work on a Q buffer +/// whose result is discarded (80% for Qwen3-4B's 16:4 GQA). This variant +/// launches only `num_kv_heads` blocks per token. 
+#[allow(clippy::too_many_arguments)]
+pub fn k_norm_rope_batch_decode_into(
+    ctx: &DeviceContext,
+    k: &mut HiddenStates,
+    k_norm_weight: &DeviceVec,
+    cos_cache: &DeviceVec,
+    sin_cache: &DeviceVec,
+    positions_d: &CudaSlice<i32>,
+    num_kv_heads: usize,
+    head_dim: usize,
+    rms_eps: f32,
+) {
+    assert_eq!(head_dim, 128, "k_norm_rope kernel reduces over exactly 4 warps (head_dim=128)");
+    assert_eq!(k.hidden_dim, num_kv_heads * head_dim, "k must be [tokens, num_kv_heads * head_dim]");
+    assert!(positions_d.len() >= k.seq_len, "need one RoPE position per K token");
+    let (k_ptr, _gk) = k.data.device_ptr_mut(&ctx.stream);
+    let (kn_ptr, _gkn) = k_norm_weight.data.device_ptr(&ctx.stream);
+    let (cos_ptr, _gc) = cos_cache.data.device_ptr(&ctx.stream);
+    let (sin_ptr, _gs) = sin_cache.data.device_ptr(&ctx.stream);
+    let (pos_ptr, _gp) = positions_d.device_ptr(&ctx.stream);
+    unsafe {
+        ffi::k_norm_rope_batched_decode_cuda(
+            k_ptr as *mut ffi::Half,
+            kn_ptr as *const ffi::Half,
+            cos_ptr as *const ffi::Half,
+            sin_ptr as *const ffi::Half,
+            pos_ptr as *const i32,
+            num_kv_heads as i32,
+            head_dim as i32,
+            k.seq_len as i32, // one grid row per token
+            rms_eps,
+            (cos_cache.data.len() / head_dim) as i32, // kernel traps on positions outside this
+            ctx.stream.cu_stream(),
+        );
+    }
+}
+
 /// Batched QK RMSNorm + partial RoPE for Qwen3.5 HD256 decode.
/// /// Reads Q from interleaved `q_full` ([q, gate] per head), writes prepared Q into `q`, diff --git a/openinfer-kernels/src/ops/dense_attention.rs b/openinfer-kernels/src/ops/dense_attention.rs new file mode 100644 index 00000000..bf438707 --- /dev/null +++ b/openinfer-kernels/src/ops/dense_attention.rs @@ -0,0 +1,204 @@ +use anyhow::Result; +use cudarc::driver::{CudaSlice, DevicePtr, DevicePtrMut}; + +use crate::ffi; +use crate::tensor::{DeviceContext, HiddenStates}; + +#[allow(clippy::too_many_arguments)] +pub fn single_prefill_nhd_noncausal_into( + ctx: &DeviceContext, + q: &HiddenStates, + k: &HiddenStates, + v: &HiddenStates, + out: &mut HiddenStates, + num_qo_heads: usize, + num_kv_heads: usize, + head_dim: usize, +) -> Result<()> { + let q_dim = num_qo_heads * head_dim; + let kv_dim = num_kv_heads * head_dim; + assert_eq!(q.hidden_dim, q_dim); + assert_eq!(k.hidden_dim, kv_dim); + assert_eq!(v.hidden_dim, kv_dim); + assert_eq!(v.seq_len, k.seq_len); + assert_eq!(out.hidden_dim, q_dim); + assert_eq!(out.seq_len, q.seq_len); + assert_eq!( + head_dim, 128, + "FlashInfer wrapper is instantiated for head_dim=128" + ); + + let (q_ptr, _gq) = q.data.device_ptr(&ctx.stream); + let (k_ptr, _gk) = k.data.device_ptr(&ctx.stream); + let (v_ptr, _gv) = v.data.device_ptr(&ctx.stream); + let (out_ptr, _go) = out.data.device_ptr_mut(&ctx.stream); + let sm_scale = 1.0f32 / (head_dim as f32).sqrt(); + let status = unsafe { + ffi::single_prefill_nhd_noncausal_cuda( + q_ptr as *const ffi::Half, + out_ptr as *mut ffi::Half, + k_ptr as *const ffi::Half, + v_ptr as *const ffi::Half, + num_qo_heads as i32, + num_kv_heads as i32, + head_dim as i32, + q.seq_len as i32, + k.seq_len as i32, + sm_scale, + ctx.stream.cu_stream(), + ) + }; + if status != 0 { + anyhow::bail!( + "single_prefill_nhd_noncausal_cuda failed: status={}, q_len={}, kv_len={}, q_heads={}, kv_heads={}, head_dim={}", + status, + q.seq_len, + k.seq_len, + num_qo_heads, + num_kv_heads, + head_dim + ); + } + Ok(()) 
+} + +pub struct RaggedPrefillPlan { + q_indptr: CudaSlice, + kv_indptr: CudaSlice, + request_indices: CudaSlice, + qo_tile_indices: CudaSlice, + kv_tile_indices: CudaSlice, + kv_chunk_size: CudaSlice, + total_num_rows: CudaSlice, + batch_size: usize, + total_q_len: usize, +} + +impl RaggedPrefillPlan { + pub fn new( + ctx: &DeviceContext, + q_lens: &[usize], + kv_lens: &[usize], + group_size: usize, + ) -> Result { + anyhow::ensure!(!q_lens.is_empty(), "ragged prefill batch is empty"); + anyhow::ensure!( + q_lens.len() == kv_lens.len(), + "q_lens len {} != kv_lens len {}", + q_lens.len(), + kv_lens.len() + ); + anyhow::ensure!(group_size > 0, "group_size must be positive"); + let mut q_indptr = Vec::with_capacity(q_lens.len() + 1); + let mut kv_indptr = Vec::with_capacity(kv_lens.len() + 1); + q_indptr.push(0i32); + kv_indptr.push(0i32); + for (&q_len, &kv_len) in q_lens.iter().zip(kv_lens.iter()) { + anyhow::ensure!(q_len > 0, "ragged prefill q_len must be positive"); + anyhow::ensure!(kv_len > 0, "ragged prefill kv_len must be positive"); + q_indptr.push(q_indptr.last().copied().unwrap() + q_len as i32); + kv_indptr.push(kv_indptr.last().copied().unwrap() + kv_len as i32); + } + let total_q_len = *q_indptr.last().unwrap() as usize; + let mut request_indices = Vec::new(); + let mut qo_tile_indices = Vec::new(); + let mut kv_tile_indices = Vec::new(); + const CTA_TILE_Q: usize = 16; + for (req_idx, &q_len) in q_lens.iter().enumerate() { + let packed_q_len = q_len * group_size; + let tiles = packed_q_len.div_ceil(CTA_TILE_Q); + for tile in 0..tiles { + request_indices.push(req_idx as i32); + qo_tile_indices.push(tile as i32); + kv_tile_indices.push(0i32); + } + } + let kv_chunk_size: Vec = kv_lens.iter().map(|&len| len as i32).collect(); + Ok(Self { + q_indptr: ctx.stream.clone_htod(&q_indptr)?, + kv_indptr: ctx.stream.clone_htod(&kv_indptr)?, + request_indices: ctx.stream.clone_htod(&request_indices)?, + qo_tile_indices: ctx.stream.clone_htod(&qo_tile_indices)?, + 
kv_tile_indices: ctx.stream.clone_htod(&kv_tile_indices)?, + kv_chunk_size: ctx.stream.clone_htod(&kv_chunk_size)?, + total_num_rows: ctx.stream.clone_htod(&[total_q_len as u32])?, + batch_size: q_lens.len(), + total_q_len, + }) + } +} + +#[allow(clippy::too_many_arguments)] +pub fn batch_prefill_ragged_nhd_noncausal_into( + ctx: &DeviceContext, + q: &HiddenStates, + k: &HiddenStates, + v: &HiddenStates, + out: &mut HiddenStates, + plan: &RaggedPrefillPlan, + num_qo_heads: usize, + num_kv_heads: usize, + head_dim: usize, +) -> Result<()> { + let q_dim = num_qo_heads * head_dim; + let kv_dim = num_kv_heads * head_dim; + assert_eq!(q.hidden_dim, q_dim); + assert_eq!(k.hidden_dim, kv_dim); + assert_eq!(v.hidden_dim, kv_dim); + assert_eq!(v.seq_len, k.seq_len); + assert_eq!(out.hidden_dim, q_dim); + assert_eq!(out.seq_len, q.seq_len); + assert_eq!(q.seq_len, plan.total_q_len); + assert_eq!( + head_dim, 128, + "FlashInfer ragged wrapper is instantiated for head_dim=128" + ); + + let (q_ptr, _gq) = q.data.device_ptr(&ctx.stream); + let (k_ptr, _gk) = k.data.device_ptr(&ctx.stream); + let (v_ptr, _gv) = v.data.device_ptr(&ctx.stream); + let (out_ptr, _go) = out.data.device_ptr_mut(&ctx.stream); + let (q_indptr, _) = plan.q_indptr.device_ptr(&ctx.stream); + let (kv_indptr, _) = plan.kv_indptr.device_ptr(&ctx.stream); + let (request_indices, _) = plan.request_indices.device_ptr(&ctx.stream); + let (qo_tile_indices, _) = plan.qo_tile_indices.device_ptr(&ctx.stream); + let (kv_tile_indices, _) = plan.kv_tile_indices.device_ptr(&ctx.stream); + let (kv_chunk_size, _) = plan.kv_chunk_size.device_ptr(&ctx.stream); + let (total_num_rows, _) = plan.total_num_rows.device_ptr(&ctx.stream); + let sm_scale = 1.0f32 / (head_dim as f32).sqrt(); + let status = unsafe { + ffi::batch_prefill_ragged_nhd_noncausal_cuda( + q_ptr as *const ffi::Half, + out_ptr as *mut ffi::Half, + k_ptr as *const ffi::Half, + v_ptr as *const ffi::Half, + q_indptr as *const i32, + kv_indptr as *const i32, + 
request_indices as *const i32, + qo_tile_indices as *const i32, + kv_tile_indices as *const i32, + kv_chunk_size as *const i32, + total_num_rows as *const u32, + num_qo_heads as i32, + num_kv_heads as i32, + head_dim as i32, + q.seq_len as i32, + plan.batch_size as i32, + plan.request_indices.len() as i32, + sm_scale, + ctx.stream.cu_stream(), + ) + }; + if status != 0 { + anyhow::bail!( + "batch_prefill_ragged_nhd_noncausal_cuda failed: status={}, total_q_len={}, batch_size={}, q_heads={}, kv_heads={}, head_dim={}", + status, + q.seq_len, + plan.batch_size, + num_qo_heads, + num_kv_heads, + head_dim + ); + } + Ok(()) +} diff --git a/openinfer-kernels/src/ops/elementwise.rs b/openinfer-kernels/src/ops/elementwise.rs index 6956ab47..581ec975 100644 --- a/openinfer-kernels/src/ops/elementwise.rs +++ b/openinfer-kernels/src/ops/elementwise.rs @@ -481,6 +481,49 @@ pub fn silu_mul_fused_batch_into( Ok(()) } +/// Strided segment copy for DFlash batch K/V concatenation. +/// +/// Copies `src_seg_len` rows from every request in the batch from a contiguous +/// source (`[batch_size * src_seg_len, dim]`) into a strided destination +/// (`[batch_size * dst_seg_total, dim]`), placing each request's segment at +/// `dst_row_offset` within its per-request block. One launch copies the entire +/// batch's segment, replacing `batch_size` individual `memcpy_dtod` calls. +/// +/// Used to build the ragged-attention K/V layout `[ctx | noise]` per request +/// from the separately-projected `k_ctx`/`k_noise` buffers. 
+pub fn strided_segment_copy_into(
+    ctx: &DeviceContext,
+    src: &HiddenStates,
+    dst: &mut HiddenStates,
+    src_seg_len: usize,
+    dst_seg_total: usize,
+    dst_row_offset: usize,
+    batch_size: usize,
+) -> Result<()> {
+    let dim = src.hidden_dim;
+    assert_eq!(dst.hidden_dim, dim);
+    assert_eq!(src.seq_len, batch_size * src_seg_len);
+    assert!(dst_row_offset + src_seg_len <= dst_seg_total);
+    assert!(batch_size * dst_seg_total <= dst.seq_len);
+    // src_seg_len and dst_row_offset are bounded by dst_seg_total (checked above).
+    assert!(dim.max(dst_seg_total).max(batch_size) <= i32::MAX as usize, "sizes must fit i32 FFI args");
+    let (src_ptr, _g0) = src.data.device_ptr(&ctx.stream);
+    let (dst_ptr, _g1) = dst.data.device_ptr_mut(&ctx.stream);
+    let result = unsafe {
+        ffi::strided_segment_copy_cuda(
+            src_ptr as *const ffi::Half,
+            dst_ptr as *mut ffi::Half,
+            dim as i32,
+            src_seg_len as i32,
+            dst_seg_total as i32,
+            dst_row_offset as i32,
+            batch_size as i32,
+            ctx.stream.cu_stream(),
+        )
+    };
+    result.result()?;
+    Ok(())
+}
 /// Extract a single token's vector from a HiddenStates batch (GPU copy)
 pub fn extract_vec(
     ctx: &DeviceContext,
diff --git a/openinfer-qwen3-4b-dflash/Cargo.toml b/openinfer-qwen3-4b-dflash/Cargo.toml
new file mode 100644
index 00000000..47b11a25
--- /dev/null
+++ b/openinfer-qwen3-4b-dflash/Cargo.toml
@@ -0,0 +1,33 @@
+[package]
+name = "openinfer-qwen3-4b-dflash"
+license = "Apache-2.0"
+version = "0.1.0"
+edition = "2024"
+
+[dependencies]
+anyhow = { workspace = true }
+crossbeam-channel = { workspace = true }
+cudarc = { workspace = true }
+half = { workspace = true }
+log = { workspace = true }
+memmap2 = { workspace = true }
+openinfer-core = { workspace = true }
+openinfer-kernels = { workspace = true }
+safetensors = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+
+[[bin]]
+name = "qwen3_dflash_forward_fixture"
+path = "src/bin/qwen3_dflash_forward_fixture.rs"
+
+[[bin]]
+name = "qwen3_dflash_forward_bench"
+path = "src/bin/qwen3_dflash_forward_bench.rs"
+
+[[bin]]
+name = "qwen3_dflash_batch_bench"
+path = "src/bin/qwen3_dflash_batch_bench.rs"
+
+[lints]
+workspace = true
diff --git a/openinfer-qwen3-4b-dflash/src/batch_buffers.rs b/openinfer-qwen3-4b-dflash/src/batch_buffers.rs new file mode 100644 index 00000000..8d7cef74 --- /dev/null +++ b/openinfer-qwen3-4b-dflash/src/batch_buffers.rs @@ -0,0 +1,177 @@ +use anyhow::Result; +use cudarc::driver::CudaSlice; +use openinfer_core::ops::RaggedPrefillPlan; +use openinfer_core::tensor::HiddenStates; + +use crate::weights::DFlashDraftModel; + +pub struct DFlashBatchBuffers { + pub(crate) max_batch_size: usize, + pub(crate) max_q_len: usize, + pub(crate) max_ctx_len: usize, + /// Active shape for the current batch — set by `set_active_shape` before + /// each forward. `q_len`/`ctx_len` may shrink below `max_*`; the physical + /// buffers are sized for the max, so the active values only narrow the view. + pub(crate) q_len: usize, + pub(crate) ctx_len: usize, + pub(crate) total_q_len: usize, + pub(crate) total_ctx_len: usize, + pub(crate) total_kv_len: usize, + pub(crate) noise: HiddenStates, + pub(crate) target_hidden: HiddenStates, + pub(crate) target_projected: HiddenStates, + pub(crate) target_normed: HiddenStates, + pub(crate) hidden: HiddenStates, + pub(crate) hidden_out: HiddenStates, + pub(crate) normed: HiddenStates, + pub(crate) q: HiddenStates, + pub(crate) q_ctx_scratch: HiddenStates, + pub(crate) k_ctx: HiddenStates, + pub(crate) k_noise: HiddenStates, + pub(crate) v_ctx: HiddenStates, + pub(crate) v_noise: HiddenStates, + pub(crate) k_all: HiddenStates, + pub(crate) v_all: HiddenStates, + pub(crate) attn_out: HiddenStates, + pub(crate) o_buf: HiddenStates, + pub(crate) gate_up: HiddenStates, + pub(crate) act_out: HiddenStates, + pub(crate) positions_q: CudaSlice, + pub(crate) positions_ctx: CudaSlice, + pub(crate) ragged_plan: Option, +} + +pub(crate) struct CachedRaggedPlan { + pub(crate) batch_size: usize, + pub(crate) q_len: usize, + pub(crate) ctx_len: usize, + pub(crate) plan: RaggedPrefillPlan, +} + +impl DFlashBatchBuffers { + /// Allocate a single-instance buffer sized 
for the worst case + /// (`max_batch_size × max_q_len` / `× max_ctx_len`). Each forward narrows + /// the active shape via `set_active_shape`, mirroring Qwen3's + /// `BatchDecodeBuffers` (one allocation, dynamic `set_batch_size`). + pub(crate) fn new( + model: &DFlashDraftModel, + max_batch_size: usize, + max_q_len: usize, + max_ctx_len: usize, + ) -> Result { + anyhow::ensure!(max_batch_size > 0, "max_batch_size must be positive"); + anyhow::ensure!(max_q_len > 0, "max_q_len must be positive"); + anyhow::ensure!(max_ctx_len > 0, "max_ctx_len must be positive"); + let config = model.config(); + let ctx = model.device_context(); + let hidden = config.hidden_size; + let target_hidden_dim = config.hidden_size * config.target_layer_count(); + let q_dim = config.q_dim(); + let kv_dim = config.kv_dim(); + let total_q_len = max_batch_size * max_q_len; + let total_ctx_len = max_batch_size * max_ctx_len; + let total_kv_len = max_batch_size * (max_ctx_len + max_q_len); + Ok(Self { + max_batch_size, + max_q_len, + max_ctx_len, + q_len: max_q_len, + ctx_len: max_ctx_len, + total_q_len, + total_ctx_len, + total_kv_len, + noise: HiddenStates::zeros(ctx, hidden, total_q_len)?, + target_hidden: HiddenStates::zeros(ctx, target_hidden_dim, total_ctx_len)?, + target_projected: HiddenStates::zeros(ctx, hidden, total_ctx_len)?, + target_normed: HiddenStates::zeros(ctx, hidden, total_ctx_len)?, + hidden: HiddenStates::zeros(ctx, hidden, total_q_len)?, + hidden_out: HiddenStates::zeros(ctx, hidden, total_q_len)?, + normed: HiddenStates::zeros(ctx, hidden, total_q_len)?, + q: HiddenStates::zeros(ctx, q_dim, total_q_len)?, + q_ctx_scratch: HiddenStates::zeros(ctx, q_dim, total_ctx_len)?, + k_ctx: HiddenStates::zeros(ctx, kv_dim, total_ctx_len)?, + k_noise: HiddenStates::zeros(ctx, kv_dim, total_q_len)?, + v_ctx: HiddenStates::zeros(ctx, kv_dim, total_ctx_len)?, + v_noise: HiddenStates::zeros(ctx, kv_dim, total_q_len)?, + k_all: HiddenStates::zeros(ctx, kv_dim, total_kv_len)?, + v_all: 
HiddenStates::zeros(ctx, kv_dim, total_kv_len)?, + attn_out: HiddenStates::zeros(ctx, q_dim, total_q_len)?, + o_buf: HiddenStates::zeros(ctx, hidden, total_q_len)?, + gate_up: HiddenStates::zeros(ctx, 2 * config.intermediate_size, total_q_len)?, + act_out: HiddenStates::zeros(ctx, config.intermediate_size, total_q_len)?, + positions_q: ctx.stream.alloc_zeros(total_q_len)?, + positions_ctx: ctx.stream.alloc_zeros(total_ctx_len)?, + ragged_plan: None, + }) + } + + /// Narrow the active shape for this forward: sets `q_len`/`ctx_len` and + /// recomputes every buffer's `seq_len` to `batch_size × (q|ctx)`. Buffers + /// stay sized for the max, so callers can freely vary batch/q/ctx below it. + pub(crate) fn set_active_shape(&mut self, batch_size: usize, q_len: usize, ctx_len: usize) { + debug_assert!(batch_size <= self.max_batch_size); + debug_assert!(q_len <= self.max_q_len); + debug_assert!(ctx_len <= self.max_ctx_len); + self.q_len = q_len; + self.ctx_len = ctx_len; + self.total_q_len = batch_size * q_len; + self.total_ctx_len = batch_size * ctx_len; + self.total_kv_len = batch_size * (ctx_len + q_len); + self.noise.seq_len = self.total_q_len; + self.target_hidden.seq_len = self.total_ctx_len; + self.target_projected.seq_len = self.total_ctx_len; + self.target_normed.seq_len = self.total_ctx_len; + self.hidden.seq_len = self.total_q_len; + self.hidden_out.seq_len = self.total_q_len; + self.normed.seq_len = self.total_q_len; + self.q.seq_len = self.total_q_len; + self.q_ctx_scratch.seq_len = self.total_ctx_len; + self.k_ctx.seq_len = self.total_ctx_len; + self.k_noise.seq_len = self.total_q_len; + self.v_ctx.seq_len = self.total_ctx_len; + self.v_noise.seq_len = self.total_q_len; + self.k_all.seq_len = self.total_kv_len; + self.v_all.seq_len = self.total_kv_len; + self.attn_out.seq_len = self.total_q_len; + self.o_buf.seq_len = self.total_q_len; + self.gate_up.seq_len = self.total_q_len; + self.act_out.seq_len = self.total_q_len; + } + + pub(crate) fn 
prepare_ragged_plan( + &mut self, + model: &DFlashDraftModel, + batch_size: usize, + ) -> Result<()> { + // The plan depends on (batch_size, q_len, ctx_len); with a single + // instance buffer any of them can change between forwards, so all three + // must be part of the cache key. + let needs_rebuild = self + .ragged_plan + .as_ref() + .map(|cached| { + cached.batch_size != batch_size + || cached.q_len != self.q_len + || cached.ctx_len != self.ctx_len + }) + .unwrap_or(true); + if needs_rebuild { + let config = model.config(); + let q_lens = vec![self.q_len; batch_size]; + let kv_lens = vec![self.ctx_len + self.q_len; batch_size]; + let plan = RaggedPrefillPlan::new( + model.device_context(), + &q_lens, + &kv_lens, + config.num_attention_heads / config.num_key_value_heads, + )?; + self.ragged_plan = Some(CachedRaggedPlan { + batch_size, + q_len: self.q_len, + ctx_len: self.ctx_len, + plan, + }); + } + Ok(()) + } +} diff --git a/openinfer-qwen3-4b-dflash/src/batch_forward.rs b/openinfer-qwen3-4b-dflash/src/batch_forward.rs new file mode 100644 index 00000000..70e742f4 --- /dev/null +++ b/openinfer-qwen3-4b-dflash/src/batch_forward.rs @@ -0,0 +1,446 @@ +use anyhow::Result; +use half::bf16; +use openinfer_core::ops; +use openinfer_core::tensor::{DeviceContext, HiddenStates}; + +use crate::batch_buffers::DFlashBatchBuffers; +use crate::forward::DFlashTargetHidden; +use crate::weights::{DFlashDraftModel, DFlashLayer}; + +pub struct DFlashBatchInput<'a> { + pub noise_embedding: &'a HiddenStates, + pub target_hidden: DFlashTargetHidden<'a>, + pub position_ids: &'a [i32], +} + +pub struct DFlashHostBatchInput<'a> { + pub noise_embedding: &'a [bf16], + pub target_hidden: &'a [bf16], + pub position_ids: &'a [i32], +} + +impl DFlashDraftModel { + pub fn create_batch_buffers( + &self, + max_batch_size: usize, + max_q_len: usize, + max_ctx_len: usize, + ) -> Result { + DFlashBatchBuffers::new(self, max_batch_size, max_q_len, max_ctx_len) + } + + pub fn forward_batch<'a>( + 
&self, + requests: &[DFlashBatchInput<'_>], + bufs: &'a mut DFlashBatchBuffers, + ) -> Result<&'a HiddenStates> { + anyhow::ensure!(!requests.is_empty(), "DFlash batch is empty"); + anyhow::ensure!( + requests.len() <= bufs.max_batch_size, + "DFlash batch size {} exceeds buffer capacity {}", + requests.len(), + bufs.max_batch_size + ); + // All requests in an exact-shape batch share one (q_len, ctx_len); read + // it from the first, then narrow the buffer's active shape to match. + let (q_len, ctx_len) = self.validate_forward_inputs( + requests[0].noise_embedding, + &requests[0].target_hidden, + requests[0].position_ids, + )?; + anyhow::ensure!( + q_len <= bufs.max_q_len && ctx_len <= bufs.max_ctx_len, + "DFlash batch shape q_len={}, ctx_len={} exceeds buffer capacity q_len={}, ctx_len={}", + q_len, + ctx_len, + bufs.max_q_len, + bufs.max_ctx_len, + ); + // Exact-shape batch: the first request is fully validated above, so the + // rest only need to match the three lengths that fix (q_len, ctx_len) + // — re-running the full validator per request just repeats the same + // hidden_dim / positivity checks against the same config. + for req in &requests[1..] 
{
+            anyhow::ensure!(
+                req.noise_embedding.seq_len == q_len
+                    && req.noise_embedding.hidden_dim == requests[0].noise_embedding.hidden_dim,
+                "DFlash exact-shape batch noise_embedding shape mismatch"
+            );
+            anyhow::ensure!(
+                req.target_hidden.concatenated.seq_len == ctx_len,
+                "DFlash exact-shape batch target_hidden seq_len mismatch"
+            );
+            // compact_inputs copies ctx_len tokens at the buffer's target width
+            // and copy_hidden only debug-asserts the source width, so a request
+            // with a different target width must be rejected here.
+            anyhow::ensure!(
+                req.target_hidden.concatenated.hidden_dim
+                    == requests[0].target_hidden.concatenated.hidden_dim,
+                "DFlash exact-shape batch target_hidden hidden_dim mismatch"
+            );
+            anyhow::ensure!(
+                req.position_ids.len() == ctx_len + q_len,
+                "DFlash exact-shape batch position_ids len mismatch"
+            );
+        }
+        bufs.set_active_shape(requests.len(), q_len, ctx_len);
+        compact_inputs(self.device_context(), requests, bufs)?;
+        self.forward_compact_batch(requests.len(), bufs)?;
+        Ok(&bufs.normed)
+    }
+
+    /// Host-slice variant of `forward_batch`.
+    ///
+    /// Each request carries flat host slices (bf16 noise embedding, bf16 target
+    /// hidden states, i32 position ids). The batch shape `(q_len, ctx_len)` is
+    /// derived from the first request's slice lengths and later requests must
+    /// match it. The inputs are uploaded into `bufs`, the compact batch forward
+    /// runs, and a reference to `bufs.normed` is returned.
+    ///
+    /// Errors if the batch is empty, exceeds the buffer's batch capacity, has
+    /// slice lengths that are not whole multiples of the per-token widths, or
+    /// exceeds the buffer's `max_q_len` / `max_ctx_len`.
+    pub fn forward_host_batch<'a>(
+        &self,
+        requests: &[DFlashHostBatchInput<'_>],
+        bufs: &'a mut DFlashBatchBuffers,
+    ) -> Result<&'a HiddenStates> {
+        anyhow::ensure!(!requests.is_empty(), "DFlash host batch is empty");
+        anyhow::ensure!(
+            requests.len() <= bufs.max_batch_size,
+            "DFlash host batch size {} exceeds buffer capacity {}",
+            requests.len(),
+            bufs.max_batch_size
+        );
+        let config = self.config();
+        let hidden = config.hidden_size;
+        let target_hidden_dim = config.hidden_size * config.target_layer_count();
+        // Derive the shared (q_len, ctx_len) from the first request, the same
+        // way forward_batch derives it from device tensors.
+        let first = &requests[0];
+        anyhow::ensure!(
+            first.noise_embedding.len() % hidden == 0,
+            "noise_embedding len {} is not a multiple of hidden_size {}",
+            first.noise_embedding.len(),
+            hidden,
+        );
+        let q_len = first.noise_embedding.len() / hidden;
+        anyhow::ensure!(
+            first.target_hidden.len() % target_hidden_dim == 0,
+            "target_hidden len {} is not a multiple of target_hidden_dim {}",
+            first.target_hidden.len(),
+            target_hidden_dim,
+        );
+        let ctx_len = first.target_hidden.len() / target_hidden_dim;
+        anyhow::ensure!(q_len > 0, "DFlash host batch q_len must be positive");
+        anyhow::ensure!(ctx_len > 0, "DFlash host batch ctx_len must be positive");
+        anyhow::ensure!(
+            q_len <= bufs.max_q_len && ctx_len <= bufs.max_ctx_len,
+            "DFlash host batch shape q_len={}, ctx_len={} exceeds buffer capacity q_len={}, ctx_len={}",
+            q_len,
+            ctx_len,
+            bufs.max_q_len,
+            bufs.max_ctx_len,
+        );
+        let noise_len = q_len * hidden;
+        let target_len = ctx_len * target_hidden_dim;
+        let position_len = ctx_len + q_len;
+        // Check every request, the first included. Its noise/target lengths
+        // match trivially, but its position_ids length is not validated anywhere
+        // above, and compact_host_inputs slices position_ids at ctx_len without
+        // further checks.
+        for req in requests
{ + anyhow::ensure!( + req.noise_embedding.len() == noise_len, + "noise_embedding len {} != {}", + req.noise_embedding.len(), + noise_len + ); + anyhow::ensure!( + req.target_hidden.len() == target_len, + "target_hidden len {} != {}", + req.target_hidden.len(), + target_len + ); + anyhow::ensure!( + req.position_ids.len() == position_len, + "position_ids len {} != {}", + req.position_ids.len(), + position_len + ); + } + bufs.set_active_shape(requests.len(), q_len, ctx_len); + compact_host_inputs(self.device_context(), requests, bufs)?; + self.forward_compact_batch(requests.len(), bufs)?; + Ok(&bufs.normed) + } + + fn forward_compact_batch( + &self, + batch_size: usize, + bufs: &mut DFlashBatchBuffers, + ) -> Result<()> { + let config = self.config(); + ops::gemm_into_checked( + self.device_context(), + &self.fc, + &bufs.target_hidden, + &mut bufs.target_projected, + )?; + ops::rms_norm_batch_into( + self.device_context(), + &bufs.target_projected, + &self.hidden_norm, + config.rms_norm_eps, + &mut bufs.target_normed, + ); + copy_hidden( + self.device_context(), + &bufs.noise, + 0, + &mut bufs.hidden, + 0, + config.hidden_size, + bufs.total_q_len, + )?; + for layer in &self.layers { + self.forward_compact_batch_layer(layer, batch_size, bufs)?; + } + ops::rms_norm_batch_into( + self.device_context(), + &bufs.hidden, + &self.norm, + config.rms_norm_eps, + &mut bufs.normed, + ); + Ok(()) + } + + fn forward_compact_batch_layer( + &self, + layer: &DFlashLayer, + batch_size: usize, + bufs: &mut DFlashBatchBuffers, + ) -> Result<()> { + let config = self.config(); + let ctx = self.device_context(); + ops::rms_norm_batch_into( + ctx, + &bufs.hidden, + &layer.input_layernorm, + config.rms_norm_eps, + &mut bufs.normed, + ); + ops::gemm_into_checked(ctx, &layer.attention.q_proj, &bufs.normed, &mut bufs.q)?; + ops::gemm_into_checked( + ctx, + &layer.attention.k_proj, + &bufs.normed, + &mut bufs.k_noise, + )?; + ops::gemm_into_checked( + ctx, + &layer.attention.v_proj, + 
&bufs.normed, + &mut bufs.v_noise, + )?; + ops::qk_norm_rope_batch_decode_into( + ctx, + &mut bufs.q, + &mut bufs.k_noise, + &layer.attention.q_norm, + &layer.attention.k_norm, + &self.cos_cache, + &self.sin_cache, + &bufs.positions_q, + config.num_attention_heads, + config.num_key_value_heads, + config.head_dim, + config.rms_norm_eps, + ); + + ops::gemm_into_checked( + ctx, + &layer.attention.k_proj, + &bufs.target_normed, + &mut bufs.k_ctx, + )?; + ops::gemm_into_checked( + ctx, + &layer.attention.v_proj, + &bufs.target_normed, + &mut bufs.v_ctx, + )?; + // Context-K needs norm + RoPE but has no corresponding Q. The K-only + // kernel launches num_kv_heads blocks per token instead of + // num_q_heads + num_kv_heads, dropping 80% of the joint kernel's work + // (the dead Q branch) for Qwen3-4B's 16:4 GQA ratio. + ops::k_norm_rope_batch_decode_into( + ctx, + &mut bufs.k_ctx, + &layer.attention.k_norm, + &self.cos_cache, + &self.sin_cache, + &bufs.positions_ctx, + config.num_key_value_heads, + config.head_dim, + config.rms_norm_eps, + ); + + // Concatenate per-request [ctx | noise] K/V into the contiguous layout + // the ragged attention kernel expects. Two strided segment copies per + // tensor (ctx segment at offset 0, noise segment at offset ctx_len) + // replace the old 2 * batch_size memcpy_dtod loop (`compact_kv`): + // bs=32 dropped from 128 launches/layer to 4. 
+ let kv_seg_total = bufs.ctx_len + bufs.q_len; + ops::strided_segment_copy_into( + ctx, + &bufs.k_ctx, + &mut bufs.k_all, + bufs.ctx_len, + kv_seg_total, + 0, + batch_size, + )?; + ops::strided_segment_copy_into( + ctx, + &bufs.k_noise, + &mut bufs.k_all, + bufs.q_len, + kv_seg_total, + bufs.ctx_len, + batch_size, + )?; + ops::strided_segment_copy_into( + ctx, + &bufs.v_ctx, + &mut bufs.v_all, + bufs.ctx_len, + kv_seg_total, + 0, + batch_size, + )?; + ops::strided_segment_copy_into( + ctx, + &bufs.v_noise, + &mut bufs.v_all, + bufs.q_len, + kv_seg_total, + bufs.ctx_len, + batch_size, + )?; + bufs.prepare_ragged_plan(self, batch_size)?; + let cached_plan = bufs.ragged_plan.take().expect("ragged plan exists"); + let attention_result = ops::batch_prefill_ragged_nhd_noncausal_into( + ctx, + &bufs.q, + &bufs.k_all, + &bufs.v_all, + &mut bufs.attn_out, + &cached_plan.plan, + config.num_attention_heads, + config.num_key_value_heads, + config.head_dim, + ); + bufs.ragged_plan = Some(cached_plan); + attention_result?; + ops::gemm_into_checked( + ctx, + &layer.attention.o_proj, + &bufs.attn_out, + &mut bufs.o_buf, + )?; + openinfer_kernels::ops::fused_add_rms_norm_round_batch_into( + ctx, + &mut bufs.hidden, + &bufs.o_buf, + &layer.post_attention_layernorm, + config.rms_norm_eps, + &mut bufs.normed, + )?; + ops::gemm_into_checked( + ctx, + &layer.mlp.gate_up_proj, + &bufs.normed, + &mut bufs.gate_up, + )?; + ops::silu_mul_fused_batch_into(ctx, &bufs.gate_up, &mut bufs.act_out)?; + ops::gemm_into_checked(ctx, &layer.mlp.down_proj, &bufs.act_out, &mut bufs.o_buf)?; + ops::add_batch_into(ctx, &bufs.hidden, &bufs.o_buf, &mut bufs.hidden_out)?; + std::mem::swap(&mut bufs.hidden, &mut bufs.hidden_out); + Ok(()) + } +} + +fn compact_inputs( + ctx: &DeviceContext, + requests: &[DFlashBatchInput<'_>], + bufs: &mut DFlashBatchBuffers, +) -> Result<()> { + let hidden = bufs.noise.hidden_dim; + let target_hidden = bufs.target_hidden.hidden_dim; + let mut pos_q = 
Vec::with_capacity(bufs.total_q_len); + let mut pos_ctx = Vec::with_capacity(bufs.total_ctx_len); + for (i, req) in requests.iter().enumerate() { + copy_hidden( + ctx, + req.noise_embedding, + 0, + &mut bufs.noise, + i * bufs.q_len, + hidden, + bufs.q_len, + )?; + copy_hidden( + ctx, + req.target_hidden.concatenated, + 0, + &mut bufs.target_hidden, + i * bufs.ctx_len, + target_hidden, + bufs.ctx_len, + )?; + pos_ctx.extend_from_slice(&req.position_ids[..bufs.ctx_len]); + pos_q.extend_from_slice(&req.position_ids[bufs.ctx_len..]); + } + let mut dst_q = bufs.positions_q.slice_mut(..pos_q.len()); + ctx.stream.memcpy_htod(&pos_q, &mut dst_q)?; + let mut dst_ctx = bufs.positions_ctx.slice_mut(..pos_ctx.len()); + ctx.stream.memcpy_htod(&pos_ctx, &mut dst_ctx)?; + Ok(()) +} + +fn compact_host_inputs( + ctx: &DeviceContext, + requests: &[DFlashHostBatchInput<'_>], + bufs: &mut DFlashBatchBuffers, +) -> Result<()> { + let hidden = bufs.noise.hidden_dim; + let target_hidden = bufs.target_hidden.hidden_dim; + let q_len = bufs.q_len; + let ctx_len = bufs.ctx_len; + let batch_size = requests.len(); + + // Stitch all requests into contiguous host slices, then upload each tensor + // in a single H2D copy — matches Qwen3's batch metadata upload pattern and + // avoids one launch per request per tensor. 
+ let mut noise_flat = Vec::with_capacity(batch_size * q_len * hidden); + let mut target_flat = Vec::with_capacity(batch_size * ctx_len * target_hidden); + let mut pos_q = Vec::with_capacity(batch_size * q_len); + let mut pos_ctx = Vec::with_capacity(batch_size * ctx_len); + for req in requests { + noise_flat.extend_from_slice(req.noise_embedding); + target_flat.extend_from_slice(req.target_hidden); + pos_ctx.extend_from_slice(&req.position_ids[..ctx_len]); + pos_q.extend_from_slice(&req.position_ids[ctx_len..]); + } + + let mut noise_dst = bufs.noise.data.slice_mut(..noise_flat.len()); + ctx.stream.memcpy_htod(&noise_flat, &mut noise_dst)?; + let mut target_dst = bufs.target_hidden.data.slice_mut(..target_flat.len()); + ctx.stream.memcpy_htod(&target_flat, &mut target_dst)?; + let mut dst_q = bufs.positions_q.slice_mut(..pos_q.len()); + ctx.stream.memcpy_htod(&pos_q, &mut dst_q)?; + let mut dst_ctx = bufs.positions_ctx.slice_mut(..pos_ctx.len()); + ctx.stream.memcpy_htod(&pos_ctx, &mut dst_ctx)?; + Ok(()) +} + +pub(crate) fn copy_hidden( + ctx: &DeviceContext, + src: &HiddenStates, + src_token_offset: usize, + dst: &mut HiddenStates, + dst_token_offset: usize, + hidden_dim: usize, + token_count: usize, +) -> Result<()> { + debug_assert_eq!(src.hidden_dim, hidden_dim); + debug_assert_eq!(dst.hidden_dim, hidden_dim); + debug_assert!(src_token_offset + token_count <= src.seq_len); + debug_assert!(dst_token_offset + token_count <= dst.seq_len); + let len = hidden_dim * token_count; + let src_offset = hidden_dim * src_token_offset; + let dst_offset = hidden_dim * dst_token_offset; + let src_view = src.data.slice(src_offset..src_offset + len); + let mut dst_view = dst.data.slice_mut(dst_offset..dst_offset + len); + ctx.stream.memcpy_dtod(&src_view, &mut dst_view)?; + Ok(()) +} diff --git a/openinfer-qwen3-4b-dflash/src/bin/qwen3_dflash_batch_bench.rs b/openinfer-qwen3-4b-dflash/src/bin/qwen3_dflash_batch_bench.rs new file mode 100644 index 00000000..9ba748ae --- 
/dev/null +++ b/openinfer-qwen3-4b-dflash/src/bin/qwen3_dflash_batch_bench.rs @@ -0,0 +1,227 @@ +use std::path::PathBuf; +use std::time::Instant; + +use anyhow::{Context, Result, bail}; +use half::bf16; +use openinfer_core::tensor::HiddenStates; +use openinfer_qwen3_4b_dflash::{DFlashBatchInput, DFlashDraftModel, DFlashTargetHidden}; +use serde::Serialize; + +fn main() -> Result<()> { + let args = Args::parse()?; + let model = DFlashDraftModel::load(&args.model_path, args.device)?; + let config = model.config(); + let ctx = model.device_context(); + let mut reports = Vec::new(); + + for &batch_size in &args.batch_sizes { + let mut noises = Vec::with_capacity(batch_size); + let mut targets = Vec::with_capacity(batch_size); + let mut positions = Vec::with_capacity(batch_size); + for i in 0..batch_size { + let noise = deterministic_bf16(args.q_len * config.hidden_size, 0xD4A5_0000 + i as u64); + let target = deterministic_bf16( + args.ctx_len * config.hidden_size * config.target_layer_count(), + 0xC0DE_0000 + i as u64, + ); + noises.push(HiddenStates { + data: ctx.stream.clone_htod(&noise).context("noise h2d")?, + hidden_dim: config.hidden_size, + seq_len: args.q_len, + }); + targets.push(HiddenStates { + data: ctx.stream.clone_htod(&target).context("target h2d")?, + hidden_dim: config.hidden_size * config.target_layer_count(), + seq_len: args.ctx_len, + }); + positions.push( + (0..(args.ctx_len + args.q_len)) + .map(|pos| pos as i32) + .collect::>(), + ); + } + let mut bufs = model.create_batch_buffers(batch_size, args.q_len, args.ctx_len)?; + let inputs = build_inputs(&noises, &targets, &positions); + for _ in 0..args.warmup { + let _ = model.forward_batch(&inputs, &mut bufs)?; + ctx.sync()?; + } + let mut latencies_ms = Vec::with_capacity(args.iters); + for _ in 0..args.iters { + ctx.sync()?; + let started = Instant::now(); + let _ = model.forward_batch(&inputs, &mut bufs)?; + ctx.sync()?; + latencies_ms.push(started.elapsed().as_secs_f64() * 1000.0); + } + let 
stats = Stats::from(&latencies_ms); + let mean_s = stats.mean / 1000.0; + reports.push(BatchReport { + batch_size, + ctx_len: args.ctx_len, + q_len: args.q_len, + warmup: args.warmup, + iters: args.iters, + draft_tokens_per_s: (batch_size * args.q_len) as f64 / mean_s, + requests_per_s: batch_size as f64 / mean_s, + latency_ms: stats, + }); + } + + let report = Report { + schema: 1, + engine: "openinfer-qwen3-4b-dflash-batch", + model_path: args.model_path.to_string_lossy().to_string(), + device: args.device, + hidden_size: config.hidden_size, + target_layer_count: config.target_layer_count(), + reports, + }; + println!("{}", serde_json::to_string_pretty(&report)?); + Ok(()) +} + +fn build_inputs<'a>( + noises: &'a [HiddenStates], + targets: &'a [HiddenStates], + positions: &'a [Vec], +) -> Vec> { + noises + .iter() + .zip(targets.iter()) + .zip(positions.iter()) + .map(|((noise, target), position_ids)| DFlashBatchInput { + noise_embedding: noise, + target_hidden: DFlashTargetHidden { + concatenated: target, + }, + position_ids, + }) + .collect() +} + +#[derive(Clone)] +struct Args { + model_path: PathBuf, + device: usize, + ctx_len: usize, + q_len: usize, + warmup: usize, + iters: usize, + batch_sizes: Vec, +} + +impl Args { + fn parse() -> Result { + let mut model_path = PathBuf::from("/home/hezhaozhao/models/Qwen3-4B-DFlash-b16"); + let mut device = 0usize; + let mut ctx_len = 2usize; + let mut q_len = 16usize; + let mut warmup = 5usize; + let mut iters = 30usize; + let mut batch_sizes = vec![1, 2, 4, 8, 16, 32]; + let mut args = std::env::args().skip(1); + while let Some(arg) = args.next() { + match arg.as_str() { + "--model-path" => model_path = PathBuf::from(next_value(&mut args, &arg)?), + "--device" => device = next_value(&mut args, &arg)?.parse()?, + "--ctx-len" => ctx_len = next_value(&mut args, &arg)?.parse()?, + "--q-len" => q_len = next_value(&mut args, &arg)?.parse()?, + "--warmup" => warmup = next_value(&mut args, &arg)?.parse()?, + "--iters" => 
iters = next_value(&mut args, &arg)?.parse()?, + "--batch-sizes" => { + batch_sizes = next_value(&mut args, &arg)? + .split(',') + .map(str::parse) + .collect::, _>>()?; + } + _ => bail!("unknown argument {arg}"), + } + } + if ctx_len == 0 || q_len == 0 || iters == 0 { + bail!("--ctx-len, --q-len, and --iters must be greater than zero"); + } + if batch_sizes.is_empty() || batch_sizes.contains(&0) { + bail!("--batch-sizes must contain positive batch sizes"); + } + Ok(Self { + model_path, + device, + ctx_len, + q_len, + warmup, + iters, + batch_sizes, + }) + } +} + +fn next_value(args: &mut impl Iterator, flag: &str) -> Result { + args.next() + .with_context(|| format!("{flag} requires a value")) +} + +fn deterministic_bf16(len: usize, seed: u64) -> Vec { + let mut state = seed; + let mut out = Vec::with_capacity(len); + for _ in 0..len { + state = state.wrapping_mul(6364136223846793005).wrapping_add(1); + let bits = ((state >> 32) as u32) as f32 / (u32::MAX as f32); + out.push(bf16::from_f32((bits * 2.0 - 1.0) * 0.125)); + } + out +} + +#[derive(Serialize)] +struct Report { + schema: u32, + engine: &'static str, + model_path: String, + device: usize, + hidden_size: usize, + target_layer_count: usize, + reports: Vec, +} + +#[derive(Serialize)] +struct BatchReport { + batch_size: usize, + ctx_len: usize, + q_len: usize, + warmup: usize, + iters: usize, + draft_tokens_per_s: f64, + requests_per_s: f64, + latency_ms: Stats, +} + +#[derive(Serialize)] +struct Stats { + mean: f64, + p50: f64, + p90: f64, + p99: f64, + min: f64, + max: f64, +} + +impl Stats { + fn from(values: &[f64]) -> Self { + let mut sorted = values.to_vec(); + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let mean = sorted.iter().sum::() / sorted.len() as f64; + Self { + mean, + p50: percentile(&sorted, 0.50), + p90: percentile(&sorted, 0.90), + p99: percentile(&sorted, 0.99), + min: sorted[0], + max: sorted[sorted.len() - 1], + } + } +} + +fn percentile(sorted: &[f64], q: f64) -> f64 { + let 
idx = ((sorted.len() - 1) as f64 * q).round() as usize; + sorted[idx.min(sorted.len() - 1)] +} diff --git a/openinfer-qwen3-4b-dflash/src/bin/qwen3_dflash_forward_bench.rs b/openinfer-qwen3-4b-dflash/src/bin/qwen3_dflash_forward_bench.rs new file mode 100644 index 00000000..dac9f374 --- /dev/null +++ b/openinfer-qwen3-4b-dflash/src/bin/qwen3_dflash_forward_bench.rs @@ -0,0 +1,308 @@ +use std::path::PathBuf; +use std::time::Instant; + +use anyhow::{Context, Result, bail}; +use half::bf16; +use openinfer_core::tensor::HiddenStates; +use openinfer_qwen3_4b_dflash::{DFlashDraftModel, DFlashTargetHidden}; +use safetensors::{Dtype, SafeTensors}; +use serde::Serialize; + +fn main() -> Result<()> { + let args = Args::parse()?; + let model = DFlashDraftModel::load(&args.model_path, args.device)?; + let config = model.config(); + let ctx = model.device_context(); + + let (noise, target_hidden, positions, ctx_len, q_len) = if let Some(fixture) = &args.fixture { + let bytes = std::fs::read(fixture) + .with_context(|| format!("failed to read fixture {}", fixture.display()))?; + let st = SafeTensors::deserialize(&bytes).context("parse fixture")?; + // Derive ctx_len/q_len from the fixture's actual tensor shapes so the + // bench works for any --ctx-len/--q-len the Python side used, rather + // than requiring the caller to repeat them on both sides. 
+        let noise_view = st
+            .tensor("noise_embedding")
+            .with_context(|| "missing tensor noise_embedding")?;
+        // Read the sequence axis with `get` rather than indexing so a fixture
+        // tensor of the wrong rank is reported as an error instead of panicking
+        // before read_bf16 gets to validate the full shape.
+        let q_len = noise_view
+            .shape()
+            .get(1)
+            .copied()
+            .context("noise_embedding must have shape [1, q_len, hidden_size]")?;
+        let target_view = st
+            .tensor("target_hidden")
+            .with_context(|| "missing tensor target_hidden")?;
+        let ctx_len = target_view
+            .shape()
+            .get(1)
+            .copied()
+            .context("target_hidden must have shape [1, ctx_len, target_hidden_dim]")?;
+        let noise = read_bf16(&st, "noise_embedding", &[1, q_len, config.hidden_size])?;
+        let target_hidden = read_bf16(
+            &st,
+            "target_hidden",
+            &[1, ctx_len, config.hidden_size * config.target_layer_count()],
+        )?;
+        let positions = read_i32(&st, "position_ids", &[1, ctx_len + q_len])?;
+        (noise, target_hidden, positions, ctx_len, q_len)
+    } else {
+        let noise = deterministic_bf16(args.q_len * config.hidden_size, 0xD4A5_4B16);
+        let target_hidden = deterministic_bf16(
+            args.ctx_len * config.hidden_size * config.target_layer_count(),
+            0xD4A5_C0DE,
+        );
+        let positions = (0..(args.ctx_len + args.q_len))
+            .map(|pos| pos as i32)
+            .collect::<Vec<i32>>();
+        (noise, target_hidden, positions, args.ctx_len, args.q_len)
+    };
+
+    let noise = HiddenStates {
+        data: ctx.stream.clone_htod(&noise).context("noise h2d")?,
+        hidden_dim: config.hidden_size,
+        seq_len: q_len,
+    };
+    let target_hidden = HiddenStates {
+        data: ctx
+            .stream
+            .clone_htod(&target_hidden)
+            .context("target hidden h2d")?,
+        hidden_dim: config.hidden_size * config.target_layer_count(),
+        seq_len: ctx_len,
+    };
+    ctx.sync()?;
+
+    let mut cache = model.create_draft_cache(q_len, ctx_len, ctx_len + q_len)?;
+    if args.draft_cache {
+        model.prepare_step_context(
+            DFlashTargetHidden {
+                concatenated: &target_hidden,
+            },
+            &positions,
+            &mut cache,
+        )?;
+        ctx.sync()?;
+    }
+    for _ in 0..args.warmup {
+        if args.draft_cache {
+            cache.reset();
+            model.prepare_step_context(
+                DFlashTargetHidden {
+                    concatenated: &target_hidden,
+                },
+                &positions,
+                &mut cache,
+            )?;
+            let _out = model.forward_with_draft_cache(&noise, &positions, &mut cache)?;
+        } else {
+            let _out = model.forward_with_cache(
+                &noise,
+                DFlashTargetHidden {
+                    concatenated: &target_hidden,
+                },
+                &positions,
+                &mut cache,
+            )?;
+        }
+        ctx.sync()?;
+    }
+
+    let mut latencies_ms = Vec::with_capacity(args.iters);
+    for _ in 0..args.iters {
+        ctx.sync()?;
+        let started = Instant::now();
+        if args.draft_cache {
+            cache.reset();
+            model.prepare_step_context(
+                DFlashTargetHidden {
+                    concatenated: &target_hidden,
+                },
+                &positions,
+                &mut cache,
+            )?;
+            let _out = model.forward_with_draft_cache(&noise, &positions, &mut cache)?;
+        } else {
+            let _out = model.forward_with_cache(
+                &noise,
+                DFlashTargetHidden {
+                    concatenated: &target_hidden,
+                },
+                &positions,
+                &mut cache,
+            )?;
+        }
+        ctx.sync()?;
+        latencies_ms.push(started.elapsed().as_secs_f64() * 1000.0);
+    }
+
+    // Report the shape that was actually benchmarked. With --fixture the
+    // lengths come from the fixture tensors, not from --ctx-len/--q-len, so
+    // echoing the CLI values would mislabel the measurement.
+    let report = Report {
+        schema: 1,
+        engine: "openinfer-qwen3-4b-dflash",
+        model_path: args.model_path.to_string_lossy().to_string(),
+        device: args.device,
+        ctx_len,
+        q_len,
+        hidden_size: config.hidden_size,
+        target_layer_count: config.target_layer_count(),
+        draft_cache: args.draft_cache,
+        warmup: args.warmup,
+        iters: args.iters,
+        latency_ms: Stats::from(&latencies_ms),
+    };
+    println!("{}", serde_json::to_string_pretty(&report)?);
+    Ok(())
+}
+
+#[derive(Clone)]
+struct Args {
+    model_path: PathBuf,
+    fixture: Option<PathBuf>,
+    device: usize,
+    ctx_len: usize,
+    q_len: usize,
+    warmup: usize,
+    iters: usize,
+    draft_cache: bool,
+}
+
+impl Args {
+    fn parse() -> Result<Self> {
+        let mut model_path = PathBuf::from("/home/hezhaozhao/models/Qwen3-4B-DFlash-b16");
+        let mut fixture = None;
+        let mut device = 0usize;
+        let mut ctx_len = 2usize;
+        let mut q_len = 16usize;
+        let mut warmup = 5usize;
+        let mut iters = 30usize;
+        let mut draft_cache = false;
+        let mut args = std::env::args().skip(1);
+        while let Some(arg) = args.next() {
+            match arg.as_str() {
+                "--model-path" => model_path = PathBuf::from(next_value(&mut args, &arg)?),
+                "--fixture" => fixture = Some(PathBuf::from(next_value(&mut args, &arg)?)),
+                "--device" => device = next_value(&mut args,
&arg)?.parse()?, + "--ctx-len" => ctx_len = next_value(&mut args, &arg)?.parse()?, + "--q-len" => q_len = next_value(&mut args, &arg)?.parse()?, + "--warmup" => warmup = next_value(&mut args, &arg)?.parse()?, + "--iters" => iters = next_value(&mut args, &arg)?.parse()?, + "--draft-cache" | "--context-cache" => draft_cache = true, + _ => bail!("unknown argument {arg}"), + } + } + if ctx_len == 0 { + bail!("--ctx-len must be greater than zero"); + } + if q_len == 0 { + bail!("--q-len must be greater than zero"); + } + if iters == 0 { + bail!("--iters must be greater than zero"); + } + Ok(Self { + model_path, + fixture, + device, + ctx_len, + q_len, + warmup, + iters, + draft_cache, + }) + } +} + +fn next_value(args: &mut impl Iterator, flag: &str) -> Result { + args.next() + .with_context(|| format!("{flag} requires a value")) +} + +fn deterministic_bf16(len: usize, seed: u64) -> Vec { + let mut state = seed; + let mut out = Vec::with_capacity(len); + for _ in 0..len { + state = state.wrapping_mul(6364136223846793005).wrapping_add(1); + let bits = ((state >> 32) as u32) as f32 / (u32::MAX as f32); + let value = (bits * 2.0 - 1.0) * 0.125; + out.push(bf16::from_f32(value)); + } + out +} + +#[derive(Serialize)] +struct Report { + schema: u32, + engine: &'static str, + model_path: String, + device: usize, + ctx_len: usize, + q_len: usize, + hidden_size: usize, + target_layer_count: usize, + draft_cache: bool, + warmup: usize, + iters: usize, + latency_ms: Stats, +} + +#[derive(Serialize)] +struct Stats { + mean: f64, + p50: f64, + p90: f64, + p99: f64, + min: f64, + max: f64, +} + +impl Stats { + fn from(values: &[f64]) -> Self { + let mut sorted = values.to_vec(); + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let mean = sorted.iter().sum::() / sorted.len() as f64; + Self { + mean, + p50: percentile(&sorted, 0.50), + p90: percentile(&sorted, 0.90), + p99: percentile(&sorted, 0.99), + min: sorted[0], + max: sorted[sorted.len() - 1], + } + } +} + +fn read_bf16(st: 
&SafeTensors<'_>, name: &str, shape: &[usize]) -> Result> { + let view = st + .tensor(name) + .with_context(|| format!("missing tensor {name}"))?; + if view.dtype() != Dtype::BF16 { + bail!("{name} must be BF16, got {:?}", view.dtype()); + } + if view.shape() != shape { + bail!( + "{name} shape mismatch: expected {shape:?}, got {:?}", + view.shape() + ); + } + Ok(view + .data() + .chunks_exact(2) + .map(|chunk| bf16::from_bits(u16::from_le_bytes([chunk[0], chunk[1]]))) + .collect()) +} + +fn read_i32(st: &SafeTensors<'_>, name: &str, shape: &[usize]) -> Result> { + let view = st + .tensor(name) + .with_context(|| format!("missing tensor {name}"))?; + if view.dtype() != Dtype::I32 { + bail!("{name} must be I32, got {:?}", view.dtype()); + } + if view.shape() != shape { + bail!( + "{name} shape mismatch: expected {shape:?}, got {:?}", + view.shape() + ); + } + Ok(view + .data() + .chunks_exact(4) + .map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect()) +} + +fn percentile(sorted: &[f64], q: f64) -> f64 { + let idx = ((sorted.len() - 1) as f64 * q).round() as usize; + sorted[idx.min(sorted.len() - 1)] +} diff --git a/openinfer-qwen3-4b-dflash/src/bin/qwen3_dflash_forward_fixture.rs b/openinfer-qwen3-4b-dflash/src/bin/qwen3_dflash_forward_fixture.rs new file mode 100644 index 00000000..08ff66a3 --- /dev/null +++ b/openinfer-qwen3-4b-dflash/src/bin/qwen3_dflash_forward_fixture.rs @@ -0,0 +1,155 @@ +use std::collections::HashMap; +use std::path::PathBuf; + +use anyhow::{Context, Result, bail}; +use half::bf16; +use openinfer_core::tensor::HiddenStates; +use openinfer_qwen3_4b_dflash::{DFlashDraftModel, DFlashTargetHidden}; +use safetensors::{Dtype, SafeTensors, tensor::TensorView}; + +fn main() -> Result<()> { + let args = Args::parse()?; + let fixture_bytes = std::fs::read(&args.fixture).with_context(|| { + format!( + "failed to read input fixture {}", + args.fixture.to_string_lossy() + ) + })?; + let st = 
SafeTensors::deserialize(&fixture_bytes).context("parse input fixture")?; + let model = DFlashDraftModel::load(&args.model_path, args.device)?; + let config = model.config(); + let ctx = model.device_context(); + + let noise = bf16_tensor(&st, "noise_embedding")?; + let target_hidden = bf16_tensor(&st, "target_hidden")?; + let positions = i32_tensor(&st, "position_ids")?; + + if noise.1.len() != 3 || noise.1[0] != 1 || noise.1[2] != config.hidden_size { + bail!( + "noise_embedding shape mismatch: expected [1, q_len, {}], got {:?}", + config.hidden_size, + noise.1 + ); + } + if target_hidden.1.len() != 3 + || target_hidden.1[0] != 1 + || target_hidden.1[2] != config.hidden_size * config.target_layer_count() + { + bail!( + "target_hidden shape mismatch: expected [1, ctx_len, {}], got {:?}", + config.hidden_size * config.target_layer_count(), + target_hidden.1 + ); + } + let q_len = noise.1[1]; + let ctx_len = target_hidden.1[1]; + ensure_shape("position_ids", &positions.1, &[1, ctx_len + q_len])?; + + let noise_embedding = HiddenStates { + data: ctx.stream.clone_htod(&noise.0)?, + hidden_dim: config.hidden_size, + seq_len: q_len, + }; + let target_hidden = HiddenStates { + data: ctx.stream.clone_htod(&target_hidden.0)?, + hidden_dim: config.hidden_size * config.target_layer_count(), + seq_len: ctx_len, + }; + let out = model.forward( + &noise_embedding, + DFlashTargetHidden { + concatenated: &target_hidden, + }, + &positions.0, + )?; + ctx.sync()?; + let out = ctx.stream.clone_dtoh(&out.data)?; + ctx.sync()?; + + let out_bytes = bf16_bytes(&out); + let tensors = HashMap::from([( + "openinfer_output".to_string(), + TensorView::new(Dtype::BF16, vec![1, q_len, config.hidden_size], &out_bytes)?, + )]); + safetensors::serialize_to_file(tensors, None, &args.out)?; + Ok(()) +} + +struct Args { + model_path: PathBuf, + fixture: PathBuf, + out: PathBuf, + device: usize, +} + +impl Args { + fn parse() -> Result { + let mut model_path = None; + let mut fixture = None; + let mut 
out = None; + let mut device = 0usize; + let mut args = std::env::args().skip(1); + while let Some(arg) = args.next() { + match arg.as_str() { + "--model-path" => model_path = Some(PathBuf::from(next_value(&mut args, &arg)?)), + "--fixture" => fixture = Some(PathBuf::from(next_value(&mut args, &arg)?)), + "--out" => out = Some(PathBuf::from(next_value(&mut args, &arg)?)), + "--device" => device = next_value(&mut args, &arg)?.parse()?, + _ => bail!("unknown argument {arg}"), + } + } + Ok(Self { + model_path: model_path + .unwrap_or_else(|| PathBuf::from("/home/hezhaozhao/models/Qwen3-4B-DFlash-b16")), + fixture: fixture.context("--fixture is required")?, + out: out.context("--out is required")?, + device, + }) + } +} + +fn next_value(args: &mut impl Iterator, flag: &str) -> Result { + args.next() + .with_context(|| format!("{flag} requires a value")) +} + +fn ensure_shape(name: &str, got: &[usize], expected: &[usize]) -> Result<()> { + if got != expected { + bail!("{name} shape mismatch: expected {expected:?}, got {got:?}"); + } + Ok(()) +} + +fn bf16_tensor(st: &SafeTensors<'_>, name: &str) -> Result<(Vec, Vec)> { + let view = st.tensor(name)?; + if view.dtype() != Dtype::BF16 { + bail!("{name} must be BF16, got {:?}", view.dtype()); + } + let values = view + .data() + .chunks_exact(2) + .map(|chunk| bf16::from_bits(u16::from_le_bytes([chunk[0], chunk[1]]))) + .collect(); + Ok((values, view.shape().to_vec())) +} + +fn i32_tensor(st: &SafeTensors<'_>, name: &str) -> Result<(Vec, Vec)> { + let view = st.tensor(name)?; + if view.dtype() != Dtype::I32 { + bail!("{name} must be I32, got {:?}", view.dtype()); + } + let values = view + .data() + .chunks_exact(4) + .map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect(); + Ok((values, view.shape().to_vec())) +} + +fn bf16_bytes(values: &[bf16]) -> Vec { + let mut out = Vec::with_capacity(values.len() * 2); + for value in values { + out.extend(value.to_bits().to_le_bytes()); + } + out +} diff 
--git a/openinfer-qwen3-4b-dflash/src/config.rs b/openinfer-qwen3-4b-dflash/src/config.rs new file mode 100644 index 00000000..bebd1657 --- /dev/null +++ b/openinfer-qwen3-4b-dflash/src/config.rs @@ -0,0 +1,143 @@ +use anyhow::{Result, bail}; +use serde::Deserialize; +use std::fs; +use std::path::Path; + +#[derive(Clone, Debug, Deserialize)] +pub struct DFlashInnerConfig { + pub mask_token_id: u32, + pub target_layer_ids: Vec, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct DFlashConfig { + pub architectures: Vec, + pub attention_bias: bool, + pub attention_dropout: f32, + pub block_size: usize, + pub dflash_config: DFlashInnerConfig, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_attention_heads: usize, + pub num_hidden_layers: usize, + pub num_key_value_heads: usize, + pub num_target_layers: usize, + pub head_dim: usize, + pub max_position_embeddings: usize, + pub rms_norm_eps: f32, + pub rope_theta: f32, + pub tie_word_embeddings: bool, + pub vocab_size: usize, +} + +impl DFlashConfig { + pub fn from_model_dir(model_path: &Path) -> Result { + let content = fs::read_to_string(model_path.join("config.json"))?; + let config: Self = serde_json::from_str(&content)?; + config.validate()?; + Ok(config) + } + + pub fn validate(&self) -> Result<()> { + if self + .architectures + .iter() + .all(|name| name != "DFlashDraftModel") + { + bail!("DFlash config architectures must include DFlashDraftModel"); + } + if self.attention_bias { + bail!("DFlash v1 expects bias-free Qwen3 projections"); + } + if self.attention_dropout != 0.0 { + bail!("DFlash inference expects attention_dropout=0"); + } + if self.num_hidden_layers == 0 { + bail!("DFlash draft must have at least one layer"); + } + if self.num_hidden_layers != 5 { + bail!( + "openinfer-qwen3-4b-dflash supports only Qwen3-4B-DFlash-b16 with 5 draft layers, got {}", + self.num_hidden_layers + ); + } + if self.block_size != 16 { + bail!( + "openinfer-qwen3-4b-dflash supports only 
Qwen3-4B-DFlash-b16 block_size=16, got {}", + self.block_size + ); + } + if self.dflash_config.mask_token_id != 151669 { + bail!( + "openinfer-qwen3-4b-dflash supports only Qwen3-4B-DFlash-b16 mask_token_id=151669, got {}", + self.dflash_config.mask_token_id + ); + } + if self.hidden_size == 0 || self.head_dim == 0 { + bail!("DFlash hidden_size/head_dim must be positive"); + } + if self.num_attention_heads == 0 || self.num_key_value_heads == 0 { + bail!("DFlash attention/KV head counts must be positive"); + } + if self.num_attention_heads % self.num_key_value_heads != 0 { + bail!("DFlash GQA requires attention heads divisible by KV heads"); + } + if self.dflash_config.target_layer_ids.len() != self.num_hidden_layers { + bail!( + "DFlash target_layer_ids len {} must match draft layers {}", + self.dflash_config.target_layer_ids.len(), + self.num_hidden_layers + ); + } + if self + .dflash_config + .target_layer_ids + .iter() + .any(|&layer| layer >= self.num_target_layers) + { + bail!("DFlash target_layer_ids must be within num_target_layers"); + } + if self.dflash_config.target_layer_ids.as_slice() != [1, 9, 17, 25, 33] { + bail!( + "openinfer-qwen3-4b-dflash supports only Qwen3-4B-DFlash-b16 target_layer_ids=[1, 9, 17, 25, 33], got {:?}", + self.dflash_config.target_layer_ids + ); + } + Ok(()) + } + + pub fn target_layer_count(&self) -> usize { + self.dflash_config.target_layer_ids.len() + } + + pub fn q_dim(&self) -> usize { + self.num_attention_heads * self.head_dim + } + + pub fn kv_dim(&self) -> usize { + self.num_key_value_heads * self.head_dim + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const LOCAL_DFLASH: &str = "/home/hezhaozhao/models/Qwen3-4B-DFlash-b16"; + + #[test] + fn parses_local_dflash_config() { + let path = Path::new(LOCAL_DFLASH); + if !path.exists() { + eprintln!("skipping: {LOCAL_DFLASH} does not exist"); + return; + } + let config = DFlashConfig::from_model_dir(path).expect("config"); + assert_eq!(config.num_hidden_layers, 5); + 
assert_eq!(config.block_size, 16); + assert_eq!(config.dflash_config.mask_token_id, 151669); + assert_eq!(config.dflash_config.target_layer_ids, [1, 9, 17, 25, 33]); + assert_eq!(config.hidden_size, 2560); + assert_eq!(config.intermediate_size, 9728); + } +} diff --git a/openinfer-qwen3-4b-dflash/src/executor.rs b/openinfer-qwen3-4b-dflash/src/executor.rs new file mode 100644 index 00000000..de23e08b --- /dev/null +++ b/openinfer-qwen3-4b-dflash/src/executor.rs @@ -0,0 +1,681 @@ +use std::collections::HashMap; +use std::path::Path; +use std::time::{Duration, Instant}; + +use anyhow::Result; +use half::bf16; +use openinfer_core::tensor::{DeviceContext, HiddenStates}; + +use crate::batch_buffers::DFlashBatchBuffers; +use crate::batch_forward::{DFlashBatchInput, DFlashHostBatchInput, copy_hidden}; +use crate::forward::{DFlashDraftCache, DFlashTargetHidden}; +use crate::weights::DFlashDraftModel; + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)] +pub struct DFlashRequestId(pub u64); + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +pub enum DFlashCacheMode { + NoCache, + DraftCache, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +pub struct DFlashBatchKey { + pub q_len: usize, + pub ctx_len: usize, + pub past_len: usize, + pub cache_mode: DFlashCacheMode, +} + +pub struct DFlashDraftRequest { + pub request_id: DFlashRequestId, + pub noise_embedding: HiddenStates, + pub target_hidden: HiddenStates, + pub position_ids: Vec, + pub cache_mode: DFlashCacheMode, +} + +pub struct DFlashDraftHostRequest { + pub request_id: DFlashRequestId, + pub noise_embedding: Vec, + pub target_hidden: Vec, + pub position_ids: Vec, + pub q_len: usize, + pub ctx_len: usize, + pub cache_mode: DFlashCacheMode, +} + +pub struct DFlashDraftResponse { + pub request_id: DFlashRequestId, + pub output: HiddenStates, + pub cache_seq_len: usize, + pub batch_size: usize, + pub elapsed: Duration, +} + +pub struct DFlashDraftHostResponse { + pub request_id: 
DFlashRequestId, + pub output: Vec, + pub hidden_dim: usize, + pub seq_len: usize, + pub cache_seq_len: usize, + pub batch_size: usize, + pub elapsed: Duration, +} + +pub struct DFlashDraftBatchResponse { + pub request_ids: Vec, + pub output: HiddenStates, + pub cache_seq_lens: Vec, + pub batch_size: usize, + pub q_len: usize, + pub elapsed: Duration, +} + +pub struct DFlashDraftBatchView<'a> { + pub request_ids: Vec, + pub output: &'a HiddenStates, + pub cache_seq_lens: Vec, + pub batch_size: usize, + pub q_len: usize, + pub elapsed: Duration, +} + +pub struct DFlashExecutorOptions { + pub max_batch_size: usize, + pub max_step_context_len: usize, + /// Largest draft length (`q_len`) the executor must serve. Batch buffers + /// are sized once for `max_batch_size × max_q_len`, so every shape at or + /// below it reuses the same allocation (mirrors Qwen3's `BatchDecodeBuffers`). + pub max_q_len: usize, + pub max_seq_len: usize, + /// Upper bound on resident draft caches. Each `DraftCache` request creates + /// a per-request `DFlashDraftCache` (full `ForwardBuffers` + per-layer past + /// K/V); without a cap they accumulate forever and leak GPU memory. + /// Admission fails closed when this is exceeded — callers must `drop_cache` + /// a retired request before submitting a new one. Mirrors Qwen3's per- + /// request block accounting under the fixed `KvCacheManager` pool. + pub max_caches: usize, +} + +impl Default for DFlashExecutorOptions { + fn default() -> Self { + Self { + max_batch_size: 32, + max_step_context_len: 16, + max_q_len: 16, + max_seq_len: 4096, + max_caches: 64, + } + } +} + +pub struct DFlashExecutor { + model: DFlashDraftModel, + options: DFlashExecutorOptions, + /// Single-instance batch buffer, sized for the worst case + /// (`max_batch_size × max_q_len × max_step_context_len`). Each forward + /// narrows the active shape via `set_active_shape` instead of reallocating. 
+ buffers: DFlashBatchBuffers, + caches: HashMap, +} + +impl DFlashExecutor { + pub fn load( + model_path: &Path, + device_ordinal: usize, + options: DFlashExecutorOptions, + ) -> Result { + let model = DFlashDraftModel::load(model_path, device_ordinal)?; + let buffers = model.create_batch_buffers( + options.max_batch_size, + options.max_q_len, + options.max_step_context_len, + )?; + Ok(Self { + model, + options, + buffers, + caches: HashMap::new(), + }) + } + + pub fn model(&self) -> &DFlashDraftModel { + &self.model + } + + pub fn max_batch_size(&self) -> usize { + self.options.max_batch_size + } + + pub fn batch_key(&self, req: &DFlashDraftRequest) -> Result { + let target = DFlashTargetHidden { + concatenated: &req.target_hidden, + }; + let (q_len, ctx_len) = + self.model + .validate_forward_inputs(&req.noise_embedding, &target, &req.position_ids)?; + let past_len = self + .caches + .get(&req.request_id) + .map(DFlashDraftCache::seq_len) + .unwrap_or(0); + Ok(DFlashBatchKey { + q_len, + ctx_len, + past_len, + cache_mode: req.cache_mode, + }) + } + + pub fn host_batch_key(&self, req: &DFlashDraftHostRequest) -> Result { + let config = self.model.config(); + anyhow::ensure!( + req.noise_embedding.len() == req.q_len * config.hidden_size, + "noise_embedding len {} != q_len * hidden_size {}", + req.noise_embedding.len(), + req.q_len * config.hidden_size + ); + anyhow::ensure!( + req.target_hidden.len() + == req.ctx_len * config.hidden_size * config.target_layer_count(), + "target_hidden len {} != ctx_len * target_layer_count * hidden_size {}", + req.target_hidden.len(), + req.ctx_len * config.hidden_size * config.target_layer_count() + ); + anyhow::ensure!( + req.position_ids.len() == req.ctx_len + req.q_len, + "position_ids len {} != ctx_len + q_len {}", + req.position_ids.len(), + req.ctx_len + req.q_len + ); + let past_len = self + .caches + .get(&req.request_id) + .map(DFlashDraftCache::seq_len) + .unwrap_or(0); + Ok(DFlashBatchKey { + q_len: req.q_len, + 
ctx_len: req.ctx_len, + past_len, + cache_mode: req.cache_mode, + }) + } + + pub fn execute_batch( + &mut self, + requests: Vec, + ) -> Result> { + let batch = self.execute_batch_compact(requests)?; + self.split_compact_response(batch) + } + + pub fn execute_host_batch_compact( + &mut self, + requests: Vec, + ) -> Result { + anyhow::ensure!(!requests.is_empty(), "DFlash host executor batch is empty"); + anyhow::ensure!( + requests.len() <= self.options.max_batch_size, + "DFlash host executor batch size {} exceeds max_batch_size {}", + requests.len(), + self.options.max_batch_size + ); + let key = self.host_batch_key(&requests[0])?; + for req in &requests[1..] { + let req_key = self.host_batch_key(req)?; + anyhow::ensure!( + req_key == key, + "DFlash host executor requires exact-shape batch: first={key:?}, got={req_key:?}" + ); + } + if key.cache_mode == DFlashCacheMode::DraftCache { + return self.execute_cached_host_requests_serial_compact(requests, key); + } + anyhow::ensure!( + key.q_len <= self.options.max_q_len, + "DFlash host q_len {} exceeds executor max_q_len {}", + key.q_len, + self.options.max_q_len + ); + anyhow::ensure!( + key.ctx_len <= self.options.max_step_context_len, + "DFlash host ctx_len {} exceeds executor max_step_context_len {}", + key.ctx_len, + self.options.max_step_context_len + ); + let started = Instant::now(); + let batch_size = requests.len(); + let request_ids = requests + .iter() + .map(|request| request.request_id) + .collect::>(); + let inputs = requests + .iter() + .map(|req| DFlashHostBatchInput { + noise_embedding: &req.noise_embedding, + target_hidden: &req.target_hidden, + position_ids: &req.position_ids, + }) + .collect::>(); + let batch_output = self.model.forward_host_batch(&inputs, &mut self.buffers)?; + self.model.device_context().sync()?; + let elapsed = started.elapsed(); + // forward returns a borrow into self.buffers; materialize an owned copy + // so the next batch can reuse the buffer without aliasing the response. 
+ let output = clone_batch_output(self.model.device_context(), batch_output)?; + Ok(DFlashDraftBatchResponse { + request_ids, + output, + cache_seq_lens: vec![0; batch_size], + batch_size, + q_len: key.q_len, + elapsed, + }) + } + + pub fn execute_host_batch( + &mut self, + requests: Vec, + ) -> Result> { + let batch = self.execute_host_batch_compact(requests)?; + self.split_compact_response(batch) + } + + pub fn execute_host_batch_host( + &mut self, + requests: Vec, + ) -> Result> { + let batch = self.execute_host_batch_compact(requests)?; + self.split_compact_host_response(batch) + } + + pub fn execute_host_batch_view( + &mut self, + requests: Vec, + ) -> Result> { + anyhow::ensure!(!requests.is_empty(), "DFlash host executor batch is empty"); + anyhow::ensure!( + requests.len() <= self.options.max_batch_size, + "DFlash host executor batch size {} exceeds max_batch_size {}", + requests.len(), + self.options.max_batch_size + ); + let key = self.host_batch_key(&requests[0])?; + for req in &requests[1..] 
{ + let req_key = self.host_batch_key(req)?; + anyhow::ensure!( + req_key == key, + "DFlash host executor requires exact-shape batch: first={key:?}, got={req_key:?}" + ); + } + anyhow::ensure!( + key.cache_mode == DFlashCacheMode::NoCache, + "borrowed host batch view currently supports only NoCache mode" + ); + anyhow::ensure!( + key.q_len <= self.options.max_q_len, + "DFlash host q_len {} exceeds executor max_q_len {}", + key.q_len, + self.options.max_q_len + ); + anyhow::ensure!( + key.ctx_len <= self.options.max_step_context_len, + "DFlash host ctx_len {} exceeds executor max_step_context_len {}", + key.ctx_len, + self.options.max_step_context_len + ); + let started = Instant::now(); + let batch_size = requests.len(); + let request_ids = requests + .iter() + .map(|request| request.request_id) + .collect::>(); + let inputs = requests + .iter() + .map(|req| DFlashHostBatchInput { + noise_embedding: &req.noise_embedding, + target_hidden: &req.target_hidden, + position_ids: &req.position_ids, + }) + .collect::>(); + let output = self.model.forward_host_batch(&inputs, &mut self.buffers)?; + self.model.device_context().sync()?; + Ok(DFlashDraftBatchView { + request_ids, + output, + cache_seq_lens: vec![0; batch_size], + batch_size, + q_len: key.q_len, + elapsed: started.elapsed(), + }) + } + + pub fn execute_batch_compact( + &mut self, + requests: Vec, + ) -> Result { + anyhow::ensure!(!requests.is_empty(), "DFlash executor batch is empty"); + anyhow::ensure!( + requests.len() <= self.options.max_batch_size, + "DFlash executor batch size {} exceeds max_batch_size {}", + requests.len(), + self.options.max_batch_size + ); + let key = self.batch_key(&requests[0])?; + for req in &requests[1..] 
{ + let req_key = self.batch_key(req)?; + anyhow::ensure!( + req_key == key, + "DFlash executor requires exact-shape batch: first={key:?}, got={req_key:?}" + ); + } + match key.cache_mode { + DFlashCacheMode::NoCache => self.execute_uncached_batch_compact(requests, key), + DFlashCacheMode::DraftCache => { + self.execute_cached_requests_serial_compact(requests, key) + } + } + } + + pub fn reset_cache(&mut self, request_id: DFlashRequestId) -> Result<()> { + let Some(cache) = self.caches.get_mut(&request_id) else { + anyhow::bail!("unknown DFlash cache request_id {:?}", request_id); + }; + cache.reset(); + Ok(()) + } + + pub fn crop_cache(&mut self, request_id: DFlashRequestId, seq_len: usize) -> Result<()> { + let Some(cache) = self.caches.get_mut(&request_id) else { + anyhow::bail!("unknown DFlash cache request_id {:?}", request_id); + }; + cache.crop(seq_len)?; + Ok(()) + } + + pub fn cache_seq_len(&self, request_id: DFlashRequestId) -> Result { + self.caches + .get(&request_id) + .map(DFlashDraftCache::seq_len) + .ok_or_else(|| anyhow::anyhow!("unknown DFlash cache request_id {:?}", request_id)) + } + + /// Release a request's draft cache. Mirrors Qwen3's `drop_request` + /// (`openinfer-qwen3-4b/src/executor.rs`): remove the entry and let RAII + /// drop the GPU buffers. Idempotent — a missing cache is not an error, so + /// callers can retire a request from any lifecycle state. + pub fn drop_cache(&mut self, request_id: DFlashRequestId) -> Result<()> { + self.caches.remove(&request_id); + Ok(()) + } + + /// Resident cache count, for admission diagnostics. + pub fn cache_count(&self) -> usize { + self.caches.len() + } + + /// Ensure a draft cache exists for `request_id`, enforcing the + /// `max_caches` cap. Existing caches are reused (a re-submitted request + /// keeps its past state). Over-cap admission fails closed. 
Returns without + /// borrowing the cache so callers can then use disjoint `&self.model` and + /// `&mut self.caches` borrows in the same scope (NLL split borrow). + fn ensure_cache_entry( + &mut self, + request_id: DFlashRequestId, + key: &DFlashBatchKey, + ) -> Result<()> { + if !self.caches.contains_key(&request_id) { + anyhow::ensure!( + self.caches.len() < self.options.max_caches, + "DFlash cache pool full: {} resident caches, max_caches={}; drop_cache a retired request before submitting a new one", + self.caches.len(), + self.options.max_caches, + ); + let cache = self.model.create_draft_cache( + key.q_len, + self.options.max_step_context_len, + self.options.max_seq_len, + )?; + self.caches.insert(request_id, cache); + } + Ok(()) + } + + fn execute_uncached_batch_compact( + &mut self, + requests: Vec, + key: DFlashBatchKey, + ) -> Result { + anyhow::ensure!( + key.q_len <= self.options.max_q_len, + "DFlash q_len {} exceeds executor max_q_len {}", + key.q_len, + self.options.max_q_len + ); + anyhow::ensure!( + key.ctx_len <= self.options.max_step_context_len, + "DFlash ctx_len {} exceeds executor max_step_context_len {}", + key.ctx_len, + self.options.max_step_context_len + ); + let started = Instant::now(); + let batch_size = requests.len(); + let request_ids = requests + .iter() + .map(|request| request.request_id) + .collect::>(); + let inputs = requests + .iter() + .map(|req| DFlashBatchInput { + noise_embedding: &req.noise_embedding, + target_hidden: DFlashTargetHidden { + concatenated: &req.target_hidden, + }, + position_ids: &req.position_ids, + }) + .collect::>(); + let batch_output = self.model.forward_batch(&inputs, &mut self.buffers)?; + self.model.device_context().sync()?; + let elapsed = started.elapsed(); + let output = clone_batch_output(self.model.device_context(), batch_output)?; + Ok(DFlashDraftBatchResponse { + request_ids, + output, + cache_seq_lens: vec![0; batch_size], + batch_size, + q_len: key.q_len, + elapsed, + }) + } + + fn 
execute_cached_requests_serial_compact( + &mut self, + requests: Vec, + key: DFlashBatchKey, + ) -> Result { + let started = Instant::now(); + let batch_size = requests.len(); + let mut request_ids = Vec::with_capacity(batch_size); + let mut cache_seq_lens = Vec::with_capacity(batch_size); + let mut output = HiddenStates::zeros( + self.model.device_context(), + self.model.config().hidden_size, + batch_size * key.q_len, + )?; + for (i, req) in requests.into_iter().enumerate() { + self.ensure_cache_entry(req.request_id, &key)?; + let cache = self.caches.get_mut(&req.request_id).expect("cache exists"); + self.model.prepare_step_context( + DFlashTargetHidden { + concatenated: &req.target_hidden, + }, + &req.position_ids, + cache, + )?; + let out = self.model.forward_with_draft_cache( + &req.noise_embedding, + &req.position_ids, + cache, + )?; + self.model.device_context().sync()?; + copy_hidden( + self.model.device_context(), + out, + 0, + &mut output, + i * key.q_len, + self.model.config().hidden_size, + key.q_len, + )?; + request_ids.push(req.request_id); + cache_seq_lens.push(cache.seq_len()); + } + Ok(DFlashDraftBatchResponse { + request_ids, + output, + cache_seq_lens, + batch_size, + q_len: key.q_len, + elapsed: started.elapsed(), + }) + } + + fn execute_cached_host_requests_serial_compact( + &mut self, + requests: Vec, + key: DFlashBatchKey, + ) -> Result { + let started = Instant::now(); + let batch_size = requests.len(); + let config = self.model.config(); + let hidden = config.hidden_size; + let target_hidden_dim = config.hidden_size * config.target_layer_count(); + let mut request_ids = Vec::with_capacity(batch_size); + let mut cache_seq_lens = Vec::with_capacity(batch_size); + let mut output = + HiddenStates::zeros(self.model.device_context(), hidden, batch_size * key.q_len)?; + for (i, req) in requests.into_iter().enumerate() { + let noise_embedding = HiddenStates { + data: self + .model + .device_context() + .stream + .clone_htod(&req.noise_embedding)?, + 
hidden_dim: hidden, + seq_len: key.q_len, + }; + let target_hidden = HiddenStates { + data: self + .model + .device_context() + .stream + .clone_htod(&req.target_hidden)?, + hidden_dim: target_hidden_dim, + seq_len: key.ctx_len, + }; + self.ensure_cache_entry(req.request_id, &key)?; + let cache = self.caches.get_mut(&req.request_id).expect("cache exists"); + self.model.prepare_step_context( + DFlashTargetHidden { + concatenated: &target_hidden, + }, + &req.position_ids, + cache, + )?; + let out = + self.model + .forward_with_draft_cache(&noise_embedding, &req.position_ids, cache)?; + self.model.device_context().sync()?; + copy_hidden( + self.model.device_context(), + out, + 0, + &mut output, + i * key.q_len, + hidden, + key.q_len, + )?; + request_ids.push(req.request_id); + cache_seq_lens.push(cache.seq_len()); + } + Ok(DFlashDraftBatchResponse { + request_ids, + output, + cache_seq_lens, + batch_size, + q_len: key.q_len, + elapsed: started.elapsed(), + }) + } + + fn split_compact_response( + &self, + batch: DFlashDraftBatchResponse, + ) -> Result> { + let mut responses = Vec::with_capacity(batch.batch_size); + for i in 0..batch.batch_size { + let mut output = HiddenStates::zeros( + self.model.device_context(), + self.model.config().hidden_size, + batch.q_len, + )?; + copy_hidden( + self.model.device_context(), + &batch.output, + i * batch.q_len, + &mut output, + 0, + self.model.config().hidden_size, + batch.q_len, + )?; + responses.push(DFlashDraftResponse { + request_id: batch.request_ids[i], + output, + cache_seq_len: batch.cache_seq_lens[i], + batch_size: batch.batch_size, + elapsed: batch.elapsed, + }); + } + Ok(responses) + } + + fn split_compact_host_response( + &self, + batch: DFlashDraftBatchResponse, + ) -> Result> { + let host = self + .model + .device_context() + .stream + .clone_dtoh(&batch.output.data)?; + self.model.device_context().sync()?; + let row_len = batch.output.hidden_dim * batch.q_len; + let mut responses = 
Vec::with_capacity(batch.batch_size); + for i in 0..batch.batch_size { + responses.push(DFlashDraftHostResponse { + request_id: batch.request_ids[i], + output: host[i * row_len..(i + 1) * row_len].to_vec(), + hidden_dim: batch.output.hidden_dim, + seq_len: batch.q_len, + cache_seq_len: batch.cache_seq_lens[i], + batch_size: batch.batch_size, + elapsed: batch.elapsed, + }); + } + Ok(responses) + } +} + +/// Materialize an owned snapshot of a batch forward's output (a borrow into +/// the single-instance buffer). One allocation + one device-to-device copy of +/// the active region; the next batch may overwrite the buffer immediately. +fn clone_batch_output(ctx: &DeviceContext, src: &HiddenStates) -> Result { + let mut dst = HiddenStates::zeros(ctx, src.hidden_dim, src.seq_len)?; + let len = src.hidden_dim * src.seq_len; + let src_view = src.data.slice(..len); + let mut dst_view = dst.data.slice_mut(..len); + ctx.stream.memcpy_dtod(&src_view, &mut dst_view)?; + Ok(dst) +} diff --git a/openinfer-qwen3-4b-dflash/src/forward.rs b/openinfer-qwen3-4b-dflash/src/forward.rs new file mode 100644 index 00000000..0ac4b07e --- /dev/null +++ b/openinfer-qwen3-4b-dflash/src/forward.rs @@ -0,0 +1,886 @@ +use anyhow::Result; +use cudarc::driver::CudaSlice; +use openinfer_core::ops; +use openinfer_core::tensor::HiddenStates; + +use crate::weights::{DFlashDraftModel, DFlashLayer}; + +pub struct DFlashTargetHidden<'a> { + /// HF reference layout: `[seq_len, target_layer_count * hidden_size]`. 
+ pub concatenated: &'a HiddenStates, +} + +pub struct DFlashDraftCache { + pub(crate) q_len: usize, + pub(crate) state: DFlashDraftState, + pub(crate) step: DFlashStepContext, + pub(crate) scratch: ForwardBuffers, +} + +pub(crate) struct DFlashDraftState { + pub(crate) max_seq_len: usize, + pub(crate) seq_len: usize, + pub(crate) layers: Vec, +} + +pub(crate) struct DFlashStepContext { + pub(crate) max_len: usize, + pub(crate) len: usize, + pub(crate) valid: bool, + pub(crate) layers: Vec, +} + +pub(crate) struct DFlashLayerStepContext { + pub(crate) k_ctx: HiddenStates, + pub(crate) v_ctx: HiddenStates, +} + +pub(crate) struct DFlashLayerPastKv { + pub(crate) k_past: HiddenStates, + pub(crate) v_past: HiddenStates, +} + +pub(crate) struct ForwardBuffers { + pub(crate) hidden_out: HiddenStates, + pub(crate) target_projected: HiddenStates, + pub(crate) target_normed: HiddenStates, + pub(crate) normed: HiddenStates, + pub(crate) q: HiddenStates, + pub(crate) q_ctx_scratch: HiddenStates, + pub(crate) k_ctx: HiddenStates, + pub(crate) k_noise: HiddenStates, + pub(crate) v_ctx: HiddenStates, + pub(crate) v_noise: HiddenStates, + pub(crate) k_all: HiddenStates, + pub(crate) v_all: HiddenStates, + pub(crate) attn_out: HiddenStates, + pub(crate) o_buf: HiddenStates, + pub(crate) gate_up: HiddenStates, + pub(crate) act_out: HiddenStates, + pub(crate) positions_q: CudaSlice, + pub(crate) positions_ctx: CudaSlice, +} + +impl DFlashDraftModel { + pub fn create_draft_cache( + &self, + q_len: usize, + max_step_context_len: usize, + max_seq_len: usize, + ) -> Result { + anyhow::ensure!(q_len > 0, "DFlash scratch requires q_len greater than zero"); + anyhow::ensure!( + max_step_context_len > 0, + "DFlash cache requires max_step_context_len greater than zero" + ); + anyhow::ensure!( + max_seq_len >= max_step_context_len + q_len, + "DFlash cache max_seq_len {} must fit at least one step: context {} + q_len {}", + max_seq_len, + max_step_context_len, + q_len + ); + 
Ok(DFlashDraftCache { + q_len, + state: DFlashDraftState::new(self, max_seq_len)?, + step: DFlashStepContext::new(self, max_step_context_len)?, + scratch: ForwardBuffers::new(self, q_len, max_step_context_len)?, + }) + } + + pub fn forward( + &self, + noise_embedding: &HiddenStates, + target_hidden: DFlashTargetHidden<'_>, + position_ids: &[i32], + ) -> Result { + let (q_len, ctx_len) = + self.validate_forward_inputs(noise_embedding, &target_hidden, position_ids)?; + let mut bufs = ForwardBuffers::new(self, q_len, ctx_len)?; + self.project_target_hidden(target_hidden, &mut bufs)?; + self.run_forward(noise_embedding, ctx_len, position_ids, &mut bufs)?; + Ok(bufs.normed) + } + + pub fn forward_with_cache<'a>( + &self, + noise_embedding: &HiddenStates, + target_hidden: DFlashTargetHidden<'_>, + position_ids: &[i32], + cache: &'a mut DFlashDraftCache, + ) -> Result<&'a HiddenStates> { + let (q_len, ctx_len) = + self.validate_forward_inputs(noise_embedding, &target_hidden, position_ids)?; + anyhow::ensure!( + cache.q_len == q_len && cache.step.max_len >= ctx_len, + "DFlash cache shape mismatch: cache q_len={}, max_step_context_len={} but input q_len={}, ctx_len={}", + cache.q_len, + cache.step.max_len, + q_len, + ctx_len + ); + cache.reset(); + self.prepare_step_context(target_hidden, position_ids, cache)?; + self.run_forward(noise_embedding, ctx_len, position_ids, &mut cache.scratch)?; + cache.step.valid = false; + Ok(&cache.scratch.normed) + } + + pub fn prepare_step_context( + &self, + target_hidden: DFlashTargetHidden<'_>, + position_ids: &[i32], + cache: &mut DFlashDraftCache, + ) -> Result<()> { + let config = &self.config; + let ctx_len = target_hidden.concatenated.seq_len; + anyhow::ensure!( + ctx_len <= cache.step.max_len, + "DFlash step context length {} exceeds cache capacity {}", + ctx_len, + cache.step.max_len + ); + anyhow::ensure!( + cache.state.seq_len + ctx_len + cache.q_len <= cache.state.max_seq_len, + "DFlash draft cache would exceed capacity: past 
{} + ctx {} + q {} > {}", + cache.state.seq_len, + ctx_len, + cache.q_len, + cache.state.max_seq_len + ); + anyhow::ensure!( + ctx_len > 0, + "DFlash step context must contain at least one token" + ); + anyhow::ensure!( + position_ids.len() >= ctx_len, + "position_ids len {} < ctx_len {}", + position_ids.len(), + ctx_len + ); + anyhow::ensure!( + target_hidden.concatenated.hidden_dim + == config.target_layer_count() * config.hidden_size, + "target_hidden hidden_dim {} != {}", + target_hidden.concatenated.hidden_dim, + config.target_layer_count() * config.hidden_size + ); + set_step_context_len(&mut cache.scratch, &mut cache.step.layers, ctx_len); + let mut positions_ctx = cache.scratch.positions_ctx.slice_mut(..ctx_len); + self.ctx + .stream + .memcpy_htod(&position_ids[..ctx_len], &mut positions_ctx)?; + + ops::gemm_into_checked( + &self.ctx, + &self.fc, + target_hidden.concatenated, + &mut cache.scratch.target_projected, + )?; + ops::rms_norm_batch_into( + &self.ctx, + &cache.scratch.target_projected, + &self.hidden_norm, + config.rms_norm_eps, + &mut cache.scratch.target_normed, + ); + for (layer, cached) in self.layers.iter().zip(cache.step.layers.iter_mut()) { + ops::gemm_into_checked( + &self.ctx, + &layer.attention.k_proj, + &cache.scratch.target_normed, + &mut cached.k_ctx, + )?; + ops::gemm_into_checked( + &self.ctx, + &layer.attention.v_proj, + &cache.scratch.target_normed, + &mut cached.v_ctx, + )?; + ops::qk_norm_rope_batch_decode_into( + &self.ctx, + &mut cache.scratch.q_ctx_scratch, + &mut cached.k_ctx, + &layer.attention.q_norm, + &layer.attention.k_norm, + &self.cos_cache, + &self.sin_cache, + &cache.scratch.positions_ctx, + config.num_attention_heads, + config.num_key_value_heads, + config.head_dim, + config.rms_norm_eps, + ); + } + cache.step.len = ctx_len; + cache.step.valid = true; + Ok(()) + } + + pub fn forward_with_draft_cache<'a>( + &self, + noise_embedding: &HiddenStates, + position_ids: &[i32], + cache: &'a mut DFlashDraftCache, + ) -> 
Result<&'a HiddenStates> { + anyhow::ensure!(cache.step.valid, "DFlash step context is not prepared"); + anyhow::ensure!( + noise_embedding.hidden_dim == self.config.hidden_size, + "noise_embedding hidden_dim {} != {}", + noise_embedding.hidden_dim, + self.config.hidden_size + ); + anyhow::ensure!( + noise_embedding.seq_len == cache.q_len, + "noise_embedding q_len {} != scratch q_len {}", + noise_embedding.seq_len, + cache.q_len + ); + anyhow::ensure!( + position_ids.len() == cache.step.len + cache.q_len, + "position_ids len {} != step_context_len + q_len {}", + position_ids.len(), + cache.step.len + cache.q_len + ); + anyhow::ensure!( + cache.state.seq_len + cache.step.len + cache.q_len <= cache.state.max_seq_len, + "DFlash draft cache would exceed capacity: past {} + ctx {} + q {} > {}", + cache.state.seq_len, + cache.step.len, + cache.q_len, + cache.state.max_seq_len + ); + let past_len = cache.state.seq_len; + self.run_forward_with_draft_cache(noise_embedding, past_len, position_ids, cache)?; + cache.step.valid = false; + Ok(&cache.scratch.normed) + } + + pub(crate) fn validate_forward_inputs( + &self, + noise_embedding: &HiddenStates, + target_hidden: &DFlashTargetHidden<'_>, + position_ids: &[i32], + ) -> Result<(usize, usize)> { + let config = &self.config; + anyhow::ensure!( + noise_embedding.hidden_dim == config.hidden_size, + "noise_embedding hidden_dim {} != {}", + noise_embedding.hidden_dim, + config.hidden_size + ); + let ctx_len = target_hidden.concatenated.seq_len; + let q_len = noise_embedding.seq_len; + anyhow::ensure!( + ctx_len > 0, + "DFlash forward requires at least one target-hidden token" + ); + anyhow::ensure!( + q_len > 0, + "DFlash forward requires at least one noise token" + ); + anyhow::ensure!( + target_hidden.concatenated.hidden_dim + == config.target_layer_count() * config.hidden_size, + "target_hidden hidden_dim {} != {}", + target_hidden.concatenated.hidden_dim, + config.target_layer_count() * config.hidden_size + ); + 
anyhow::ensure!( + position_ids.len() == ctx_len + q_len, + "position_ids len {} != ctx_len + q_len {}", + position_ids.len(), + ctx_len + q_len + ); + Ok((q_len, ctx_len)) + } + + fn project_target_hidden( + &self, + target_hidden: DFlashTargetHidden<'_>, + bufs: &mut ForwardBuffers, + ) -> Result<()> { + let config = &self.config; + ops::gemm_into_checked( + &self.ctx, + &self.fc, + target_hidden.concatenated, + &mut bufs.target_projected, + )?; + ops::rms_norm_batch_into( + &self.ctx, + &bufs.target_projected, + &self.hidden_norm, + config.rms_norm_eps, + &mut bufs.target_normed, + ); + Ok(()) + } + + pub(crate) fn run_forward( + &self, + noise_embedding: &HiddenStates, + ctx_len: usize, + position_ids: &[i32], + bufs: &mut ForwardBuffers, + ) -> Result<()> { + let q_len = noise_embedding.seq_len; + let mut positions_q = bufs.positions_q.slice_mut(..q_len); + self.ctx + .stream + .memcpy_htod(&position_ids[ctx_len..], &mut positions_q)?; + let mut positions_ctx = bufs.positions_ctx.slice_mut(..ctx_len); + self.ctx + .stream + .memcpy_htod(&position_ids[..ctx_len], &mut positions_ctx)?; + + let mut hidden = clone_hidden(&self.ctx, noise_embedding)?; + for layer in &self.layers { + self.forward_layer(layer, &mut hidden, bufs)?; + } + ops::rms_norm_batch_into( + &self.ctx, + &hidden, + &self.norm, + self.config.rms_norm_eps, + &mut bufs.normed, + ); + Ok(()) + } + + fn run_forward_with_draft_cache( + &self, + noise_embedding: &HiddenStates, + past_len: usize, + position_ids: &[i32], + cache: &mut DFlashDraftCache, + ) -> Result<()> { + let ctx_len = cache.step.len; + let q_len = noise_embedding.seq_len; + let total_len = past_len + ctx_len + q_len; + let mut positions_q = cache.scratch.positions_q.slice_mut(..q_len); + self.ctx + .stream + .memcpy_htod(&position_ids[ctx_len..], &mut positions_q)?; + + let mut hidden = clone_hidden(&self.ctx, noise_embedding)?; + for layer_idx in 0..self.layers.len() { + let layer = &self.layers[layer_idx]; + 
self.forward_layer_with_draft_cache( + layer, + past_len, + total_len, + &cache.step.layers[layer_idx], + &mut cache.state.layers[layer_idx], + &mut hidden, + &mut cache.scratch, + )?; + } + ops::rms_norm_batch_into( + &self.ctx, + &hidden, + &self.norm, + self.config.rms_norm_eps, + &mut cache.scratch.normed, + ); + cache.state.seq_len = total_len; + set_past_seq_len(&mut cache.state.layers, total_len); + Ok(()) + } + + pub(crate) fn forward_layer( + &self, + layer: &DFlashLayer, + hidden: &mut HiddenStates, + bufs: &mut ForwardBuffers, + ) -> Result<()> { + let config = &self.config; + let q_len = hidden.seq_len; + let ctx_len = bufs.target_normed.seq_len; + + ops::rms_norm_batch_into( + &self.ctx, + hidden, + &layer.input_layernorm, + config.rms_norm_eps, + &mut bufs.normed, + ); + + ops::gemm_into_checked( + &self.ctx, + &layer.attention.q_proj, + &bufs.normed, + &mut bufs.q, + )?; + ops::gemm_into_checked( + &self.ctx, + &layer.attention.k_proj, + &bufs.normed, + &mut bufs.k_noise, + )?; + ops::gemm_into_checked( + &self.ctx, + &layer.attention.v_proj, + &bufs.normed, + &mut bufs.v_noise, + )?; + + ops::qk_norm_rope_batch_decode_into( + &self.ctx, + &mut bufs.q, + &mut bufs.k_noise, + &layer.attention.q_norm, + &layer.attention.k_norm, + &self.cos_cache, + &self.sin_cache, + &bufs.positions_q, + config.num_attention_heads, + config.num_key_value_heads, + config.head_dim, + config.rms_norm_eps, + ); + ops::gemm_into_checked( + &self.ctx, + &layer.attention.k_proj, + &bufs.target_normed, + &mut bufs.k_ctx, + )?; + ops::gemm_into_checked( + &self.ctx, + &layer.attention.v_proj, + &bufs.target_normed, + &mut bufs.v_ctx, + )?; + // Normalize and rotate context K with its own positions. Q has already + // been prepared above; q_ctx_scratch only reuses the shared Q/K kernel. 
+ ops::qk_norm_rope_batch_decode_into( + &self.ctx, + &mut bufs.q_ctx_scratch, + &mut bufs.k_ctx, + &layer.attention.q_norm, + &layer.attention.k_norm, + &self.cos_cache, + &self.sin_cache, + &bufs.positions_ctx, + config.num_attention_heads, + config.num_key_value_heads, + config.head_dim, + config.rms_norm_eps, + ); + concat_kv( + &self.ctx, + &bufs.k_ctx, + &bufs.k_noise, + ctx_len, + q_len, + &mut bufs.k_all, + )?; + concat_kv( + &self.ctx, + &bufs.v_ctx, + &bufs.v_noise, + ctx_len, + q_len, + &mut bufs.v_all, + )?; + + ops::single_prefill_nhd_noncausal_into( + &self.ctx, + &bufs.q, + &bufs.k_all, + &bufs.v_all, + &mut bufs.attn_out, + config.num_attention_heads, + config.num_key_value_heads, + config.head_dim, + )?; + ops::gemm_into_checked( + &self.ctx, + &layer.attention.o_proj, + &bufs.attn_out, + &mut bufs.o_buf, + )?; + openinfer_kernels::ops::fused_add_rms_norm_round_batch_into( + &self.ctx, + hidden, + &bufs.o_buf, + &layer.post_attention_layernorm, + config.rms_norm_eps, + &mut bufs.normed, + )?; + + ops::gemm_into_checked( + &self.ctx, + &layer.mlp.gate_up_proj, + &bufs.normed, + &mut bufs.gate_up, + )?; + ops::silu_mul_fused_batch_into(&self.ctx, &bufs.gate_up, &mut bufs.act_out)?; + ops::gemm_into_checked( + &self.ctx, + &layer.mlp.down_proj, + &bufs.act_out, + &mut bufs.o_buf, + )?; + ops::add_batch_into(&self.ctx, hidden, &bufs.o_buf, &mut bufs.hidden_out)?; + std::mem::swap(hidden, &mut bufs.hidden_out); + Ok(()) + } + + fn forward_layer_with_draft_cache( + &self, + layer: &DFlashLayer, + past_len: usize, + total_len: usize, + step_context: &DFlashLayerStepContext, + past: &mut DFlashLayerPastKv, + hidden: &mut HiddenStates, + bufs: &mut ForwardBuffers, + ) -> Result<()> { + let config = &self.config; + let q_len = hidden.seq_len; + let ctx_len = bufs.target_normed.seq_len; + + ops::rms_norm_batch_into( + &self.ctx, + hidden, + &layer.input_layernorm, + config.rms_norm_eps, + &mut bufs.normed, + ); + + ops::gemm_into_checked( + &self.ctx, + 
&layer.attention.q_proj, + &bufs.normed, + &mut bufs.q, + )?; + ops::gemm_into_checked( + &self.ctx, + &layer.attention.k_proj, + &bufs.normed, + &mut bufs.k_noise, + )?; + ops::gemm_into_checked( + &self.ctx, + &layer.attention.v_proj, + &bufs.normed, + &mut bufs.v_noise, + )?; + + ops::qk_norm_rope_batch_decode_into( + &self.ctx, + &mut bufs.q, + &mut bufs.k_noise, + &layer.attention.q_norm, + &layer.attention.k_norm, + &self.cos_cache, + &self.sin_cache, + &bufs.positions_q, + config.num_attention_heads, + config.num_key_value_heads, + config.head_dim, + config.rms_norm_eps, + ); + + append_kv( + &self.ctx, + &step_context.k_ctx, + &bufs.k_noise, + past_len, + ctx_len, + q_len, + &mut past.k_past, + )?; + append_kv( + &self.ctx, + &step_context.v_ctx, + &bufs.v_noise, + past_len, + ctx_len, + q_len, + &mut past.v_past, + )?; + past.k_past.seq_len = total_len; + past.v_past.seq_len = total_len; + + ops::single_prefill_nhd_noncausal_into( + &self.ctx, + &bufs.q, + &past.k_past, + &past.v_past, + &mut bufs.attn_out, + config.num_attention_heads, + config.num_key_value_heads, + config.head_dim, + )?; + ops::gemm_into_checked( + &self.ctx, + &layer.attention.o_proj, + &bufs.attn_out, + &mut bufs.o_buf, + )?; + openinfer_kernels::ops::fused_add_rms_norm_round_batch_into( + &self.ctx, + hidden, + &bufs.o_buf, + &layer.post_attention_layernorm, + config.rms_norm_eps, + &mut bufs.normed, + )?; + + ops::gemm_into_checked( + &self.ctx, + &layer.mlp.gate_up_proj, + &bufs.normed, + &mut bufs.gate_up, + )?; + ops::silu_mul_fused_batch_into(&self.ctx, &bufs.gate_up, &mut bufs.act_out)?; + ops::gemm_into_checked( + &self.ctx, + &layer.mlp.down_proj, + &bufs.act_out, + &mut bufs.o_buf, + )?; + ops::add_batch_into(&self.ctx, hidden, &bufs.o_buf, &mut bufs.hidden_out)?; + std::mem::swap(hidden, &mut bufs.hidden_out); + Ok(()) + } +} + +impl DFlashDraftCache { + pub fn seq_len(&self) -> usize { + self.state.seq_len + } + + pub fn reset(&mut self) { + self.state.seq_len = 0; + 
self.step.len = 0; + self.step.valid = false; + set_past_seq_len(&mut self.state.layers, 0); + } + + pub fn crop(&mut self, seq_len: usize) -> Result<()> { + anyhow::ensure!( + seq_len <= self.state.seq_len, + "cannot crop DFlash draft cache from {} to larger length {}", + self.state.seq_len, + seq_len + ); + self.state.seq_len = seq_len; + self.step.valid = false; + self.step.len = 0; + set_past_seq_len(&mut self.state.layers, seq_len); + Ok(()) + } +} + +impl DFlashDraftState { + fn new(model: &DFlashDraftModel, max_seq_len: usize) -> Result { + let config = &model.config; + let kv_dim = config.kv_dim(); + let mut layers = Vec::with_capacity(config.num_hidden_layers); + for _ in 0..config.num_hidden_layers { + layers.push(DFlashLayerPastKv { + k_past: HiddenStates::zeros(&model.ctx, kv_dim, max_seq_len)?, + v_past: HiddenStates::zeros(&model.ctx, kv_dim, max_seq_len)?, + }); + } + Ok(Self { + max_seq_len, + seq_len: 0, + layers, + }) + } +} + +impl DFlashStepContext { + fn new(model: &DFlashDraftModel, max_len: usize) -> Result { + let config = &model.config; + let kv_dim = config.kv_dim(); + let mut layers = Vec::with_capacity(config.num_hidden_layers); + for _ in 0..config.num_hidden_layers { + layers.push(DFlashLayerStepContext { + k_ctx: HiddenStates::zeros(&model.ctx, kv_dim, max_len)?, + v_ctx: HiddenStates::zeros(&model.ctx, kv_dim, max_len)?, + }); + } + Ok(Self { + max_len, + len: 0, + valid: false, + layers, + }) + } +} + +impl ForwardBuffers { + pub(crate) fn new(model: &DFlashDraftModel, q_len: usize, ctx_len: usize) -> Result { + let config = &model.config; + let ctx = &model.ctx; + let hidden = config.hidden_size; + let q_dim = config.q_dim(); + let kv_dim = config.kv_dim(); + Ok(Self { + hidden_out: HiddenStates::zeros(ctx, hidden, q_len)?, + target_projected: HiddenStates::zeros(ctx, hidden, ctx_len)?, + target_normed: HiddenStates::zeros(ctx, hidden, ctx_len)?, + normed: HiddenStates::zeros(ctx, hidden, q_len)?, + q: HiddenStates::zeros(ctx, 
q_dim, q_len)?, + q_ctx_scratch: HiddenStates::zeros(ctx, q_dim, ctx_len)?, + k_ctx: HiddenStates::zeros(ctx, kv_dim, ctx_len)?, + k_noise: HiddenStates::zeros(ctx, kv_dim, q_len)?, + v_ctx: HiddenStates::zeros(ctx, kv_dim, ctx_len)?, + v_noise: HiddenStates::zeros(ctx, kv_dim, q_len)?, + k_all: HiddenStates::zeros(ctx, kv_dim, ctx_len + q_len)?, + v_all: HiddenStates::zeros(ctx, kv_dim, ctx_len + q_len)?, + attn_out: HiddenStates::zeros(ctx, q_dim, q_len)?, + o_buf: HiddenStates::zeros(ctx, hidden, q_len)?, + gate_up: HiddenStates::zeros(ctx, 2 * config.intermediate_size, q_len)?, + act_out: HiddenStates::zeros(ctx, config.intermediate_size, q_len)?, + positions_q: ctx.stream.alloc_zeros(q_len)?, + positions_ctx: ctx.stream.alloc_zeros(ctx_len)?, + }) + } +} + +pub(crate) fn clone_hidden( + ctx: &openinfer_core::tensor::DeviceContext, + input: &HiddenStates, +) -> Result { + let mut out = HiddenStates::zeros(ctx, input.hidden_dim, input.seq_len)?; + let src = input.data.slice(..input.hidden_dim * input.seq_len); + let mut dst = out.data.slice_mut(..input.hidden_dim * input.seq_len); + ctx.stream.memcpy_dtod(&src, &mut dst)?; + Ok(out) +} + +pub(crate) fn concat_kv( + ctx: &openinfer_core::tensor::DeviceContext, + ctx_part: &HiddenStates, + noise_part: &HiddenStates, + ctx_len: usize, + q_len: usize, + out: &mut HiddenStates, +) -> Result<()> { + debug_assert_eq!(ctx_part.seq_len, ctx_len); + debug_assert_eq!(noise_part.seq_len, q_len); + debug_assert_eq!(ctx_part.hidden_dim, noise_part.hidden_dim); + debug_assert_eq!(out.hidden_dim, ctx_part.hidden_dim); + debug_assert_eq!(out.seq_len, ctx_len + q_len); + let ctx_src = ctx_part.data.slice(..ctx_part.hidden_dim * ctx_len); + let mut ctx_dst = out.data.slice_mut(..ctx_part.hidden_dim * ctx_len); + ctx.stream.memcpy_dtod(&ctx_src, &mut ctx_dst)?; + let noise_src = noise_part.data.slice(..noise_part.hidden_dim * q_len); + let offset = ctx_part.hidden_dim * ctx_len; + let mut noise_dst = out + .data + 
.slice_mut(offset..offset + noise_part.hidden_dim * q_len);
+    ctx.stream.memcpy_dtod(&noise_src, &mut noise_dst)?;
+    Ok(())
+}
+
+/// Appends one step's context K/V rows followed by its noise K/V rows to a
+/// per-layer past buffer.
+///
+/// Rows are `hidden_dim` elements wide. `ctx_part` is written starting at row
+/// `past_len`, and `noise_part` starting at row `past_len + ctx_len`.
+/// `out.seq_len` is NOT updated here; the caller sets it after both the K and
+/// V buffers have been appended.
+///
+/// Returns an error if either device-to-device copy fails.
+pub(crate) fn append_kv(
+    ctx: &openinfer_core::tensor::DeviceContext,
+    ctx_part: &HiddenStates,
+    noise_part: &HiddenStates,
+    past_len: usize,
+    ctx_len: usize,
+    q_len: usize,
+    out: &mut HiddenStates,
+) -> Result<()> {
+    debug_assert_eq!(ctx_part.seq_len, ctx_len);
+    debug_assert_eq!(noise_part.seq_len, q_len);
+    debug_assert_eq!(ctx_part.hidden_dim, noise_part.hidden_dim);
+    debug_assert_eq!(out.hidden_dim, ctx_part.hidden_dim);
+    // Capacity is checked in elements, not rows: `out.data` holds
+    // `hidden_dim` elements per row, so the row count must be scaled before
+    // it is compared with `out.data.len()`.
+    debug_assert!(
+        (past_len + ctx_len + q_len) * out.hidden_dim <= out.data.len(),
+        "append_kv would write past the end of the past K/V buffer"
+    );
+    let ctx_src = ctx_part.data.slice(..ctx_part.hidden_dim * ctx_len);
+    let ctx_offset = ctx_part.hidden_dim * past_len;
+    let mut ctx_dst = out
+        .data
+        .slice_mut(ctx_offset..ctx_offset + ctx_part.hidden_dim * ctx_len);
+    ctx.stream.memcpy_dtod(&ctx_src, &mut ctx_dst)?;
+    let noise_src = noise_part.data.slice(..noise_part.hidden_dim * q_len);
+    // Noise rows land directly after the context rows that were just written.
+    let noise_offset = ctx_part.hidden_dim * (past_len + ctx_len);
+    let mut noise_dst = out
+        .data
+        .slice_mut(noise_offset..noise_offset + noise_part.hidden_dim * q_len);
+    ctx.stream.memcpy_dtod(&noise_src, &mut noise_dst)?;
+    Ok(())
+}
+
+pub(crate) fn set_step_context_len(
+    bufs: &mut ForwardBuffers,
+    layers: &mut [DFlashLayerStepContext],
+    ctx_len: usize,
+) {
+    bufs.target_projected.seq_len = ctx_len;
+    bufs.target_normed.seq_len = ctx_len;
+    bufs.q_ctx_scratch.seq_len = ctx_len;
+    bufs.k_ctx.seq_len = ctx_len;
+    bufs.v_ctx.seq_len = ctx_len;
+    for layer in layers {
+        layer.k_ctx.seq_len = ctx_len;
+        layer.v_ctx.seq_len = ctx_len;
+    }
+}
+
+pub(crate) fn set_past_seq_len(layers: &mut [DFlashLayerPastKv], seq_len: usize) {
+    for layer in layers {
+        layer.k_past.seq_len = seq_len;
+        layer.v_past.seq_len = seq_len;
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use half::bf16;
+    use std::path::Path;
+
+    const LOCAL_DFLASH: &str = "/home/hezhaozhao/models/Qwen3-4B-DFlash-b16";
+
+    #[test]
+ fn draft_forward_smoke_local_model() { + let path = Path::new(LOCAL_DFLASH); + if !path.exists() { + eprintln!("skipping: {LOCAL_DFLASH} does not exist"); + return; + } + + let model = DFlashDraftModel::load(path, 0).expect("load model"); + let config = model.config(); + let ctx_len = 1; + let q_len = 1; + let noise_host = vec![bf16::ZERO; config.hidden_size * q_len]; + let target_host = + vec![bf16::ZERO; config.hidden_size * config.target_layer_count() * ctx_len]; + let noise_embedding = HiddenStates { + data: model.ctx.stream.clone_htod(&noise_host).expect("noise h2d"), + hidden_dim: config.hidden_size, + seq_len: q_len, + }; + let target_hidden = HiddenStates { + data: model + .ctx + .stream + .clone_htod(&target_host) + .expect("target h2d"), + hidden_dim: config.hidden_size * config.target_layer_count(), + seq_len: ctx_len, + }; + + let out = model + .forward( + &noise_embedding, + DFlashTargetHidden { + concatenated: &target_hidden, + }, + &[0, 1], + ) + .expect("forward"); + model.ctx.sync().expect("sync"); + assert_eq!(out.hidden_dim, config.hidden_size); + assert_eq!(out.seq_len, q_len); + } +} diff --git a/openinfer-qwen3-4b-dflash/src/lib.rs b/openinfer-qwen3-4b-dflash/src/lib.rs new file mode 100644 index 00000000..4b64dcfc --- /dev/null +++ b/openinfer-qwen3-4b-dflash/src/lib.rs @@ -0,0 +1,19 @@ +mod batch_buffers; +mod batch_forward; +mod config; +mod executor; +mod forward; +mod scheduler; +mod weights; + +pub use batch_buffers::DFlashBatchBuffers; +pub use batch_forward::DFlashBatchInput; +pub use config::{DFlashConfig, DFlashInnerConfig}; +pub use executor::{ + DFlashBatchKey, DFlashCacheMode, DFlashDraftBatchResponse, DFlashDraftHostRequest, + DFlashDraftHostResponse, DFlashDraftRequest, DFlashDraftResponse, DFlashExecutor, + DFlashExecutorOptions, DFlashRequestId, +}; +pub use forward::{DFlashDraftCache, DFlashTargetHidden}; +pub use scheduler::{DFlashSchedulerHandle, DFlashSchedulerOptions}; +pub use weights::DFlashDraftModel; diff --git 
a/openinfer-qwen3-4b-dflash/src/scheduler.rs b/openinfer-qwen3-4b-dflash/src/scheduler.rs new file mode 100644 index 00000000..3eafc84c --- /dev/null +++ b/openinfer-qwen3-4b-dflash/src/scheduler.rs @@ -0,0 +1,485 @@ +use std::collections::VecDeque; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::thread::{self, JoinHandle}; +use std::time::{Duration, Instant}; + +use anyhow::Result; +use crossbeam_channel as channel; + +use crate::executor::{ + DFlashBatchKey, DFlashDraftHostRequest, DFlashDraftHostResponse, DFlashExecutor, + DFlashExecutorOptions, DFlashRequestId, +}; + +pub struct DFlashSchedulerOptions { + pub executor: DFlashExecutorOptions, + pub max_wait: Duration, + pub max_total_tokens: usize, +} + +impl Default for DFlashSchedulerOptions { + fn default() -> Self { + Self { + executor: DFlashExecutorOptions::default(), + max_wait: Duration::from_micros(200), + max_total_tokens: 512, + } + } +} + +/// Handle to the DFlash draft scheduler thread. Mirrors the `EngineHandle` +/// pattern (`openinfer-engine::engine::EngineHandle`): the handle is cheaply +/// cloneable (shared sender), and the last clone's `Drop` closes the channel +/// and joins the scheduler thread, replying "stopped" to any in-flight +/// requests. This prevents leaking the GPU-owner thread when a caller drops +/// the handle without an explicit shutdown. +#[derive(Clone)] +pub struct DFlashSchedulerHandle { + inner: Arc, +} + +struct DFlashSchedulerInner { + submit_tx: Option>, + join_handle: Option>, +} + +impl Drop for DFlashSchedulerInner { + fn drop(&mut self) { + // Drop our sender first; when the last sender goes, the scheduler + // loop's `recv` returns `Err` and the thread flushes pending requests + // via `send_stopped` before exiting (mirrors EngineHandle::Drop in + // openinfer-engine/src/engine.rs). + self.submit_tx.take(); + if let Some(join_handle) = self.join_handle.take() { + // Never join from inside the scheduler thread itself. 
+            if join_handle.thread().id() != thread::current().id() {
+                let _ = join_handle.join();
+            }
+        }
+    }
+}
+
+/// Messages sent from a `DFlashSchedulerHandle` to the scheduler thread. Each
+/// variant carries the one-shot channel on which the result is returned.
+enum SchedulerMessage {
+    Submit {
+        request: DFlashDraftHostRequest,
+        response_tx: channel::Sender<Result<DFlashDraftHostResponse>>,
+    },
+    ResetCache {
+        request_id: DFlashRequestId,
+        response_tx: channel::Sender<Result<()>>,
+    },
+    DropCache {
+        request_id: DFlashRequestId,
+        response_tx: channel::Sender<Result<()>>,
+    },
+    CropCache {
+        request_id: DFlashRequestId,
+        seq_len: usize,
+        response_tx: channel::Sender<Result<()>>,
+    },
+    CacheSeqLen {
+        request_id: DFlashRequestId,
+        // NOTE(review): payload type inferred from `cache_seq_len`'s return
+        // type; confirm against the executor signature.
+        response_tx: channel::Sender<Result<usize>>,
+    },
+}
+
+/// A draft request waiting to be batched, stamped with its enqueue time.
+struct PendingRequest {
+    request: DFlashDraftHostRequest,
+    response_tx: channel::Sender<Result<DFlashDraftHostResponse>>,
+    queued_at: Instant,
+}
+
+/// One entry of the scheduler's pending queue: either a batchable draft
+/// request or a cache-control operation.
+enum PendingItem {
+    Submit(PendingRequest),
+    Control(SchedulerControl),
+}
+
+/// Cache-control operations queued alongside draft requests.
+enum SchedulerControl {
+    ResetCache {
+        request_id: DFlashRequestId,
+        response_tx: channel::Sender<Result<()>>,
+    },
+    DropCache {
+        request_id: DFlashRequestId,
+        response_tx: channel::Sender<Result<()>>,
+    },
+    CropCache {
+        request_id: DFlashRequestId,
+        seq_len: usize,
+        response_tx: channel::Sender<Result<()>>,
+    },
+    CacheSeqLen {
+        request_id: DFlashRequestId,
+        // NOTE(review): payload type inferred; see `SchedulerMessage::CacheSeqLen`.
+        response_tx: channel::Sender<Result<usize>>,
+    },
+}
+
+impl DFlashSchedulerHandle {
+    /// Spawns the `qwen3-dflash-scheduler` thread, loads the executor on it,
+    /// and returns a handle once initialization has been reported.
+    ///
+    /// # Errors
+    /// Returns an error if the OS cannot spawn the thread, if the executor
+    /// fails to load, or if the initialization channel closes before a
+    /// result is received.
+    pub fn start(
+        model_path: &Path,
+        device_ordinal: usize,
+        options: DFlashSchedulerOptions,
+    ) -> Result<Self> {
+        let (submit_tx, submit_rx) = channel::unbounded();
+        let (init_tx, init_rx) = channel::bounded(1);
+        let model_path = PathBuf::from(model_path);
+        let max_wait = options.max_wait;
+        let max_total_tokens = options.max_total_tokens;
+        let join_handle = thread::Builder::new()
+            .name("qwen3-dflash-scheduler".into())
+            .spawn(move || {
+                let mut executor =
+                    match DFlashExecutor::load(&model_path, device_ordinal, options.executor) {
+                        Ok(executor) => executor,
+                        Err(err) => {
+                            let _ = init_tx.send(Err(err));
+                            return;
+                        }
+                    };
+                let _ = init_tx.send(Ok(()));
+                scheduler_loop(&mut executor, submit_rx, max_wait, max_total_tokens);
+            })
+            // Spawn failure is an ordinary OS error; report it to the caller
+            // rather than panicking inside a fallible constructor.
+            .map_err(|err| {
+                anyhow::anyhow!("failed to spawn DFlash scheduler thread: {err}")
+            })?;
+
init_rx + .recv() + .map_err(|_| anyhow::anyhow!("DFlash scheduler initialization channel closed"))??; + Ok(Self { + inner: Arc::new(DFlashSchedulerInner { + submit_tx: Some(submit_tx), + join_handle: Some(join_handle), + }), + }) + } + + fn submit_tx(&self) -> Result<&channel::Sender> { + self.inner + .submit_tx + .as_ref() + .ok_or_else(|| anyhow::anyhow!("DFlash scheduler is closed")) + } + + pub fn submit(&self, request: DFlashDraftHostRequest) -> Result { + let (response_tx, response_rx) = channel::bounded(1); + self.submit_tx()? + .send(SchedulerMessage::Submit { + request, + response_tx, + }) + .map_err(|_| anyhow::anyhow!("DFlash scheduler is closed"))?; + response_rx + .recv() + .map_err(|_| anyhow::anyhow!("DFlash scheduler response channel closed"))? + } + + pub fn reset_cache(&self, request_id: DFlashRequestId) -> Result<()> { + let (response_tx, response_rx) = channel::bounded(1); + self.submit_tx()? + .send(SchedulerMessage::ResetCache { + request_id, + response_tx, + }) + .map_err(|_| anyhow::anyhow!("DFlash scheduler is closed"))?; + response_rx + .recv() + .map_err(|_| anyhow::anyhow!("DFlash scheduler response channel closed"))? + } + + /// Release a request's draft cache and reclaim its GPU buffers. Mirrors + /// Qwen3's `drop_request`: the executor removes the cache entry and RAII + /// frees the per-layer past K/V + scratch. Idempotent — retiring a + /// request that never created a cache is not an error. Callers should + /// invoke this once a draft request is verified or abandoned so the + /// `max_caches` pool does not fill with dead entries. + pub fn drop_cache(&self, request_id: DFlashRequestId) -> Result<()> { + let (response_tx, response_rx) = channel::bounded(1); + self.submit_tx()? + .send(SchedulerMessage::DropCache { + request_id, + response_tx, + }) + .map_err(|_| anyhow::anyhow!("DFlash scheduler is closed"))?; + response_rx + .recv() + .map_err(|_| anyhow::anyhow!("DFlash scheduler response channel closed"))? 
+ } + + pub fn crop_cache(&self, request_id: DFlashRequestId, seq_len: usize) -> Result<()> { + let (response_tx, response_rx) = channel::bounded(1); + self.submit_tx()? + .send(SchedulerMessage::CropCache { + request_id, + seq_len, + response_tx, + }) + .map_err(|_| anyhow::anyhow!("DFlash scheduler is closed"))?; + response_rx + .recv() + .map_err(|_| anyhow::anyhow!("DFlash scheduler response channel closed"))? + } + + pub fn cache_seq_len(&self, request_id: DFlashRequestId) -> Result { + let (response_tx, response_rx) = channel::bounded(1); + self.submit_tx()? + .send(SchedulerMessage::CacheSeqLen { + request_id, + response_tx, + }) + .map_err(|_| anyhow::anyhow!("DFlash scheduler is closed"))?; + response_rx + .recv() + .map_err(|_| anyhow::anyhow!("DFlash scheduler response channel closed"))? + } +} + +fn scheduler_loop( + executor: &mut DFlashExecutor, + submit_rx: channel::Receiver, + max_wait: Duration, + max_total_tokens: usize, +) { + let mut pending: VecDeque = VecDeque::new(); + loop { + if pending.is_empty() { + match submit_rx.recv() { + Ok(msg) => handle_message_or_enqueue(msg, &mut pending), + Err(_) => break, + } + } + while let Ok(msg) = submit_rx.try_recv() { + handle_message_or_enqueue(msg, &mut pending); + } + if pending.is_empty() { + continue; + } + let head_wait = pending + .front() + .and_then(PendingItem::queued_elapsed) + .unwrap_or(max_wait); + if pending.len() == 1 && head_wait < max_wait { + let timeout = max_wait - head_wait; + if let Ok(msg) = submit_rx.recv_timeout(timeout) { + handle_message_or_enqueue(msg, &mut pending); + while let Ok(msg) = submit_rx.try_recv() { + handle_message_or_enqueue(msg, &mut pending); + } + } + } + drain_one_batch(executor, &mut pending, max_total_tokens); + } + for pending in pending { + pending.send_stopped(); + } +} + +fn handle_message_or_enqueue(msg: SchedulerMessage, pending: &mut VecDeque) { + match msg { + SchedulerMessage::Submit { + request, + response_tx, + } => 
pending.push_back(PendingItem::Submit(PendingRequest { + request, + response_tx, + queued_at: Instant::now(), + })), + SchedulerMessage::ResetCache { + request_id, + response_tx, + } => pending.push_back(PendingItem::Control(SchedulerControl::ResetCache { + request_id, + response_tx, + })), + SchedulerMessage::DropCache { + request_id, + response_tx, + } => pending.push_back(PendingItem::Control(SchedulerControl::DropCache { + request_id, + response_tx, + })), + SchedulerMessage::CropCache { + request_id, + seq_len, + response_tx, + } => pending.push_back(PendingItem::Control(SchedulerControl::CropCache { + request_id, + seq_len, + response_tx, + })), + SchedulerMessage::CacheSeqLen { + request_id, + response_tx, + } => pending.push_back(PendingItem::Control(SchedulerControl::CacheSeqLen { + request_id, + response_tx, + })), + } +} + +fn drain_one_batch( + executor: &mut DFlashExecutor, + pending: &mut VecDeque, + max_total_tokens: usize, +) { + let Some(first) = pending.pop_front() else { + return; + }; + let PendingItem::Submit(first) = first else { + if let PendingItem::Control(control) = first { + control.execute(executor); + } + return; + }; + let key = match executor.host_batch_key(&first.request) { + Ok(key) => key, + Err(err) => { + let _ = first.response_tx.send(Err(err)); + return; + } + }; + let max_batch_size = executor_max_batch_size(executor); + let mut batch = vec![first]; + let mut total_tokens = key.q_len + key.ctx_len + key.past_len; + if total_tokens > max_total_tokens { + let err = anyhow::anyhow!( + "DFlash scheduler request total tokens {} exceeds max_total_tokens {}", + total_tokens, + max_total_tokens + ); + let first = batch.pop().expect("first request exists"); + let _ = first.response_tx.send(Err(err)); + return; + } + let mut i = 0; + while i < pending.len() && batch.len() < max_batch_size { + if !matches!(pending.get(i), Some(PendingItem::Submit(_))) { + break; + } + let matches = pending + .get(i) + .map(|candidate| { + let 
PendingItem::Submit(candidate) = candidate else { + return false; + }; + request_matches_key( + executor, + &candidate.request, + key, + total_tokens, + max_total_tokens, + ) + }) + .unwrap_or(false); + if matches { + total_tokens += key.q_len + key.ctx_len + key.past_len; + match pending.remove(i).expect("pending index exists") { + PendingItem::Submit(request) => batch.push(request), + PendingItem::Control(_) => unreachable!("control items are batch barriers"), + } + } else { + i += 1; + } + } + let response_txs = batch + .iter() + .map(|req| req.response_tx.clone()) + .collect::>(); + let requests = batch.into_iter().map(|pending| pending.request).collect(); + match executor.execute_host_batch_host(requests) { + Ok(responses) => { + for (response_tx, response) in response_txs.into_iter().zip(responses.into_iter()) { + let _ = response_tx.send(Ok(response)); + } + } + Err(err) => { + let message = err.to_string(); + for response_tx in response_txs { + let _ = response_tx.send(Err(anyhow::anyhow!(message.clone()))); + } + } + } +} + +fn request_matches_key( + executor: &DFlashExecutor, + request: &DFlashDraftHostRequest, + key: DFlashBatchKey, + current_total_tokens: usize, + max_total_tokens: usize, +) -> bool { + executor + .host_batch_key(request) + .map(|candidate| { + let candidate_tokens = candidate.q_len + candidate.ctx_len + candidate.past_len; + candidate == key && current_total_tokens + candidate_tokens <= max_total_tokens + }) + .unwrap_or(false) +} + +fn executor_max_batch_size(executor: &DFlashExecutor) -> usize { + executor.max_batch_size() +} + +impl PendingItem { + fn queued_elapsed(&self) -> Option { + match self { + PendingItem::Submit(request) => Some(request.queued_at.elapsed()), + PendingItem::Control(_) => None, + } + } + + fn send_stopped(self) { + match self { + PendingItem::Submit(request) => { + let _ = request + .response_tx + .send(Err(anyhow::anyhow!("DFlash scheduler stopped"))); + } + PendingItem::Control(control) => 
control.send_stopped(), + } + } +} + +impl SchedulerControl { + fn execute(self, executor: &mut DFlashExecutor) { + match self { + SchedulerControl::ResetCache { + request_id, + response_tx, + } => { + let _ = response_tx.send(executor.reset_cache(request_id)); + } + SchedulerControl::DropCache { + request_id, + response_tx, + } => { + let _ = response_tx.send(executor.drop_cache(request_id)); + } + SchedulerControl::CropCache { + request_id, + seq_len, + response_tx, + } => { + let _ = response_tx.send(executor.crop_cache(request_id, seq_len)); + } + SchedulerControl::CacheSeqLen { + request_id, + response_tx, + } => { + let _ = response_tx.send(executor.cache_seq_len(request_id)); + } + } + } + + fn send_stopped(self) { + match self { + SchedulerControl::ResetCache { response_tx, .. } + | SchedulerControl::DropCache { response_tx, .. } + | SchedulerControl::CropCache { response_tx, .. } => { + let _ = response_tx.send(Err(anyhow::anyhow!("DFlash scheduler stopped"))); + } + SchedulerControl::CacheSeqLen { response_tx, .. 
} => { + let _ = response_tx.send(Err(anyhow::anyhow!("DFlash scheduler stopped"))); + } + } + } +} diff --git a/openinfer-qwen3-4b-dflash/src/weights.rs b/openinfer-qwen3-4b-dflash/src/weights.rs new file mode 100644 index 00000000..c0f4f7aa --- /dev/null +++ b/openinfer-qwen3-4b-dflash/src/weights.rs @@ -0,0 +1,274 @@ +use anyhow::{Context, Result, bail}; +use log::info; +use openinfer_core::tensor::{DeviceContext, DeviceMatrix, DeviceVec}; +use openinfer_core::weight_loader::{ + deserialize_shards, load_shard_info, load_tensor_1d, load_tensor_2d, mmap_shards, + precompute_rope, +}; +use std::collections::HashMap; +use std::path::Path; + +use crate::config::DFlashConfig; + +pub(crate) struct DFlashAttention { + pub(crate) q_proj: DeviceMatrix, + pub(crate) k_proj: DeviceMatrix, + pub(crate) v_proj: DeviceMatrix, + pub(crate) o_proj: DeviceMatrix, + pub(crate) q_norm: DeviceVec, + pub(crate) k_norm: DeviceVec, +} + +pub(crate) struct DFlashMlp { + pub(crate) gate_up_proj: DeviceMatrix, + pub(crate) down_proj: DeviceMatrix, +} + +pub(crate) struct DFlashLayer { + pub(crate) input_layernorm: DeviceVec, + pub(crate) attention: DFlashAttention, + pub(crate) post_attention_layernorm: DeviceVec, + pub(crate) mlp: DFlashMlp, +} + +pub struct DFlashDraftModel { + pub(crate) ctx: DeviceContext, + pub(crate) config: DFlashConfig, + pub(crate) layers: Vec, + pub(crate) fc: DeviceMatrix, + pub(crate) hidden_norm: DeviceVec, + pub(crate) norm: DeviceVec, + pub(crate) cos_cache: DeviceVec, + pub(crate) sin_cache: DeviceVec, +} + +// SAFETY: The model owns one CUDA context/stream and is intended to run on one +// worker thread at a time, matching other OpenInfer model structs. 
+unsafe impl Send for DFlashDraftModel {}
+// NOTE(review): the safety comment above argues for one-thread-at-a-time use,
+// which justifies `Send`; confirm that shared `&DFlashDraftModel` access from
+// several threads is actually sound before relying on `Sync`.
+unsafe impl Sync for DFlashDraftModel {}
+
+impl DFlashDraftModel {
+    /// Loads the DFlash draft model from `model_path` onto CUDA device
+    /// `device_ordinal`.
+    ///
+    /// Reads the config, memory-maps the safetensors shards, uploads the
+    /// `fc`, `hidden_norm`, `norm` and per-layer weights while checking each
+    /// shape against the config, and precomputes the RoPE cos/sin caches.
+    ///
+    /// # Errors
+    /// Returns an error if the path is not valid UTF-8, if the config or any
+    /// shard cannot be read, or if a tensor is missing or has a shape the
+    /// config does not expect. Tensor, layer and RoPE failures name the item
+    /// that failed.
+    pub fn load(model_path: &Path, device_ordinal: usize) -> Result<Self> {
+        info!(
+            "Loading Qwen3-4B DFlash draft model from {}",
+            model_path.display()
+        );
+        let ctx = DeviceContext::new_with_device(device_ordinal)?;
+        let config = DFlashConfig::from_model_dir(model_path)?;
+        let model_path_str = model_path
+            .to_str()
+            .ok_or_else(|| anyhow::anyhow!("DFlash model path must be valid UTF-8"))?;
+        let (shard_paths, weight_map) = load_shard_info(model_path_str)?;
+        let mmaps = mmap_shards(&shard_paths)?;
+        let shards = deserialize_shards(&mmaps)?;
+
+        let fc = load_tensor_2d(&ctx, &shards, &weight_map, "fc.weight")
+            .context("load DFlash fc.weight")?;
+        ensure_matrix_shape(
+            "fc.weight",
+            &fc,
+            config.hidden_size,
+            config.hidden_size * config.target_layer_count(),
+        )?;
+        let hidden_norm = load_tensor_1d(&ctx, &shards, &weight_map, "hidden_norm.weight")
+            .context("load DFlash hidden_norm.weight")?;
+        let norm = load_tensor_1d(&ctx, &shards, &weight_map, "norm.weight")
+            .context("load DFlash norm.weight")?;
+        ensure_vec_len("hidden_norm.weight", &hidden_norm, config.hidden_size)?;
+        ensure_vec_len("norm.weight", &norm, config.hidden_size)?;
+
+        let mut layers = Vec::with_capacity(config.num_hidden_layers);
+        for layer_idx in 0..config.num_hidden_layers {
+            // Layer-level errors only mention a bare tensor name such as
+            // "q_proj", so attach the layer index here.
+            let layer = load_layer(&ctx, &shards, &weight_map, &config, layer_idx)
+                .with_context(|| format!("load DFlash layer {layer_idx}"))?;
+            layers.push(layer);
+        }
+        let (cos_cache, sin_cache) = precompute_rope(
+            &ctx,
+            config.head_dim,
+            config.max_position_embeddings,
+            config.rope_theta,
+        )
+        .context("precompute DFlash RoPE cache")?;
+
+        Ok(Self {
+            ctx,
+            config,
+            layers,
+            fc,
+            hidden_norm,
+            norm,
+            cos_cache,
+            sin_cache,
+        })
+    }
+
+    /// Returns the parsed model configuration.
+    pub fn config(&self) -> &DFlashConfig {
+        &self.config
+    }
+
+    /// Returns the `target_layer_ids` list from the DFlash config section.
+    pub fn target_layer_ids(&self) -> &[usize] {
+        &self.config.dflash_config.target_layer_ids
+    }
+
+    /// Returns the `mask_token_id` from the DFlash config section.
+    pub fn mask_token_id(&self) -> u32 {
+        self.config.dflash_config.mask_token_id
+    }
+
+    /// Returns the device context that owns this model's buffers.
+    pub fn device_context(&self) -> &DeviceContext {
+        &self.ctx
+    }
+}
+
+fn load_layer(
+    ctx: &DeviceContext,
+    shards:
&[safetensors::SafeTensors<'_>], + weight_map: &HashMap, + config: &DFlashConfig, + layer_idx: usize, +) -> Result { + let prefix = format!("layers.{layer_idx}"); + let q_proj = load_tensor_2d( + ctx, + shards, + weight_map, + &format!("{prefix}.self_attn.q_proj.weight"), + )?; + let k_proj = load_tensor_2d( + ctx, + shards, + weight_map, + &format!("{prefix}.self_attn.k_proj.weight"), + )?; + let v_proj = load_tensor_2d( + ctx, + shards, + weight_map, + &format!("{prefix}.self_attn.v_proj.weight"), + )?; + let o_proj = load_tensor_2d( + ctx, + shards, + weight_map, + &format!("{prefix}.self_attn.o_proj.weight"), + )?; + ensure_matrix_shape("q_proj", &q_proj, config.q_dim(), config.hidden_size)?; + ensure_matrix_shape("k_proj", &k_proj, config.kv_dim(), config.hidden_size)?; + ensure_matrix_shape("v_proj", &v_proj, config.kv_dim(), config.hidden_size)?; + ensure_matrix_shape("o_proj", &o_proj, config.hidden_size, config.q_dim())?; + + let gate_proj = load_tensor_2d( + ctx, + shards, + weight_map, + &format!("{prefix}.mlp.gate_proj.weight"), + )?; + let up_proj = load_tensor_2d( + ctx, + shards, + weight_map, + &format!("{prefix}.mlp.up_proj.weight"), + )?; + let gate_up_proj = DeviceMatrix::vstack(ctx, &[&gate_proj, &up_proj])?; + let down_proj = load_tensor_2d( + ctx, + shards, + weight_map, + &format!("{prefix}.mlp.down_proj.weight"), + )?; + ensure_matrix_shape( + "gate_up_proj", + &gate_up_proj, + 2 * config.intermediate_size, + config.hidden_size, + )?; + ensure_matrix_shape( + "down_proj", + &down_proj, + config.hidden_size, + config.intermediate_size, + )?; + + let input_layernorm = load_tensor_1d( + ctx, + shards, + weight_map, + &format!("{prefix}.input_layernorm.weight"), + )?; + let post_attention_layernorm = load_tensor_1d( + ctx, + shards, + weight_map, + &format!("{prefix}.post_attention_layernorm.weight"), + )?; + let q_norm = load_tensor_1d( + ctx, + shards, + weight_map, + &format!("{prefix}.self_attn.q_norm.weight"), + )?; + let k_norm = 
load_tensor_1d( + ctx, + shards, + weight_map, + &format!("{prefix}.self_attn.k_norm.weight"), + )?; + ensure_vec_len("input_layernorm", &input_layernorm, config.hidden_size)?; + ensure_vec_len( + "post_attention_layernorm", + &post_attention_layernorm, + config.hidden_size, + )?; + ensure_vec_len("q_norm", &q_norm, config.head_dim)?; + ensure_vec_len("k_norm", &k_norm, config.head_dim)?; + + Ok(DFlashLayer { + input_layernorm, + attention: DFlashAttention { + q_proj, + k_proj, + v_proj, + o_proj, + q_norm, + k_norm, + }, + post_attention_layernorm, + mlp: DFlashMlp { + gate_up_proj, + down_proj, + }, + }) +} + +fn ensure_matrix_shape(name: &str, matrix: &DeviceMatrix, rows: usize, cols: usize) -> Result<()> { + if matrix.rows != rows || matrix.cols != cols { + bail!( + "{name} shape mismatch: expected [{rows}, {cols}], got [{}, {}]", + matrix.rows, + matrix.cols + ); + } + Ok(()) +} + +fn ensure_vec_len(name: &str, vector: &DeviceVec, len: usize) -> Result<()> { + if vector.len != len { + bail!("{name} length mismatch: expected {len}, got {}", vector.len); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + const LOCAL_DFLASH: &str = "/home/hezhaozhao/models/Qwen3-4B-DFlash-b16"; + + #[test] + fn loads_local_dflash_weights() { + let path = Path::new(LOCAL_DFLASH); + if !path.exists() { + eprintln!("skipping: {LOCAL_DFLASH} does not exist"); + return; + } + let model = DFlashDraftModel::load(path, 0).expect("load model"); + assert_eq!(model.layers.len(), 5); + assert_eq!(model.fc.rows, 2560); + assert_eq!(model.fc.cols, 12800); + } +} diff --git a/openinfer-qwen3-4b-dflash/tests/hf_golden_gate.rs b/openinfer-qwen3-4b-dflash/tests/hf_golden_gate.rs new file mode 100644 index 00000000..54ee0903 --- /dev/null +++ b/openinfer-qwen3-4b-dflash/tests/hf_golden_gate.rs @@ -0,0 +1,810 @@ +//! HuggingFace remote-code golden gate for the standalone Qwen3-4B-DFlash draft. +//! +//! The fixture is generated by: +//! +//! ```ignore +//! 
/// Golden gate: the draft forward pass must reproduce the stored HF
/// remote-code output through three entry points — plain `forward`,
/// `forward_with_cache` on a fresh cache, and `prepare_step_context` followed
/// by `forward_with_draft_cache` — and the draft cache must report and crop its
/// sequence length as expected.
#[test]
fn dflash_forward_matches_hf_remote_code() {
    // Skip (rather than fail) when the checkpoint or the golden fixture is
    // not present on this machine.
    let Some(model_path) = model_path_or_skip("dflash golden gate") else {
        return;
    };
    let golden_path = Path::new(GOLDEN);
    if !golden_path.exists() {
        eprintln!("skipping dflash golden gate: {GOLDEN} does not exist");
        return;
    }

    let bytes = std::fs::read(golden_path).expect("read golden");
    let st = SafeTensors::deserialize(&bytes).expect("parse golden");
    let model = DFlashDraftModel::load(&model_path, 0).expect("load dflash");
    let config = model.config();
    let ctx = model.device_context();

    // Fixture tensors; the shape passed to each loader is asserted against
    // the shape stored in the fixture.
    let noise = bf16_tensor(&st, "noise_embedding", &[1, 3, config.hidden_size]);
    let target_hidden = bf16_tensor(
        &st,
        "target_hidden",
        &[1, 2, config.hidden_size * config.target_layer_count()],
    );
    let expected = bf16_tensor(&st, "output", &[1, 3, config.hidden_size]);
    let positions = i32_tensor(&st, "position_ids", &[1, 5]);

    // Upload the host tensors to the device.
    let noise_embedding = HiddenStates {
        data: ctx.stream.clone_htod(&noise).expect("noise h2d"),
        hidden_dim: config.hidden_size,
        seq_len: 3,
    };
    let target_hidden = HiddenStates {
        data: ctx.stream.clone_htod(&target_hidden).expect("target h2d"),
        hidden_dim: config.hidden_size * config.target_layer_count(),
        seq_len: 2,
    };

    // Path 1: forward without any cache.
    let uncached = model
        .forward(
            &noise_embedding,
            DFlashTargetHidden {
                concatenated: &target_hidden,
            },
            &positions,
        )
        .expect("forward");
    ctx.sync().expect("sync");
    let uncached = ctx.stream.clone_dtoh(&uncached.data).expect("output d2h");
    ctx.sync().expect("sync");
    assert_deltas("dflash HF golden deltas", &uncached, &expected);

    // Path 2: one-shot forward that also fills a freshly created cache.
    // NOTE(review): the arguments look like (max_q_len, max_step_context_len,
    // max_seq_len), mirroring the executor options used elsewhere in this
    // file — confirm against `create_draft_cache`.
    let mut cache = model
        .create_draft_cache(3, 2, 8)
        .expect("create draft cache");
    let cached_one_shot = model
        .forward_with_cache(
            &noise_embedding,
            DFlashTargetHidden {
                concatenated: &target_hidden,
            },
            &positions,
            &mut cache,
        )
        .expect("cached one-shot forward");
    ctx.sync().expect("sync");
    let cached_one_shot = ctx
        .stream
        .clone_dtoh(&cached_one_shot.data)
        .expect("output d2h");
    ctx.sync().expect("sync");
    assert_deltas(
        "dflash unified-cache one-shot HF golden deltas",
        &cached_one_shot,
        &expected,
    );

    // Path 3: reset the cache, prepare the step context separately, then run
    // the draft forward against the prepared cache.
    cache.reset();
    model
        .prepare_step_context(
            DFlashTargetHidden {
                concatenated: &target_hidden,
            },
            &positions,
            &mut cache,
        )
        .expect("prepare step context");
    let cached = model
        .forward_with_draft_cache(&noise_embedding, &positions, &mut cache)
        .expect("cached forward");
    ctx.sync().expect("sync");
    let cached = ctx.stream.clone_dtoh(&cached.data).expect("output d2h");
    ctx.sync().expect("sync");
    assert_deltas("dflash draft-cache HF golden deltas", &cached, &expected);
    // Presumably 5 = 2 context rows + 3 query rows (it equals the length of
    // `position_ids`) — confirm against the cache implementation.
    assert_eq!(cache.seq_len(), 5);
    cache.crop(2).expect("crop draft cache");
    assert_eq!(cache.seq_len(), 2);
}
= std::fs::read(golden_path).expect("read golden"); + let st = SafeTensors::deserialize(&bytes).expect("parse golden"); + let model = DFlashDraftModel::load(&model_path, 0).expect("load dflash"); + let config = model.config(); + let ctx = model.device_context(); + + let noise0 = bf16_tensor(&st, "noise_embedding", &[1, 3, config.hidden_size]); + let target0 = bf16_tensor( + &st, + "target_hidden", + &[1, 2, config.hidden_size * config.target_layer_count()], + ); + let positions0 = i32_tensor(&st, "position_ids", &[1, 5]); + let mut noise1 = noise0.clone(); + for (i, value) in noise1.iter_mut().enumerate() { + if i % 13 == 0 { + *value = bf16::from_f32(value.to_f32() + 0.015625); + } + } + let mut target1 = target0.clone(); + for (i, value) in target1.iter_mut().enumerate() { + if i % 31 == 0 { + *value = bf16::from_f32(value.to_f32() - 0.03125); + } + } + let mut positions1 = positions0.clone(); + for value in &mut positions1 { + *value += 2; + } + let noise_a = HiddenStates { + data: ctx.stream.clone_htod(&noise0).expect("noise h2d"), + hidden_dim: config.hidden_size, + seq_len: 3, + }; + let target_a = HiddenStates { + data: ctx.stream.clone_htod(&target0).expect("target h2d"), + hidden_dim: config.hidden_size * config.target_layer_count(), + seq_len: 2, + }; + let noise_b = HiddenStates { + data: ctx.stream.clone_htod(&noise1).expect("noise h2d"), + hidden_dim: config.hidden_size, + seq_len: 3, + }; + let target_b = HiddenStates { + data: ctx.stream.clone_htod(&target1).expect("target h2d"), + hidden_dim: config.hidden_size * config.target_layer_count(), + seq_len: 2, + }; + + let single = model + .forward( + &noise_a, + DFlashTargetHidden { + concatenated: &target_a, + }, + &positions0, + ) + .expect("single forward"); + ctx.sync().expect("sync"); + let single = ctx.stream.clone_dtoh(&single.data).expect("single d2h"); + let single_row1 = model + .forward( + &noise_b, + DFlashTargetHidden { + concatenated: &target_b, + }, + &positions1, + ) + .expect("single 
/// Executor gate: a two-request batch comes back as two responses, in
/// submission order, each tagged with its own request id and reporting the
/// batch size it ran in.
#[test]
fn dflash_executor_returns_request_tagged_batch_outputs() {
    // Skip (rather than fail) when the checkpoint or the golden fixture is
    // not present on this machine.
    let Some(model_path) = model_path_or_skip("dflash executor gate") else {
        return;
    };
    let golden_path = Path::new(GOLDEN);
    if !golden_path.exists() {
        eprintln!("skipping dflash executor gate: {GOLDEN} does not exist");
        return;
    }

    let bytes = std::fs::read(golden_path).expect("read golden");
    let st = SafeTensors::deserialize(&bytes).expect("parse golden");
    let mut executor = DFlashExecutor::load(
        &model_path,
        0,
        DFlashExecutorOptions {
            max_batch_size: 2,
            max_step_context_len: 2,
            max_q_len: 3,
            max_seq_len: 8,
            max_caches: 8,
        },
    )
    .expect("load executor");
    let hidden_size = executor.model().config().hidden_size;
    let target_layer_count = executor.model().config().target_layer_count();
    let ctx = executor.model().device_context();
    let noise = bf16_tensor(&st, "noise_embedding", &[1, 3, hidden_size]);
    let target = bf16_tensor(
        &st,
        "target_hidden",
        &[1, 2, hidden_size * target_layer_count],
    );
    let positions = i32_tensor(&st, "position_ids", &[1, 5]);
    // Build a request from the fixture inputs; each call uploads fresh device
    // copies, so the two requests differ only in their id.
    let mk_req = |request_id| DFlashDraftRequest {
        request_id: DFlashRequestId(request_id),
        noise_embedding: HiddenStates {
            data: ctx.stream.clone_htod(&noise).expect("noise h2d"),
            hidden_dim: hidden_size,
            seq_len: 3,
        },
        target_hidden: HiddenStates {
            data: ctx.stream.clone_htod(&target).expect("target h2d"),
            hidden_dim: hidden_size * target_layer_count,
            seq_len: 2,
        },
        position_ids: positions.clone(),
        cache_mode: DFlashCacheMode::NoCache,
    };
    let responses = executor
        .execute_batch(vec![mk_req(7), mk_req(8)])
        .expect("execute batch");
    // Responses must preserve submission order and carry the caller's ids.
    assert_eq!(responses.len(), 2);
    assert_eq!(responses[0].request_id, DFlashRequestId(7));
    assert_eq!(responses[1].request_id, DFlashRequestId(8));
    assert_eq!(responses[0].output.hidden_dim, hidden_size);
    assert_eq!(responses[0].output.seq_len, 3);
    assert_eq!(responses[0].batch_size, 2);
}
{ + *value = bf16::from_f32(value.to_f32() - 0.03125); + } + } + let mut positions1 = positions0.clone(); + for value in &mut positions1 { + *value += 2; + } + let scheduler = DFlashSchedulerHandle::start( + &model_path, + 0, + DFlashSchedulerOptions { + executor: DFlashExecutorOptions { + max_batch_size: 2, + max_step_context_len: 2, + max_q_len: 3, + max_seq_len: 8, + max_caches: 8, + }, + max_wait: std::time::Duration::from_millis(50), + max_total_tokens: 16, + }, + ) + .expect("start scheduler"); + let barrier = Arc::new(Barrier::new(3)); + let scheduler0 = scheduler.clone(); + let barrier0 = Arc::clone(&barrier); + let t0 = std::thread::spawn(move || { + barrier0.wait(); + scheduler0.submit(DFlashDraftHostRequest { + request_id: DFlashRequestId(42), + noise_embedding: noise0, + target_hidden: target0, + position_ids: positions0, + q_len: 3, + ctx_len: 2, + cache_mode: DFlashCacheMode::NoCache, + }) + }); + let barrier1 = Arc::clone(&barrier); + let t1 = std::thread::spawn(move || { + barrier1.wait(); + scheduler.submit(DFlashDraftHostRequest { + request_id: DFlashRequestId(43), + noise_embedding: noise1, + target_hidden: target1, + position_ids: positions1, + q_len: 3, + ctx_len: 2, + cache_mode: DFlashCacheMode::NoCache, + }) + }); + barrier.wait(); + let response0 = t0 + .join() + .expect("join scheduler request 0") + .expect("submit 0"); + let response1 = t1 + .join() + .expect("join scheduler request 1") + .expect("submit 1"); + assert_eq!(response0.request_id, DFlashRequestId(42)); + assert_eq!(response1.request_id, DFlashRequestId(43)); + assert_eq!(response0.hidden_dim, config.hidden_size); + assert_eq!(response1.hidden_dim, config.hidden_size); + assert_eq!(response0.seq_len, 3); + assert_eq!(response1.seq_len, 3); + assert_eq!(response0.output.len(), config.hidden_size * 3); + assert_eq!(response1.output.len(), config.hidden_size * 3); + assert_eq!(response0.batch_size, 2); + assert_eq!(response1.batch_size, 2); + assert_eq!(response0.cache_seq_len, 
/// Scheduler cache gate: a `DraftCache` request creates a per-request cache
/// whose length can then be queried, cropped and reset through the scheduler
/// handle.
#[test]
fn dflash_scheduler_manages_draft_cache() {
    // Skip (rather than fail) when the checkpoint or the golden fixture is
    // not present on this machine.
    let Some(model_path) = model_path_or_skip("dflash scheduler cache gate") else {
        return;
    };
    let golden_path = Path::new(GOLDEN);
    if !golden_path.exists() {
        eprintln!("skipping dflash scheduler cache gate: {GOLDEN} does not exist");
        return;
    }

    let bytes = std::fs::read(golden_path).expect("read golden");
    let st = SafeTensors::deserialize(&bytes).expect("parse golden");
    let config =
        openinfer_qwen3_4b_dflash::DFlashConfig::from_model_dir(&model_path).expect("load config");
    let noise = bf16_tensor(&st, "noise_embedding", &[1, 3, config.hidden_size]);
    let target = bf16_tensor(
        &st,
        "target_hidden",
        &[1, 2, config.hidden_size * config.target_layer_count()],
    );
    let positions = i32_tensor(&st, "position_ids", &[1, 5]);
    let scheduler = DFlashSchedulerHandle::start(
        &model_path,
        0,
        DFlashSchedulerOptions {
            executor: DFlashExecutorOptions {
                max_batch_size: 2,
                max_step_context_len: 2,
                max_q_len: 3,
                max_seq_len: 8,
                max_caches: 8,
            },
            max_wait: std::time::Duration::from_millis(10),
            max_total_tokens: 16,
        },
    )
    .expect("start scheduler");
    let request_id = DFlashRequestId(99);
    let response = scheduler
        .submit(DFlashDraftHostRequest {
            request_id,
            noise_embedding: noise,
            target_hidden: target,
            position_ids: positions,
            q_len: 3,
            ctx_len: 2,
            cache_mode: DFlashCacheMode::DraftCache,
        })
        .expect("submit cached request");
    assert_eq!(response.request_id, request_id);
    // 5 == ctx_len (2) + q_len (3) for this request.
    assert_eq!(response.cache_seq_len, 5);
    assert_eq!(
        scheduler.cache_seq_len(request_id).expect("cache seq len"),
        5
    );
    // Cropping shortens the cache to the requested length.
    scheduler.crop_cache(request_id, 2).expect("crop cache");
    assert_eq!(
        scheduler.cache_seq_len(request_id).expect("cache seq len"),
        2
    );
    // Resetting empties the cache but keeps it registered: the length query
    // still succeeds and reports 0.
    scheduler.reset_cache(request_id).expect("reset cache");
    assert_eq!(
        scheduler.cache_seq_len(request_id).expect("cache seq len"),
        0
    );
}
dflash_scheduler_control_messages_are_fifo() { + let Some(model_path) = model_path_or_skip("dflash scheduler fifo gate") else { + return; + }; + let golden_path = Path::new(GOLDEN); + if !golden_path.exists() { + eprintln!("skipping dflash scheduler fifo gate: {GOLDEN} does not exist"); + return; + } + + let bytes = std::fs::read(golden_path).expect("read golden"); + let st = SafeTensors::deserialize(&bytes).expect("parse golden"); + let config = + openinfer_qwen3_4b_dflash::DFlashConfig::from_model_dir(&model_path).expect("load config"); + let noise = bf16_tensor(&st, "noise_embedding", &[1, 3, config.hidden_size]); + let target = bf16_tensor( + &st, + "target_hidden", + &[1, 2, config.hidden_size * config.target_layer_count()], + ); + let positions = i32_tensor(&st, "position_ids", &[1, 5]); + let scheduler = DFlashSchedulerHandle::start( + &model_path, + 0, + DFlashSchedulerOptions { + executor: DFlashExecutorOptions { + max_batch_size: 2, + max_step_context_len: 2, + max_q_len: 3, + max_seq_len: 8, + max_caches: 8, + }, + max_wait: std::time::Duration::from_millis(100), + max_total_tokens: 16, + }, + ) + .expect("start scheduler"); + let request_id = DFlashRequestId(123); + // The scheduler uses one unbounded channel for both submit and control + // messages, so FIFO ordering is guaranteed by construction: each call + // blocks until the scheduler thread has processed it. Submit the cached + // request first; when it returns the cache must exist, then the following + // control calls run strictly after it. 
/// Cache-control gate: `reset_cache`, `crop_cache` and `cache_seq_len` must
/// all fail with an "unknown DFlash cache request_id" error for an id that
/// never created a cache, both on the executor directly and through the
/// scheduler handle.
#[test]
fn dflash_cache_control_rejects_unknown_request_ids() {
    // Skip (rather than fail) when the checkpoint is not present.
    let Some(model_path) = model_path_or_skip("dflash cache rejection gate") else {
        return;
    };
    let mut executor = DFlashExecutor::load(
        &model_path,
        0,
        DFlashExecutorOptions {
            max_batch_size: 2,
            max_step_context_len: 2,
            max_q_len: 3,
            max_seq_len: 8,
            max_caches: 8,
        },
    )
    .expect("load executor");
    // No request with this id is ever submitted in this test.
    let unknown = DFlashRequestId(777);
    let reset_err = executor.reset_cache(unknown).expect_err("reset must fail");
    assert!(
        reset_err
            .to_string()
            .contains("unknown DFlash cache request_id"),
        "unexpected reset error: {reset_err}"
    );
    let crop_err = executor.crop_cache(unknown, 1).expect_err("crop must fail");
    assert!(
        crop_err
            .to_string()
            .contains("unknown DFlash cache request_id"),
        "unexpected crop error: {crop_err}"
    );
    let seq_err = executor
        .cache_seq_len(unknown)
        .expect_err("cache seq len must fail");
    assert!(
        seq_err
            .to_string()
            .contains("unknown DFlash cache request_id"),
        "unexpected seq len error: {seq_err}"
    );

    // Same three checks through the scheduler handle.
    // NOTE(review): `executor` is still alive here, so if the scheduler loads
    // its own copy of the model both copies are resident on device 0 at the
    // same time — confirm this is intended.
    let scheduler = DFlashSchedulerHandle::start(&model_path, 0, DFlashSchedulerOptions::default())
        .expect("start scheduler");
    let reset_err = scheduler
        .reset_cache(unknown)
        .expect_err("scheduler reset must fail");
    assert!(
        reset_err
            .to_string()
            .contains("unknown DFlash cache request_id"),
        "unexpected scheduler reset error: {reset_err}"
    );
    let crop_err = scheduler
        .crop_cache(unknown, 1)
        .expect_err("scheduler crop must fail");
    assert!(
        crop_err
            .to_string()
            .contains("unknown DFlash cache request_id"),
        "unexpected scheduler crop error: {crop_err}"
    );
    let seq_err = scheduler
        .cache_seq_len(unknown)
        .expect_err("scheduler cache seq len must fail");
    assert!(
        seq_err
            .to_string()
            .contains("unknown DFlash cache request_id"),
        "unexpected scheduler seq len error: {seq_err}"
    );
}
+ let scheduler = DFlashSchedulerHandle::start( + &model_path, + 0, + DFlashSchedulerOptions { + executor: DFlashExecutorOptions { + max_batch_size: 2, + max_step_context_len: 2, + max_q_len: 3, + max_seq_len: 8, + max_caches: 1, + }, + max_wait: std::time::Duration::from_millis(10), + max_total_tokens: 16, + }, + ) + .expect("start scheduler"); + + let first = DFlashRequestId(1); + let second = DFlashRequestId(2); + let submit = |id: DFlashRequestId| { + scheduler.submit(DFlashDraftHostRequest { + request_id: id, + noise_embedding: noise.clone(), + target_hidden: target.clone(), + position_ids: positions.clone(), + q_len: 3, + ctx_len: 2, + cache_mode: DFlashCacheMode::DraftCache, + }) + }; + + submit(first).expect("first cached submit creates a cache"); + assert_eq!( + scheduler.cache_seq_len(first).expect("first cache exists"), + 5 + ); + + // Pool is full (max_caches=1): a second distinct request must fail closed. + let overflow_err = match submit(second) { + Ok(_) => panic!("overflow submit must fail closed, but succeeded"), + Err(err) => err, + }; + assert!( + overflow_err.to_string().contains("DFlash cache pool full"), + "unexpected overflow error: {overflow_err}" + ); + + // drop_cache is idempotent and releases the slot for reuse. + scheduler.drop_cache(first).expect("drop first cache"); + // Idempotent: dropping an already-removed (or never-seen) id is not an error. + scheduler + .drop_cache(first) + .expect("drop_cache is idempotent"); + scheduler + .drop_cache(DFlashRequestId(999)) + .expect("drop_cache unknown id is idempotent"); + // The retired id's cache is gone, so reads fail closed. + let gone_err = scheduler + .cache_seq_len(first) + .expect_err("retired cache must be gone"); + assert!( + gone_err + .to_string() + .contains("unknown DFlash cache request_id"), + "unexpected retired-cache error: {gone_err}" + ); + + // Slot is reclaimed: the second request now succeeds. 
+ submit(second).expect("second submit after drop succeeds"); + assert_eq!( + scheduler + .cache_seq_len(second) + .expect("second cache exists"), + 5 + ); +} + +fn assert_deltas(label: &str, actual: &[bf16], expected: &[bf16]) { + assert_eq!(actual.len(), expected.len()); + let mut deltas = actual + .iter() + .zip(expected.iter()) + .map(|(got, want)| (got.to_f32() - want.to_f32()).abs()) + .collect::>(); + deltas.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let mean = deltas.iter().sum::() / deltas.len() as f32; + let p99 = deltas[((deltas.len() as f32 * 0.99).floor() as usize).min(deltas.len() - 1)]; + let max = deltas[deltas.len() - 1]; + eprintln!( + "{label}: mean={mean:.6}, p99={p99:.6}, max={max:.6}, n={}", + deltas.len() + ); + assert!(mean <= MEAN_TOL, "mean delta {mean} > {MEAN_TOL}"); + assert!(p99 <= P99_TOL, "p99 delta {p99} > {P99_TOL}; max={max}"); +} + +fn tensor<'a>(st: &'a SafeTensors<'_>, name: &str, dtype: Dtype, shape: &[usize]) -> &'a [u8] { + let view = st + .tensor(name) + .unwrap_or_else(|err| panic!("golden missing {name}: {err}")); + assert_eq!(view.dtype(), dtype, "{name} dtype mismatch"); + assert_eq!(view.shape(), shape, "{name} shape mismatch"); + view.data() +} + +fn bf16_tensor(st: &SafeTensors<'_>, name: &str, shape: &[usize]) -> Vec { + tensor(st, name, Dtype::BF16, shape) + .chunks_exact(2) + .map(|chunk| bf16::from_bits(u16::from_le_bytes([chunk[0], chunk[1]]))) + .collect() +} + +fn i32_tensor(st: &SafeTensors<'_>, name: &str, shape: &[usize]) -> Vec { + tensor(st, name, Dtype::I32, shape) + .chunks_exact(4) + .map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect() +} + +fn model_path_or_skip(label: &str) -> Option { + let path = std::env::var("OPENINFER_DFLASH_TEST_MODEL_PATH") + .map(PathBuf::from) + .unwrap_or_else(|_| PathBuf::from(LOCAL_DFLASH)); + let config_path = path.join("config.json"); + if !config_path.exists() { + eprintln!( + "skipping {label}: {}/config.json does not exist; set 
OPENINFER_DFLASH_TEST_MODEL_PATH to run it", + path.display() + ); + return None; + } + let config_text = std::fs::read_to_string(&config_path).unwrap_or_else(|err| { + panic!( + "failed to read DFlash config {}: {err}", + config_path.display() + ) + }); + let config: serde_json::Value = serde_json::from_str(&config_text).unwrap_or_else(|err| { + panic!( + "failed to parse DFlash config {}: {err}", + config_path.display() + ) + }); + let is_dflash = config + .get("architectures") + .and_then(serde_json::Value::as_array) + .map(|items| { + items + .iter() + .any(|item| item.as_str() == Some("DFlashDraftModel")) + }) + .unwrap_or(false); + if !is_dflash { + eprintln!( + "skipping {label}: {} is not a DFlashDraftModel checkpoint; set OPENINFER_DFLASH_TEST_MODEL_PATH", + path.display() + ); + return None; + } + Some(path) +} diff --git a/test_data/qwen3-4b-dflash-hf-golden.safetensors b/test_data/qwen3-4b-dflash-hf-golden.safetensors new file mode 100644 index 00000000..6007c5c6 Binary files /dev/null and b/test_data/qwen3-4b-dflash-hf-golden.safetensors differ diff --git a/tools/accuracy/bench_qwen3_4b_dflash_forward.py b/tools/accuracy/bench_qwen3_4b_dflash_forward.py new file mode 100644 index 00000000..34fe05b2 --- /dev/null +++ b/tools/accuracy/bench_qwen3_4b_dflash_forward.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +"""Benchmark Qwen3-4B-DFlash forward in Hugging Face and OpenInfer. + +The benchmark uses the same synthetic fixed inputs for both engines, so the +result isolates the standalone drafter forward cost. It does not measure the +full speculative decoding loop because the OpenInfer target/controller path is +not implemented yet. 
def stats(values: list[float]) -> dict[str, float]:
    """Summarize samples as mean, nearest-rank p50/p90/p99, min and max.

    An empty input yields all-zero statistics instead of raising.
    """
    ordered = sorted(values)
    count = len(ordered)
    if count == 0:
        return dict.fromkeys(("mean", "p50", "p90", "p99", "min", "max"), 0.0)

    last = count - 1

    def pct(q: float) -> float:
        # Nearest rank via round(); exact .5 ties go to the even index.
        rank = min(round(last * q), last)
        return float(ordered[rank])

    summary: dict[str, float] = {"mean": float(sum(ordered) / count)}
    summary["p50"] = pct(0.50)
    summary["p90"] = pct(0.90)
    summary["p99"] = pct(0.99)
    summary["min"] = float(ordered[0])
    summary["max"] = float(ordered[last])
    return summary
parser.add_argument("--target-model-path", default="/home/hezhaozhao/models/Qwen3-4B") + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for the DFlash forward benchmark") + + draft = AutoModel.from_pretrained( + args.draft_model_path, + dtype=torch.bfloat16, + device_map={"": f"cuda:{args.device}"}, + trust_remote_code=True, + ).eval() + device = next(draft.parameters()).device + + gen = torch.Generator(device=device).manual_seed(SEED) + hidden = draft.config.hidden_size + target_layer_count = len(draft.target_layer_ids) + noise_embedding = torch.randn((1, args.q_len, hidden), generator=gen, device=device, dtype=torch.bfloat16) + target_hidden = torch.randn( + (1, args.ctx_len, hidden * target_layer_count), + generator=gen, + device=device, + dtype=torch.bfloat16, + ) + position_ids = torch.arange(args.ctx_len + args.q_len, device=device, dtype=torch.int32).unsqueeze(0) + fixture_path = Path(args.fixture_out) + fixture_path.parent.mkdir(parents=True, exist_ok=True) + save_file( + { + "noise_embedding": noise_embedding.detach().to("cpu", dtype=torch.bfloat16).contiguous(), + "target_hidden": target_hidden.detach().to("cpu", dtype=torch.bfloat16).contiguous(), + "position_ids": position_ids.detach().to("cpu", dtype=torch.int32).contiguous(), + }, + str(fixture_path), + ) + + hf_latencies = [] + with torch.inference_mode(): + for _ in range(args.warmup): + _ = draft( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=position_ids, + use_cache=False, + is_causal=False, + ) + torch.cuda.synchronize(device) + for _ in range(args.iters): + start = time.perf_counter() + _ = draft( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=position_ids, + use_cache=False, + is_causal=False, + ) + torch.cuda.synchronize(device) + hf_latencies.append((time.perf_counter() - start) * 1000.0) + + openinfer_latencies = None + if args.openinfer_bin is not None: + cmd = [ + 
str(args.openinfer_bin), + "--model-path", + args.draft_model_path, + "--fixture", + str(fixture_path), + "--device", + str(args.device), + "--ctx-len", + str(args.ctx_len), + "--q-len", + str(args.q_len), + "--warmup", + str(args.warmup), + "--iters", + str(args.iters), + ] + openinfer_draft_cache = args.openinfer_draft_cache or args.openinfer_context_cache + if openinfer_draft_cache: + cmd.append("--draft-cache") + raw = subprocess.run(cmd, check=True, capture_output=True, text=True).stdout + payload = json.loads(raw) + openinfer_latencies = payload["latency_ms"] + + report = { + "schema": 1, + "draft_model_path": args.draft_model_path, + "target_model_path": args.target_model_path, + "device": args.device, + "ctx_len": args.ctx_len, + "q_len": args.q_len, + "warmup": args.warmup, + "iters": args.iters, + "openinfer_draft_cache": args.openinfer_draft_cache or args.openinfer_context_cache, + "fixture_out": str(fixture_path), + "hf_remote_code": { + "engine": "transformers", + "latency_ms": stats(hf_latencies), + }, + "openinfer": openinfer_latencies, + } + out = Path(args.out) + out.parent.mkdir(parents=True, exist_ok=True) + out.write_text(json.dumps(report, indent=2, ensure_ascii=False) + "\n", encoding="utf-8") + print(f"wrote {out}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/accuracy/compare_qwen3_4b_dflash_drafter_generation.py b/tools/accuracy/compare_qwen3_4b_dflash_drafter_generation.py new file mode 100644 index 00000000..3ea144d8 --- /dev/null +++ b/tools/accuracy/compare_qwen3_4b_dflash_drafter_generation.py @@ -0,0 +1,466 @@ +#!/usr/bin/env python3 +"""Compare Qwen3-4B-DFlash HF drafter vs OpenInfer drafter in one target loop. + +This is an end-to-end drafter-substitution probe for the current +`openinfer-qwen3-4b-dflash` boundary. The target model, tokenizer, target KV +cache, target verification, target `lm_head`, and greedy sampler all come from +Transformers. 
def sha256_u32_le(values: list[int]) -> str:
    """Return the SHA-256 hex digest of ``values`` packed as little-endian u32.

    Every value is encoded as exactly four unsigned little-endian bytes before
    hashing. ``int.to_bytes`` raises ``OverflowError`` for values outside
    ``[0, 2**32)``.
    """
    digest = hashlib.sha256()
    for value in values:
        digest.update(int(value).to_bytes(4, byteorder="little", signed=False))
    return digest.hexdigest()


def sha256_text(text: str) -> str:
    """Return the SHA-256 hex digest of ``text`` encoded as UTF-8."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def first_diff(left: list[int], right: list[int]) -> dict[str, Any] | None:
    """Describe the first position where two token sequences diverge.

    Args:
        left: Token ids produced with the HF drafter.
        right: Token ids produced with the OpenInfer drafter.

    Returns:
        ``None`` when the sequences are identical. Otherwise a dict with the
        diverging ``index``, the token on each side (``None`` when that side
        has already ended) and a ``reason`` of ``"token_mismatch"`` or
        ``"length_mismatch"``.
    """
    for index, (hf_token, oi_token) in enumerate(zip(left, right)):
        if hf_token != oi_token:
            return {
                "index": index,
                "hf_drafter": hf_token,
                "openinfer_drafter": oi_token,
                "reason": "token_mismatch",
            }
    if len(left) != len(right):
        # The common prefix matched, so the divergence is the first token that
        # exists on only one side.
        limit = min(len(left), len(right))
        return {
            "index": limit,
            "hf_drafter": left[limit] if len(left) > limit else None,
            "openinfer_drafter": right[limit] if len(right) > limit else None,
            "reason": "length_mismatch",
        }
    return None


def input_device(model: torch.nn.Module) -> torch.device:
    """Return the device of the first parameter of ``model``."""
    return next(model.parameters()).device


def extract_context_feature(hidden_states: tuple[torch.Tensor, ...], layer_ids: list[int]) -> torch.Tensor:
    """Concatenate the selected layers' hidden states along the last dim."""
    # HF hidden_states includes the embedding output at index 0.
    return torch.cat([hidden_states[layer_id + 1] for layer_id in layer_ids], dim=-1)


def greedy(logits: torch.Tensor) -> torch.Tensor:
    """Return the argmax index over the last dimension of ``logits``."""
    return torch.argmax(logits, dim=-1)


def tensor_deltas(got: torch.Tensor, want: torch.Tensor) -> dict[str, float]:
    """Summarize the element-wise absolute difference between two tensors.

    Both tensors are compared in float32 on the CPU.

    Args:
        got: Tensor under test.
        want: Reference tensor; must have exactly the same shape as ``got``.

    Returns:
        A dict with ``mean``, ``p99`` and ``max`` of ``|got - want|`` and the
        element count ``n``. All statistics are ``0.0`` for empty tensors.

    Raises:
        ValueError: If the shapes differ. Without this check the subtraction
            would broadcast and report a plausible-looking but meaningless
            delta.
    """
    if got.shape != want.shape:
        raise ValueError(
            f"tensor_deltas shape mismatch: got {tuple(got.shape)}, want {tuple(want.shape)}"
        )
    deltas = (got.float() - want.float()).abs().flatten().detach().cpu()
    count = deltas.numel()
    if count == 0:
        return {"mean": 0.0, "p99": 0.0, "max": 0.0, "n": 0}
    sorted_deltas = torch.sort(deltas).values
    # Clamp so that small tensors still index a valid element.
    p99_index = min(int(count * 0.99), count - 1)
    return {
        "mean": float(deltas.mean().item()),
        "p99": float(sorted_deltas[p99_index].item()),
        "max": float(sorted_deltas[-1].item()),
        "n": int(count),
    }


def merge_delta_stats(items: list[dict[str, float]]) -> dict[str, float] | None:
    """Combine per-block ``tensor_deltas`` summaries into one summary.

    Returns:
        ``None`` when ``items`` is empty. Otherwise a dict with the
        sample-weighted ``mean``, the largest per-block ``p99``, the largest
        ``max`` and the total element count ``n``.
    """
    if not items:
        return None
    total_n = sum(int(item["n"]) for item in items)
    if total_n == 0:
        return {"mean": 0.0, "p99": 0.0, "max": 0.0, "n": 0}
    # The exact aggregate p99 needs raw samples. For this report the per-block
    # worst p99 is the conservative summary, and max is exact.
    return {
        "mean": sum(item["mean"] * item["n"] for item in items) / total_n,
        "p99": max(item["p99"] for item in items),
        "max": max(item["max"] for item in items),
        "n": total_n,
    }
@dataclass
class Runtime:
    """Shared state for one drafter-substitution comparison run."""

    # Target causal LM; supplies embed_tokens, lm_head and the verify forward.
    target: torch.nn.Module
    # HF remote-code draft model used as the reference drafter.
    draft: torch.nn.Module
    # Tokenizer used to encode prompts and decode generated ids.
    tokenizer: Any
    # Target layers whose hidden states are concatenated as drafter context.
    target_layer_ids: list[int]
    # Number of positions drafted per step.
    block_size: int
    # Token id used to fill positions that have not been generated yet.
    mask_token_id: int
    # Token ids that terminate generation.
    stop_token_ids: list[int]
    # Prebuilt fixture binary; when None the binary is run through cargo.
    openinfer_bin: Path | None
    # Passed to the OpenInfer binary as --model-path.
    draft_model_path: Path
    # Working directory for the OpenInfer subprocess.
    repo_root: Path
    # Passed to the OpenInfer binary as --device.
    device_ordinal: int
    # Whether to also run the HF drafter and report OpenInfer-vs-HF deltas.
    collect_hidden_delta: bool


def _openinfer_command(runtime: Runtime, fixture: Path, out: Path) -> list[str]:
    """Build the argv that runs the OpenInfer forward fixture once."""
    if runtime.openinfer_bin is not None:
        launcher = [str(runtime.openinfer_bin)]
    else:
        launcher = [
            "cargo",
            "run",
            "--release",
            "-p",
            "openinfer-qwen3-4b-dflash",
            "--bin",
            "qwen3_dflash_forward_fixture",
            "--",
        ]
    return [
        *launcher,
        "--model-path",
        str(runtime.draft_model_path),
        "--fixture",
        str(fixture),
        "--out",
        str(out),
        "--device",
        str(runtime.device_ordinal),
    ]


def run_openinfer_draft(
    runtime: Runtime,
    *,
    noise_embedding: torch.Tensor,
    target_hidden: torch.Tensor,
    position_ids: torch.Tensor,
    temp_dir: Path,
    step_index: int,
) -> torch.Tensor:
    """Run one draft forward through the OpenInfer fixture binary.

    The inputs are written to a safetensors fixture in ``temp_dir``, the
    binary is executed with ``runtime.repo_root`` as working directory, and
    the ``openinfer_output`` tensor is read back from the output file.

    Returns:
        The drafter output as bf16 on the target model's device.

    Raises:
        subprocess.CalledProcessError: If the binary exits non-zero.
        KeyError: If the output file has no ``openinfer_output`` tensor.
    """
    fixture = temp_dir / f"dflash-input-{step_index:03d}.safetensors"
    out = temp_dir / f"dflash-output-{step_index:03d}.safetensors"
    save_file(
        {
            "noise_embedding": noise_embedding.detach().to("cpu", dtype=torch.bfloat16).contiguous(),
            "target_hidden": target_hidden.detach().to("cpu", dtype=torch.bfloat16).contiguous(),
            "position_ids": position_ids.detach().to("cpu", dtype=torch.int32).contiguous(),
        },
        str(fixture),
    )
    subprocess.run(_openinfer_command(runtime, fixture, out), cwd=runtime.repo_root, check=True)
    tensors = load_file(str(out))
    return tensors["openinfer_output"].to(input_device(runtime.target), dtype=torch.bfloat16)


def draft_hidden(
    runtime: Runtime,
    *,
    kind: str,
    noise_embedding: torch.Tensor,
    target_hidden: torch.Tensor,
    position_ids: torch.Tensor,
    temp_dir: Path,
    step_index: int,
) -> tuple[torch.Tensor, dict[str, float] | None]:
    """Run one draft forward with the selected drafter.

    Args:
        runtime: Shared models and settings.
        kind: ``"hf"`` for the HF remote-code drafter, ``"openinfer"`` for the
            OpenInfer fixture binary.
        noise_embedding: Embeddings of the block being drafted.
        target_hidden: Target context features fed to the drafter.
        position_ids: Position ids covering context and block.
        temp_dir: Scratch directory for OpenInfer fixture files.
        step_index: Step counter used to name the fixture files.

    Returns:
        ``(hidden, delta)``. ``delta`` is the OpenInfer-vs-HF summary when
        ``kind == "openinfer"`` and ``runtime.collect_hidden_delta`` is set,
        otherwise ``None``.

    Raises:
        ValueError: If ``kind`` is not one of the two supported drafters.
    """
    if kind not in ("hf", "openinfer"):
        raise ValueError(f"unknown drafter kind: {kind!r}")

    # The HF forward is only needed as the result itself or as the reference
    # for the delta; skip it when neither applies.
    hf_hidden = None
    if kind == "hf" or runtime.collect_hidden_delta:
        with torch.inference_mode():
            hf_hidden = runtime.draft(
                target_hidden=target_hidden,
                noise_embedding=noise_embedding,
                position_ids=position_ids,
                use_cache=False,
                is_causal=False,
            )
    if kind == "hf":
        return hf_hidden, None

    oi_hidden = run_openinfer_draft(
        runtime,
        noise_embedding=noise_embedding,
        target_hidden=target_hidden,
        position_ids=position_ids,
        temp_dir=temp_dir,
        step_index=step_index,
    )
    delta = tensor_deltas(oi_hidden, hf_hidden) if hf_hidden is not None else None
    return oi_hidden, delta
@torch.inference_mode()
def generate_with_drafter(
    runtime: Runtime,
    *,
    prompt: str,
    max_new_tokens: int,
    kind: str,
    temp_dir: Path,
) -> dict[str, Any]:
    """Greedy block-draft generation with the target model as verifier.

    The target model prefills the prompt, then each step drafts one block with
    the selected drafter, verifies it with the target, keeps the accepted
    prefix plus one target token, and crops the target KV cache accordingly.
    No draft-side cache is used: the drafter only sees the target features of
    the tokens accepted in the previous step.

    The whole function runs under ``torch.inference_mode()`` so that
    ``embed_tokens`` and ``lm_head`` neither build autograd graphs nor receive
    inference tensors while autograd is recording.

    Args:
        runtime: Shared models and settings.
        prompt: Text prompt to continue.
        max_new_tokens: Upper bound on generated tokens.
        kind: Drafter to use, forwarded to ``draft_hidden``.
        temp_dir: Scratch directory for OpenInfer fixture files.

    Returns:
        A dict with prompt/full/generated token ids, decoded texts, SHA-256
        digests of the generated ids and text, the per-step advance lengths
        and the merged hidden-state delta summary (or ``None``).
    """
    dev = input_device(runtime.target)
    encoded = runtime.tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(dev)
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens
    # One extra block of mask tokens so the last step can always write a full
    # block plus the bonus token without bounds checks.
    output_ids = torch.full(
        (1, max_length + runtime.block_size),
        runtime.mask_token_id,
        dtype=torch.long,
        device=dev,
    )
    all_position_ids = torch.arange(output_ids.shape[1], device=dev).unsqueeze(0)
    # Built once; the stop set does not change during generation.
    stop_tensor = torch.tensor(runtime.stop_token_ids, device=dev) if runtime.stop_token_ids else None

    target_cache = DynamicCache()
    output = runtime.target(
        input_ids,
        position_ids=all_position_ids[:, :num_input_tokens],
        past_key_values=target_cache,
        use_cache=True,
        logits_to_keep=1,
        output_hidden_states=True,
    )
    output_ids[:, :num_input_tokens] = input_ids
    output_ids[:, num_input_tokens : num_input_tokens + 1] = greedy(output.logits)
    target_hidden = extract_context_feature(output.hidden_states, runtime.target_layer_ids)

    start = num_input_tokens
    accepted_plus_fallback_lengths: list[int] = []
    hidden_deltas: list[dict[str, float]] = []
    step_index = 0
    while start < max_length:
        q_len = runtime.block_size
        block_output_ids = output_ids[:, start : start + q_len].clone()
        block_position_ids = all_position_ids[:, start : start + q_len]
        noise_embedding = runtime.target.model.embed_tokens(block_output_ids)

        # Context positions are the ones immediately preceding the block.
        ctx_len = target_hidden.shape[1]
        draft_position_ids = all_position_ids[:, start - ctx_len : start + q_len]
        hidden, delta = draft_hidden(
            runtime,
            kind=kind,
            noise_embedding=noise_embedding,
            target_hidden=target_hidden,
            position_ids=draft_position_ids,
            temp_dir=temp_dir,
            step_index=step_index,
        )
        if delta is not None:
            hidden_deltas.append(delta)
        # Position 0 of the block is already known; only the rest is drafted.
        draft_logits = runtime.target.lm_head(hidden[:, -runtime.block_size + 1 :, :])
        block_output_ids[:, 1:] = greedy(draft_logits)

        output = runtime.target(
            block_output_ids,
            position_ids=block_position_ids,
            past_key_values=target_cache,
            use_cache=True,
            output_hidden_states=True,
        )
        posterior = greedy(output.logits)
        # Count leading drafted tokens that agree with the target's own choice.
        matches = block_output_ids[:, 1:] == posterior[:, :-1]
        acceptance_length = int(matches.cumprod(dim=1).sum(dim=1)[0].item())
        advanced = acceptance_length + 1
        output_ids[:, start : start + advanced] = block_output_ids[:, :advanced]
        # The target's token after the accepted prefix seeds the next block.
        output_ids[:, start + advanced] = posterior[:, acceptance_length]
        start += advanced
        target_cache.crop(start)
        target_hidden = extract_context_feature(output.hidden_states, runtime.target_layer_ids)[:, :advanced, :]
        accepted_plus_fallback_lengths.append(advanced)
        step_index += 1

        if stop_tensor is not None:
            generated_so_far = output_ids[0, num_input_tokens : min(start + 1, max_length)]
            if torch.isin(generated_so_far, stop_tensor).any():
                break

    full_ids = output_ids[0, :max_length]
    full_ids = full_ids[full_ids != runtime.mask_token_id]
    if stop_tensor is not None:
        generated = full_ids[num_input_tokens:]
        stop_positions = torch.isin(generated, stop_tensor).nonzero(as_tuple=True)[0]
        if stop_positions.numel() > 0:
            # Keep everything up to and including the first stop token.
            full_ids = full_ids[: num_input_tokens + int(stop_positions[0].item()) + 1]

    full_token_ids = [int(token) for token in full_ids.detach().cpu().tolist()]
    generated_token_ids = full_token_ids[num_input_tokens:]
    full_text = runtime.tokenizer.decode(full_token_ids, skip_special_tokens=False)
    generated_text = runtime.tokenizer.decode(generated_token_ids, skip_special_tokens=False)
    return {
        "prompt_token_ids": [int(token) for token in input_ids[0].detach().cpu().tolist()],
        "full_token_ids": full_token_ids,
        "generated_token_ids": generated_token_ids,
        "full_text": full_text,
        "generated_text": generated_text,
        "token_sha256": sha256_u32_le(generated_token_ids),
        "text_sha256": sha256_text(generated_text),
        "accepted_plus_fallback_lengths": accepted_plus_fallback_lengths,
        "hidden_delta_vs_hf": merge_delta_stats(hidden_deltas),
    }
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the drafter comparison script."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--target-model-path", required=True)
    # NOTE: machine-specific default; pass the flag explicitly on other hosts.
    parser.add_argument("--draft-model-path", default="/home/hezhaozhao/models/Qwen3-4B-DFlash-b16")
    parser.add_argument("--out", default="target/accuracy/qwen3-dflash/drafter-generation.json")
    parser.add_argument("--prompt", action="append", help="Prompt to test; can be repeated.")
    parser.add_argument("--max-new-tokens", type=int, default=12)
    parser.add_argument("--openinfer-bin", type=Path, help="Path to a built qwen3_dflash_forward_fixture binary.")
    parser.add_argument("--repo-root", type=Path, default=Path(__file__).resolve().parents[2])
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--skip-hidden-delta", action="store_true")
    parser.add_argument("--stop-token-id", type=int, action="append", default=[])
    return parser.parse_args()


def resolve_mask_token_id(draft: Any) -> int:
    """Return the drafter's mask token id.

    Prefers ``draft.mask_token_id`` and falls back to
    ``draft.config.dflash_config["mask_token_id"]`` only when the attribute is
    missing or ``None``. An explicit ``is None`` test is used so that a valid
    id of ``0`` is not mistaken for "missing".
    """
    mask_token_id = getattr(draft, "mask_token_id", None)
    if mask_token_id is None:
        mask_token_id = draft.config.dflash_config["mask_token_id"]
    return int(mask_token_id)


def main() -> int:
    """Run every prompt with both drafters and write a JSON comparison report.

    Returns:
        ``0`` when all prompts produce identical tokens and text with both
        drafters, ``1`` otherwise.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    import transformers

    args = parse_args()
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for the DFlash drafter generation comparison")

    device_map = {"": f"cuda:{args.device}"}
    target = AutoModelForCausalLM.from_pretrained(
        args.target_model_path,
        dtype=torch.bfloat16,
        device_map=device_map,
        trust_remote_code=True,
    ).eval()
    draft = AutoModel.from_pretrained(
        args.draft_model_path,
        dtype=torch.bfloat16,
        device_map=device_map,
        trust_remote_code=True,
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(args.target_model_path, trust_remote_code=True)

    # Stop set = user-supplied ids plus the target config's EOS id(s).
    stop_token_ids = list(args.stop_token_id)
    eos = getattr(target.config, "eos_token_id", None)
    if isinstance(eos, int):
        stop_token_ids.append(eos)
    elif isinstance(eos, list):
        stop_token_ids.extend(int(token) for token in eos)
    stop_token_ids = sorted(set(stop_token_ids))

    runtime = Runtime(
        target=target,
        draft=draft,
        tokenizer=tokenizer,
        target_layer_ids=[int(layer) for layer in draft.target_layer_ids],
        block_size=int(draft.block_size),
        mask_token_id=resolve_mask_token_id(draft),
        stop_token_ids=stop_token_ids,
        openinfer_bin=args.openinfer_bin,
        draft_model_path=Path(args.draft_model_path),
        repo_root=args.repo_root,
        device_ordinal=args.device,
        collect_hidden_delta=not args.skip_hidden_delta,
    )

    prompts = args.prompt or DEFAULT_PROMPTS
    cases = []
    with tempfile.TemporaryDirectory(prefix="qwen3-dflash-parity-") as tmp:
        temp_dir = Path(tmp)
        for index, prompt in enumerate(prompts):
            hf = generate_with_drafter(
                runtime,
                prompt=prompt,
                max_new_tokens=args.max_new_tokens,
                kind="hf",
                temp_dir=temp_dir,
            )
            openinfer = generate_with_drafter(
                runtime,
                prompt=prompt,
                max_new_tokens=args.max_new_tokens,
                kind="openinfer",
                temp_dir=temp_dir,
            )
            token_diff = first_diff(hf["generated_token_ids"], openinfer["generated_token_ids"])
            text_match = hf["generated_text"] == openinfer["generated_text"]
            token_match = token_diff is None
            if token_match and text_match:
                classification = "all_token_text_exact"
            else:
                classification = "drafter_generation_mismatch"
            cases.append(
                {
                    "id": f"prompt_{index:03d}",
                    "prompt": prompt,
                    "max_new_tokens": args.max_new_tokens,
                    "prompt_token_ids": hf["prompt_token_ids"],
                    "hf_drafter": hf,
                    "openinfer_drafter": openinfer,
                    "token_match": token_match,
                    "text_match": text_match,
                    "classification": classification,
                    "first_diff": token_diff,
                }
            )
            print(
                f"{classification}: {prompt!r}; "
                f"hf_accept={hf['accepted_plus_fallback_lengths']} "
                f"openinfer_accept={openinfer['accepted_plus_fallback_lengths']}"
            )

    result = {
        "schema": 1,
        "comparison": "qwen3_4b_dflash_drafter_generation",
        "mode": "greedy_bs1_no_draft_cache_drafter_substitution",
        "target_model_path": args.target_model_path,
        "draft_model_path": args.draft_model_path,
        "openinfer_bin": str(args.openinfer_bin) if args.openinfer_bin else None,
        "block_size": runtime.block_size,
        "target_layer_ids": runtime.target_layer_ids,
        "mask_token_id": runtime.mask_token_id,
        "stop_token_ids": runtime.stop_token_ids,
        "torch_version": torch.__version__,
        "transformers_version": transformers.__version__,
        "case_count": len(cases),
        "all_token_text_exact": all(case["classification"] == "all_token_text_exact" for case in cases),
        "cases": cases,
    }
    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(result, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
    print(f"wrote {out}")
    # A non-zero exit lets CI treat any drafter mismatch as a failure.
    return 0 if result["all_token_text_exact"] else 1
def main() -> int:
    """Generate the seed-pinned HF golden fixture for the DFlash draft model.

    Loads the remote-code draft model in bf16 on CUDA, builds synthetic
    ``noise_embedding`` / ``target_hidden`` / ``position_ids`` inputs from a
    fixed seed, runs one non-causal, cache-free forward and stores inputs and
    output in a safetensors file together with descriptive metadata.

    Returns:
        ``0`` on success.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # NOTE: machine-specific default; pass the flag explicitly on other hosts.
    parser.add_argument("--model-path", default="/home/hezhaozhao/models/Qwen3-4B-DFlash-b16")
    parser.add_argument("--out", default="test_data/qwen3-4b-dflash-hf-golden.safetensors")
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required to generate the DFlash bf16 golden")

    model = AutoModel.from_pretrained(
        args.model_path,
        dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    ).eval()

    gen = torch.Generator(device="cuda").manual_seed(SEED)
    hidden = model.config.hidden_size
    target_layers = len(model.target_layer_ids)
    # Draw order matters: changing it changes the fixture for the same seed.
    noise_embedding = torch.randn(
        (1, Q_LEN, hidden),
        generator=gen,
        device="cuda",
        dtype=torch.bfloat16,
    )
    target_hidden = torch.randn(
        (1, CTX_LEN, hidden * target_layers),
        generator=gen,
        device="cuda",
        dtype=torch.bfloat16,
    )
    position_ids = torch.arange(CTX_LEN + Q_LEN, device="cuda", dtype=torch.int64).unsqueeze(0)

    with torch.inference_mode():
        output = model(
            noise_embedding=noise_embedding,
            target_hidden=target_hidden,
            position_ids=position_ids,
            use_cache=False,
            is_causal=False,
        )
    torch.cuda.synchronize()

    # save_file rejects non-contiguous tensors, so normalize the layout
    # explicitly rather than relying on how the model produced its output.
    tensors = {
        "noise_embedding": noise_embedding.cpu().contiguous(),
        "target_hidden": target_hidden.cpu().contiguous(),
        "position_ids": position_ids.to(torch.int32).cpu().contiguous(),
        "output": output.cpu().contiguous(),
    }
    # Same fallback as the drafter comparison script: the attribute may be
    # absent, in which case the value lives in the DFlash config.
    mask_token_id = getattr(model, "mask_token_id", None)
    if mask_token_id is None:
        mask_token_id = model.config.dflash_config["mask_token_id"]
    # safetensors metadata values must be strings.
    meta = {
        "model_path": args.model_path,
        "seed": str(SEED),
        "ctx_len": str(CTX_LEN),
        "q_len": str(Q_LEN),
        "hidden_size": str(hidden),
        "target_layer_ids": ",".join(str(layer) for layer in model.target_layer_ids),
        "block_size": str(model.block_size),
        "mask_token_id": str(mask_token_id),
        "torch_version": torch.__version__,
        "transformers_version": __import__("transformers").__version__,
    }
    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    save_file(tensors, str(out), metadata=meta)
    print(f"wrote {out}: ctx_len={CTX_LEN}, q_len={Q_LEN}, hidden={hidden}, seed={SEED}")
    return 0