diff --git a/server/src/common/daemon_loop.cpp b/server/src/common/daemon_loop.cpp index 30b5638c..7f70185c 100644 --- a/server/src/common/daemon_loop.cpp +++ b/server/src/common/daemon_loop.cpp @@ -29,11 +29,22 @@ namespace dflash::common { // ── DaemonIO ──────────────────────────────────────────────────────────── +bool DaemonIO::should_cancel() const { + if (cancelled.load(std::memory_order_relaxed)) return true; + if (is_cancelled && is_cancelled()) { + cancelled.store(true, std::memory_order_relaxed); + return true; + } + return false; +} + void DaemonIO::emit(int32_t v) const { + if (should_cancel()) return; + // Call the token callback for non-sentinel tokens. if (on_token && v >= 0) { if (!on_token(v)) { - cancelled = true; + cancelled.store(true, std::memory_order_relaxed); return; } } diff --git a/server/src/common/dflash_spec_decode.cpp b/server/src/common/dflash_spec_decode.cpp index 141e45e9..36bcade6 100644 --- a/server/src/common/dflash_spec_decode.cpp +++ b/server/src/common/dflash_spec_decode.cpp @@ -86,6 +86,7 @@ bool run_dflash_spec_decode( auto t_dec0 = std::chrono::steady_clock::now(); while (n_generated < n_gen) { + if (io.should_cancel()) break; const int need_commit_budget = n_gen - n_generated; // ── Build noise input for draft ──────────────────────────────────── @@ -109,10 +110,12 @@ bool run_dflash_spec_decode( // ── Draft compute (local or remote) ─────────────────────────────── const float * draft_hidden_host = nullptr; if (use_remote_draft) { + if (io.should_cancel()) break; if (!remote_draft->propose(committed, draft_ctx, noise_embed, remote_hidden)) { std::fprintf(stderr, "dflash-spec remote draft propose failed\n"); return false; } + if (io.should_cancel()) break; draft_hidden_host = remote_hidden.data(); } else { if (!build_draft_step(draft_sg, draft_weights, /*lm_head=*/nullptr, draft_backend, @@ -138,10 +141,12 @@ bool run_dflash_spec_decode( ggml_backend_tensor_set(draft_sg.positions_k, pos_k.data(), 0, sizeof(int32_t) * 
pos_k.size()); auto st = ggml_backend_graph_compute(draft_backend, draft_sg.gf); - if (st != GGML_STATUS_SUCCESS) { + const auto compute_result = classify_daemon_compute_result(st, io); + if (compute_result == DaemonComputeResult::Failed) { std::fprintf(stderr, "dflash-spec draft compute %d\n", (int)st); return false; } + if (compute_result == DaemonComputeResult::Cancelled) break; // Read draft hidden states out to host so the target adapter can // project them through its own LM head (target-internal layout). local_hidden.resize((size_t)hidden * q_len); @@ -174,6 +179,7 @@ bool run_dflash_spec_decode( std::fprintf(stderr, "dflash-spec snapshot_kv failed\n"); return false; } + if (io.should_cancel()) break; int verify_last_tok = -1; if (!target.verify_batch(draft_tok, committed, verify_last_tok, &target_tok)) { @@ -209,6 +215,7 @@ bool run_dflash_spec_decode( std::fprintf(stderr, "dflash-spec restore_kv failed\n"); return false; } + if (io.should_cancel()) break; std::vector replay_tok((size_t)commit_n); for (int i = 0; i < commit_n; i++) { @@ -226,7 +233,7 @@ bool run_dflash_spec_decode( for (int i = 0; i < commit_n; i++) { out_all.push_back(replay_tok[i]); io.emit(replay_tok[i]); - if (io.cancelled) break; + if (io.should_cancel()) break; ++emitted; if (target.is_eos(replay_tok[i])) hit_eos = true; } @@ -234,7 +241,7 @@ bool run_dflash_spec_decode( n_generated += emitted; n_accept_sum += std::min(accept_n, emitted); n_draft_steps++; - if (io.cancelled) break; + if (io.should_cancel()) break; if (hit_eos) break; } if (!use_remote_draft && draft_backend) ggml_backend_synchronize(draft_backend); diff --git a/server/src/common/model_backend.h b/server/src/common/model_backend.h index de439092..523cd3dc 100644 --- a/server/src/common/model_backend.h +++ b/server/src/common/model_backend.h @@ -10,6 +10,7 @@ #pragma once +#include #include #include #include @@ -25,16 +26,39 @@ namespace dflash::common { // Token callback for streaming generation. 
Called once per committed token. // Return true to continue generation, false to abort. using TokenCallback = std::function; +using CancelCallback = std::function; // ─── I/O handle passed to backend methods that need protocol output ───── struct DaemonIO { int stream_fd = -1; + DaemonIO() = default; + DaemonIO(const DaemonIO & other) + : stream_fd(other.stream_fd) + , on_token(other.on_token) + , is_cancelled(other.is_cancelled) + , cancelled(other.cancelled.load(std::memory_order_relaxed)) {} + + DaemonIO & operator=(const DaemonIO & other) { + if (this == &other) return *this; + stream_fd = other.stream_fd; + on_token = other.on_token; + is_cancelled = other.is_cancelled; + cancelled.store(other.cancelled.load(std::memory_order_relaxed), + std::memory_order_relaxed); + return *this; + } + // Optional token callback. When set, emit() calls this for each token // (excluding the -1 sentinel). If it returns false, the `cancelled` // flag is set and the caller should abort generation. TokenCallback on_token; - mutable bool cancelled = false; + CancelCallback is_cancelled; + mutable std::atomic cancelled{false}; + + // Observe cooperative cancellation even before a token callback fires + // (for example while a long prefill/spec-decode step is still working). + bool should_cancel() const; // Write a single int32 to the stream fd (token or -1 sentinel). // Also invokes on_token if set. Sets cancelled=true if on_token @@ -45,6 +69,22 @@ struct DaemonIO { DaemonIO with_token_callback(const TokenCallback & cb) const; }; +enum class DaemonComputeResult { + Success, + Cancelled, + Failed, +}; + +inline DaemonComputeResult classify_daemon_compute_result( + ggml_status status, + const DaemonIO & io) { + // A completed graph failure is a backend error even if the client also + // disconnected while the graph was running. 
+ if (status != GGML_STATUS_SUCCESS) return DaemonComputeResult::Failed; + if (io.should_cancel()) return DaemonComputeResult::Cancelled; + return DaemonComputeResult::Success; +} + // ─── Generate request/result ──────────────────────────────────────────── // Thinking-budget force-close hook. Mirrors antirez/ds4 ds4_eval.c's @@ -131,6 +171,11 @@ struct GenerateResult { float accept_rate = 0.0f; // True when spec decode actually ran (accept_rate==0 still needs a bandit update). bool spec_decode_ran = false; + // True when decode emitted only tokens that the API layer suppresses + // (for example an immediate EOS/EOT). This is semantically equivalent + // to zero output for clients and should take the same AR retry path as + // an empty token vector. + bool empty_visible_output = false; }; // ─── Backend interface ────────────────────────────────────────────────── @@ -160,7 +205,7 @@ struct ModelBackend { if (!should_retry_empty_spec_decode(req, result)) return result; std::fprintf(stderr, - "[backend] spec-decode produced zero tokens; retrying with AR decode\n"); + "[backend] spec-decode produced no visible output; retrying with AR decode\n"); GenerateRequest retry = req; retry.force_ar_decode = true; return merge_empty_spec_retry_result(result, generate(retry, io)); @@ -188,7 +233,7 @@ struct ModelBackend { if (!should_retry_empty_spec_decode(req, result)) return result; std::fprintf(stderr, - "[backend] restored spec-decode produced zero tokens; retrying with AR decode\n"); + "[backend] restored spec-decode produced no visible output; retrying with AR decode\n"); GenerateRequest retry = req; retry.force_ar_decode = true; return merge_empty_spec_retry_result(result, @@ -201,7 +246,7 @@ struct ModelBackend { && !req.force_ar_decode && result.ok && result.spec_decode_ran - && result.tokens.empty(); + && (result.tokens.empty() || result.empty_visible_output); } static GenerateResult merge_empty_spec_retry_result( diff --git a/server/src/gemma4/gemma4_backend.cpp 
b/server/src/gemma4/gemma4_backend.cpp index e09ce575..0525d7cb 100644 --- a/server/src/gemma4/gemma4_backend.cpp +++ b/server/src/gemma4/gemma4_backend.cpp @@ -202,7 +202,6 @@ bool Gemma4Backend::unpark(const std::string & what) { int Gemma4Backend::do_prefill(const std::vector & tokens, const DaemonIO & io, int kv_offset) { - (void)io; const int n = (int)tokens.size(); const int hidden = w_.n_embd; const int chunk = cfg_.chunk; @@ -212,6 +211,7 @@ int Gemma4Backend::do_prefill(const std::vector & tokens, int pos = 0; while (pos < n) { + if (io.should_cancel()) return -1; int len = std::min(chunk, n - pos); // Limit chunk to avoid ring-buffer wrap for SWA layers @@ -233,6 +233,7 @@ int Gemma4Backend::do_prefill(const std::vector & tokens, std::fprintf(stderr, "[gemma4] prefill step failed at pos=%d\n", kv_pos); return -1; } + if (io.should_cancel()) return -1; pos += len; cache_.cur_pos = kv_offset + pos; @@ -368,7 +369,7 @@ bool Gemma4Backend::do_decode(int committed, int n_gen, io.emit(next); committed++; cache_.cur_pos = committed; - if (io.cancelled) break; + if (io.should_cancel()) break; // Check EOS if (next == w_.eos_id || next == w_.eos_chat_id) break; @@ -586,7 +587,7 @@ bool Gemma4Backend::do_spec_decode(int committed, int n_gen, out_tokens.push_back(tok); io.emit(tok); emitted++; - if (io.cancelled) break; + if (io.should_cancel()) break; if (tok == w_.eos_id || tok == w_.eos_chat_id) { hit_eos = true; break; } @@ -596,7 +597,7 @@ bool Gemma4Backend::do_spec_decode(int committed, int n_gen, n_generated += emitted; n_accept_sum += std::min(accept_n, emitted); n_draft_steps++; - if (io.cancelled) break; + if (io.should_cancel()) break; if (hit_eos) break; } @@ -640,9 +641,17 @@ GenerateResult Gemma4Backend::generate(const GenerateRequest & req, result.prefill_s = std::chrono::duration( std::chrono::steady_clock::now() - t_prefill_start).count(); if (committed < 0) { + if (out_io.should_cancel()) { + result.ok = true; + return result; + } result.error = 
"prefill"; return result; } + if (out_io.should_cancel()) { + result.ok = true; + return result; + } // Inline snapshot at snap_pos for prefix cache if (req.snap_slot >= 0 && req.snap_pos > 0 && req.snap_pos <= committed) { @@ -703,7 +712,7 @@ GenerateResult Gemma4Backend::generate(const GenerateRequest & req, } result.tokens.push_back(first); out_io.emit(first); - if (out_io.cancelled) { + if (out_io.should_cancel()) { out_io.emit(-1); result.ok = true; return result; @@ -804,6 +813,10 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot, std::vector delta(req.prompt.begin() + snap_pos, req.prompt.end()); committed = do_prefill(delta, out_io, /*kv_offset=*/snap_pos); if (committed < 0) { + if (out_io.should_cancel()) { + result.ok = true; + return result; + } result.error = "prefill"; return result; } @@ -815,6 +828,10 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot, // else: prompt_len == snap_pos → no delta, committed stays at snap_pos result.prefill_s = std::chrono::duration( std::chrono::steady_clock::now() - t_prefill_start).count(); + if (out_io.should_cancel()) { + result.ok = true; + return result; + } // Inline snapshot at snap_pos for prefix cache (new snap from this request) if (req.snap_slot >= 0 && req.snap_pos > 0 && req.snap_pos <= committed) { @@ -883,7 +900,7 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot, } result.tokens.push_back(first); out_io.emit(first); - if (out_io.cancelled) { + if (out_io.should_cancel()) { out_io.emit(-1); result.ok = true; return result; diff --git a/server/src/gemma4/gemma4_layer_split_adapter.cpp b/server/src/gemma4/gemma4_layer_split_adapter.cpp index 4e7c6a87..95b3925a 100644 --- a/server/src/gemma4/gemma4_layer_split_adapter.cpp +++ b/server/src/gemma4/gemma4_layer_split_adapter.cpp @@ -374,7 +374,7 @@ bool Gemma4LayerSplitAdapter::decode_ar( const auto & w = shards_.front().weights; out_tokens.push_back(last_tok); io.emit(last_tok); - if (io.cancelled) { + if 
(io.should_cancel()) { io.emit(-1); return true; } @@ -392,7 +392,7 @@ bool Gemma4LayerSplitAdapter::decode_ar( out_tokens.push_back(last_tok); io.emit(last_tok); ++committed; - if (io.cancelled) break; + if (io.should_cancel()) break; if (last_tok == w.eos_id || last_tok == w.eos_chat_id) break; } io.emit(-1); diff --git a/server/src/laguna/laguna_backend.cpp b/server/src/laguna/laguna_backend.cpp index d6108e4e..ad8dbf27 100644 --- a/server/src/laguna/laguna_backend.cpp +++ b/server/src/laguna/laguna_backend.cpp @@ -171,11 +171,19 @@ GenerateResult LagunaBackend::generate(const GenerateRequest & req, bool ok = true; const int n_chunks = (N + args_.chunk - 1) / args_.chunk; for (int c = 0; c < n_chunks && ok; ++c) { + if (out_io.should_cancel()) { + result.ok = true; + return result; + } const int kv_start = c * args_.chunk; const int n_tok = std::min(args_.chunk, N - c * args_.chunk); ok = laguna_step(backend_, w_, cache_, embed_pf.data() + (size_t)kv_start * w_.n_embd, n_tok, kv_start, no_mask, last_logits); + if (out_io.should_cancel()) { + result.ok = true; + return result; + } } if (!ok) { result.error = "prefill"; return result; } auto t_pf1 = std::chrono::steady_clock::now(); @@ -267,7 +275,7 @@ GenerateResult LagunaBackend::generate(const GenerateRequest & req, history.push_back(next_tok); if (should_emit) { out_io.emit(next_tok); - if (out_io.cancelled) break; + if (out_io.should_cancel()) break; } if (!w_.embedder.embed(&next_tok, 1, embed_step.data())) { ok = false; break; } std::vector step_logits; @@ -328,12 +336,20 @@ GenerateResult LagunaBackend::restore_and_generate(int slot, bool ok = true; const int n_chunks = (diff_n + args_.chunk - 1) / args_.chunk; for (int c = 0; c < n_chunks && ok; ++c) { + if (out_io.should_cancel()) { + result.ok = true; + return result; + } const int off = c * args_.chunk; const int n_tok = std::min(args_.chunk, diff_n - off); const int starts = kv_start + off; ok = laguna_step(backend_, w_, cache_, embed_diff.data() + 
(size_t)off * w_.n_embd, n_tok, starts, no_mask, last_logits); + if (out_io.should_cancel()) { + result.ok = true; + return result; + } } if (!ok) { result.error = "prefill"; return result; } @@ -402,7 +418,7 @@ GenerateResult LagunaBackend::restore_and_generate(int slot, history.push_back(next_tok); result.tokens.push_back(next_tok); out_io.emit(next_tok); - if (out_io.cancelled) break; + if (out_io.should_cancel()) break; if (!w_.embedder.embed(&next_tok, 1, embed_step.data())) { ok = false; break; } std::vector step_logits; if (!laguna_step(backend_, w_, cache_, embed_step.data(), 1, diff --git a/server/src/qwen3/qwen3_backend.cpp b/server/src/qwen3/qwen3_backend.cpp index e2adc7f6..9fcd96eb 100644 --- a/server/src/qwen3/qwen3_backend.cpp +++ b/server/src/qwen3/qwen3_backend.cpp @@ -384,6 +384,7 @@ int Qwen3Backend::do_prefill(const std::vector & tokens, std::vector embed_buf((size_t)chunk * hidden); for (int start = 0; start < total; start += chunk) { + if (io.should_cancel()) return -1; const int n = std::min(chunk, total - start); // CPU embedding: read rows from tok_embd (which is on GPU) @@ -422,6 +423,7 @@ int Qwen3Backend::do_prefill(const std::vector & tokens, if (!do_step(embed_buf.data(), n, kv_offset + start, logits)) { return -1; } + if (io.should_cancel()) return -1; committed = kv_offset + start + n; cache_.cur_pos = committed; last_logits_ = std::move(logits); @@ -464,7 +466,7 @@ bool Qwen3Backend::do_decode(int committed, int n_gen, io.emit(next); committed++; cache_.cur_pos = committed; - if (io.cancelled) break; + if (io.should_cancel()) break; // Check EOS if (next == 151643 || next == 151645) break; @@ -525,9 +527,17 @@ GenerateResult Qwen3Backend::generate(const GenerateRequest & req, // Prefill const int committed = do_prefill(req.prompt, out_io); if (committed < 0) { + if (out_io.should_cancel()) { + result.ok = true; + return result; + } result.error = "prefill"; return result; } + if (out_io.should_cancel()) { + result.ok = true; + 
return result; + } // Inline snapshot at snap_pos for prefix cache if (req.snap_slot >= 0 && req.snap_pos > 0 && req.snap_pos <= committed) { @@ -596,7 +606,7 @@ GenerateResult Qwen3Backend::generate(const GenerateRequest & req, } result.tokens.push_back(first); out_io.emit(first); - if (out_io.cancelled) { + if (out_io.should_cancel()) { out_io.emit(-1); result.ok = true; return result; @@ -691,10 +701,18 @@ GenerateResult Qwen3Backend::restore_and_generate(int slot, req.prompt.end()); const int committed = do_prefill(remaining, out_io, prefix_len); if (committed < 0) { + if (out_io.should_cancel()) { + result.ok = true; + return result; + } result.error = "prefill after restore"; return result; } } + if (out_io.should_cancel()) { + result.ok = true; + return result; + } // Now generate (decode) from here const int total_committed = (int)req.prompt.size(); @@ -763,7 +781,7 @@ GenerateResult Qwen3Backend::restore_and_generate(int slot, } result.tokens.push_back(first); out_io.emit(first); - if (out_io.cancelled) { + if (out_io.should_cancel()) { out_io.emit(-1); result.ok = true; return result; diff --git a/server/src/qwen35/qwen35_backend.cpp b/server/src/qwen35/qwen35_backend.cpp index e3b161d8..ba8f74f9 100644 --- a/server/src/qwen35/qwen35_backend.cpp +++ b/server/src/qwen35/qwen35_backend.cpp @@ -40,6 +40,15 @@ static float bf16_bits_to_f32(uint16_t bits) { ( ((w).eos_chat_id >= 0 && (tok) == (w).eos_chat_id) \ || ((w).eos_id >= 0 && (tok) == (w).eos_id ) ) +static bool qwen35_empty_visible_output(const std::vector & tokens, + const TargetWeights & w) { + if (tokens.empty()) return true; + for (int32_t tok : tokens) { + if (!IS_EOS_TOK(tok, w)) return false; + } + return true; +} + // ── Construction / destruction ────────────────────────────────────────── Qwen35Backend::Qwen35Backend(const Qwen35Config & cfg) : cfg_(cfg) {} @@ -562,9 +571,17 @@ GenerateResult Qwen35Backend::generate(const GenerateRequest & req, auto t_prefill_start = 
std::chrono::steady_clock::now(); const int committed = do_prefill(req.prompt, out_io, req.snap_pos, req.snap_slot); if (committed < 0) { + if (out_io.should_cancel()) { + result.ok = true; + return result; + } result.error = "prefill"; return result; } + if (out_io.should_cancel()) { + result.ok = true; + return result; + } auto t_prefill_end = std::chrono::steady_clock::now(); result.prefill_s = std::chrono::duration(t_prefill_end - t_prefill_start).count(); @@ -590,6 +607,10 @@ GenerateResult Qwen35Backend::generate(const GenerateRequest & req, req.hint_tokens, &req.budget_hook, &result.budget_forced_close, &result.degenerate_decode_close); + if (decode_ok) { + result.empty_visible_output = + qwen35_empty_visible_output(result.tokens, w_); + } } if (!decode_ok) { result.error = "decode"; @@ -657,6 +678,10 @@ GenerateResult Qwen35Backend::restore_and_generate(int slot, std::vector delta = restore_prompt_delta(req.prompt, snap_pos); committed = do_prefill(delta, out_io, req.snap_pos, req.snap_slot, /*kv_offset=*/snap_pos); if (committed < 0) { + if (out_io.should_cancel()) { + result.ok = true; + return result; + } result.error = "prefill"; return result; } @@ -668,6 +693,10 @@ GenerateResult Qwen35Backend::restore_and_generate(int slot, out_io.emit(-1); return result; } + if (out_io.should_cancel()) { + result.ok = true; + return result; + } // Decode if (req.n_gen > 0) { @@ -691,6 +720,10 @@ GenerateResult Qwen35Backend::restore_and_generate(int slot, req.hint_tokens, &req.budget_hook, &result.budget_forced_close, &result.degenerate_decode_close); + if (decode_ok) { + result.empty_visible_output = + qwen35_empty_visible_output(result.tokens, w_); + } } if (!decode_ok) { result.error = "decode"; @@ -714,8 +747,6 @@ int Qwen35Backend::do_prefill(const std::vector & tokens, const DaemonIO & io, int snap_pos, int snap_slot, int kv_offset) { - (void)io; - const int hidden = w_.n_embd; const int vocab = w_.n_vocab; int prefill_ubatch = 512; @@ -740,6 +771,8 @@ int 
Qwen35Backend::do_prefill(const std::vector & tokens, std::vector embed_buf((size_t)hidden * prefill_ubatch); int committed = kv_offset; for (int start = 0; start < prompt_len;) { + if (io.should_cancel()) return -1; + const int kv_pos = kv_offset + start; int n_tokens = std::min(prefill_ubatch, prompt_len - start); @@ -817,10 +850,12 @@ int Qwen35Backend::do_prefill(const std::vector & tokens, // Compute auto st = ggml_backend_graph_compute(target_backend_, sg_.gf); - if (st != GGML_STATUS_SUCCESS) { + const auto compute_result = classify_daemon_compute_result(st, io); + if (compute_result == DaemonComputeResult::Failed) { std::fprintf(stderr, "prefill compute @%d failed\n", kv_pos); return -1; } + if (compute_result == DaemonComputeResult::Cancelled) return -1; after_target_compute(sg_, kv_pos, n_tokens); int32_t last_tok = -1; @@ -1026,7 +1061,7 @@ bool Qwen35Backend::do_ar_decode(int committed, int n_gen, io.emit(next_tok); committed++; cache_.cur_pos = committed; - if (io.cancelled) break; + if (io.should_cancel()) break; if (IS_EOS_TOK(next_tok, w_)) break; @@ -1184,6 +1219,7 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, auto t_dec0 = std::chrono::steady_clock::now(); while (n_generated < n_gen) { + if (io.should_cancel()) break; const int need_commit_budget = n_gen - n_generated; // Budget tail-off: when remaining budget is within the spec-decode @@ -1250,12 +1286,14 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, draft_feature_mirror_can_view(feature_mirror_, committed, draft_ctx, mirror_slot0); if (use_remote_draft) { + if (io.should_cancel()) break; local_hidden.clear(); if (!remote_draft_.propose(committed, draft_ctx, noise_embed, local_hidden)) { std::fprintf(stderr, "spec-decode: remote draft propose failed\n"); step_graph_destroy(draft_sg); return false; } + if (io.should_cancel()) break; } else { if (!build_draft_step(draft_sg, dw_, /*lm_head=*/nullptr, draft_backend_, draft_ctx, use_mirror_view ? 
&feature_mirror_ : nullptr, @@ -1283,11 +1321,13 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, sizeof(int32_t) * pos_k.size()); auto st = ggml_backend_graph_compute(draft_backend_, draft_sg.gf); - if (st != GGML_STATUS_SUCCESS) { + const auto compute_result = classify_daemon_compute_result(st, io); + if (compute_result == DaemonComputeResult::Failed) { std::fprintf(stderr, "spec-decode: draft compute failed\n"); step_graph_destroy(draft_sg); return false; } + if (compute_result == DaemonComputeResult::Cancelled) break; // Read draft hidden states to host for LM-head projection. local_hidden.resize((size_t)hidden * q_len); @@ -1319,6 +1359,7 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, step_graph_destroy(draft_sg); return false; } + if (io.should_cancel()) break; int verify_last_tok = -1; if (!target->verify_batch(draft_tok, committed, verify_last_tok, &target_tok)) { @@ -1351,6 +1392,7 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, step_graph_destroy(draft_sg); return false; } + if (io.should_cancel()) break; std::vector replay_tok((size_t)commit_n); for (int i = 0; i < commit_n; i++) { @@ -1382,7 +1424,7 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, out_tokens.push_back(replay_tok[i]); io.emit(replay_tok[i]); emitted++; - if (io.cancelled) break; + if (io.should_cancel()) break; if (IS_EOS_TOK(replay_tok[i], w_)) { hit_eos = true; break; } } committed += emitted; @@ -1390,7 +1432,7 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, n_generated += emitted; n_accept_sum += std::min(accept_n, emitted); n_draft_steps++; - if (io.cancelled) break; + if (io.should_cancel()) break; if (hit_eos) break; } diff --git a/server/src/qwen35/qwen35_layer_split_adapter.cpp b/server/src/qwen35/qwen35_layer_split_adapter.cpp index a9ece2c3..4598cdfd 100644 --- a/server/src/qwen35/qwen35_layer_split_adapter.cpp +++ b/server/src/qwen35/qwen35_layer_split_adapter.cpp @@ -380,7 +380,7 @@ bool 
Qwen35LayerSplitAdapter::decode_ar( out_tokens.push_back(last_tok); io.emit(last_tok); - if (io.cancelled) { + if (io.should_cancel()) { io.emit(-1); return true; } @@ -401,7 +401,7 @@ bool Qwen35LayerSplitAdapter::decode_ar( } out_tokens.push_back(next_tok); io.emit(next_tok); - if (io.cancelled) break; + if (io.should_cancel()) break; if (is_eos_tok(next_tok, shards_.front().weights)) break; last_tok = next_tok; ++committed; diff --git a/server/src/qwen35moe/qwen35moe_backend.cpp b/server/src/qwen35moe/qwen35moe_backend.cpp index 7eae3855..be89b7ce 100644 --- a/server/src/qwen35moe/qwen35moe_backend.cpp +++ b/server/src/qwen35moe/qwen35moe_backend.cpp @@ -449,7 +449,7 @@ bool Qwen35MoeBackend::run_ar_decode_path(int committed, int n_gen, io.emit(next_tok); committed++; target_cache().cur_pos = committed; - if (io.cancelled) break; + if (io.should_cancel()) break; if (is_eos_tok(next_tok, target_weights())) break; } step_graph_destroy(layer_sg); @@ -524,6 +524,7 @@ GenerateResult Qwen35MoeBackend::generate(const GenerateRequest & req, // Helper: process one token through all layers (host-based with cached graphs) auto process_one_token = [&](int kv_pos) -> bool { for (int il = 0; il < n_layer; ++il) { + if (out_io.should_cancel()) return false; const bool is_attn = (((il + 1) % target_weights().full_attention_interval) == 0); const auto t0 = HybridClock::now(); @@ -551,7 +552,9 @@ GenerateResult Qwen35MoeBackend::generate(const GenerateRequest & req, build_us_total += elapsed_us(t0, t1); auto st = ggml_backend_graph_compute(target_backend(), sg_ptr->gf); - if (st != GGML_STATUS_SUCCESS) return false; + const auto compute_result = classify_daemon_compute_result(st, out_io); + if (compute_result == DaemonComputeResult::Failed) return false; + if (compute_result == DaemonComputeResult::Cancelled) return false; const auto t2 = HybridClock::now(); compute_us_total += elapsed_us(t1, t2); @@ -645,6 +648,11 @@ GenerateResult Qwen35MoeBackend::generate(const 
GenerateRequest & req, const int n_expert_used = target_weights().n_expert_used; std::vector embed_all((size_t)prompt_len * (size_t)hidden); for (int i = 0; i < prompt_len; ++i) { + if (out_io.should_cancel()) { + result.ok = true; + cleanup_graphs(); + return result; + } int32_t tok = req.prompt[(size_t)i]; if (!target_weights().embedder.embed(&tok, 1, embed_all.data() + (size_t)i * (size_t)hidden)) { result.error = "prefill_embed"; @@ -659,6 +667,11 @@ GenerateResult Qwen35MoeBackend::generate(const GenerateRequest & req, const auto & L = target_weights().layers[(size_t)il]; for (int chunk_start = 0; chunk_start < prompt_len; chunk_start += prefill_chunk) { + if (out_io.should_cancel()) { + result.ok = true; + cleanup_graphs(); + return result; + } const int chunk_len = std::min(prefill_chunk, prompt_len - chunk_start); const auto t0 = HybridClock::now(); @@ -708,12 +721,19 @@ GenerateResult Qwen35MoeBackend::generate(const GenerateRequest & req, // Compute batched pre-FFN auto st = ggml_backend_graph_compute(target_backend(), prefill_sg.gf); - if (st != GGML_STATUS_SUCCESS) { + const auto compute_result = classify_daemon_compute_result(st, out_io); + if (compute_result == DaemonComputeResult::Failed) { result.error = "prefill_compute"; step_graph_destroy(prefill_sg); cleanup_graphs(); return result; } + if (compute_result == DaemonComputeResult::Cancelled) { + result.ok = true; + step_graph_destroy(prefill_sg); + cleanup_graphs(); + return result; + } const auto t2 = HybridClock::now(); compute_us_total += elapsed_us(t1, t2); @@ -816,6 +836,11 @@ GenerateResult Qwen35MoeBackend::generate(const GenerateRequest & req, target_cache().cur_pos = committed; auto t_prefill_end = std::chrono::steady_clock::now(); result.prefill_s = std::chrono::duration(t_prefill_end - t_prefill_start).count(); + if (out_io.should_cancel()) { + result.ok = true; + cleanup_graphs(); + return result; + } // ── Hybrid Decode ── if (req.n_gen > 0) { @@ -879,6 +904,11 @@ GenerateResult 
Qwen35MoeBackend::generate(const GenerateRequest & req, } result.tokens.push_back(first_tok); out_io.emit(first_tok); + if (out_io.should_cancel()) { + result.ok = true; + cleanup_graphs(); + return result; + } if (!is_eos_tok(first_tok, target_weights())) { committed++; target_cache().cur_pos = committed; @@ -894,6 +924,11 @@ GenerateResult Qwen35MoeBackend::generate(const GenerateRequest & req, } embed_us_total += elapsed_us(t_emb0, HybridClock::now()); if (!process_one_token(committed)) { + if (out_io.should_cancel()) { + result.ok = true; + cleanup_graphs(); + return result; + } result.error = "decode"; cleanup_graphs(); return result; @@ -920,7 +955,7 @@ GenerateResult Qwen35MoeBackend::generate(const GenerateRequest & req, out_io.emit(next_tok); committed++; target_cache().cur_pos = committed; - if (out_io.cancelled) break; + if (out_io.should_cancel()) break; if (is_eos_tok(next_tok, target_weights())) break; } } @@ -1318,7 +1353,7 @@ bool Qwen35MoeBackend::do_hybrid_spec_decode(int committed, int n_gen, out_tokens.push_back(replay_tok[i]); io.emit(replay_tok[i]); emitted++; - if (io.cancelled) break; + if (io.should_cancel()) break; if (is_eos_tok(replay_tok[i], target_weights())) { hit_eos = true; break; } } committed += emitted; @@ -1326,7 +1361,7 @@ bool Qwen35MoeBackend::do_hybrid_spec_decode(int committed, int n_gen, n_generated += emitted; n_accept_sum += std::min(accept_n, emitted); n_draft_steps++; - if (io.cancelled) break; + if (io.should_cancel()) break; if (hit_eos) break; } diff --git a/server/src/server/http_server.cpp b/server/src/server/http_server.cpp index 362c2f4d..ce55f3ea 100644 --- a/server/src/server/http_server.cpp +++ b/server/src/server/http_server.cpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include #include #include @@ -24,6 +26,21 @@ namespace dflash::common { +// Linux reports peer half-close promptly via POLLRDHUP. 
macOS/BSD do not
+// expose that flag, so those builds rely on HUP/ERR plus readable EOF detected
+// by the MSG_PEEK fallback in client_socket_disconnected().
+#ifdef POLLRDHUP
+static constexpr short kDisconnectPollEvents =
+    static_cast<short>(POLLIN | POLLHUP | POLLERR | POLLNVAL | POLLRDHUP);
+static constexpr short kDisconnectCloseEvents =
+    static_cast<short>(POLLHUP | POLLERR | POLLNVAL | POLLRDHUP);
+#else
+static constexpr short kDisconnectPollEvents =
+    static_cast<short>(POLLIN | POLLHUP | POLLERR | POLLNVAL);
+static constexpr short kDisconnectCloseEvents =
+    static_cast<short>(POLLHUP | POLLERR | POLLNVAL);
+#endif
+
 // ─── /props constants ───────────────────────────────────────────────────
 //
 // SERVER_NAME / SERVER_VERSION mirror the Python server's identity strings
@@ -77,6 +94,132 @@ static size_t json_array_size(const json & value) {
     return value.is_array() ? value.size() : 0;
 }
 
+// Zero-timeout probe: true when poll()/recv() report `fd` closed or failed.
+static bool client_socket_disconnected(int fd) {
+    struct pollfd pfd{fd, kDisconnectPollEvents, 0};
+    int pr;
+    do {
+        pr = poll(&pfd, 1, 0);
+    } while (pr < 0 && errno == EINTR);
+    if (pr <= 0) return false;
+    if (pfd.revents & kDisconnectCloseEvents) return true;
+    if (!(pfd.revents & POLLIN)) return false;
+
+    // Readable: peek one byte to tell buffered request data from EOF.
+    char byte;
+    const ssize_t n = recv(fd, &byte, 1, MSG_PEEK);
+    if (n >= 0) return n == 0;
+    // Transient errors are not a disconnect; any other failure is.
+    return errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR;
+}
+
+// Background thread that probes every watched client socket and raises the
+// watcher's shared flag once the peer is gone. One instance per process.
+class DisconnectPoller {
+public:
+    static DisconnectPoller & instance() {
+        static DisconnectPoller poller;
+        return poller;
+    }
+
+    // Registers `fd`; returns the id to pass to unwatch().
+    uint64_t watch(int fd, std::shared_ptr<std::atomic<bool>> cancelled) {
+        std::lock_guard<std::mutex> lk(mu_);
+        const uint64_t id = next_id_++;
+        watches_.push_back({id, fd, std::move(cancelled)});
+        cv_.notify_one();
+        return id;
+    }
+
+    void unwatch(uint64_t id) {
+        std::lock_guard<std::mutex> lk(mu_);
+        watches_.erase(
+            std::remove_if(watches_.begin(), watches_.end(),
+                           [&](const Watch & watch) { return watch.id == id; }),
+            watches_.end());
+    }
+
+private:
+    struct Watch {
+        uint64_t id;
+        int fd;
+        std::weak_ptr<std::atomic<bool>> cancelled;
+    };
+
+    DisconnectPoller() { thread_ = std::thread([this]() { run(); }); }
+
+    ~DisconnectPoller() {
+        {
+            // Set the flag under mu_ so run() cannot evaluate its wait
+            // predicate and then block after missing the notification.
+            std::lock_guard<std::mutex> lk(mu_);
+            stopping_.store(true, std::memory_order_relaxed);
+        }
+        cv_.notify_one();
+        if (thread_.joinable()) thread_.join();
+    }
+
+    void run() {
+        const auto stop_requested = [this]() { return stopping_.load(); };
+        std::unique_lock<std::mutex> lk(mu_);
+        for (;;) {
+            // Pace the sweeps. Waiting only on "watches present" returns at
+            // once while a request is active and spins this thread.
+            if (cv_.wait_for(lk, std::chrono::milliseconds(100),
+                             stop_requested)) return;
+            // With nothing to watch, sleep until watch() or shutdown.
+            cv_.wait(lk, [&]() { return stop_requested() || !watches_.empty(); });
+            if (stop_requested()) return;
+
+            std::vector<std::pair<int, std::shared_ptr<std::atomic<bool>>>> active;
+            for (auto it = watches_.begin(); it != watches_.end(); ) {
+                if (auto flag = it->cancelled.lock()) {
+                    active.emplace_back(it->fd, std::move(flag));
+                    ++it;
+                } else {
+                    it = watches_.erase(it);  // flag owner is gone
+                }
+            }
+            // Probe outside the lock so watch()/unwatch() never wait on poll().
+            lk.unlock();
+            for (const auto & [fd, cancelled] : active) {
+                if (!cancelled->load(std::memory_order_relaxed) &&
+                    client_socket_disconnected(fd)) {
+                    cancelled->store(true, std::memory_order_relaxed);
+                }
+            }
+            lk.lock();
+        }
+    }
+
+    std::mutex mu_;
+    std::condition_variable cv_;
+    std::vector<Watch> watches_;
+    std::thread thread_;
+    std::atomic<bool> stopping_{false};
+    uint64_t next_id_ = 1;
+};
+
+// RAII registration of one request socket with the DisconnectPoller.
+class RequestDisconnectWatcher {
+public:
+    explicit RequestDisconnectWatcher(int fd)
+        : fd_(fd)
+        , cancelled_(std::make_shared<std::atomic<bool>>(false))
+        , watch_id_(DisconnectPoller::instance().watch(fd_, cancelled_)) {}
+
+    ~RequestDisconnectWatcher() {
+        if (watch_id_ != 0) DisconnectPoller::instance().unwatch(watch_id_);
+    }
+
+    bool cancelled() const { return cancelled_->load(std::memory_order_relaxed); }
+
+private:
+    int fd_;
+    std::shared_ptr<std::atomic<bool>> cancelled_;
+    uint64_t watch_id_ = 0;
+};
+
 // Build the /props response body.
 //
 // Non-static so unit tests can call it directly (declared in http_server.h).
@@ -751,14 +894,44 @@ void HttpServer::handle_client(int fd) { bool HttpServer::route_request(int fd, const HttpRequest & hr) { if (hr.method != "POST") return false; + const bool generation_route = + hr.path == "/v1/chat/completions" || + hr.path == "/v1/messages" || + hr.path == "/v1/responses"; + + // Watch generation routes from the start of request handling. Large agent + // prompts can spend meaningful time in render/tokenize before any SSE + // write occurs, so cancellation cannot wait until generation starts. + // count_tokens is bounded to parse/tokenize/response and avoids watcher + // registration on its short request path. + std::unique_ptr disconnect_watcher; + if (generation_route) { + int flags = fcntl(fd, F_GETFL, 0); + if (flags >= 0) fcntl(fd, F_SETFL, flags | O_NONBLOCK); + disconnect_watcher = std::make_unique(fd); + } + std::fprintf(stderr, "[server] request path=%s body_bytes=%zu\n", hr.path.c_str(), hr.body.size()); ParsedRequest req; std::string err; + auto request_cancelled = [&]() { + return disconnect_watcher && disconnect_watcher->cancelled(); + }; + auto drop_cancelled_request = [&](const char * phase) { + std::fprintf(stderr, + "[server] client disconnected during %s before enqueue; " + "dropping request path=%s id=%s\n", + phase, + hr.path.c_str(), + req.response_id.c_str()); + return true; + }; try { json body = json::parse(hr.body); + if (request_cancelled()) return drop_cancelled_request("json_parse"); // Common fields. req.stream = body.value("stream", false); @@ -880,10 +1053,12 @@ bool HttpServer::route_request(int fd, const HttpRequest & hr) { } else { return false; } + if (request_cancelled()) return drop_cancelled_request("message_parse"); // Render messages to text and tokenize. 
std::vector chat_msgs = normalize_chat_messages(req.messages, req.format, tool_memory_); + if (request_cancelled()) return drop_cancelled_request("message_normalize"); // Determine thinking mode BEFORE rendering so the template can inject // the \n\n\n\n block when thinking is disabled. @@ -1044,22 +1219,29 @@ bool HttpServer::route_request(int fd, const HttpRequest & hr) { true, enable_thinking, tools_json); } - req.prompt_tokens = tokenizer_.encode(rendered); + if (request_cancelled()) return drop_cancelled_request("template_render"); + req.prompt_tokens = tokenizer_.encode(rendered, request_cancelled); + if (request_cancelled()) return drop_cancelled_request("tokenize"); // count_tokens: short-circuit after tokenization. Skip generation // entirely — Anthropic's contract is just `{"input_tokens": N}`. if (count_tokens_only) { + if (request_cancelled()) return drop_cancelled_request("count_tokens"); json resp = {{"input_tokens", (int)req.prompt_tokens.size()}}; send_response(fd, 200, "application/json", resp.dump() + "\n"); return true; } + } catch (const TokenizationCancelled &) { + return drop_cancelled_request("tokenize"); } catch (const std::exception & e) { + if (request_cancelled()) return drop_cancelled_request("parse_error"); send_error(fd, 400, std::string("JSON parse error: ") + e.what()); return true; // handled (with error) } // Check context length. + if (request_cancelled()) return drop_cancelled_request("context_check"); if ((int)req.prompt_tokens.size() + req.max_output > config_.max_ctx) { send_error(fd, 400, "prompt + max_tokens exceeds context window"); return true; @@ -1080,21 +1262,37 @@ bool HttpServer::route_request(int fd, const HttpRequest & hr) { req.stop_sequences.size(), req.model.c_str()); - // Set socket non-blocking for send() stall detection during streaming. - int flags = fcntl(fd, F_GETFL, 0); - if (flags >= 0) fcntl(fd, F_SETFL, flags | O_NONBLOCK); - // Enqueue job and wait for worker. 
ServerJob job; job.fd = fd; job.req = std::move(req); + job.cancelled.store(request_cancelled(), std::memory_order_relaxed); enqueue(&job); - // Wait for the worker to signal completion. + // Wait for the worker to signal completion. While the worker is busy it + // cannot observe a pre-token disconnect by sending SSE, so the client + // thread keeps watching the socket and flips the job's cooperative + // cancellation flag if the peer goes away. { std::unique_lock lk(job.mu); - job.cv.wait(lk, [&]() { return job.done; }); + bool disconnect_logged = false; + while (!job.done) { + if (job.cv.wait_for(lk, std::chrono::milliseconds(100), + [&]() { return job.done; })) { + break; + } + lk.unlock(); + if (!disconnect_logged && request_cancelled()) { + job.cancelled.store(true, std::memory_order_relaxed); + disconnect_logged = true; + std::fprintf(stderr, + "[server] client disconnected before worker completed; " + "cancelling chat %s\n", + job.req.response_id.c_str()); + } + lk.lock(); + } } return true; @@ -1110,6 +1308,9 @@ void HttpServer::worker_loop() { int fd = job->fd; const auto & req = job->req; auto started_at = std::chrono::steady_clock::now(); + auto job_cancelled = [job]() { + return job->cancelled.load(std::memory_order_relaxed); + }; auto finish_job = [&]() { std::lock_guard lk(job->mu); @@ -1140,10 +1341,19 @@ void HttpServer::worker_loop() { req.max_output, json_array_size(req.tools)); + if (job_cancelled()) { + std::fprintf(stderr, + "[server] chat CANCELLED %s before generation started\n", + req.response_id.c_str()); + finish_job(); + continue; + } + // Send SSE headers. if (req.stream) { if (!send_sse_headers(fd)) { // Client already disconnected before we started. 
+ job->cancelled.store(true, std::memory_order_relaxed); finish_job(); continue; } @@ -1165,6 +1375,7 @@ void HttpServer::worker_loop() { } } if (!start_ok) { + job->cancelled.store(true, std::memory_order_relaxed); finish_job(); continue; } @@ -1385,7 +1596,13 @@ void HttpServer::worker_loop() { cold_req.snap_pos = cold_boundary; // save at end of prefix DaemonIO cold_io; cold_io.stream_fd = -1; + cold_io.is_cancelled = job_cancelled; auto cold_result = backend_.generate_with_empty_spec_fallback(cold_req, cold_io); + if (cold_io.should_cancel()) { + job->cancelled.store(true, std::memory_order_relaxed); + finish_job(); + continue; + } if (cold_result.ok && backend_.snapshot_used(DISK_STAGING_SLOT)) { disk_cache_.learn_layout(DISK_STAGING_SLOT); std::vector prefix_tokens(effective_prompt.begin(), @@ -1428,12 +1645,17 @@ void HttpServer::worker_loop() { // Set up DaemonIO with on_token callback for streaming + disconnect. DaemonIO io; io.stream_fd = -1; // no pipe — we write SSE directly + io.is_cancelled = job_cancelled; int completion_tokens = 0; + bool visible_output_seen = false; bool client_disconnected = false; io.on_token = [&](int32_t token) -> bool { - if (client_disconnected) return false; + if (client_disconnected || job_cancelled()) { + client_disconnected = true; + return false; + } completion_tokens++; // Skip EOS/EOT/special tokens — don't forward to SSE. @@ -1443,8 +1665,13 @@ void HttpServer::worker_loop() { const std::string & raw = tokenizer_.raw_token(token); + // Reasoning delimiters are intentionally counted as visible stream + // output. The cache gate below is meant to reject zero-output + // disconnect/empty-spec cases, not streams where the client saw a + // reasoning-only response. 
// Gemma4 thinking channel: map <|channel> → , \n if (raw == "<|channel>") { + visible_output_seen = true; if (req.stream) { auto chunks = emitter.emit_token(""); for (const auto & chunk : chunks) @@ -1453,6 +1680,7 @@ void HttpServer::worker_loop() { return true; } if (raw == "") { + visible_output_seen = true; if (req.stream) { auto chunks = emitter.emit_token("\n"); for (const auto & chunk : chunks) @@ -1469,6 +1697,7 @@ void HttpServer::worker_loop() { // reasoning_content with empty visible content. Forward the text // form into the emitter so parse_reasoning() can split correctly. if (raw == "" || raw == "") { + visible_output_seen = true; if (req.stream) { auto chunks = emitter.emit_token( raw == "" ? "\n" : ""); @@ -1487,6 +1716,10 @@ void HttpServer::worker_loop() { std::string text = tokenizer_.token_text(token); + if (!text.empty()) { + visible_output_seen = true; + } + if (req.stream && !text.empty()) { auto chunks = emitter.emit_token(text); for (const auto & chunk : chunks) { @@ -1514,6 +1747,10 @@ void HttpServer::worker_loop() { } else { result = backend_.generate_with_empty_spec_fallback(gen_req, io); } + if (io.should_cancel()) { + client_disconnected = true; + job->cancelled.store(true, std::memory_order_relaxed); + } // Lazy-draft: park decode draft after generate to free VRAM. if (config_.lazy_draft) { @@ -1541,7 +1778,7 @@ void HttpServer::worker_loop() { // Confirm or abort the inline snapshot. if (snap_prepared) { - if (completion_tokens > 0 && !client_disconnected && + if (completion_tokens > 0 && visible_output_seen && !client_disconnected && backend_.snapshot_used(snap_slot)) { prefix_cache_.confirm_inline_snap(snap_slot, snap_cut, effective_prompt); // Track for shutdown save. @@ -1564,7 +1801,8 @@ void HttpServer::worker_loop() { // Continued checkpoint: save if total tokens crossed an interval boundary. // This captures prompt + all generated tokens for long conversation reuse. 
- if (!disk_cache_.disabled() && result.ok && completion_tokens > 0 && !client_disconnected) { + if (!disk_cache_.disabled() && result.ok && completion_tokens > 0 && + visible_output_seen && !client_disconnected) { int final_pos = (int)effective_prompt.size() + (int)result.tokens.size(); if (final_pos >= disk_cache_.continued_interval()) { // Build all_tokens = effective_prompt + result.tokens @@ -1580,7 +1818,8 @@ void HttpServer::worker_loop() { } // Full-compress cache: reserve + confirm after successful generation. - if (pflash_compressed && completion_tokens > 0 && !client_disconnected) { + if (pflash_compressed && completion_tokens > 0 && + visible_output_seen && !client_disconnected) { int full_slot = prefix_cache_.prepare_full_snap(req.prompt_tokens); if (full_slot >= 0) { prefix_cache_.confirm_full_snap(full_slot, req.prompt_tokens, diff --git a/server/src/server/http_server.h b/server/src/server/http_server.h index 999eb5d9..f373d3b8 100644 --- a/server/src/server/http_server.h +++ b/server/src/server/http_server.h @@ -6,8 +6,9 @@ // - Per-client thread: parse HTTP request, enqueue job, wait for completion // - Single worker thread: dequeue jobs, call ModelBackend::generate() // -// Client disconnect detection: the worker writes SSE chunks via send(). -// If send() fails (EPIPE/ECONNRESET), generation aborts immediately. +// Client disconnect detection: the client thread watches the socket while the +// worker owns generation. The worker also treats failed SSE writes as +// cancellation, so pre-token and mid-stream disconnects both abort generation. 
#pragma once @@ -311,6 +312,7 @@ class HttpServer { struct ServerJob { int fd = -1; ParsedRequest req; + std::atomic cancelled{false}; bool done = false; std::mutex mu; std::condition_variable cv; diff --git a/server/src/server/tokenizer.cpp b/server/src/server/tokenizer.cpp index 5ff4b1a7..489dbe31 100644 --- a/server/src/server/tokenizer.cpp +++ b/server/src/server/tokenizer.cpp @@ -18,6 +18,13 @@ namespace dflash::common { +static void check_tokenizer_cancelled( + const TokenizerCancelCallback & should_cancel) { + if (should_cancel && should_cancel()) { + throw TokenizationCancelled(); + } +} + // ─── Unicode helpers ──────────────────────────────────────────────────── static int utf8_len(uint8_t c) { @@ -147,7 +154,9 @@ static bool is_newline(uint32_t cp) { // \s+(?!\S) | // \s+ -std::vector Tokenizer::pre_tokenize(const std::string & text) const { +std::vector Tokenizer::pre_tokenize( + const std::string & text, + const TokenizerCancelCallback & should_cancel) const { std::vector pieces; const char * s = text.c_str(); const size_t len = text.size(); @@ -159,6 +168,8 @@ std::vector Tokenizer::pre_tokenize(const std::string & text) const }; while (pos < len) { + check_tokenizer_cancelled(should_cancel); + size_t start = pos; int cplen = 0; uint32_t cp = peek_cp(pos, &cplen); @@ -204,6 +215,7 @@ std::vector Tokenizer::pre_tokenize(const std::string & text) const // One or more letter/mark chars if (cl > 0 && (is_letter(c) || is_mark(c))) { while (cl > 0 && (is_letter(c) || is_mark(c))) { + check_tokenizer_cancelled(should_cancel); p += cl; c = peek_cp(p, &cl); } @@ -234,12 +246,14 @@ std::vector Tokenizer::pre_tokenize(const std::string & text) const size_t punc_start = p; while (cl > 0 && !is_whitespace(c) && !is_letter(c) && !is_mark(c) && !is_digit(c)) { + check_tokenizer_cancelled(should_cancel); p += cl; c = peek_cp(p, &cl); } if (p > punc_start) { // Trailing newlines while (cl > 0 && is_newline(c)) { + check_tokenizer_cancelled(should_cancel); p += cl; c = 
peek_cp(p, &cl); } @@ -256,11 +270,13 @@ std::vector Tokenizer::pre_tokenize(const std::string & text) const uint32_t c = peek_cp(p, &cl); // Consume leading whitespace while (cl > 0 && is_whitespace(c) && !is_newline(c)) { + check_tokenizer_cancelled(should_cancel); p += cl; c = peek_cp(p, &cl); } if (cl > 0 && is_newline(c)) { while (cl > 0 && is_newline(c)) { + check_tokenizer_cancelled(should_cancel); p += cl; c = peek_cp(p, &cl); } @@ -277,6 +293,7 @@ std::vector Tokenizer::pre_tokenize(const std::string & text) const c = peek_cp(p, &cl); size_t prev_p = pos; // position before last whitespace char while (cl > 0 && is_whitespace(c)) { + check_tokenizer_cancelled(should_cancel); prev_p = p; p += cl; c = peek_cp(p, &cl); @@ -361,7 +378,10 @@ static std::string encode_gpt2_bpe(const std::string & text) { } // Encode a single pre-tokenized piece using BPE merges. -std::vector Tokenizer::bpe_encode_piece(const std::string & piece) const { +std::vector Tokenizer::bpe_encode_piece( + const std::string & piece, + const TokenizerCancelCallback & should_cancel) const { + check_tokenizer_cancelled(should_cancel); if (piece.empty()) return {}; std::vector symbols; @@ -380,6 +400,7 @@ std::vector Tokenizer::bpe_encode_piece(const std::string & piece) cons std::string encoded; encoded.reserve(sp_piece.size()); for (char c : sp_piece) { + check_tokenizer_cancelled(should_cancel); if (c == ' ') { encoded += "\xe2\x96\x81"; } else { @@ -397,6 +418,7 @@ std::vector Tokenizer::bpe_encode_piece(const std::string & piece) cons const char * p = encoded.c_str(); const char * end = p + encoded.size(); while (p < end) { + check_tokenizer_cancelled(should_cancel); int cplen; utf8_decode(p, (size_t)(end - p), &cplen); if (cplen <= 0) cplen = 1; @@ -424,6 +446,7 @@ std::vector Tokenizer::bpe_encode_piece(const std::string & piece) cons // Split into individual GPT-2-encoded bytes as initial BPE symbols. 
for (size_t i = 0; i < piece.size(); i++) { + check_tokenizer_cancelled(should_cancel); std::string sym = byte_to_gpt2_unicode((uint8_t)piece[i]); auto sit = token_to_id_.find(sym); if (sit != token_to_id_.end()) { @@ -446,10 +469,12 @@ std::vector Tokenizer::bpe_encode_piece(const std::string & piece) cons // Iteratively merge the highest-priority pair until no more merges apply. while (symbols.size() > 1) { + check_tokenizer_cancelled(should_cancel); int best_rank = std::numeric_limits::max(); size_t best_pos = SIZE_MAX; for (size_t i = 0; i + 1 < symbols.size(); i++) { + check_tokenizer_cancelled(should_cancel); std::string pair = symbols[i] + " " + symbols[i + 1]; auto mit = merge_rank_.find(pair); if (mit != merge_rank_.end() && mit->second < best_rank) { @@ -469,6 +494,7 @@ std::vector Tokenizer::bpe_encode_piece(const std::string & piece) cons std::vector ids; ids.reserve(symbols.size()); for (const auto & sym : symbols) { + check_tokenizer_cancelled(should_cancel); auto sit = token_to_id_.find(sym); if (sit != token_to_id_.end()) { ids.push_back(sit->second); @@ -493,6 +519,7 @@ std::vector Tokenizer::bpe_encode_piece(const std::string & piece) cons const char * p = sym.c_str(); const char * end = p + sym.size(); while (p < end) { + check_tokenizer_cancelled(should_cancel); int cplen; uint32_t cp = utf8_decode(p, (size_t)(end - p), &cplen); uint8_t orig_byte; @@ -634,12 +661,21 @@ bool Tokenizer::load_from_gguf(const char * model_path) { } std::vector Tokenizer::encode(const std::string & text) const { + return encode(text, TokenizerCancelCallback{}); +} + +std::vector Tokenizer::encode( + const std::string & text, + const TokenizerCancelCallback & should_cancel) const { + check_tokenizer_cancelled(should_cancel); + // If no added tokens, fast path: pre-tokenize → BPE entire text. 
if (added_tokens_.empty()) { - std::vector pieces = pre_tokenize(text); + std::vector pieces = pre_tokenize(text, should_cancel); std::vector ids; for (const auto & piece : pieces) { - auto piece_ids = bpe_encode_piece(piece); + check_tokenizer_cancelled(should_cancel); + auto piece_ids = bpe_encode_piece(piece, should_cancel); ids.insert(ids.end(), piece_ids.begin(), piece_ids.end()); } return ids; @@ -650,6 +686,8 @@ std::vector Tokenizer::encode(const std::string & text) const { std::vector ids; size_t pos = 0; while (pos < text.size()) { + check_tokenizer_cancelled(should_cancel); + // Try to match any added token at current position. bool matched = false; for (const auto & [tok_str, tok_id] : added_tokens_) { @@ -666,6 +704,7 @@ std::vector Tokenizer::encode(const std::string & text) const { // Find the next special token (or end of string). size_t next_special = text.size(); for (const auto & [tok_str, tok_id] : added_tokens_) { + check_tokenizer_cancelled(should_cancel); size_t found = text.find(tok_str, pos); if (found != std::string::npos && found < next_special) { next_special = found; @@ -674,9 +713,10 @@ std::vector Tokenizer::encode(const std::string & text) const { // Pre-tokenize + BPE the normal segment. 
std::string segment = text.substr(pos, next_special - pos); - std::vector pieces = pre_tokenize(segment); + std::vector pieces = pre_tokenize(segment, should_cancel); for (const auto & piece : pieces) { - auto piece_ids = bpe_encode_piece(piece); + check_tokenizer_cancelled(should_cancel); + auto piece_ids = bpe_encode_piece(piece, should_cancel); ids.insert(ids.end(), piece_ids.begin(), piece_ids.end()); } pos = next_special; diff --git a/server/src/server/tokenizer.h b/server/src/server/tokenizer.h index 5484fa47..a9eb3286 100644 --- a/server/src/server/tokenizer.h +++ b/server/src/server/tokenizer.h @@ -10,12 +10,23 @@ #pragma once #include +#include +#include #include #include #include namespace dflash::common { +class TokenizationCancelled final : public std::exception { +public: + const char * what() const noexcept override { + return "tokenization cancelled"; + } +}; + +using TokenizerCancelCallback = std::function; + class Tokenizer { public: Tokenizer() = default; @@ -31,6 +42,9 @@ class Tokenizer { // ─── Encode ────────────────────────────────────────────────────── // Tokenize a UTF-8 string into token IDs. std::vector encode(const std::string & text) const; + std::vector encode( + const std::string & text, + const TokenizerCancelCallback & should_cancel) const; // ─── Decode ────────────────────────────────────────────────────── // Convert a single token ID to its text representation. @@ -55,10 +69,14 @@ class Tokenizer { private: // Pre-tokenize text into pieces using Qwen3/3.5 regex pattern. - std::vector pre_tokenize(const std::string & text) const; + std::vector pre_tokenize( + const std::string & text, + const TokenizerCancelCallback & should_cancel) const; // Apply BPE merges to a single pre-tokenized piece. 
- std::vector bpe_encode_piece(const std::string & piece) const; + std::vector bpe_encode_piece( + const std::string & piece, + const TokenizerCancelCallback & should_cancel) const; // Vocabulary: id → token string std::vector id_to_token_; diff --git a/server/test/test_server_unit.cpp b/server/test/test_server_unit.cpp index 275ec935..abf4d4ad 100644 --- a/server/test/test_server_unit.cpp +++ b/server/test/test_server_unit.cpp @@ -16,6 +16,7 @@ #include "server/api_types.h" #include "server/http_server.h" #include "server/chat_template.h" +#include "common/model_backend.h" #include "common/sampler.h" #include "common/backend_ipc.h" #include "placement/pflash_placement.h" @@ -1340,7 +1341,7 @@ static void test_layer_split_backend_inline_snapshot_and_restore_delta() { req.snap_slot = 2; req.snap_pos = 3; DaemonIO io; - GenerateResult result = backend.generate(req, io); + GenerateResult result = backend.generate_with_empty_spec_fallback(req, io); TEST_ASSERT(result.ok); TEST_ASSERT(raw->reset_called); @@ -2413,6 +2414,99 @@ static void test_usage_timings_omitted_when_null() { TEST_ASSERT(finish_str.find("[DONE]") != std::string::npos); } +// DaemonIO cancellation tests +// ═══════════════════════════════════════════════════════════════════════ + +static void test_daemon_io_cancel_callback_stops_emit_before_token_callback() { + bool token_callback_called = false; + DaemonIO io; + io.is_cancelled = []() { return true; }; + io.on_token = [&](int32_t) -> bool { + token_callback_called = true; + return true; + }; + + io.emit(123); + + TEST_ASSERT(io.cancelled.load()); + TEST_ASSERT(!token_callback_called); +} + +static void test_daemon_io_with_token_callback_preserves_cancel_callback() { + bool cancelled = false; + int callback_score = 0; + + DaemonIO io; + io.is_cancelled = [&]() { return cancelled; }; + io.on_token = [&](int32_t) -> bool { + callback_score += 1; + return true; + }; + + DaemonIO out = io.with_token_callback([&](int32_t) -> bool { + callback_score += 10; 
+ return true; + }); + + out.emit(7); + TEST_ASSERT(callback_score == 11); + + cancelled = true; + out.emit(8); + TEST_ASSERT(out.cancelled.load()); + TEST_ASSERT(callback_score == 11); +} + +static void test_daemon_compute_result_prefers_failure_over_cancel() { + int cancel_polls = 0; + DaemonIO io; + io.is_cancelled = [&]() { + cancel_polls++; + return true; + }; + + const auto result = + classify_daemon_compute_result(GGML_STATUS_FAILED, io); + + TEST_ASSERT(result == DaemonComputeResult::Failed); + TEST_ASSERT(cancel_polls == 0); + TEST_ASSERT(!io.cancelled.load()); +} + +static void test_daemon_compute_result_reports_cancel_after_success() { + int cancel_polls = 0; + DaemonIO io; + io.is_cancelled = [&]() { + cancel_polls++; + return true; + }; + + const auto result = + classify_daemon_compute_result(GGML_STATUS_SUCCESS, io); + + TEST_ASSERT(result == DaemonComputeResult::Cancelled); + TEST_ASSERT(cancel_polls == 1); + TEST_ASSERT(io.cancelled.load()); +} + +static void test_tokenizer_encode_honors_cancel_callback() { + Tokenizer tok; + int cancel_polls = 0; + bool cancelled = false; + + try { + tok.encode("large request body", [&]() { + cancel_polls++; + return cancel_polls >= 2; + }); + } catch (const TokenizationCancelled &) { + cancelled = true; + } + + TEST_ASSERT(cancel_polls >= 2); + TEST_ASSERT(cancelled); +} + // ModelBackend common empty-spec retry tests // ═══════════════════════════════════════════════════════════════════════ @@ -2421,6 +2515,8 @@ struct EmptySpecRetryBackend : MockBackend { int restore_calls = 0; bool generate_saw_force_ar = false; bool restore_saw_force_ar = false; + bool generate_first_empty_visible = false; + bool restore_first_empty_visible = false; GenerateResult generate(const GenerateRequest & req, const DaemonIO &) override { @@ -2432,6 +2528,10 @@ struct EmptySpecRetryBackend : MockBackend { result.tokens = {42}; } else { result.spec_decode_ran = true; + if (generate_first_empty_visible) { + result.tokens = {2}; + 
result.empty_visible_output = true; + } } return result; } @@ -2446,6 +2546,10 @@ struct EmptySpecRetryBackend : MockBackend { result.tokens = {84}; } else { result.spec_decode_ran = true; + if (restore_first_empty_visible) { + result.tokens = {2}; + result.empty_visible_output = true; + } } return result; } @@ -2486,6 +2590,45 @@ static void test_model_backend_retries_empty_spec_restore_once_with_ar() { TEST_ASSERT(backend.restore_saw_force_ar); } +static void test_model_backend_retries_empty_visible_spec_generate_once_with_ar() { + EmptySpecRetryBackend backend; + backend.generate_first_empty_visible = true; + GenerateRequest req; + req.prompt = {1, 2, 3}; + req.n_gen = 4; + DaemonIO io; + + GenerateResult result = backend.generate_with_empty_spec_fallback(req, io); + + TEST_ASSERT(result.ok); + TEST_ASSERT(result.tokens.size() == 1); + TEST_ASSERT(result.tokens[0] == 42); + TEST_ASSERT(!result.empty_visible_output); + TEST_ASSERT(result.spec_decode_ran); + TEST_ASSERT(backend.generate_calls == 2); + TEST_ASSERT(backend.generate_saw_force_ar); +} + +static void test_model_backend_retries_empty_visible_spec_restore_once_with_ar() { + EmptySpecRetryBackend backend; + backend.restore_first_empty_visible = true; + GenerateRequest req; + req.prompt = {1, 2, 3}; + req.n_gen = 4; + DaemonIO io; + + GenerateResult result = + backend.restore_and_generate_with_empty_spec_fallback(7, req, io); + + TEST_ASSERT(result.ok); + TEST_ASSERT(result.tokens.size() == 1); + TEST_ASSERT(result.tokens[0] == 84); + TEST_ASSERT(!result.empty_visible_output); + TEST_ASSERT(result.spec_decode_ran); + TEST_ASSERT(backend.restore_calls == 2); + TEST_ASSERT(backend.restore_saw_force_ar); +} + // GenerateResult.accept_rate plumbing tests (Day 1 of bandit MVP) // ═══════════════════════════════════════════════════════════════════════ @@ -2714,9 +2857,18 @@ int main() { RUN_TEST(test_usage_timings_zero_decode_no_div_by_zero); RUN_TEST(test_usage_timings_omitted_when_null); + std::fprintf(stderr, 
"\n── DaemonIO cancellation ──\n"); + RUN_TEST(test_daemon_io_cancel_callback_stops_emit_before_token_callback); + RUN_TEST(test_daemon_io_with_token_callback_preserves_cancel_callback); + RUN_TEST(test_daemon_compute_result_prefers_failure_over_cancel); + RUN_TEST(test_daemon_compute_result_reports_cancel_after_success); + RUN_TEST(test_tokenizer_encode_honors_cancel_callback); + std::fprintf(stderr, "\n── ModelBackend empty-spec retry ──\n"); RUN_TEST(test_model_backend_retries_empty_spec_generate_once_with_ar); RUN_TEST(test_model_backend_retries_empty_spec_restore_once_with_ar); + RUN_TEST(test_model_backend_retries_empty_visible_spec_generate_once_with_ar); + RUN_TEST(test_model_backend_retries_empty_visible_spec_restore_once_with_ar); std::fprintf(stderr, "\n── GenerateResult.accept_rate ──\n"); RUN_TEST(test_generate_result_accept_rate_defaults_to_zero);