Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion server/src/common/dflash_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,21 @@ struct DFlashTarget {
// During forward, the target MUST capture intermediate activations at
// the layers specified by capture_layer_ids() and store them in the
// draft's feature ring (how this happens is implementation-defined).
//
// `pad_to`, when greater than tokens.size(), asks the implementation to
// build the forward graph at a fixed width of `pad_to` tokens (padding rows
// are masked out and never read) so the graph's node dimensions stay
// constant across decode steps and the CUDA-graph cache can replay instead
// of recapturing. The consumed result for the first tokens.size() positions
// is identical to an unpadded call. Implementations may ignore it (the
// default 0 preserves today's variable-width behavior); callers must only
// pad within a window whose KV slots are already resident.
virtual bool verify_batch(const std::vector<int32_t> & tokens,
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax = nullptr,
bool capture_ssm_intermediates = false) = 0;
bool capture_ssm_intermediates = false,
int pad_to = 0) = 0;

// Read the full [n_tokens x vocab] f32 logits produced by the most
// recent verify_batch call. Used by sampled-verify (spec decode with
Expand Down
4 changes: 3 additions & 1 deletion server/src/gemma4/gemma4_dflash_target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ bool Gemma4DFlashTarget::verify_batch(
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax,
bool capture_ssm_intermediates) {
bool capture_ssm_intermediates,
int pad_to) {
(void)capture_ssm_intermediates; // Gemma4 is pure-attention, no SSM state
(void)pad_to; // fixed-width verify not implemented here
const int n_tokens = (int)tokens.size();
if (n_tokens <= 0) return false;

Expand Down
3 changes: 2 additions & 1 deletion server/src/gemma4/gemma4_dflash_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class Gemma4DFlashTarget : public DFlashTarget {
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax = nullptr,
bool capture_ssm_intermediates = false) override;
bool capture_ssm_intermediates = false,
int pad_to = 0) override;

// kvflash: route verify writes through the pool (slots allocated here,
// slot-space mask inside gemma4_verify_batch). Non-owning.
Expand Down
11 changes: 10 additions & 1 deletion server/src/qwen35/qwen35_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1846,6 +1846,14 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
const int max_verify_tokens = cfg_.ddtree_mode
? std::max<int>(dw_.block_size, cfg_.ddtree_budget + 1)
: dw_.block_size;
// Opt-in: pad the post-accept replay verify to the same width as the main
// verify (q_len) so both share one ggml-cuda captured graph instead of
// forcing a recapture each low-acceptance step. Off by default — it only
// helps when the verify graph is CUDA-graph-eligible (q_len within the MoE
// mul_mat_id batch limit), and otherwise just adds padded compute.
static const bool g_fixed_verify_width =
(std::getenv("DFLASH_QWEN35_FIXED_VERIFY") != nullptr);
const int replay_pad_to = g_fixed_verify_width ? q_len : 0;
if ((cfg_.fast_rollback || cfg_.ddtree_mode) && !cache_.rollback_ctx) {
if (!migrate_prefill_cache(w_, cfg_.device.max_ctx,
max_verify_tokens,
Expand Down Expand Up @@ -2415,7 +2423,8 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
for (int i = 0; i < commit_n; i++) {
replay_batch[i] = (i < accept_n) ? draft_tok[i] : bonus_tok;
}
if (!target->verify_batch(replay_batch, committed, replay_last_tok, nullptr)) {
if (!target->verify_batch(replay_batch, committed, replay_last_tok, nullptr,
/*capture_ssm_intermediates=*/false, replay_pad_to)) {
std::fprintf(stderr, "spec-decode: replay failed\n");
step_graph_destroy(draft_sg);
return false;
Expand Down
35 changes: 24 additions & 11 deletions server/src/qwen35/qwen35_dflash_target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,21 @@ bool Qwen35DFlashTarget::verify_batch(
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax,
bool capture_ssm_intermediates) {
const int n_tokens = (int)tokens.size();
if (n_tokens <= 0) return false;
bool capture_ssm_intermediates,
int pad_to) {
const int n_real = (int)tokens.size();
if (n_real <= 0) return false;

// Fixed-width verify: pad the graph to `pad_to` tokens so its node
// dimensions stay constant across decode steps and the ggml-cuda CUDA-graph
// cache can replay instead of recapturing every step. Padding rows are never
// consumed: real rows (0..n_real-1) attend only to committed positions and
// their own causal slots, so causality/masking excludes every padded column
// and the argmax for rows 0..n_real-1 is bit-identical to an unpadded call.
// The caller must only pad within a window whose slots are already resident
// (e.g. the same round's main verify), so this allocates no new pool slots
// and triggers no eviction.
const int n_tokens = (pad_to > n_real) ? pad_to : n_real;

const int hidden = w_.n_embd;
const bool pool = pager_ != nullptr;
Expand Down Expand Up @@ -102,9 +114,10 @@ bool Qwen35DFlashTarget::verify_batch(
}

// Embed input tokens and fill positions.
std::vector<float> embed((size_t)n_tokens * hidden);
if (!w_.embedder.embed(tokens.data(), n_tokens, embed.data())) {
std::fprintf(stderr, "verify_batch: embed failed (n=%d)\n", n_tokens);
// Padding rows keep a zero embedding; they are masked out and never read.
std::vector<float> embed((size_t)n_tokens * hidden, 0.0f);
if (!w_.embedder.embed(tokens.data(), n_real, embed.data())) {
std::fprintf(stderr, "verify_batch: embed failed (n=%d)\n", n_real);
return false;
}
ggml_backend_tensor_set(sg_.inp_embed, embed.data(), 0,
Expand Down Expand Up @@ -167,17 +180,17 @@ bool Qwen35DFlashTarget::verify_batch(
return false;
}

// Read argmax results from GPU.
std::vector<int32_t> argmax_buf(n_tokens);
// Read argmax results from GPU — only the real (unpadded) positions.
std::vector<int32_t> argmax_buf(n_real);
ggml_backend_tensor_get(sg_.argmax_tokens, argmax_buf.data(), 0,
sizeof(int32_t) * n_tokens);
last_tok = argmax_buf[n_tokens - 1];
sizeof(int32_t) * n_real);
last_tok = argmax_buf[n_real - 1];

if (all_argmax) {
*all_argmax = std::move(argmax_buf);
}

cache_.cur_pos = base_pos + n_tokens;
cache_.cur_pos = base_pos + n_real;
return true;
}

Expand Down
3 changes: 2 additions & 1 deletion server/src/qwen35/qwen35_dflash_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class Qwen35DFlashTarget : public DFlashTarget {
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax = nullptr,
bool capture_ssm_intermediates = false) override;
bool capture_ssm_intermediates = false,
int pad_to = 0) override;

bool read_verify_logits(int n_tokens, std::vector<float> & out) override;

Expand Down
4 changes: 3 additions & 1 deletion server/src/qwen35/qwen35_layer_split_dflash_target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ bool Qwen35LayerSplitDFlashTarget::verify_batch(
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax,
bool capture_ssm_intermediates) {
bool capture_ssm_intermediates,
int pad_to) {
(void)pad_to; // fixed-width verify not implemented for the layer-split path
if (shards_.empty()) return false;
if (remote_target_shard_ && remote_target_shard_->active()) {
return run_qwen35_mixed_layer_split_forward(
Expand Down
3 changes: 2 additions & 1 deletion server/src/qwen35/qwen35_layer_split_dflash_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class Qwen35LayerSplitDFlashTarget : public DFlashTarget {
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax = nullptr,
bool capture_ssm_intermediates = false) override;
bool capture_ssm_intermediates = false,
int pad_to = 0) override;

bool snapshot_kv() override;
bool restore_kv() override;
Expand Down