From eb76e8c7299d25ecb628eda80b5701e9a2538cf2 Mon Sep 17 00:00:00 2001 From: cheese-cakee Date: Fri, 19 Jun 2026 19:29:53 +0530 Subject: [PATCH] perf(qwen35): fixed-width verify graph for CUDA-graph replay MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The kvflash spec-decode verify path already builds a step-invariant ggml graph (set_rows + kv_write_rows, stride-256 FA span, persistent step arena), so the ggml-cuda CUDA-graph cache can replay it across decode steps. But the post-accept replay verify runs at a variable width (commit_n), building a graph whose node dimensions differ from the q_len-wide main verify — every low-acceptance step forces a recapture, and alternating verify/replay never settles on one captured graph. Add an optional fixed-width path. verify_batch(..., pad_to) builds the forward at max(pad_to, tokens.size()) tokens; the padding rows carry a zero embedding and are masked out. Real rows attend only to committed positions and their own causal slots, so causality/masking excludes every padded column — the argmax consumed for the real positions is bit-identical to an unpadded call. The caller pads the replay to q_len so it reuses the main verify's graph; those slots are already resident from the same round's main verify, so this allocates no new pool slot and triggers no eviction. - DFlashTarget::verify_batch grows a trailing `pad_to = 0`; the default preserves the current variable-width behavior. Qwen35DFlashTarget implements it; the gemma4 and layer-split overrides accept and ignore it. - Gated behind DFLASH_QWEN35_FIXED_VERIFY (off by default). The win only lands when the verify graph is CUDA-graph-eligible: for an MoE target ggml-cuda disables graphs when the mul_mat_id token batch exceeds mmvq_mmid_max (~8 on Turing+), so a wider block_size would only pay the padded compute. 
--- server/src/common/dflash_target.h | 12 ++++++- server/src/gemma4/gemma4_dflash_target.cpp | 4 ++- server/src/gemma4/gemma4_dflash_target.h | 3 +- server/src/qwen35/qwen35_backend.cpp | 11 +++++- server/src/qwen35/qwen35_dflash_target.cpp | 35 +++++++++++++------ server/src/qwen35/qwen35_dflash_target.h | 3 +- .../qwen35_layer_split_dflash_target.cpp | 4 ++- .../qwen35/qwen35_layer_split_dflash_target.h | 3 +- 8 files changed, 57 insertions(+), 18 deletions(-) diff --git a/server/src/common/dflash_target.h b/server/src/common/dflash_target.h index b7244c1b1..da01750f6 100644 --- a/server/src/common/dflash_target.h +++ b/server/src/common/dflash_target.h @@ -30,11 +30,21 @@ struct DFlashTarget { // During forward, the target MUST capture intermediate activations at // the layers specified by capture_layer_ids() and store them in the // draft's feature ring (how this happens is implementation-defined). + // + // `pad_to`, when greater than tokens.size(), asks the implementation to + // build the forward graph at a fixed width of `pad_to` tokens (padding rows + // are masked out and never read) so the graph's node dimensions stay + // constant across decode steps and the CUDA-graph cache can replay instead + // of recapturing. The consumed result for the first tokens.size() positions + // is identical to an unpadded call. Implementations may ignore it (the + // default 0 preserves today's variable-width behavior); callers must only + // pad within a window whose KV slots are already resident. virtual bool verify_batch(const std::vector & tokens, int base_pos, int & last_tok, std::vector * all_argmax = nullptr, - bool capture_ssm_intermediates = false) = 0; + bool capture_ssm_intermediates = false, + int pad_to = 0) = 0; // Read the full [n_tokens x vocab] f32 logits produced by the most // recent verify_batch call. 
Used by sampled-verify (spec decode with diff --git a/server/src/gemma4/gemma4_dflash_target.cpp b/server/src/gemma4/gemma4_dflash_target.cpp index 92a712c9e..ca2793f9d 100644 --- a/server/src/gemma4/gemma4_dflash_target.cpp +++ b/server/src/gemma4/gemma4_dflash_target.cpp @@ -38,8 +38,10 @@ bool Gemma4DFlashTarget::verify_batch( int base_pos, int & last_tok, std::vector<int32_t> * all_argmax, - bool capture_ssm_intermediates) { + bool capture_ssm_intermediates, + int pad_to) { (void)capture_ssm_intermediates; // Gemma4 is pure-attention, no SSM state + (void)pad_to; // fixed-width verify not implemented here const int n_tokens = (int)tokens.size(); if (n_tokens <= 0) return false; diff --git a/server/src/gemma4/gemma4_dflash_target.h b/server/src/gemma4/gemma4_dflash_target.h index 52ef56aaa..1ccade308 100644 --- a/server/src/gemma4/gemma4_dflash_target.h +++ b/server/src/gemma4/gemma4_dflash_target.h @@ -31,7 +31,8 @@ class Gemma4DFlashTarget : public DFlashTarget { int base_pos, int & last_tok, std::vector<int32_t> * all_argmax = nullptr, - bool capture_ssm_intermediates = false) override; + bool capture_ssm_intermediates = false, + int pad_to = 0) override; // kvflash: route verify writes through the pool (slots allocated here, // slot-space mask inside gemma4_verify_batch). Non-owning. diff --git a/server/src/qwen35/qwen35_backend.cpp b/server/src/qwen35/qwen35_backend.cpp index 4b1978ed8..c5689e773 100644 --- a/server/src/qwen35/qwen35_backend.cpp +++ b/server/src/qwen35/qwen35_backend.cpp @@ -1846,6 +1846,14 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, const int max_verify_tokens = cfg_.ddtree_mode ? std::max(dw_.block_size, cfg_.ddtree_budget + 1) : dw_.block_size; + // Opt-in: pad the post-accept replay verify to the same width as the main + // verify (q_len) so both share one ggml-cuda captured graph instead of + // forcing a recapture each low-acceptance step. 
Off by default — it only + // helps when the verify graph is CUDA-graph-eligible (q_len within the MoE + // mul_mat_id batch limit), and otherwise just adds padded compute. + static const bool g_fixed_verify_width = + (std::getenv("DFLASH_QWEN35_FIXED_VERIFY") != nullptr); + const int replay_pad_to = g_fixed_verify_width ? q_len : 0; if ((cfg_.fast_rollback || cfg_.ddtree_mode) && !cache_.rollback_ctx) { if (!migrate_prefill_cache(w_, cfg_.device.max_ctx, max_verify_tokens, @@ -2415,7 +2423,8 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, for (int i = 0; i < commit_n; i++) { replay_batch[i] = (i < accept_n) ? draft_tok[i] : bonus_tok; } - if (!target->verify_batch(replay_batch, committed, replay_last_tok, nullptr)) { + if (!target->verify_batch(replay_batch, committed, replay_last_tok, nullptr, + /*capture_ssm_intermediates=*/false, replay_pad_to)) { std::fprintf(stderr, "spec-decode: replay failed\n"); step_graph_destroy(draft_sg); return false; diff --git a/server/src/qwen35/qwen35_dflash_target.cpp b/server/src/qwen35/qwen35_dflash_target.cpp index 88be8c979..28a696143 100644 --- a/server/src/qwen35/qwen35_dflash_target.cpp +++ b/server/src/qwen35/qwen35_dflash_target.cpp @@ -43,9 +43,21 @@ bool Qwen35DFlashTarget::verify_batch( int base_pos, int & last_tok, std::vector * all_argmax, - bool capture_ssm_intermediates) { - const int n_tokens = (int)tokens.size(); - if (n_tokens <= 0) return false; + bool capture_ssm_intermediates, + int pad_to) { + const int n_real = (int)tokens.size(); + if (n_real <= 0) return false; + + // Fixed-width verify: pad the graph to `pad_to` tokens so its node + // dimensions stay constant across decode steps and the ggml-cuda CUDA-graph + // cache can replay instead of recapturing every step. 
Padding rows are never + // consumed: real rows (0..n_real-1) attend only to committed positions and + // their own causal slots, so causality/masking excludes every padded column + // and the argmax for rows 0..n_real-1 is bit-identical to an unpadded call. + // The caller must only pad within a window whose slots are already resident + // (e.g. the same round's main verify), so this allocates no new pool slots + // and triggers no eviction. + const int n_tokens = (pad_to > n_real) ? pad_to : n_real; const int hidden = w_.n_embd; const bool pool = pager_ != nullptr; @@ -102,9 +114,10 @@ bool Qwen35DFlashTarget::verify_batch( } // Embed input tokens and fill positions. - std::vector<float> embed((size_t)n_tokens * hidden); - if (!w_.embedder.embed(tokens.data(), n_tokens, embed.data())) { - std::fprintf(stderr, "verify_batch: embed failed (n=%d)\n", n_tokens); + // Padding rows keep a zero embedding; they are masked out and never read. + std::vector<float> embed((size_t)n_tokens * hidden, 0.0f); + if (!w_.embedder.embed(tokens.data(), n_real, embed.data())) { + std::fprintf(stderr, "verify_batch: embed failed (n=%d)\n", n_real); return false; } ggml_backend_tensor_set(sg_.inp_embed, embed.data(), 0, @@ -167,17 +180,17 @@ bool Qwen35DFlashTarget::verify_batch( return false; } - // Read argmax results from GPU. - std::vector<int32_t> argmax_buf(n_tokens); + // Read argmax results from GPU — only the real (unpadded) positions. 
+ std::vector<int32_t> argmax_buf(n_real); ggml_backend_tensor_get(sg_.argmax_tokens, argmax_buf.data(), 0, - sizeof(int32_t) * n_tokens); - last_tok = argmax_buf[n_tokens - 1]; + sizeof(int32_t) * n_real); + last_tok = argmax_buf[n_real - 1]; if (all_argmax) { *all_argmax = std::move(argmax_buf); } - cache_.cur_pos = base_pos + n_tokens; + cache_.cur_pos = base_pos + n_real; return true; } diff --git a/server/src/qwen35/qwen35_dflash_target.h b/server/src/qwen35/qwen35_dflash_target.h index 3c8864b6b..3cdad5797 100644 --- a/server/src/qwen35/qwen35_dflash_target.h +++ b/server/src/qwen35/qwen35_dflash_target.h @@ -37,7 +37,8 @@ class Qwen35DFlashTarget : public DFlashTarget { int base_pos, int & last_tok, std::vector<int32_t> * all_argmax = nullptr, - bool capture_ssm_intermediates = false) override; + bool capture_ssm_intermediates = false, + int pad_to = 0) override; bool read_verify_logits(int n_tokens, std::vector<float> & out) override; diff --git a/server/src/qwen35/qwen35_layer_split_dflash_target.cpp b/server/src/qwen35/qwen35_layer_split_dflash_target.cpp index 9228fa99c..88145b6cc 100644 --- a/server/src/qwen35/qwen35_layer_split_dflash_target.cpp +++ b/server/src/qwen35/qwen35_layer_split_dflash_target.cpp @@ -40,7 +40,9 @@ bool Qwen35LayerSplitDFlashTarget::verify_batch( int base_pos, int & last_tok, std::vector<int32_t> * all_argmax, - bool capture_ssm_intermediates) { + bool capture_ssm_intermediates, + int pad_to) { + (void)pad_to; // fixed-width verify not implemented for the layer-split path if (shards_.empty()) return false; if (remote_target_shard_ && remote_target_shard_->active()) { return run_qwen35_mixed_layer_split_forward( diff --git a/server/src/qwen35/qwen35_layer_split_dflash_target.h b/server/src/qwen35/qwen35_layer_split_dflash_target.h index 274a4a2db..c9abad236 100644 --- a/server/src/qwen35/qwen35_layer_split_dflash_target.h +++ b/server/src/qwen35/qwen35_layer_split_dflash_target.h @@ -43,7 +43,8 @@ class Qwen35LayerSplitDFlashTarget : public DFlashTarget { int 
base_pos, int & last_tok, std::vector<int32_t> * all_argmax = nullptr, - bool capture_ssm_intermediates = false) override; + bool capture_ssm_intermediates = false, + int pad_to = 0) override; bool snapshot_kv() override; bool restore_kv() override;