Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion server/src/common/daemon_loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,22 @@ namespace dflash::common {

// ── DaemonIO ────────────────────────────────────────────────────────────

// Reports whether generation should stop.
//
// Returns true when the `cancelled` flag is already set, or when the optional
// `is_cancelled` callback is installed and returns true. In the latter case
// the flag is set before returning, so later calls take the flag path without
// invoking the callback again.
bool DaemonIO::should_cancel() const {
    // Fast path: the flag has already been set.
    const bool already_latched = cancelled.load(std::memory_order_relaxed);
    if (already_latched) {
        return true;
    }

    // No probe installed, or the probe says "keep going".
    if (!is_cancelled || !is_cancelled()) {
        return false;
    }

    // The probe requested a stop: remember it for subsequent checks.
    cancelled.store(true, std::memory_order_relaxed);
    return true;
}

void DaemonIO::emit(int32_t v) const {
if (should_cancel()) return;

// Call the token callback for non-sentinel tokens.
if (on_token && v >= 0) {
if (!on_token(v)) {
cancelled = true;
cancelled.store(true, std::memory_order_relaxed);
return;
}
}
Expand Down
13 changes: 10 additions & 3 deletions server/src/common/dflash_spec_decode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ bool run_dflash_spec_decode(

auto t_dec0 = std::chrono::steady_clock::now();
while (n_generated < n_gen) {
if (io.should_cancel()) break;
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
const int need_commit_budget = n_gen - n_generated;

// ── Build noise input for draft ────────────────────────────────────
Expand All @@ -109,10 +110,12 @@ bool run_dflash_spec_decode(
// ── Draft compute (local or remote) ───────────────────────────────
const float * draft_hidden_host = nullptr;
if (use_remote_draft) {
if (io.should_cancel()) break;
if (!remote_draft->propose(committed, draft_ctx, noise_embed, remote_hidden)) {
std::fprintf(stderr, "dflash-spec remote draft propose failed\n");
return false;
}
if (io.should_cancel()) break;
draft_hidden_host = remote_hidden.data();
} else {
if (!build_draft_step(draft_sg, draft_weights, /*lm_head=*/nullptr, draft_backend,
Expand All @@ -138,10 +141,12 @@ bool run_dflash_spec_decode(
ggml_backend_tensor_set(draft_sg.positions_k, pos_k.data(), 0,
sizeof(int32_t) * pos_k.size());
auto st = ggml_backend_graph_compute(draft_backend, draft_sg.gf);
if (st != GGML_STATUS_SUCCESS) {
const auto compute_result = classify_daemon_compute_result(st, io);
if (compute_result == DaemonComputeResult::Failed) {
std::fprintf(stderr, "dflash-spec draft compute %d\n", (int)st);
return false;
}
if (compute_result == DaemonComputeResult::Cancelled) break;
// Read draft hidden states out to host so the target adapter can
// project them through its own LM head (target-internal layout).
local_hidden.resize((size_t)hidden * q_len);
Expand Down Expand Up @@ -174,6 +179,7 @@ bool run_dflash_spec_decode(
std::fprintf(stderr, "dflash-spec snapshot_kv failed\n");
return false;
}
if (io.should_cancel()) break;

int verify_last_tok = -1;
if (!target.verify_batch(draft_tok, committed, verify_last_tok, &target_tok)) {
Expand Down Expand Up @@ -209,6 +215,7 @@ bool run_dflash_spec_decode(
std::fprintf(stderr, "dflash-spec restore_kv failed\n");
return false;
}
if (io.should_cancel()) break;

std::vector<int32_t> replay_tok((size_t)commit_n);
for (int i = 0; i < commit_n; i++) {
Expand All @@ -226,15 +233,15 @@ bool run_dflash_spec_decode(
for (int i = 0; i < commit_n; i++) {
out_all.push_back(replay_tok[i]);
io.emit(replay_tok[i]);
if (io.cancelled) break;
if (io.should_cancel()) break;
++emitted;
if (target.is_eos(replay_tok[i])) hit_eos = true;
}
committed += emitted;
n_generated += emitted;
n_accept_sum += std::min(accept_n, emitted);
n_draft_steps++;
if (io.cancelled) break;
if (io.should_cancel()) break;
if (hit_eos) break;
}
if (!use_remote_draft && draft_backend) ggml_backend_synchronize(draft_backend);
Expand Down
53 changes: 49 additions & 4 deletions server/src/common/model_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#pragma once

#include <atomic>
#include <cstdint>
#include <cstdio>
#include <functional>
Expand All @@ -25,16 +26,39 @@ namespace dflash::common {
// Token callback for streaming generation. Called once per committed token.
// Return true to continue generation, false to abort.
using TokenCallback = std::function<bool(int32_t token)>;
// Cancellation probe polled by DaemonIO::should_cancel().
// Return true to request that generation stop, false to keep going.
// Note the polarity is the opposite of TokenCallback's return value.
using CancelCallback = std::function<bool()>;

// ─── I/O handle passed to backend methods that need protocol output ─────
struct DaemonIO {
int stream_fd = -1;

DaemonIO() = default;
DaemonIO(const DaemonIO & other)
: stream_fd(other.stream_fd)
, on_token(other.on_token)
, is_cancelled(other.is_cancelled)
, cancelled(other.cancelled.load(std::memory_order_relaxed)) {}

DaemonIO & operator=(const DaemonIO & other) {
if (this == &other) return *this;
stream_fd = other.stream_fd;
on_token = other.on_token;
is_cancelled = other.is_cancelled;
cancelled.store(other.cancelled.load(std::memory_order_relaxed),
std::memory_order_relaxed);
return *this;
}

// Optional token callback. When set, emit() calls this for each token
// (excluding the -1 sentinel). If it returns false, the `cancelled`
// flag is set and the caller should abort generation.
TokenCallback on_token;
mutable bool cancelled = false;
CancelCallback is_cancelled;
mutable std::atomic<bool> cancelled{false};

// Observe cooperative cancellation even before a token callback fires
// (for example while a long prefill/spec-decode step is still working).
bool should_cancel() const;

// Write a single int32 to the stream fd (token or -1 sentinel).
// Also invokes on_token if set. Sets cancelled=true if on_token
Expand All @@ -45,6 +69,22 @@ struct DaemonIO {
DaemonIO with_token_callback(const TokenCallback & cb) const;
};

// Outcome of a backend graph compute as seen by the daemon loop
// (produced by classify_daemon_compute_result).
enum class DaemonComputeResult {
    // The graph completed successfully and no cancellation was observed.
    Success = 0,
    // The graph completed successfully, but cancellation was requested.
    Cancelled = 1,
    // The graph did not complete successfully.
    Failed = 2,
};

// Maps a ggml compute status plus the I/O handle's cancellation state onto a
// DaemonComputeResult.
//
//   status  the status returned by the graph compute call
//   io      I/O handle whose should_cancel() is consulted on success
//
// A non-success status always yields Failed: a graph failure is a backend
// error even if the client also disconnected while the graph was running.
// io.should_cancel() is only consulted when the status is success.
inline DaemonComputeResult classify_daemon_compute_result(
        ggml_status status,
        const DaemonIO & io) {
    const bool graph_succeeded = (status == GGML_STATUS_SUCCESS);
    if (!graph_succeeded) {
        return DaemonComputeResult::Failed;
    }
    return io.should_cancel() ? DaemonComputeResult::Cancelled
                              : DaemonComputeResult::Success;
}

// ─── Generate request/result ────────────────────────────────────────────

// Thinking-budget force-close hook. Mirrors antirez/ds4 ds4_eval.c's
Expand Down Expand Up @@ -131,6 +171,11 @@ struct GenerateResult {
float accept_rate = 0.0f;
// True when spec decode actually ran (accept_rate==0 still needs a bandit update).
bool spec_decode_ran = false;
// True when decode emitted only tokens that the API layer suppresses
// (for example an immediate EOS/EOT). This is semantically equivalent
// to zero output for clients and should take the same AR retry path as
// an empty token vector.
bool empty_visible_output = false;
};

// ─── Backend interface ──────────────────────────────────────────────────
Expand Down Expand Up @@ -160,7 +205,7 @@ struct ModelBackend {
if (!should_retry_empty_spec_decode(req, result)) return result;

std::fprintf(stderr,
"[backend] spec-decode produced zero tokens; retrying with AR decode\n");
"[backend] spec-decode produced no visible output; retrying with AR decode\n");
GenerateRequest retry = req;
retry.force_ar_decode = true;
return merge_empty_spec_retry_result(result, generate(retry, io));
Expand Down Expand Up @@ -188,7 +233,7 @@ struct ModelBackend {
if (!should_retry_empty_spec_decode(req, result)) return result;

std::fprintf(stderr,
"[backend] restored spec-decode produced zero tokens; retrying with AR decode\n");
"[backend] restored spec-decode produced no visible output; retrying with AR decode\n");
GenerateRequest retry = req;
retry.force_ar_decode = true;
return merge_empty_spec_retry_result(result,
Expand All @@ -201,7 +246,7 @@ struct ModelBackend {
&& !req.force_ar_decode
&& result.ok
&& result.spec_decode_ran
&& result.tokens.empty();
&& (result.tokens.empty() || result.empty_visible_output);
}

static GenerateResult merge_empty_spec_retry_result(
Expand Down
29 changes: 23 additions & 6 deletions server/src/gemma4/gemma4_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ bool Gemma4Backend::unpark(const std::string & what) {

int Gemma4Backend::do_prefill(const std::vector<int32_t> & tokens,
const DaemonIO & io, int kv_offset) {
(void)io;
const int n = (int)tokens.size();
const int hidden = w_.n_embd;
const int chunk = cfg_.chunk;
Expand All @@ -212,6 +211,7 @@ int Gemma4Backend::do_prefill(const std::vector<int32_t> & tokens,

int pos = 0;
while (pos < n) {
if (io.should_cancel()) return -1;
int len = std::min(chunk, n - pos);

// Limit chunk to avoid ring-buffer wrap for SWA layers
Expand All @@ -233,6 +233,7 @@ int Gemma4Backend::do_prefill(const std::vector<int32_t> & tokens,
std::fprintf(stderr, "[gemma4] prefill step failed at pos=%d\n", kv_pos);
return -1;
}
if (io.should_cancel()) return -1;

pos += len;
cache_.cur_pos = kv_offset + pos;
Expand Down Expand Up @@ -368,7 +369,7 @@ bool Gemma4Backend::do_decode(int committed, int n_gen,
io.emit(next);
committed++;
cache_.cur_pos = committed;
if (io.cancelled) break;
if (io.should_cancel()) break;

// Check EOS
if (next == w_.eos_id || next == w_.eos_chat_id) break;
Expand Down Expand Up @@ -586,7 +587,7 @@ bool Gemma4Backend::do_spec_decode(int committed, int n_gen,
out_tokens.push_back(tok);
io.emit(tok);
emitted++;
if (io.cancelled) break;
if (io.should_cancel()) break;
if (tok == w_.eos_id || tok == w_.eos_chat_id) {
hit_eos = true; break;
}
Expand All @@ -596,7 +597,7 @@ bool Gemma4Backend::do_spec_decode(int committed, int n_gen,
n_generated += emitted;
n_accept_sum += std::min(accept_n, emitted);
n_draft_steps++;
if (io.cancelled) break;
if (io.should_cancel()) break;
if (hit_eos) break;
}

Expand Down Expand Up @@ -640,9 +641,17 @@ GenerateResult Gemma4Backend::generate(const GenerateRequest & req,
result.prefill_s = std::chrono::duration<double>(
std::chrono::steady_clock::now() - t_prefill_start).count();
if (committed < 0) {
if (out_io.should_cancel()) {
result.ok = true;
return result;
}
result.error = "prefill";
return result;
}
if (out_io.should_cancel()) {
result.ok = true;
return result;
}

// Inline snapshot at snap_pos for prefix cache
if (req.snap_slot >= 0 && req.snap_pos > 0 && req.snap_pos <= committed) {
Expand Down Expand Up @@ -703,7 +712,7 @@ GenerateResult Gemma4Backend::generate(const GenerateRequest & req,
}
result.tokens.push_back(first);
out_io.emit(first);
if (out_io.cancelled) {
if (out_io.should_cancel()) {
out_io.emit(-1);
result.ok = true;
return result;
Expand Down Expand Up @@ -804,6 +813,10 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot,
std::vector<int32_t> delta(req.prompt.begin() + snap_pos, req.prompt.end());
committed = do_prefill(delta, out_io, /*kv_offset=*/snap_pos);
if (committed < 0) {
if (out_io.should_cancel()) {
result.ok = true;
return result;
}
result.error = "prefill";
return result;
}
Expand All @@ -815,6 +828,10 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot,
// else: prompt_len == snap_pos → no delta, committed stays at snap_pos
result.prefill_s = std::chrono::duration<double>(
std::chrono::steady_clock::now() - t_prefill_start).count();
if (out_io.should_cancel()) {
result.ok = true;
return result;
}

// Inline snapshot at snap_pos for prefix cache (new snap from this request)
if (req.snap_slot >= 0 && req.snap_pos > 0 && req.snap_pos <= committed) {
Expand Down Expand Up @@ -883,7 +900,7 @@ GenerateResult Gemma4Backend::restore_and_generate(int slot,
}
result.tokens.push_back(first);
out_io.emit(first);
if (out_io.cancelled) {
if (out_io.should_cancel()) {
out_io.emit(-1);
result.ok = true;
return result;
Expand Down
4 changes: 2 additions & 2 deletions server/src/gemma4/gemma4_layer_split_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ bool Gemma4LayerSplitAdapter::decode_ar(
const auto & w = shards_.front().weights;
out_tokens.push_back(last_tok);
io.emit(last_tok);
if (io.cancelled) {
if (io.should_cancel()) {
io.emit(-1);
return true;
}
Expand All @@ -392,7 +392,7 @@ bool Gemma4LayerSplitAdapter::decode_ar(
out_tokens.push_back(last_tok);
io.emit(last_tok);
++committed;
if (io.cancelled) break;
if (io.should_cancel()) break;
if (last_tok == w.eos_id || last_tok == w.eos_chat_id) break;
}
io.emit(-1);
Expand Down
20 changes: 18 additions & 2 deletions server/src/laguna/laguna_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,19 @@ GenerateResult LagunaBackend::generate(const GenerateRequest & req,
bool ok = true;
const int n_chunks = (N + args_.chunk - 1) / args_.chunk;
for (int c = 0; c < n_chunks && ok; ++c) {
if (out_io.should_cancel()) {
result.ok = true;
return result;
}
const int kv_start = c * args_.chunk;
const int n_tok = std::min(args_.chunk, N - c * args_.chunk);
ok = laguna_step(backend_, w_, cache_,
embed_pf.data() + (size_t)kv_start * w_.n_embd,
n_tok, kv_start, no_mask, last_logits);
if (out_io.should_cancel()) {
result.ok = true;
return result;
}
}
if (!ok) { result.error = "prefill"; return result; }
auto t_pf1 = std::chrono::steady_clock::now();
Expand Down Expand Up @@ -267,7 +275,7 @@ GenerateResult LagunaBackend::generate(const GenerateRequest & req,
history.push_back(next_tok);
if (should_emit) {
out_io.emit(next_tok);
if (out_io.cancelled) break;
if (out_io.should_cancel()) break;
}
if (!w_.embedder.embed(&next_tok, 1, embed_step.data())) { ok = false; break; }
std::vector<float> step_logits;
Expand Down Expand Up @@ -328,12 +336,20 @@ GenerateResult LagunaBackend::restore_and_generate(int slot,
bool ok = true;
const int n_chunks = (diff_n + args_.chunk - 1) / args_.chunk;
for (int c = 0; c < n_chunks && ok; ++c) {
if (out_io.should_cancel()) {
result.ok = true;
return result;
}
const int off = c * args_.chunk;
const int n_tok = std::min(args_.chunk, diff_n - off);
const int starts = kv_start + off;
ok = laguna_step(backend_, w_, cache_,
embed_diff.data() + (size_t)off * w_.n_embd,
n_tok, starts, no_mask, last_logits);
if (out_io.should_cancel()) {
result.ok = true;
return result;
}
}
if (!ok) { result.error = "prefill"; return result; }

Expand Down Expand Up @@ -402,7 +418,7 @@ GenerateResult LagunaBackend::restore_and_generate(int slot,
history.push_back(next_tok);
result.tokens.push_back(next_tok);
out_io.emit(next_tok);
if (out_io.cancelled) break;
if (out_io.should_cancel()) break;
if (!w_.embedder.embed(&next_tok, 1, embed_step.data())) { ok = false; break; }
std::vector<float> step_logits;
if (!laguna_step(backend_, w_, cache_, embed_step.data(), 1,
Expand Down
Loading
Loading