diff --git a/server/src/common/model_backend.h b/server/src/common/model_backend.h index 182b5003..6182bf47 100644 --- a/server/src/common/model_backend.h +++ b/server/src/common/model_backend.h @@ -98,6 +98,12 @@ struct GenerateRequest { // When non-null, the spec decode loop uses these as draft overrides, // bypassing draft model computation for covered positions. const std::vector * hint_tokens = nullptr; + // Optional env-gated dflash stall recovery: when spec decode is about to + // emit early EOS after an action preamble, inject a bare tool-call XML + // prefix and continue in AR with KV state intact. + const std::vector * stall_tool_prefix_tokens = nullptr; + const std::vector * stall_action_suffix_tokens = nullptr; + const std::vector * stall_skip_tokens = nullptr; // Optional thinking-budget hook — see BudgetHook docs above. BudgetHook budget_hook; }; diff --git a/server/src/qwen35/qwen35_backend.cpp b/server/src/qwen35/qwen35_backend.cpp index 234ad374..eaa7f49a 100644 --- a/server/src/qwen35/qwen35_backend.cpp +++ b/server/src/qwen35/qwen35_backend.cpp @@ -34,6 +34,42 @@ static float bf16_bits_to_f32(uint16_t bits) { v.u = (uint32_t)bits << 16; return v.f; } + +static bool tokens_contain_recent_sequence(const std::vector & tokens, + const std::vector & needle, + size_t max_trailing) { + if (needle.empty() || tokens.size() < needle.size()) return false; + const size_t last_start = tokens.size() - needle.size(); + const size_t first_start = + last_start > max_trailing ? last_start - max_trailing : 0; + for (size_t start = first_start; start <= last_start; ++start) { + if (std::equal(needle.begin(), needle.end(), tokens.begin() + start)) { + return true; + } + } + return false; +} + +static bool tokens_have_recent_any(const std::vector & tokens, + const std::vector & candidates, + size_t max_trailing) { + if (tokens.empty() || candidates.empty()) return false; + for (size_t trailing = 0; trailing <= max_trailing; ++trailing) { + if (tokens.size() <= trailing) break; + const int32_t tok = tokens[tokens.size() - 1 - trailing]; + if (std::find(candidates.begin(), candidates.end(), tok) != candidates.end()) { + return true; + } + } + return false; +} + +static int env_int_or_default(const char * name, int fallback) { + if (const char * raw = std::getenv(name)) { + if (*raw) return std::atoi(raw); + } + return fallback; +} } // namespace #define IS_EOS_TOK(tok, w) \ @@ -577,14 +613,27 @@ GenerateResult Qwen35Backend::generate(const GenerateRequest & req, // without sacrificing spec-decode throughput for the bulk of // generation. Most requests never hit the tail because the // model closes naturally well before the budget edge. - if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io, + { + bool _sd_ok = do_spec_decode(committed, req.n_gen, result.tokens, out_io, result.accept_rate, result.spec_decode_ran, - req.hint_tokens, &req.budget_hook, + req.hint_tokens, + req.stall_tool_prefix_tokens, + req.stall_action_suffix_tokens, + req.stall_skip_tokens, + &req.budget_hook, &result.budget_forced_close, - &result.degenerate_decode_close)) { + &result.degenerate_decode_close); + if (_sd_ok && result.tokens.empty()) { + // FIX: spec-decode degenerate empty (EOS as first token) on certain + // agentic turns -> fall back to AR decode, which is verified to produce + // correct non-empty output for exactly these contexts (temp-0 parity). + _sd_ok = do_ar_decode(committed, req.n_gen, result.tokens, out_io, req.budget_hook, &result.budget_forced_close, &result.degenerate_decode_close); + } + if (!_sd_ok) { result.error = "decode"; return result; } + } result.decode_s = std::chrono::duration( std::chrono::steady_clock::now() - t_decode_start).count(); } @@ -668,14 +717,27 @@ GenerateResult Qwen35Backend::restore_and_generate(int slot, // without sacrificing spec-decode throughput for the bulk of // generation. Most requests never hit the tail because the // model closes naturally well before the budget edge. - if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io, + { + bool _sd_ok = do_spec_decode(committed, req.n_gen, result.tokens, out_io, result.accept_rate, result.spec_decode_ran, - req.hint_tokens, &req.budget_hook, + req.hint_tokens, + req.stall_tool_prefix_tokens, + req.stall_action_suffix_tokens, + req.stall_skip_tokens, + &req.budget_hook, &result.budget_forced_close, - &result.degenerate_decode_close)) { + &result.degenerate_decode_close); + if (_sd_ok && result.tokens.empty()) { + // FIX: spec-decode degenerate empty (EOS as first token) on certain + // agentic turns -> fall back to AR decode, which is verified to produce + // correct non-empty output for exactly these contexts (temp-0 parity). + _sd_ok = do_ar_decode(committed, req.n_gen, result.tokens, out_io, req.budget_hook, &result.budget_forced_close, &result.degenerate_decode_close); + } + if (!_sd_ok) { result.error = "decode"; return result; } + } result.decode_s = std::chrono::duration( std::chrono::steady_clock::now() - t_decode_start).count(); } @@ -922,6 +984,13 @@ bool Qwen35Backend::do_ar_decode(int committed, int n_gen, auto t_dec0_ar = std::chrono::steady_clock::now(); const size_t out_tokens_at_entry = out_tokens.size(); + static const int _min_floor = env_int_or_default("DFLASH_MIN_TOKENS", 0); + static const int _repeat_guard = []{ + const int explicit_guard = + env_int_or_default("DFLASH_DEGENERATE_RUN_TOKENS", -1); + if (explicit_guard >= 0) return explicit_guard; + return env_int_or_default("DFLASH_MIN_TOKENS", 0) > 0 ? 32 : 0; + }(); const int hidden = w_.n_embd; const int vocab = w_.n_vocab; @@ -1000,6 +1069,26 @@ bool Qwen35Backend::do_ar_decode(int committed, int n_gen, } } + // MIN_TOKENS_BEFORE_EOS (env DFLASH_MIN_TOKENS, default 0=off): if the + // model tries to stop before producing N tokens in this decode call, + // suppress EOS and take the best NON-eos token instead. Targets the Q4 + // 'preamble then stop, no tool_call' agentic stall. Env-gated so the + // default production lane is byte-for-byte unchanged. + { + if (_min_floor > 0 && (int)out_tokens.size() < _min_floor && IS_EOS_TOK(next_tok, w_)) { + int alt = -1; float altbest = -1e30f; + for (int v = 0; v < vocab; v++) { + if (IS_EOS_TOK(v, w_)) continue; + if (logits_buf[v] > altbest) { altbest = logits_buf[v]; alt = v; } + } + if (alt >= 0) { + FILE* _d = std::fopen("/tmp/dflash_floor.log", "a"); + if (_d) { std::fprintf(_d, "[floor] eos@%d -> alt=%d\n", (int)out_tokens.size(), alt); std::fclose(_d); } + next_tok = alt; + } + } + } + maybe_force_close(next_tok, committed); out_tokens.push_back(next_tok); @@ -1010,6 +1099,22 @@ bool Qwen35Backend::do_ar_decode(int committed, int n_gen, if (IS_EOS_TOK(next_tok, w_)) break; + if (_repeat_guard > 0 && (int)out_tokens.size() >= _repeat_guard) { + int run = 1; + for (int j = (int)out_tokens.size() - 2; j >= 0; --j) { + if (out_tokens[j] != next_tok) break; + run++; + } + if (run >= _repeat_guard) { + std::fprintf(stderr, + "[degenerate-decode] token %d repeated %d times - " + "breaking AR loop at committed=%d\n", + next_tok, run, committed); + if (degenerate_close_out) *degenerate_close_out = true; + break; + } + } + // Degenerate-decode watchdog. Once we're past the budget-hook's // close sequence (model in post-`` content phase), watch // for repetition loops. The aime2025-02 case at think_max=4k @@ -1100,6 +1205,9 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, float & out_accept_rate, bool & out_spec_ran, const std::vector * hint_tokens, + const std::vector * stall_tool_prefix_tokens, + const std::vector * stall_action_suffix_tokens, + const std::vector * stall_skip_tokens, const BudgetHook * budget_hook, bool * forced_close_out, bool * degenerate_close_out) { @@ -1138,6 +1246,10 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, } out_spec_ran = true; + static const int _min_floor = []{ + const char* e = std::getenv("DFLASH_MIN_TOKENS"); + return e ? std::atoi(e) : 0; + }(); // ── DFlash spec-decode: draft → verify → accept → replay ────────── @@ -1208,6 +1320,26 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, } } + if (last_tok < 0 && !out_tokens.empty()) { + std::fprintf(stderr, + "[spec-decode] invalid draft seed %d after %d emitted tokens; " + "switching to AR\n", + last_tok, (int)out_tokens.size()); + step_graph_destroy(draft_sg); + cache_.last_tok = out_tokens.back(); + const int ar_n_gen = n_gen - n_generated; + if (ar_n_gen <= 0) { + io.emit(-1); + return true; + } + BudgetHook tail_hook = budget_hook ? *budget_hook : BudgetHook{}; + bool ok = do_ar_decode(committed, ar_n_gen, out_tokens, io, + tail_hook, forced_close_out, + degenerate_close_out); + io.emit(-1); + return ok; + } + // 1. Build noise input for draft noise_ids[0] = last_tok; for (int i = 1; i < q_len; i++) noise_ids[i] = target->mask_token_id(); @@ -1342,7 +1474,6 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, step_graph_destroy(draft_sg); return false; } - last_tok = replay_last_tok; // 7. Sync features for replayed range to mirror (needed for next draft step) if (use_remote_draft && cache_.target_feat) { @@ -1357,20 +1488,110 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, // 8. Emit committed tokens (stop at EOS) bool hit_eos = false; + bool floor_to_ar = false; + bool inject_tool_prefix = false; + constexpr size_t kActionSuffixLookback = 16; + constexpr size_t kSkipSequenceLookback = 64; int emitted = 0; for (int i = 0; i < commit_n; i++) { + if (_min_floor > 0 && (int)out_tokens.size() < _min_floor && + IS_EOS_TOK(replay_tok[i], w_)) { + // Action preambles often end as "I'll check:\n\n" before EOS. + // Tokenization makes the colon several tokens back, so keep a + // modest trailing window while still requiring a recent action + // suffix and no nearby completion phrase. + const bool can_inject_tool = + stall_tool_prefix_tokens && !stall_tool_prefix_tokens->empty() && + stall_action_suffix_tokens && !stall_action_suffix_tokens->empty() && + tokens_have_recent_any(out_tokens, *stall_action_suffix_tokens, + kActionSuffixLookback) && + !(stall_skip_tokens && + tokens_contain_recent_sequence(out_tokens, + *stall_skip_tokens, + kSkipSequenceLookback)); + if (can_inject_tool) { + FILE* _d = std::fopen("/tmp/dflash_floor.log", "a"); + if (_d) { + std::fprintf(_d, + "[spec-tool-floor] eos@%d committed=%d emitted=%d prefix=%zu -> ar\n", + (int)out_tokens.size(), committed, emitted, + stall_tool_prefix_tokens->size()); + std::fclose(_d); + } + floor_to_ar = true; + inject_tool_prefix = true; + break; + } + } out_tokens.push_back(replay_tok[i]); io.emit(replay_tok[i]); emitted++; if (io.cancelled) break; if (IS_EOS_TOK(replay_tok[i], w_)) { hit_eos = true; break; } } - committed += emitted; + int injected = 0; + if (floor_to_ar) { + if (!target->restore_kv()) { + step_graph_destroy(draft_sg); + return false; + } + cache_.cur_pos = committed; + if (emitted > 0) { + std::vector replay_prefix(replay_tok.begin(), + replay_tok.begin() + emitted); + int prefix_last_tok = -1; + if (!target->verify_batch(replay_prefix, committed, + prefix_last_tok, nullptr)) { + std::fprintf(stderr, "spec-decode: floor prefix replay failed\n"); + step_graph_destroy(draft_sg); + return false; + } + } + committed += emitted; + cache_.cur_pos = committed; + if (inject_tool_prefix) { + int tool_prefix_last_tok = -1; + if (!target->verify_batch(*stall_tool_prefix_tokens, committed, + tool_prefix_last_tok, nullptr)) { + std::fprintf(stderr, "spec-decode: tool prefix replay failed\n"); + step_graph_destroy(draft_sg); + return false; + } + for (int32_t tok : *stall_tool_prefix_tokens) { + out_tokens.push_back(tok); + io.emit(tok); + } + injected = (int)stall_tool_prefix_tokens->size(); + committed += injected; + cache_.cur_pos = committed; + } + } else { + last_tok = replay_last_tok; + committed += emitted; + } cache_.cur_pos = committed; - n_generated += emitted; + n_generated += emitted + injected; n_accept_sum += std::min(accept_n, emitted); n_draft_steps++; if (io.cancelled) break; + if (floor_to_ar) { + step_graph_destroy(draft_sg); + cache_.last_tok = out_tokens.empty() ? last_tok : out_tokens.back(); + const int total_draft_pos = std::max(1, n_draft_steps * q_len); + out_accept_rate = + (float)((double)n_accept_sum / (double)total_draft_pos); + const int ar_n_gen = n_gen - n_generated; + if (ar_n_gen <= 0) { + io.emit(-1); + return true; + } + BudgetHook tail_hook = budget_hook ? *budget_hook : BudgetHook{}; + bool ok = do_ar_decode(committed, ar_n_gen, out_tokens, io, + tail_hook, forced_close_out, + degenerate_close_out); + io.emit(-1); + return ok; + } if (hit_eos) break; } diff --git a/server/src/qwen35/qwen35_backend.h b/server/src/qwen35/qwen35_backend.h index fb9b8f60..f0884c38 100644 --- a/server/src/qwen35/qwen35_backend.h +++ b/server/src/qwen35/qwen35_backend.h @@ -227,6 +227,9 @@ class Qwen35Backend : public ModelBackend { float & out_accept_rate, bool & out_spec_ran, const std::vector * hint_tokens = nullptr, + const std::vector * stall_tool_prefix_tokens = nullptr, + const std::vector * stall_action_suffix_tokens = nullptr, + const std::vector * stall_skip_tokens = nullptr, const BudgetHook * budget_hook = nullptr, bool * forced_close_out = nullptr, bool * degenerate_close_out = nullptr); diff --git a/server/src/server/http_server.cpp b/server/src/server/http_server.cpp index a89309dd..85e22ac8 100644 --- a/server/src/server/http_server.cpp +++ b/server/src/server/http_server.cpp @@ -10,7 +10,9 @@ #include #include #include +#include #include +#include #include #include @@ -77,6 +79,90 @@ static size_t json_array_size(const json & value) { return value.is_array() ? value.size() : 0; } +static bool env_flag_enabled(const char * name) { + const char * raw = std::getenv(name); + if (!raw || !*raw) return false; + std::string value(raw); + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char c) { return (char)std::tolower(c); }); + return value != "0" && value != "false" && value != "no" && + value != "off"; +} + +static const json * find_tool_function(const json & tools, + const std::string & name) { + if (!tools.is_array() || name.empty()) return nullptr; + for (const auto & tool : tools) { + if (!tool.contains("function") || !tool["function"].is_object()) { + continue; + } + const json & fn = tool["function"]; + if (fn.value("name", "") == name) return &fn; + } + return nullptr; +} + +static std::string first_tool_parameter_name(const json & function_def) { + const auto & params = function_def.value("parameters", json::object()); + if (params.contains("required") && params["required"].is_array()) { + for (const auto & name : params["required"]) { + if (name.is_string()) return name.get(); + } + } + if (params.contains("properties") && params["properties"].is_object()) { + for (const auto & item : params["properties"].items()) { + return item.key(); + } + } + return ""; +} + +static const json * select_stall_recovery_function(const json & tools, + const json & tool_choice) { + if (!tools.is_array() || tools.empty()) return nullptr; + + if (tool_choice.is_object() && tool_choice.contains("function") && + tool_choice["function"].is_object()) { + const std::string forced_name = + tool_choice["function"].value("name", ""); + // If the request forced a concrete function, recovery must honor it; + // falling back to terminal here would synthesize invalid tool XML. + return find_tool_function(tools, forced_name); + } + + if (tool_choice.is_string() && tool_choice.get() == "required" && + tools.size() == 1 && tools[0].contains("function") && + tools[0]["function"].is_object()) { + return &tools[0]["function"]; + } + + if (const json * terminal = find_tool_function(tools, "terminal")) { + return terminal; + } + if (tools.size() == 1 && tools[0].contains("function") && + tools[0]["function"].is_object()) { + return &tools[0]["function"]; + } + return nullptr; +} + +static std::string build_stall_tool_prefix(const json & tools, + const json & tool_choice) { + const json * function_def = + select_stall_recovery_function(tools, tool_choice); + if (!function_def) return "\nvalue("name", ""); + if (name.empty()) return "\n\n"; + std::string param = first_tool_parameter_name(*function_def); + if (!param.empty()) { + prefix += "\n"; + } + return prefix; +} + // Build the /props response body. // // Non-static so unit tests can call it directly (declared in http_server.h). @@ -1332,6 +1418,32 @@ void HttpServer::worker_loop() { gen_req.hint_tokens = &hint_tokens_storage; } } + std::vector stall_tool_prefix_tokens_storage; + std::vector stall_action_suffix_tokens_storage; + std::vector stall_skip_tokens_storage; + if (!req.tools.empty() && env_flag_enabled("DFLASH_STALL_TOOL_PREFIX")) { + stall_tool_prefix_tokens_storage = + tokenizer_.encode(build_stall_tool_prefix(req.tools, + req.tool_choice)); + stall_action_suffix_tokens_storage = tokenizer_.encode(":"); + auto add_suffix_terminal = [&](const std::string & text) { + auto ids = tokenizer_.encode(text); + if (ids.empty()) return; + int32_t tok = ids.back(); + if (std::find(stall_action_suffix_tokens_storage.begin(), + stall_action_suffix_tokens_storage.end(), tok) == + stall_action_suffix_tokens_storage.end()) { + stall_action_suffix_tokens_storage.push_back(tok); + } + }; + add_suffix_terminal("`:"); + add_suffix_terminal("):"); + add_suffix_terminal("\":"); + stall_skip_tokens_storage = tokenizer_.encode(" done"); + gen_req.stall_tool_prefix_tokens = &stall_tool_prefix_tokens_storage; + gen_req.stall_action_suffix_tokens = &stall_action_suffix_tokens_storage; + gen_req.stall_skip_tokens = &stall_skip_tokens_storage; + } // Prefix cache: check for cached KV state. auto [cache_slot, prefix_len] = prefix_cache_.lookup(effective_prompt); @@ -1705,6 +1817,9 @@ void HttpServer::worker_loop() { effective_finish_reason = "length"; } } + if (result.degenerate_decode_close) { + effective_finish_reason = "length"; + } json choice = { {"index", 0}, {"message", msg}, {"finish_reason", effective_finish_reason}