Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions server/src/common/model_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ struct GenerateRequest {
// When non-null, the spec decode loop uses these as draft overrides,
// bypassing draft model computation for covered positions.
const std::vector<int32_t> * hint_tokens = nullptr;
// Optional env-gated dflash stall recovery: when spec decode is about to
// emit early EOS after an action preamble, inject a bare tool-call XML
// prefix and continue in AR with KV state intact.
const std::vector<int32_t> * stall_tool_prefix_tokens = nullptr;
const std::vector<int32_t> * stall_action_suffix_tokens = nullptr;
const std::vector<int32_t> * stall_skip_tokens = nullptr;
// Optional thinking-budget hook — see BudgetHook docs above.
BudgetHook budget_hook;
};
Expand Down
239 changes: 230 additions & 9 deletions server/src/qwen35/qwen35_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,42 @@ static float bf16_bits_to_f32(uint16_t bits) {
v.u = (uint32_t)bits << 16;
return v.f;
}

// Reports whether `needle` appears as a contiguous run near the end of
// `tokens`.
//
// A match may begin no earlier than `max_trailing` positions before the last
// index at which `needle` still fits inside `tokens`; with `max_trailing == 0`
// only an exact suffix match counts. An empty `needle`, or one longer than
// `tokens`, never matches.
static bool tokens_contain_recent_sequence(const std::vector<int32_t> & tokens,
                                           const std::vector<int32_t> & needle,
                                           size_t max_trailing) {
    const size_t hay_len = tokens.size();
    const size_t pat_len = needle.size();
    if (pat_len == 0 || pat_len > hay_len) {
        return false;
    }

    // Newest feasible start is where the pattern ends flush with the haystack;
    // the oldest start we inspect is clamped so it never goes below index 0.
    const size_t newest_start = hay_len - pat_len;
    const size_t oldest_start =
        (newest_start > max_trailing) ? (newest_start - max_trailing) : 0;

    // Searching [oldest_start, end) visits exactly the start positions
    // oldest_start..newest_start, because later starts cannot fit the pattern.
    const auto window_begin = tokens.begin() + oldest_start;
    const auto hit = std::search(window_begin, tokens.end(),
                                 needle.begin(), needle.end());
    return hit != tokens.end();
}

// Reports whether any value from `candidates` occurs among the last
// `max_trailing + 1` entries of `tokens` (or among all of them when `tokens`
// is shorter than that). Returns false if either vector is empty.
static bool tokens_have_recent_any(const std::vector<int32_t> & tokens,
                                   const std::vector<int32_t> & candidates,
                                   size_t max_trailing) {
    if (tokens.empty() || candidates.empty()) {
        return false;
    }

    // Size of the inspected tail. Compare before adding one so a very large
    // `max_trailing` cannot wrap around to a zero-length window.
    const size_t total = tokens.size();
    const size_t window = (max_trailing >= total) ? total : (max_trailing + 1);

    const auto tail_begin = tokens.end() - window;
    const auto hit = std::find_first_of(tail_begin, tokens.end(),
                                        candidates.begin(), candidates.end());
    return hit != tokens.end();
}

// Reads an integer setting from the environment variable `name`.
//
// Returns `fallback` when the variable is unset, is empty, does not begin
// with a parsable decimal integer, or holds a value that does not fit in
// `int`. Leading whitespace, an optional sign, and trailing non-digit
// characters are tolerated (e.g. "32abc" yields 32), matching the lenient
// prefix parsing std::atoi performed here previously.
//
// std::atoi is avoided because its behaviour is undefined for out-of-range
// input and it yields 0 for non-numeric text, which would silently override
// the caller's fallback with 0.
static int env_int_or_default(const char * name, int fallback) {
    const char * raw = std::getenv(name);
    if (raw == nullptr || *raw == '\0') {
        return fallback;
    }

    char * end = nullptr;
    const long long parsed = std::strtoll(raw, &end, 10);
    if (end == raw) {
        // No digits were consumed: the value is not a number at all.
        return fallback;
    }

    // Round-trip through int to reject anything outside int's range. On
    // overflow strtoll saturates to LLONG_MIN/LLONG_MAX, which also fails this
    // check, so errno does not need to be inspected.
    const int narrowed = static_cast<int>(parsed);
    if (static_cast<long long>(narrowed) != parsed) {
        return fallback;
    }
    return narrowed;
}
} // namespace

#define IS_EOS_TOK(tok, w) \
Expand Down Expand Up @@ -577,14 +613,27 @@ GenerateResult Qwen35Backend::generate(const GenerateRequest & req,
// without sacrificing spec-decode throughput for the bulk of
// generation. Most requests never hit the tail because the
// model closes </think> naturally well before the budget edge.
if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io,
{
bool _sd_ok = do_spec_decode(committed, req.n_gen, result.tokens, out_io,
result.accept_rate, result.spec_decode_ran,
req.hint_tokens, &req.budget_hook,
req.hint_tokens,
req.stall_tool_prefix_tokens,
req.stall_action_suffix_tokens,
req.stall_skip_tokens,
&req.budget_hook,
&result.budget_forced_close,
&result.degenerate_decode_close)) {
&result.degenerate_decode_close);
if (_sd_ok && result.tokens.empty()) {
// FIX: spec-decode degenerate empty (EOS as first token) on certain
// agentic turns -> fall back to AR decode, which is verified to produce
// correct non-empty output for exactly these contexts (temp-0 parity).
_sd_ok = do_ar_decode(committed, req.n_gen, result.tokens, out_io, req.budget_hook, &result.budget_forced_close, &result.degenerate_decode_close);
}
if (!_sd_ok) {
result.error = "decode";
return result;
}
}
result.decode_s = std::chrono::duration<double>(
std::chrono::steady_clock::now() - t_decode_start).count();
}
Expand Down Expand Up @@ -668,14 +717,27 @@ GenerateResult Qwen35Backend::restore_and_generate(int slot,
// without sacrificing spec-decode throughput for the bulk of
// generation. Most requests never hit the tail because the
// model closes </think> naturally well before the budget edge.
if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io,
{
bool _sd_ok = do_spec_decode(committed, req.n_gen, result.tokens, out_io,
result.accept_rate, result.spec_decode_ran,
req.hint_tokens, &req.budget_hook,
req.hint_tokens,
req.stall_tool_prefix_tokens,
req.stall_action_suffix_tokens,
req.stall_skip_tokens,
&req.budget_hook,
&result.budget_forced_close,
&result.degenerate_decode_close)) {
&result.degenerate_decode_close);
if (_sd_ok && result.tokens.empty()) {
// FIX: spec-decode degenerate empty (EOS as first token) on certain
// agentic turns -> fall back to AR decode, which is verified to produce
// correct non-empty output for exactly these contexts (temp-0 parity).
_sd_ok = do_ar_decode(committed, req.n_gen, result.tokens, out_io, req.budget_hook, &result.budget_forced_close, &result.degenerate_decode_close);
}
if (!_sd_ok) {
result.error = "decode";
return result;
}
}
result.decode_s = std::chrono::duration<double>(
std::chrono::steady_clock::now() - t_decode_start).count();
}
Expand Down Expand Up @@ -922,6 +984,13 @@ bool Qwen35Backend::do_ar_decode(int committed, int n_gen,

auto t_dec0_ar = std::chrono::steady_clock::now();
const size_t out_tokens_at_entry = out_tokens.size();
static const int _min_floor = env_int_or_default("DFLASH_MIN_TOKENS", 0);
static const int _repeat_guard = []{
const int explicit_guard =
env_int_or_default("DFLASH_DEGENERATE_RUN_TOKENS", -1);
if (explicit_guard >= 0) return explicit_guard;
return env_int_or_default("DFLASH_MIN_TOKENS", 0) > 0 ? 32 : 0;
}();

const int hidden = w_.n_embd;
const int vocab = w_.n_vocab;
Expand Down Expand Up @@ -1000,6 +1069,26 @@ bool Qwen35Backend::do_ar_decode(int committed, int n_gen,
}
}

// MIN_TOKENS_BEFORE_EOS (env DFLASH_MIN_TOKENS, default 0=off): if the
// model tries to stop before producing N tokens in this decode call,
// suppress EOS and take the best NON-eos token instead. Targets the Q4
// 'preamble then stop, no tool_call' agentic stall. Env-gated so the
// default production lane is byte-for-byte unchanged.
{
if (_min_floor > 0 && (int)out_tokens.size() < _min_floor && IS_EOS_TOK(next_tok, w_)) {
int alt = -1; float altbest = -1e30f;
for (int v = 0; v < vocab; v++) {
if (IS_EOS_TOK(v, w_)) continue;
if (logits_buf[v] > altbest) { altbest = logits_buf[v]; alt = v; }
}
if (alt >= 0) {
FILE* _d = std::fopen("/tmp/dflash_floor.log", "a");
if (_d) { std::fprintf(_d, "[floor] eos@%d -> alt=%d\n", (int)out_tokens.size(), alt); std::fclose(_d); }
next_tok = alt;
}
}
}

maybe_force_close(next_tok, committed);

out_tokens.push_back(next_tok);
Expand All @@ -1010,6 +1099,22 @@ bool Qwen35Backend::do_ar_decode(int committed, int n_gen,

if (IS_EOS_TOK(next_tok, w_)) break;

if (_repeat_guard > 0 && (int)out_tokens.size() >= _repeat_guard) {
int run = 1;
for (int j = (int)out_tokens.size() - 2; j >= 0; --j) {
if (out_tokens[j] != next_tok) break;
run++;
}
if (run >= _repeat_guard) {
std::fprintf(stderr,
"[degenerate-decode] token %d repeated %d times - "
"breaking AR loop at committed=%d\n",
next_tok, run, committed);
if (degenerate_close_out) *degenerate_close_out = true;
break;
}
}

// Degenerate-decode watchdog. Once we're past the budget-hook's
// close sequence (model in post-`</think>` content phase), watch
// for repetition loops. The aime2025-02 case at think_max=4k
Expand Down Expand Up @@ -1100,6 +1205,9 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
float & out_accept_rate,
bool & out_spec_ran,
const std::vector<int32_t> * hint_tokens,
const std::vector<int32_t> * stall_tool_prefix_tokens,
const std::vector<int32_t> * stall_action_suffix_tokens,
const std::vector<int32_t> * stall_skip_tokens,
const BudgetHook * budget_hook,
bool * forced_close_out,
bool * degenerate_close_out) {
Expand Down Expand Up @@ -1138,6 +1246,10 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
}

out_spec_ran = true;
static const int _min_floor = []{
const char* e = std::getenv("DFLASH_MIN_TOKENS");
return e ? std::atoi(e) : 0;
}();

// ── DFlash spec-decode: draft → verify → accept → replay ──────────

Expand Down Expand Up @@ -1208,6 +1320,26 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
}
}

if (last_tok < 0 && !out_tokens.empty()) {
std::fprintf(stderr,
"[spec-decode] invalid draft seed %d after %d emitted tokens; "
"switching to AR\n",
last_tok, (int)out_tokens.size());
step_graph_destroy(draft_sg);
cache_.last_tok = out_tokens.back();
const int ar_n_gen = n_gen - n_generated;
if (ar_n_gen <= 0) {
io.emit(-1);
return true;
}
BudgetHook tail_hook = budget_hook ? *budget_hook : BudgetHook{};
bool ok = do_ar_decode(committed, ar_n_gen, out_tokens, io,
tail_hook, forced_close_out,
degenerate_close_out);
io.emit(-1);
return ok;
}

// 1. Build noise input for draft
noise_ids[0] = last_tok;
for (int i = 1; i < q_len; i++) noise_ids[i] = target->mask_token_id();
Expand Down Expand Up @@ -1342,7 +1474,6 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
step_graph_destroy(draft_sg);
return false;
}
last_tok = replay_last_tok;

// 7. Sync features for replayed range to mirror (needed for next draft step)
if (use_remote_draft && cache_.target_feat) {
Expand All @@ -1357,20 +1488,110 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,

// 8. Emit committed tokens (stop at EOS)
bool hit_eos = false;
bool floor_to_ar = false;
bool inject_tool_prefix = false;
constexpr size_t kActionSuffixLookback = 16;
constexpr size_t kSkipSequenceLookback = 64;
int emitted = 0;
for (int i = 0; i < commit_n; i++) {
if (_min_floor > 0 && (int)out_tokens.size() < _min_floor &&
IS_EOS_TOK(replay_tok[i], w_)) {
// Action preambles often end as "I'll check:\n\n" before EOS.
// Tokenization makes the colon several tokens back, so keep a
// modest trailing window while still requiring a recent action
// suffix and no nearby completion phrase.
const bool can_inject_tool =
stall_tool_prefix_tokens && !stall_tool_prefix_tokens->empty() &&
stall_action_suffix_tokens && !stall_action_suffix_tokens->empty() &&
tokens_have_recent_any(out_tokens, *stall_action_suffix_tokens,
kActionSuffixLookback) &&
!(stall_skip_tokens &&
tokens_contain_recent_sequence(out_tokens,
*stall_skip_tokens,
kSkipSequenceLookback));
if (can_inject_tool) {
FILE* _d = std::fopen("/tmp/dflash_floor.log", "a");
if (_d) {
std::fprintf(_d,
"[spec-tool-floor] eos@%d committed=%d emitted=%d prefix=%zu -> ar\n",
(int)out_tokens.size(), committed, emitted,
stall_tool_prefix_tokens->size());
std::fclose(_d);
}
floor_to_ar = true;
inject_tool_prefix = true;
break;
}
}
out_tokens.push_back(replay_tok[i]);
io.emit(replay_tok[i]);
emitted++;
if (io.cancelled) break;
if (IS_EOS_TOK(replay_tok[i], w_)) { hit_eos = true; break; }
}
committed += emitted;
int injected = 0;
if (floor_to_ar) {
if (!target->restore_kv()) {
step_graph_destroy(draft_sg);
return false;
}
cache_.cur_pos = committed;
if (emitted > 0) {
std::vector<int32_t> replay_prefix(replay_tok.begin(),
replay_tok.begin() + emitted);
int prefix_last_tok = -1;
if (!target->verify_batch(replay_prefix, committed,
prefix_last_tok, nullptr)) {
std::fprintf(stderr, "spec-decode: floor prefix replay failed\n");
step_graph_destroy(draft_sg);
return false;
}
}
committed += emitted;
cache_.cur_pos = committed;
if (inject_tool_prefix) {
int tool_prefix_last_tok = -1;
if (!target->verify_batch(*stall_tool_prefix_tokens, committed,
tool_prefix_last_tok, nullptr)) {
std::fprintf(stderr, "spec-decode: tool prefix replay failed\n");
step_graph_destroy(draft_sg);
return false;
}
for (int32_t tok : *stall_tool_prefix_tokens) {
out_tokens.push_back(tok);
io.emit(tok);
}
injected = (int)stall_tool_prefix_tokens->size();
committed += injected;
cache_.cur_pos = committed;
}
} else {
last_tok = replay_last_tok;
committed += emitted;
}
cache_.cur_pos = committed;
n_generated += emitted;
n_generated += emitted + injected;
n_accept_sum += std::min(accept_n, emitted);
n_draft_steps++;
if (io.cancelled) break;
if (floor_to_ar) {
step_graph_destroy(draft_sg);
cache_.last_tok = out_tokens.empty() ? last_tok : out_tokens.back();
const int total_draft_pos = std::max(1, n_draft_steps * q_len);
out_accept_rate =
(float)((double)n_accept_sum / (double)total_draft_pos);
const int ar_n_gen = n_gen - n_generated;
if (ar_n_gen <= 0) {
io.emit(-1);
return true;
}
BudgetHook tail_hook = budget_hook ? *budget_hook : BudgetHook{};
bool ok = do_ar_decode(committed, ar_n_gen, out_tokens, io,
tail_hook, forced_close_out,
degenerate_close_out);
io.emit(-1);
return ok;
}
if (hit_eos) break;
}

Expand Down
3 changes: 3 additions & 0 deletions server/src/qwen35/qwen35_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ class Qwen35Backend : public ModelBackend {
float & out_accept_rate,
bool & out_spec_ran,
const std::vector<int32_t> * hint_tokens = nullptr,
const std::vector<int32_t> * stall_tool_prefix_tokens = nullptr,
const std::vector<int32_t> * stall_action_suffix_tokens = nullptr,
const std::vector<int32_t> * stall_skip_tokens = nullptr,
const BudgetHook * budget_hook = nullptr,
bool * forced_close_out = nullptr,
bool * degenerate_close_out = nullptr);
Expand Down
Loading
Loading