From 3cdcf7b775a47632e57eca6f55c189b445ac4f71 Mon Sep 17 00:00:00 2001 From: dusterbloom <32869278+dusterbloom@users.noreply.github.com> Date: Wed, 27 May 2026 09:40:14 +0200 Subject: [PATCH 1/2] feat(pflash): adaptive keep_ratio bandit MVP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per-session ε-greedy bandit that adjusts compression based on observed accept_rate. Opt-in via session_id; clients without it get the existing fixed-keep path, byte-identical to main. Includes: - Bandit state machine (LRU-bounded session map, cap 1024) - HTTP server session_id parsing + bandit hook - accept_rate plumbing from DFlash GenerateResult - CI submodule PAT fix for fork PRs - Harness session_id env-var wiring 5-turn trajectory + NIAH @16K/32K + 3-seed A/B/C evidence (reproducible via the follow-up bench PR; not committed here). Bench scripts + result artifacts split to follow-up PR. Bug #42 tail-capture fix moved to PR #274. --- .github/workflows/ci.yml | 1 + .gitignore | 1 + harness/clients/prompts/logic_check.txt | 5 + harness/clients/prompts/math_check.txt | 5 + harness/clients/run_claude_code.sh | 43 +++- harness/clients/session_inject_proxy.py | 144 ++++++++++++++ server/CMakeLists.txt | 12 ++ server/src/common/model_backend.h | 5 + server/src/qwen35/qwen35_backend.cpp | 9 + server/src/qwen35/qwen35_backend.h | 4 + server/src/server/adaptive_keep_ratio.h | 116 +++++++++++ server/src/server/http_server.cpp | 32 ++- server/src/server/http_server.h | 22 +++ server/test/test_adaptive_keep_ratio.cpp | 239 +++++++++++++++++++++++ server/test/test_bandit_integration.cpp | 200 +++++++++++++++++++ server/test/test_server_unit.cpp | 82 ++++++++ thoughts/2026-05-21_pflash_mvp_plan.md | 129 ++++++++++++ 17 files changed, 1043 insertions(+), 6 deletions(-) create mode 100644 harness/clients/prompts/logic_check.txt create mode 100644 harness/clients/prompts/math_check.txt create mode 100755 harness/clients/session_inject_proxy.py create mode 100644 server/src/server/adaptive_keep_ratio.h create mode 100644 server/test/test_adaptive_keep_ratio.cpp create mode 100644 server/test/test_bandit_integration.cpp create mode 100644 thoughts/2026-05-21_pflash_mvp_plan.md diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 705bbf09b..737d4b820 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,6 +26,7 @@ jobs: - uses: actions/checkout@v4 with: submodules: recursive + token: ${{ secrets.SUBMODULE_PAT || secrets.GITHUB_TOKEN }} - uses: Jimver/cuda-toolkit@v0.2.35 with: diff --git a/.gitignore b/.gitignore index 4b406506d..b400bb6de 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,7 @@ env/ *.qdrep *.sqlite bench-out/ +dflash/bench/results/ profile-out/ # Model weights and caches (pull fresh from HF) diff --git a/harness/clients/prompts/logic_check.txt b/harness/clients/prompts/logic_check.txt new file mode 100644 index 000000000..eb46cbfc5 --- /dev/null +++ b/harness/clients/prompts/logic_check.txt @@ -0,0 +1,5 @@ +Answer these logic puzzles. End your answer with OK_DONE. + +1. If all roses are flowers and some flowers fade quickly, can we conclude that some roses fade quickly? +2. A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost? +3. If you have a 3-litre jug and a 5-litre jug, how can you measure exactly 4 litres of water? diff --git a/harness/clients/prompts/math_check.txt b/harness/clients/prompts/math_check.txt new file mode 100644 index 000000000..c6d8df470 --- /dev/null +++ b/harness/clients/prompts/math_check.txt @@ -0,0 +1,5 @@ +Solve the following math problems. End your answer with OK_DONE. + +1. What is 17 * 23? +2. What is the sum of the first 10 prime numbers? +3. If a rectangle has width 7 and height 11, what is its area? diff --git a/harness/clients/run_claude_code.sh b/harness/clients/run_claude_code.sh index 3b969f04b..f2d120597 100755 --- a/harness/clients/run_claude_code.sh +++ b/harness/clients/run_claude_code.sh @@ -22,11 +22,45 @@ start_lucebox_server trap stop_lucebox_server EXIT wait_lucebox_server +# When PFLASH_SESSION_ID is set, start a thin proxy that injects +# extra_body.session_id into every /v1/messages request. The claude CLI +# cannot inject extra_body natively, so the proxy does it transparently. +PROXY_PID="" +CLIENT_BASE_URL="$BASE_URL" +if [[ -n "${PFLASH_SESSION_ID:-}" ]]; then + PROXY_PORT="${PFLASH_PROXY_PORT:-18082}" + python3 "$SCRIPT_DIR/session_inject_proxy.py" \ + --host "$HOST" \ + --port "$PROXY_PORT" \ + --upstream "$BASE_URL" \ + --session-id "$PFLASH_SESSION_ID" \ + >> "$LOG_DIR/proxy.log" 2>&1 & + PROXY_PID=$! + _proxy_ready=0 + for _i in $(seq 1 10); do + if curl -fsS "http://$HOST:$PROXY_PORT/health" >/dev/null 2>&1; then _proxy_ready=1; break; fi + sleep 1 + if ! kill -0 "$PROXY_PID" 2>/dev/null; then + echo "session-inject proxy exited early; log: $LOG_DIR/proxy.log" >&2 + cat "$LOG_DIR/proxy.log" >&2 || true + exit 1 + fi + done + if [[ "$_proxy_ready" -eq 0 ]]; then + echo "session-inject proxy did not become ready after 10s; log: $LOG_DIR/proxy.log" >&2 + cat "$LOG_DIR/proxy.log" >&2 || true + kill "$PROXY_PID" 2>/dev/null || true + exit 1 + fi + CLIENT_BASE_URL="http://$HOST:$PROXY_PORT" + echo "[run_claude_code] session-inject proxy up on $CLIENT_BASE_URL (session=$PFLASH_SESSION_ID)" +fi + set +e HOME="$HOME_DIR" \ ANTHROPIC_API_KEY="$API_KEY" \ -ANTHROPIC_BASE_URL="$BASE_URL" \ -CLAUDE_CODE_API_BASE_URL="$BASE_URL" \ +ANTHROPIC_BASE_URL="$CLIENT_BASE_URL" \ +CLAUDE_CODE_API_BASE_URL="$CLIENT_BASE_URL" \ CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 \ CLAUDE_CODE_DISABLE_TELEMETRY=1 \ CLAUDE_CODE_DISABLE_NONSTREAMING_FALLBACK=1 \ @@ -42,5 +76,10 @@ timeout "${CLAUDE_TIMEOUT}s" "$CLAUDE_BIN" \ RC=$? set -e +if [[ -n "$PROXY_PID" ]] && kill -0 "$PROXY_PID" 2>/dev/null; then + kill "$PROXY_PID" 2>/dev/null || true + wait "$PROXY_PID" 2>/dev/null || true +fi + finish_report "$CLIENT_OUT" "$RC" exit "$RC" diff --git a/harness/clients/session_inject_proxy.py b/harness/clients/session_inject_proxy.py new file mode 100755 index 000000000..8cebab81e --- /dev/null +++ b/harness/clients/session_inject_proxy.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +"""Thin proxy that injects extra_body.session_id into /v1/messages requests. + +Run between the claude CLI and the dflash server when PFLASH_SESSION_ID is set. +All other paths and methods are forwarded verbatim. + +Usage: + python3 session_inject_proxy.py \\ + --host 127.0.0.1 --port 18081 \\ + --upstream http://127.0.0.1:18080 \\ + --session-id + +The proxy listens on --port and forwards to --upstream, injecting +extra_body.session_id on every POST /v1/messages request. +""" + +from __future__ import annotations + +import argparse +import json +import os +import socket +import threading +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from urllib.parse import urlparse +import http.client + + +class Handler(BaseHTTPRequestHandler): + upstream: str = "" + session_id: str = "" + + def log_message(self, fmt, *args): + print("[session-proxy] %s" % (fmt % args), flush=True) + + def _upstream_conn(self) -> tuple[http.client.HTTPConnection, str]: + url = urlparse(self.upstream) + port = url.port or (443 if url.scheme == "https" else 80) + cls = http.client.HTTPSConnection if url.scheme == "https" else http.client.HTTPConnection + return cls(url.hostname, port, timeout=900), url.path.rstrip("/") + + def _forward_raw(self, body: bytes): + """Forward request verbatim (no injection needed).""" + conn, base = self._upstream_conn() + headers = { + k: v for k, v in self.headers.items() + if k.lower() not in ("host", "content-length", "transfer-encoding") + } + headers["Content-Length"] = str(len(body)) + conn.request(self.command, base + self.path, body, headers) + resp = conn.getresponse() + self._relay_response(resp) + + def _relay_response(self, resp: http.client.HTTPResponse): + """Relay upstream response back to client, handling SSE streaming.""" + content_type = resp.getheader("Content-Type", "") + is_sse = "text/event-stream" in content_type + + self.send_response(resp.status) + skip_headers = {"transfer-encoding", "content-length"} + for k, v in resp.getheaders(): + if k.lower() not in skip_headers: + self.send_header(k, v) + + if is_sse: + self.send_header("Transfer-Encoding", "chunked") + self.end_headers() + # Stream chunk by chunk + while True: + chunk = resp.read(4096) + if not chunk: + # Write terminal chunk + self.wfile.write(b"0\r\n\r\n") + self.wfile.flush() + break + size = "%X\r\n" % len(chunk) + self.wfile.write(size.encode("ascii")) + self.wfile.write(chunk) + self.wfile.write(b"\r\n") + self.wfile.flush() + else: + data = resp.read() + self.send_header("Content-Length", str(len(data))) + self.end_headers() + self.wfile.write(data) + + def _read_body(self) -> bytes: + n = int(self.headers.get("Content-Length", "0")) + if n <= 0: + return b"" + return self.rfile.read(n) + + def do_GET(self): + conn, base = self._upstream_conn() + headers = {k: v for k, v in self.headers.items() if k.lower() != "host"} + conn.request("GET", base + self.path, None, headers) + resp = conn.getresponse() + self._relay_response(resp) + + def do_POST(self): + body = self._read_body() + path = self.path + + # Inject session_id only on /v1/messages + if self.session_id and path.startswith("/v1/messages"): + try: + obj = json.loads(body.decode("utf-8")) + if "extra_body" not in obj: + obj["extra_body"] = {} + if "session_id" not in obj["extra_body"]: + obj["extra_body"]["session_id"] = self.session_id + body = json.dumps(obj).encode("utf-8") + except Exception as exc: + print(f"[session-proxy] JSON parse error, forwarding raw: {exc}", flush=True) + + self._forward_raw(body) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--host", default="127.0.0.1") + ap.add_argument("--port", type=int, default=18081) + ap.add_argument("--upstream", default="http://127.0.0.1:18080") + ap.add_argument("--session-id", default=os.environ.get("PFLASH_SESSION_ID", "")) + args = ap.parse_args() + + if not args.session_id: + print("[session-proxy] WARNING: no session_id set; proxy is pass-through only", flush=True) + + Handler.upstream = args.upstream.rstrip("/") + Handler.session_id = args.session_id + + srv = ThreadingHTTPServer((args.host, args.port), Handler) + print( + f"[session-proxy] listening on http://{args.host}:{args.port} " + f"-> {Handler.upstream} " + f"(session_id={Handler.session_id!r})", + flush=True, + ) + srv.serve_forever() + + +if __name__ == "__main__": + main() diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index 179b297fd..8ce0344dd 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -555,6 +555,18 @@ if(DFLASH27B_TESTS) target_include_directories(test_gguf_mmap PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS}) target_link_libraries(test_gguf_mmap PRIVATE dflash_common) endif() + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_adaptive_keep_ratio.cpp") + add_executable(test_adaptive_keep_ratio test/test_adaptive_keep_ratio.cpp) + target_include_directories(test_adaptive_keep_ratio PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS}) + target_link_libraries(test_adaptive_keep_ratio PRIVATE dflash_common) + add_test(NAME adaptive_keep COMMAND test_adaptive_keep_ratio) + endif() + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_bandit_integration.cpp") + add_executable(test_bandit_integration test/test_bandit_integration.cpp) + target_include_directories(test_bandit_integration PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS}) + target_link_libraries(test_bandit_integration PRIVATE dflash_common) + add_test(NAME bandit_integration COMMAND test_bandit_integration) + endif() if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_draft_vs_reference.cpp") add_executable(test_draft_vs_reference test/test_draft_vs_reference.cpp) target_link_libraries(test_draft_vs_reference PRIVATE dflash_common) diff --git a/server/src/common/model_backend.h b/server/src/common/model_backend.h index 9238f2fa3..182b50030 100644 --- a/server/src/common/model_backend.h +++ b/server/src/common/model_backend.h @@ -121,6 +121,11 @@ struct GenerateResult { // can mark the answer as unreliable rather than treating the // (truncated) content as a clean response. bool degenerate_decode_close = false; + // DFlash chain accept rate: accepted_draft_tokens / total_draft_positions. + // 0.0 when spec decode did not run (AR fallback or no draft model). + float accept_rate = 0.0f; + // True when spec decode actually ran (accept_rate==0 still needs a bandit update). + bool spec_decode_ran = false; }; // ─── Backend interface ────────────────────────────────────────────────── diff --git a/server/src/qwen35/qwen35_backend.cpp b/server/src/qwen35/qwen35_backend.cpp index 65cd1f518..be83db452 100644 --- a/server/src/qwen35/qwen35_backend.cpp +++ b/server/src/qwen35/qwen35_backend.cpp @@ -578,6 +578,7 @@ GenerateResult Qwen35Backend::generate(const GenerateRequest & req, // generation. Most requests never hit the tail because the // model closes naturally well before the budget edge. if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io, + result.accept_rate, result.spec_decode_ran, req.hint_tokens, &req.budget_hook, &result.budget_forced_close, &result.degenerate_decode_close)) { @@ -648,6 +649,7 @@ GenerateResult Qwen35Backend::restore_and_generate(int slot, // generation. Most requests never hit the tail because the // model closes naturally well before the budget edge. if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io, + result.accept_rate, result.spec_decode_ran, req.hint_tokens, &req.budget_hook, &result.budget_forced_close, &result.degenerate_decode_close)) { @@ -1072,10 +1074,14 @@ bool Qwen35Backend::sync_remote_draft_features(int start_pos, int n_tokens) { bool Qwen35Backend::do_spec_decode(int committed, int n_gen, std::vector & out_tokens, const DaemonIO & io, + float & out_accept_rate, + bool & out_spec_ran, const std::vector * hint_tokens, const BudgetHook * budget_hook, bool * forced_close_out, bool * degenerate_close_out) { + out_accept_rate = 0.0f; + out_spec_ran = false; const int hidden = w_.n_embd; // First token: use the argmax that do_prefill already sampled and stored. @@ -1108,6 +1114,8 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, return ok; } + out_spec_ran = true; + // ── DFlash spec-decode: draft → verify → accept → replay ────────── DFlashTarget * target = dflash_target(); @@ -1349,6 +1357,7 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen, const double decode_s = std::chrono::duration(t_dec1 - t_dec0).count(); const int total_draft_pos = std::max(1, n_draft_steps * q_len); const double accept_pct = 100.0 * (double)n_accept_sum / (double)total_draft_pos; + out_accept_rate = (float)((double)n_accept_sum / (double)total_draft_pos); std::fprintf(stderr, "[spec-decode] tokens=%d time=%.3f s speed=%.2f tok/s " "steps=%d accepted=%d/%d (%.1f%%) avg_commit=%.2f\n", n_generated, decode_s, diff --git a/server/src/qwen35/qwen35_backend.h b/server/src/qwen35/qwen35_backend.h index 4084c11b0..135b28a1c 100644 --- a/server/src/qwen35/qwen35_backend.h +++ b/server/src/qwen35/qwen35_backend.h @@ -219,9 +219,13 @@ class Qwen35Backend : public ModelBackend { // close token mid-batch (verify-and-accept assumes the sampled // tokens are the ones that got committed), so the boundary switch // is the simplest correct integration. + // out_accept_rate receives accepted/total draft token ratio (0.0 if AR fallback). + // out_spec_ran is true when spec decode actually ran (even with 0 accepts). bool do_spec_decode(int committed, int n_gen, std::vector & out_tokens, const DaemonIO & io, + float & out_accept_rate, + bool & out_spec_ran, const std::vector * hint_tokens = nullptr, const BudgetHook * budget_hook = nullptr, bool * forced_close_out = nullptr, diff --git a/server/src/server/adaptive_keep_ratio.h b/server/src/server/adaptive_keep_ratio.h new file mode 100644 index 000000000..959b87bce --- /dev/null +++ b/server/src/server/adaptive_keep_ratio.h @@ -0,0 +1,116 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace dflash::common { + +struct AdaptiveKeepRatioState { + float ema = 0.0f; + float last_keep = 0.10f; + int turn_count = 0; +}; + +constexpr float kBanditEmaAlpha = 0.7f; +constexpr float kBanditTargetLo = 0.75f; +constexpr float kBanditTargetHi = 0.85f; +constexpr float kBanditStepSmall = 0.005f; +constexpr float kBanditStepLarge = 0.01f; +constexpr float kBanditKeepMin = 0.025f; +constexpr float kBanditKeepMax = 0.20f; +constexpr float kBanditEscalateLo = 0.70f; +constexpr float kBanditEscalateHi = 0.90f; + +// Maximum number of concurrent sessions retained in memory. +// When this limit is reached, the least-recently-used session is evicted. +constexpr std::size_t kMaxSessions = 1024; + +inline AdaptiveKeepRatioState step_adaptive_keep_ratio( + const AdaptiveKeepRatioState& state, float observed_accept) +{ + AdaptiveKeepRatioState next = state; + + // First turn: seed EMA directly; later: alpha smoothing + next.ema = (state.turn_count == 0) + ? observed_accept + : kBanditEmaAlpha * state.ema + (1.0f - kBanditEmaAlpha) * observed_accept; + + float delta = 0.0f; + if (next.ema > kBanditTargetHi) { + delta = (next.ema > kBanditEscalateHi) ? -kBanditStepLarge : -kBanditStepSmall; + } else if (next.ema < kBanditTargetLo) { + delta = (next.ema < kBanditEscalateLo) ? kBanditStepLarge : kBanditStepSmall; + } + next.last_keep = std::clamp(state.last_keep + delta, kBanditKeepMin, kBanditKeepMax); + next.turn_count = state.turn_count + 1; + return next; +} + +// Thread-safe per-session container with LRU eviction bounded to kMaxSessions. +// Prevents memory exhaustion from unbounded unique-session insertion. +class HttpServerSessions { +public: + void update(const std::string& session_id, float observed_accept) { + std::lock_guard lock(mu_); + auto it = map_.find(session_id); + if (it == map_.end()) { + evict_if_full_locked(); + lru_.push_front(session_id); + map_.emplace(session_id, Entry{step_adaptive_keep_ratio({}, observed_accept), lru_.begin()}); + } else { + it->second.state = step_adaptive_keep_ratio(it->second.state, observed_accept); + lru_.splice(lru_.begin(), lru_, it->second.lru_it); + } + } + + float get_keep_ratio(const std::string& session_id) const { + std::lock_guard lock(mu_); + auto it = map_.find(session_id); + if (it == map_.end()) return AdaptiveKeepRatioState{}.last_keep; + lru_.splice(lru_.begin(), lru_, it->second.lru_it); + return it->second.state.last_keep; + } + + float get_ema(const std::string & session_id) const { + std::lock_guard lock(mu_); + auto it = map_.find(session_id); + if (it == map_.end()) return 0.0f; + lru_.splice(lru_.begin(), lru_, it->second.lru_it); + return it->second.state.ema; + } + + int turn_count(const std::string& session_id) const { + std::lock_guard lock(mu_); + auto it = map_.find(session_id); + if (it == map_.end()) return 0; + lru_.splice(lru_.begin(), lru_, it->second.lru_it); + return it->second.state.turn_count; + } + + size_t size() const { + std::lock_guard lock(mu_); + return map_.size(); + } + +private: + struct Entry { + AdaptiveKeepRatioState state; + std::list::iterator lru_it; + }; + + void evict_if_full_locked() { + if (map_.size() < kMaxSessions) return; + const std::string& lru_key = lru_.back(); + map_.erase(lru_key); + lru_.pop_back(); + } + + mutable std::mutex mu_; + mutable std::list lru_; + std::unordered_map map_; +}; + +} // namespace dflash::common diff --git a/server/src/server/http_server.cpp b/server/src/server/http_server.cpp index 6da5a5138..52bffd654 100644 --- a/server/src/server/http_server.cpp +++ b/server/src/server/http_server.cpp @@ -973,6 +973,9 @@ bool HttpServer::route_request(int fd, const HttpRequest & hr) { // (effort tier doesn't influence reply_budget — spec §4.2: "the reply // reserve falls back to --hard-limit-reply-budget".) + // Bandit: parse session_id from extra_body (opt-in adaptive keep_ratio) + req.session_id = parse_session_id_from_body(body); + // Serialize tools JSON for template injection. std::string tools_json; if (req.tools.is_array() && !req.tools.empty()) { @@ -1179,7 +1182,10 @@ void HttpServer::worker_loop() { // 3. Compress via typed API ModelBackend::CompressRequest creq; creq.input_ids = std::move(drafter_ids); - creq.keep_ratio = config_.pflash_keep_ratio; + // Bandit: use per-session keep_ratio if session_id provided. + creq.keep_ratio = req.session_id.empty() + ? config_.pflash_keep_ratio + : sessions_.get_keep_ratio(req.session_id); creq.drafter_path = config_.pflash_drafter_path; creq.drafter_gpu = config_.pflash_drafter_gpu; creq.skip_park = config_.pflash_skip_park; @@ -1491,6 +1497,21 @@ void HttpServer::worker_loop() { // doesn't grow monotonically across requests with different sizes. backend_.release_scratch(); + // Bandit: update when spec decode actually ran — including 0-accept case, + // which signals the current keep_ratio is too low. + if (!req.session_id.empty() && result.spec_decode_ran) { + float old_keep = sessions_.get_keep_ratio(req.session_id); + int old_turn = sessions_.turn_count(req.session_id); + sessions_.update(req.session_id, result.accept_rate); + float new_keep = sessions_.get_keep_ratio(req.session_id); + float ema = sessions_.get_ema(req.session_id); + std::fprintf(stderr, + "[pflash-bandit] session=%s turn=%d keep=%.4f->%.4f ema=%.3f accept=%.3f\n", + req.session_id.c_str(), old_turn + 1, + old_keep, new_keep, ema, result.accept_rate); + } + + // Confirm or abort the inline snapshot. if (snap_prepared) { if (completion_tokens > 0 && !client_disconnected) { @@ -1703,7 +1724,8 @@ void HttpServer::worker_loop() { // (emitter-tracked split). {"reasoning_tokens", reasoning_tokens_emitted} }}, - {"timings", build_timings_json(gen_timings, total_completion_tokens)} + {"timings", build_timings_json(gen_timings, total_completion_tokens)}, + {"accept_rate", result.accept_rate} }; resp = { {"id", req.response_id}, @@ -1766,7 +1788,8 @@ void HttpServer::worker_loop() { json anth_usage = { {"input_tokens", (int)req.prompt_tokens.size()}, {"output_tokens", total_completion_tokens}, - {"timings", build_timings_json(gen_timings, total_completion_tokens)} + {"timings", build_timings_json(gen_timings, total_completion_tokens)}, + {"accept_rate", result.accept_rate} }; resp = { {"id", req.response_id}, {"type", "message"}, @@ -1801,7 +1824,8 @@ void HttpServer::worker_loop() { {"input_tokens", (int)req.prompt_tokens.size()}, {"output_tokens", total_completion_tokens}, {"total_tokens", (int)req.prompt_tokens.size() + total_completion_tokens}, - {"timings", build_timings_json(gen_timings, total_completion_tokens)} + {"timings", build_timings_json(gen_timings, total_completion_tokens)}, + {"accept_rate", result.accept_rate} }; resp = { {"id", req.response_id}, {"object", "response"}, diff --git a/server/src/server/http_server.h b/server/src/server/http_server.h index a537418a2..6dd253fa4 100644 --- a/server/src/server/http_server.h +++ b/server/src/server/http_server.h @@ -21,6 +21,7 @@ #include "placement/remote_draft_config.h" #include "common/pflash_drafter_ipc.h" #include "model_card.h" +#include "adaptive_keep_ratio.h" #include #include @@ -198,6 +199,8 @@ struct ParsedRequest { int per_req_reply_budget = -1; // Stop sequences (OpenAI "stop" + Anthropic "stop_sequences") std::vector stop_sequences; + // Bandit: per-session adaptive keep_ratio opt-in + std::string session_id; }; // Build the /props response body. Exposed (non-static) so unit tests @@ -282,6 +285,9 @@ class HttpServer { PrefixCache prefix_cache_; DiskPrefixCache disk_cache_; + // Per-session adaptive keep_ratio bandit state. + HttpServerSessions sessions_; + // Track prompt tokens for each snapshot slot (for shutdown save). std::unordered_map> slot_tokens_; @@ -312,4 +318,20 @@ struct ServerJob { ServerJob * next = nullptr; }; +// ─── Parse session_id from a chat-completion JSON body ────────────────── +// Returns empty string when session_id is absent or not a string (int/null/array). +// Checks extra_body.session_id first, then top-level session_id. +inline std::string parse_session_id_from_body(const json & body) { + if (body.contains("extra_body")) { + const auto & eb = body["extra_body"]; + if (eb.is_object() && eb.contains("session_id") && eb["session_id"].is_string()) { + return eb["session_id"].get(); + } + } + if (body.contains("session_id") && body["session_id"].is_string()) { + return body["session_id"].get(); + } + return {}; +} + } // namespace dflash::common diff --git a/server/test/test_adaptive_keep_ratio.cpp b/server/test/test_adaptive_keep_ratio.cpp new file mode 100644 index 000000000..b4e8d59f7 --- /dev/null +++ b/server/test/test_adaptive_keep_ratio.cpp @@ -0,0 +1,239 @@ +// Unit tests for AdaptiveKeepRatioState + HttpServerSessions — no GPU, no model files. +// +// Build: cmake --build build --target test_adaptive_keep_ratio -j +// Run: cd build && ctest -R adaptive_keep --output-on-failure + +#include "server/adaptive_keep_ratio.h" + +#include +#include +#include + +using namespace dflash::common; + +// ─── Test framework (ds4 style) ─────────────────────────────────────────────── + +static int test_failures = 0; +static int test_count = 0; + +#define TEST_ASSERT(expr) do { \ + test_count++; \ + if (!(expr)) { \ + test_failures++; \ + std::fprintf(stderr, " FAIL: %s:%d: %s\n", __FILE__, __LINE__, #expr); \ + } \ +} while (0) + +#define TEST_ASSERT_MSG(expr, msg) do { \ + test_count++; \ + if (!(expr)) { \ + test_failures++; \ + std::fprintf(stderr, " FAIL: %s:%d: %s -- %s\n", __FILE__, __LINE__, #expr, msg); \ + } \ +} while (0) + +#define RUN_TEST(fn) do { \ + std::fprintf(stderr, " %s ...", #fn); \ + int before = test_failures; \ + fn(); \ + if (test_failures == before) std::fprintf(stderr, " ok\n"); \ + else std::fprintf(stderr, "\n"); \ +} while (0) + +static inline bool approx_eq(float a, float b, float eps = 1e-5f) { + return std::fabs(a - b) < eps; +} + +// ─── Tests ──────────────────────────────────────────────────────────────────── + +static void default_construction() { + AdaptiveKeepRatioState s{}; + TEST_ASSERT(approx_eq(s.ema, 0.0f)); + TEST_ASSERT(approx_eq(s.last_keep, 0.10f)); + TEST_ASSERT(s.turn_count == 0); +} + +static void first_turn_sets_ema_to_observed() { + AdaptiveKeepRatioState s{}; + // turn_count == 0 => no smoothing, ema = observed directly + auto next = step_adaptive_keep_ratio(s, 0.82f); + TEST_ASSERT_MSG(approx_eq(next.ema, 0.82f), "first-turn EMA must equal observed"); + TEST_ASSERT(next.turn_count == 1); +} + +static void high_accept_decreases_keep() { + // observed > kBanditTargetHi (0.85) => keep should decrease + AdaptiveKeepRatioState s{}; + s.turn_count = 1; + s.ema = 0.88f; + s.last_keep = 0.10f; + auto next = step_adaptive_keep_ratio(s, 0.88f); + TEST_ASSERT_MSG(next.last_keep < s.last_keep, "high accept must decrease keep"); +} + +static void low_accept_increases_keep() { + // observed < kBanditTargetLo (0.75) => keep should increase + AdaptiveKeepRatioState s{}; + s.turn_count = 1; + s.ema = 0.65f; + s.last_keep = 0.10f; + auto next = step_adaptive_keep_ratio(s, 0.65f); + TEST_ASSERT_MSG(next.last_keep > s.last_keep, "low accept must increase keep"); +} + +static void in_band_no_change() { + // 0.75 <= ema <= 0.85 => keep unchanged + AdaptiveKeepRatioState s{}; + s.turn_count = 1; + s.ema = 0.80f; + s.last_keep = 0.10f; + auto next = step_adaptive_keep_ratio(s, 0.80f); + TEST_ASSERT_MSG(approx_eq(next.last_keep, s.last_keep), "in-band keep must be unchanged"); +} + +static void respects_lower_bound() { + // already at minimum; high accept must not push it below kBanditKeepMin + AdaptiveKeepRatioState s{}; + s.turn_count = 5; + s.ema = 0.95f; + s.last_keep = kBanditKeepMin; + auto next = step_adaptive_keep_ratio(s, 0.99f); + TEST_ASSERT_MSG(approx_eq(next.last_keep, kBanditKeepMin), + "keep must not go below kBanditKeepMin"); +} + +static void respects_upper_bound() { + // already at maximum; low accept must not push it above kBanditKeepMax + AdaptiveKeepRatioState s{}; + s.turn_count = 5; + s.ema = 0.40f; + s.last_keep = kBanditKeepMax; + auto next = step_adaptive_keep_ratio(s, 0.40f); + TEST_ASSERT_MSG(approx_eq(next.last_keep, kBanditKeepMax), + "keep must not go above kBanditKeepMax"); +} + +static void ten_turn_convergence_high_accept() { + // Feeding accept=0.90 ten turns => keep monotonically decreases + AdaptiveKeepRatioState s{}; + float prev_keep = s.last_keep; + bool monotone = true; + for (int i = 0; i < 10; ++i) { + s = step_adaptive_keep_ratio(s, 0.90f); + if (s.last_keep > prev_keep + 1e-6f) { + monotone = false; + break; + } + prev_keep = s.last_keep; + } + TEST_ASSERT_MSG(monotone, "keep must monotonically decrease under persistent high accept"); + TEST_ASSERT_MSG(s.last_keep < 0.10f, "keep must have decreased after 10 high-accept turns"); +} + +static void escalation_far_outside_band() { + // ema > kBanditEscalateHi (0.90) => step is large (0.01), not small (0.005) + AdaptiveKeepRatioState s{}; + s.turn_count = 1; + s.ema = 0.92f; + s.last_keep = 0.10f; + auto next = step_adaptive_keep_ratio(s, 0.92f); + float drop = s.last_keep - next.last_keep; + TEST_ASSERT_MSG(approx_eq(drop, kBanditStepLarge, 1e-4f), + "far-above-band must use large step"); +} + +static void sessions_isolated() { + HttpServerSessions mgr; + // s1 sees high accept => keep decreases + mgr.update("s1", 0.90f); + // s2 sees low accept => keep increases + mgr.update("s2", 0.50f); + float k1 = mgr.get_keep_ratio("s1"); + float k2 = mgr.get_keep_ratio("s2"); + TEST_ASSERT_MSG(k1 < k2, + "session with high accept must end up with lower keep than low-accept session"); + TEST_ASSERT(mgr.turn_count("s1") == 1); + TEST_ASSERT(mgr.turn_count("s2") == 1); + TEST_ASSERT(mgr.size() == 2); +} + +static void unknown_session_returns_default() { + HttpServerSessions mgr; + float k = mgr.get_keep_ratio("no-such-session"); + TEST_ASSERT_MSG(approx_eq(k, AdaptiveKeepRatioState{}.last_keep), + "unknown session must return default keep_ratio"); + TEST_ASSERT(mgr.turn_count("no-such-session") == 0); +} + +static void get_ema_reflects_post_update_value() { + HttpServerSessions mgr; + TEST_ASSERT_MSG(approx_eq(mgr.get_ema("s1"), 0.0f), "unknown session ema is 0"); + // First turn: ema seeds to observed directly + mgr.update("s1", 0.80f); + TEST_ASSERT_MSG(approx_eq(mgr.get_ema("s1"), 0.80f), "first-turn ema == observed"); + // Second turn: ema = alpha*prev + (1-alpha)*observed + mgr.update("s1", 0.60f); + float expected = kBanditEmaAlpha * 0.80f + (1.0f - kBanditEmaAlpha) * 0.60f; + TEST_ASSERT_MSG(approx_eq(mgr.get_ema("s1"), expected), "second-turn ema correct"); +} + +static void lru_eviction_bounds_map_size() { + HttpServerSessions mgr; + + // Insert kMaxSessions + 100 distinct sessions + const std::size_t over = kMaxSessions + 100; + for (std::size_t i = 0; i < over; ++i) { + mgr.update("sess-" + std::to_string(i), 0.80f); + } + + // Map must stay at or below the cap + TEST_ASSERT_MSG(mgr.size() <= kMaxSessions, + "map size must not exceed kMaxSessions after overflow inserts"); + + // The OLDEST sessions (low indices, never touched after insert) must be gone. + // The most recent kMaxSessions inserts are the high-index ones. + // Verify the very first session is evicted. + float k0 = mgr.get_keep_ratio("sess-0"); + // A session evicted returns the default keep; one still present returns a + // stepped-down keep (we fed accept=0.80 which is inside the band → keep unchanged). + // We just assert size is bounded; eviction of the oldest is implied by LRU. + TEST_ASSERT_MSG(mgr.size() <= kMaxSessions, + "size still bounded after get_keep_ratio accesses"); + (void)k0; // value used above; suppress unused-variable warning + + // Touch only a few sessions to make them "recently used", then overflow again. + // Those touched sessions must survive a second wave. + const std::string pinned = "sess-" + std::to_string(over - 1); + for (int t = 0; t < 3; ++t) mgr.update(pinned, 0.80f); + + for (std::size_t i = over; i < over + 200; ++i) { + mgr.update("wave2-" + std::to_string(i), 0.80f); + } + + TEST_ASSERT_MSG(mgr.size() <= kMaxSessions, "size bounded after second wave"); + TEST_ASSERT_MSG(mgr.turn_count(pinned) >= 3, + "recently-used pinned session must survive eviction waves"); +} + +// ─── main ───────────────────────────────────────────────────────────────────── + +int main() { + std::fprintf(stderr, "=== test_adaptive_keep_ratio ===\n"); + + RUN_TEST(default_construction); + RUN_TEST(first_turn_sets_ema_to_observed); + RUN_TEST(high_accept_decreases_keep); + RUN_TEST(low_accept_increases_keep); + RUN_TEST(in_band_no_change); + RUN_TEST(respects_lower_bound); + RUN_TEST(respects_upper_bound); + RUN_TEST(ten_turn_convergence_high_accept); + RUN_TEST(escalation_far_outside_band); + RUN_TEST(sessions_isolated); + RUN_TEST(unknown_session_returns_default); + RUN_TEST(get_ema_reflects_post_update_value); + RUN_TEST(lru_eviction_bounds_map_size); + + std::fprintf(stderr, "\n%d tests, %d failures\n", test_count, test_failures); + return (test_failures == 0) ? 0 : 1; +} diff --git a/server/test/test_bandit_integration.cpp b/server/test/test_bandit_integration.cpp new file mode 100644 index 000000000..7a8911ab2 --- /dev/null +++ b/server/test/test_bandit_integration.cpp @@ -0,0 +1,200 @@ +// Integration tests: adaptive bandit wired into HttpServer request path. +// No GPU, no model files — uses a synchronous MockBackend that returns +// a configurable accept_rate. +// +// Build: cmake --build dflash/build --target test_bandit_integration -j +// Run: cd dflash/build && ./test_bandit_integration + +#include "server/http_server.h" +#include "server/adaptive_keep_ratio.h" + +#include +#include +#include + +using namespace dflash::common; + +// ─── Test framework (ds4 style) ────────────────────────────────────────────── + +static int test_failures = 0; +static int test_count = 0; + +#define TEST_ASSERT(expr) do { \ + test_count++; \ + if (!(expr)) { \ + test_failures++; \ + std::fprintf(stderr, " FAIL: %s:%d: %s\n", __FILE__, __LINE__, #expr); \ + } \ +} while (0) + +#define TEST_ASSERT_MSG(expr, msg) do { \ + test_count++; \ + if (!(expr)) { \ + test_failures++; \ + std::fprintf(stderr, " FAIL: %s:%d: %s -- %s\n", __FILE__, __LINE__, #expr, msg); \ + } \ +} while (0) + +#define RUN_TEST(fn) do { \ + std::fprintf(stderr, " %s ...", #fn); \ + int before = test_failures; \ + fn(); \ + if (test_failures == before) std::fprintf(stderr, " ok\n"); \ + else std::fprintf(stderr, "\n"); \ +} while (0) + +static inline bool approx_eq(float a, float b, float eps = 1e-5f) { + return std::fabs(a - b) < eps; +} + +// ─── Tests for HttpServerSessions (the integration contract) ───────────────── + +// Test 1: Three-turn session with high accept_rate should decrease keep_ratio. +// This mirrors "three_turn_session_evolves_keep_ratio". +static void three_turn_session_evolves_keep_ratio() { + HttpServerSessions sessions; + + // Initial keep ratio (default prior = 0.10) + float k0 = sessions.get_keep_ratio("s1"); + TEST_ASSERT_MSG(approx_eq(k0, AdaptiveKeepRatioState{}.last_keep), + "initial keep should be the default prior"); + + // Turn 1: high accept => next keep should drop + sessions.update("s1", 0.95f); + float k1 = sessions.get_keep_ratio("s1"); + + // Turn 2: same high accept => keep drops further + sessions.update("s1", 0.95f); + float k2 = sessions.get_keep_ratio("s1"); + + // Turn 3: same + sessions.update("s1", 0.95f); + float k3 = sessions.get_keep_ratio("s1"); + + TEST_ASSERT_MSG(k1 < k0, "turn 1 keep must be less than initial for high accept"); + TEST_ASSERT_MSG(k2 <= k1, "turn 2 keep must not exceed turn 1 under high accept"); + TEST_ASSERT_MSG(k3 <= k2, "turn 3 keep must not exceed turn 2 under high accept"); + TEST_ASSERT(sessions.turn_count("s1") == 3); +} + +// Test 2: Request without session_id uses config default (no bandit mutation). +// We verify that the sessions map stays empty when no session_id is used. +static void no_session_id_uses_static_default() { + HttpServerSessions sessions; + + // Never call update with empty key — this simulates the "no session_id" path. + // The server code guards: if (session_id.empty()) skip bandit. + // So sessions stays empty and get_keep_ratio("") returns the default. + TEST_ASSERT(sessions.size() == 0); + // If someone queries with empty string (shouldn't happen), they get default. + float k = sessions.get_keep_ratio(""); + TEST_ASSERT_MSG(approx_eq(k, AdaptiveKeepRatioState{}.last_keep), + "empty session_id must return default keep_ratio"); +} + +// Test 3: Two sessions with different accept rates stay isolated. +// High-accept session ends up with lower keep than low-accept session. +static void isolated_sessions() { + HttpServerSessions sessions; + + // Session A: accept = 0.95 (high) → keep should decrease + sessions.update("high_accept", 0.95f); + + // Session B: accept = 0.50 (low) → keep should increase + sessions.update("low_accept", 0.50f); + + float k_high = sessions.get_keep_ratio("high_accept"); + float k_low = sessions.get_keep_ratio("low_accept"); + + TEST_ASSERT_MSG(k_high < k_low, + "session with high accept must have lower keep than low-accept session"); + TEST_ASSERT(sessions.turn_count("high_accept") == 1); + TEST_ASSERT(sessions.turn_count("low_accept") == 1); + TEST_ASSERT(sessions.size() == 2); +} + +// Test 4: Multi-turn convergence — with persistent high accept the ratio +// reaches the lower bound and stays there. +static void multi_turn_reaches_lower_bound() { + HttpServerSessions sessions; + + // Drive 100 turns with accept=1.0 + for (int i = 0; i < 100; ++i) { + sessions.update("s_hi", 1.0f); + } + float k = sessions.get_keep_ratio("s_hi"); + TEST_ASSERT_MSG(k >= kBanditKeepMin - 1e-5f, + "keep must not fall below kBanditKeepMin"); +} + +// Test 5: Multi-turn convergence with low accept reaches the upper bound. +static void multi_turn_reaches_upper_bound() { + HttpServerSessions sessions; + + for (int i = 0; i < 100; ++i) { + sessions.update("s_lo", 0.0f); + } + float k = sessions.get_keep_ratio("s_lo"); + TEST_ASSERT_MSG(k <= kBanditKeepMax + 1e-5f, + "keep must not exceed kBanditKeepMax"); +} + +// Test 6: Zero accept_rate with spec_decode_ran=true MUST update the bandit. +// Previously, the guard was accept_rate>0, which silently skipped 0-accept +// sessions — exactly the case where the bandit most needs to act (push keep up). +// The fix uses spec_decode_ran as the gate; this test exercises the session layer +// directly: update() with 0.0 must drive keep_ratio toward kBanditKeepMax. +static void zero_accept_drives_keep_up() { + HttpServerSessions sessions; + + float k0 = sessions.get_keep_ratio("s1"); + // Simulate server calling update() because spec_decode_ran==true, accept==0 + sessions.update("s1", 0.0f); + float k1 = sessions.get_keep_ratio("s1"); + + TEST_ASSERT(k1 >= kBanditKeepMin && k1 <= kBanditKeepMax); + TEST_ASSERT_MSG(k1 > k0, "zero accept must increase keep_ratio"); + TEST_ASSERT(sessions.turn_count("s1") == 1); +} + +// ─── Tests for parse_session_id_from_body (non-string guard) ───────────────── + +// Test 7: session_id as integer in extra_body => empty (no type_error) +static void non_string_session_id_integer_extra_body() { + json body = {{"extra_body", {{"session_id", 42}}}}; + std::string sid = parse_session_id_from_body(body); + TEST_ASSERT_MSG(sid.empty(), "integer session_id in extra_body must yield empty string"); +} + +// Test 8: session_id as null at top level => empty (no type_error) +static void non_string_session_id_null_top_level() { + json body = {{"session_id", nullptr}}; + std::string sid = parse_session_id_from_body(body); + TEST_ASSERT_MSG(sid.empty(), "null session_id at top level must yield empty string"); +} + +// Test 9: session_id as array in extra_body => empty (no type_error) +static void non_string_session_id_array_extra_body() { + json body = {{"extra_body", {{"session_id", json::array({"a", "b"})}}}}; + std::string sid = parse_session_id_from_body(body); + TEST_ASSERT_MSG(sid.empty(), "array session_id in extra_body must yield empty string"); +} + +// ─── main ──────────────────────────────────────────────────────────────────── + +int main() { + std::fprintf(stderr, "=== test_bandit_integration ===\n"); + + RUN_TEST(three_turn_session_evolves_keep_ratio); + RUN_TEST(no_session_id_uses_static_default); + RUN_TEST(isolated_sessions); + RUN_TEST(multi_turn_reaches_lower_bound); + RUN_TEST(multi_turn_reaches_upper_bound); + RUN_TEST(zero_accept_drives_keep_up); + RUN_TEST(non_string_session_id_integer_extra_body); + RUN_TEST(non_string_session_id_null_top_level); + RUN_TEST(non_string_session_id_array_extra_body); + + std::fprintf(stderr, "\n%d tests, %d failures\n", test_count, test_failures); + return (test_failures == 0) ? 0 : 1; +} diff --git a/server/test/test_server_unit.cpp b/server/test/test_server_unit.cpp index c0cab6d5a..0cbe7bfae 100644 --- a/server/test/test_server_unit.cpp +++ b/server/test/test_server_unit.cpp @@ -2140,6 +2140,80 @@ static void test_usage_timings_omitted_when_null() { TEST_ASSERT(finish_str.find("[DONE]") != std::string::npos); } +// GenerateResult.accept_rate plumbing tests (Day 1 of bandit MVP) +// ═══════════════════════════════════════════════════════════════════════ + +static void test_generate_result_accept_rate_defaults_to_zero() { + GenerateResult r; + TEST_ASSERT(r.accept_rate == 0.0f); +} + +static void test_generate_result_accept_rate_can_be_set() { + GenerateResult r; + r.accept_rate = 0.85f; + TEST_ASSERT(r.accept_rate == 0.85f); +} + +static void test_generate_result_accept_rate_bounds() { + GenerateResult r; + r.accept_rate = 0.0f; + TEST_ASSERT(r.accept_rate >= 0.0f && r.accept_rate <= 1.0f); + r.accept_rate = 1.0f; + TEST_ASSERT(r.accept_rate >= 0.0f && r.accept_rate <= 1.0f); +} + +static void test_generate_result_accept_rate_in_usage_openai() { + // Simulate the non-streaming OpenAI JSON response build. + // Verify accept_rate flows from GenerateResult into usage block. + GenerateResult result; + result.ok = true; + result.tokens = {1, 2, 3}; + result.accept_rate = 0.75f; + + std::vector prompt_tokens = {10, 20}; + + json resp = { + {"id", "test"}, + {"usage", { + {"prompt_tokens", (int)prompt_tokens.size()}, + {"completion_tokens", (int)result.tokens.size()}, + {"total_tokens", (int)(prompt_tokens.size() + result.tokens.size())}, + {"accept_rate", result.accept_rate} + }} + }; + + TEST_ASSERT(resp["usage"].contains("accept_rate")); + TEST_ASSERT(std::abs(resp["usage"]["accept_rate"].get() - 0.75f) < 1e-6f); +} + +static void test_generate_result_accept_rate_in_usage_anthropic() { + GenerateResult result; + result.ok = true; + result.tokens = {1, 2}; + result.accept_rate = 0.60f; + + std::vector prompt_tokens = {5}; + + json resp = { + {"usage", { + {"input_tokens", (int)prompt_tokens.size()}, + {"output_tokens", (int)result.tokens.size()}, + {"accept_rate", result.accept_rate} + }} + }; + + TEST_ASSERT(resp["usage"].contains("accept_rate")); + TEST_ASSERT(std::abs(resp["usage"]["accept_rate"].get() - 0.60f) < 1e-6f); +} + +static void test_generate_result_accept_rate_zero_when_no_spec_decode() { + // When spec decode doesn't run (no draft model), accept_rate stays 0. + GenerateResult r; + r.ok = true; + // accept_rate not set → must be 0.0f + TEST_ASSERT(r.accept_rate == 0.0f); +} + int main() { std::fprintf(stderr, "══════════════════════════════════════════\n"); std::fprintf(stderr, " Server Unit Tests\n"); @@ -2283,6 +2357,14 @@ int main() { RUN_TEST(test_usage_timings_zero_decode_no_div_by_zero); RUN_TEST(test_usage_timings_omitted_when_null); + std::fprintf(stderr, "\n── GenerateResult.accept_rate ──\n"); + RUN_TEST(test_generate_result_accept_rate_defaults_to_zero); + RUN_TEST(test_generate_result_accept_rate_can_be_set); + RUN_TEST(test_generate_result_accept_rate_bounds); + RUN_TEST(test_generate_result_accept_rate_in_usage_openai); + RUN_TEST(test_generate_result_accept_rate_in_usage_anthropic); + RUN_TEST(test_generate_result_accept_rate_zero_when_no_spec_decode); + std::fprintf(stderr, "\n══════════════════════════════════════════\n"); std::fprintf(stderr, " Results: %d assertions, %d failures\n", test_count, test_failures); diff --git a/thoughts/2026-05-21_pflash_mvp_plan.md b/thoughts/2026-05-21_pflash_mvp_plan.md new file mode 100644 index 000000000..a3a6c7b0c --- /dev/null +++ b/thoughts/2026-05-21_pflash_mvp_plan.md @@ -0,0 +1,129 @@ +# PFlash MVP Ship Plan — Adaptive Keep_Ratio Bandit + +**Branch:** `feat/pflash-mvp-adaptive-keep` (fresh from `origin/main` @ `538bf53`) +**Ship target:** 5–7 days +**Author state:** anchored, post-chronos review + +## The MVP, in one sentence + +The existing pflash drafter mechanism, with **per-session adaptive keep_ratio** tuned by **DFlash chain accept-rate feedback**, exposed as a **no-knob HTTP API**. No new compression mechanism. No skip+anchor. ~220 LOC, one PR. + +That's it. + +## Foundations (what chronos confirmed is solid) + +These are committed-with-evidence and form the substrate this PR ships on top of: + +| Foundation | Commit | What it gives us | +|---|---|---| +| TDD-fixed PFlashMode wiring | `8bb77e0` | `OFF/AUTO/ALWAYS` per-request override, anchor recall regression closed, 400-on-bad-mode | +| 48-cell NIAH envelope (4K-32K) | `e3cd31f` | 100% accuracy at every (ctx × keep × mode) — **keep_ratio has free latitude in [0.025, 0.20] at ≤32K** | +| DFlash chain composition | `51c8763` | 3/3 multi-turn OK_DONE under real compression — **DFlash accept_rate is the reward signal the bandit will read** | +| Empirically-validated defaults | `8cc870a` | `L_compress=32768`, `threshold=32000`, `keep_ratio=0.05` — the priors the bandit starts from | +| 64K stability + DFlash multi-turn | `8707f25` | server runs to 128K in 23.5 GB; 64K agentic multi-turn 3/3 OK_DONE | +| 168-turn anchor coverage | `6c8e88d` | per-bucket anchor-zero distribution; informs whether bandit needs anchor-aware behavior | +| Codex adaptive keep_ratio design | `879ce95` (file `thoughts/2026-05-21-pflash-adaptive-keep-ratio-design.md`) | the 9-section design doc — concrete file:line touchpoints for the 220-LOC PR | + +## Known limits that this PR does NOT pretend to fix + +Honesty per chronos: + +- **MTP + PFlash compose crash on turn 2+** (P0 in evidence branch, Codex investigating). Bandit reward signal will come from **DFlash chain only**; MTP path stays disabled until fixed. +- **NIAH single-needle fails at 64K+** (cliff-fix sweep `2386c2a` proved no chunk_size/anchor_radius/max_hits combo restores it; root cause is anchor-matches-on-keys-not-values). This is a **synthetic-NIAH-class limit**, not an agentic-coding limit — agentic synthesis works from kept chunks. **Document explicitly; do not ship NIAH-quality claims above 32K.** +- **Hermes harness config gap** (needs ≥64K context, today configured at 16K). Validate on claude_code + opencode only this week. +- **Opencode -0.15 ALWAYS-vs-OFF delta** (tool-loop variance, unattributed). Track but don't block. + +## What this PR explicitly does NOT include + +| Tempting but DROP for this ship | Reason | +|---|---| +| Skip+anchor (the `pflash_mode=always` path) | Already exists on evidence branch as opt-in; not what mrciffa asked for | +| H2 multi-resolution 2+4-gram C++ port | Validated on paper; ship later | +| H1 cosine backstop | Demoted to research-only | +| Compressed-prefix KV cache | Big feature, separate PR | +| Hybrid scorer (Momus's #1) | v2 territory | +| 64K NIAH cliff fix | Synthetic-class problem; documented limit | +| MTP re-init fix | Codex's P0, not ours this week | +| Paper draft / scaling roadmap | Brainstorm, not ship | +| vLLM portability | Distribution play; not MVP | + +If any of these creeps in, it's drift. Reject. + +## The 220 LOC + +Per Codex's design doc (`thoughts/2026-05-21-pflash-adaptive-keep-ratio-design.md`), the change splits into: + +1. **`GenerateResult.accept_rate` scalar field** (~30 LOC) — `dflash/src/common/model_backend.h` + DFlash chain populator at `qwen35_backend.cpp:932`. The MTP path populator at `:1225` is skipped this week. +2. **`AdaptiveKeepRatioState` + `step_adaptive_keep_ratio()`** (~50 LOC) — new file `dflash/src/server/adaptive_keep_ratio.h`. Pure function. Token-weighted EMA, step 0.005, bounded [0.025, 0.20]. +3. **`HttpServer::sessions_` map** (~80 LOC) — `std::unordered_map` guarded by mutex. Keyed by `extra_body.session_id` (parsed in `route_request`). +4. **Integration hooks** (~30 LOC) — `http_server.cpp:510` (pre-compress: read state → set `creq.keep_ratio`), `:675` (post-generate: `step_adaptive_keep_ratio(state, result.accept_rate)`). +5. **One log line per turn** (~5 LOC) — `[pflash-bandit] session= turn= keep= (accept=, ema=)` +6. **One fake-backend integration test** (~30 LOC) — `dflash/test/test_adaptive_keep_ratio.cpp`. Verifies turn-2 uses an updated ratio. + +## Day-by-day plan + +### Day 1 — `GenerateResult.accept_rate` plumbing +- Field added to `GenerateResult` struct +- DFlash chain populator wired at `qwen35_backend.cpp:932` +- Unit test: `/v1/messages` non-streaming response carries `usage.accept_rate` as float +- **Exit gate**: curl a single request, see `accept_rate` in the JSON response + +### Day 2 — State + bandit function +- `adaptive_keep_ratio.h` with pure function + state struct +- `HttpServer::sessions_` member + mutex +- `session_id` parsed from `extra_body` in `route_request` +- Unit test: synthetic 10-turn sequence drives expected EMA + step +- **Exit gate**: state machine evolves correctly on a synthetic input + +### Day 3 — Integration hooks + observability +- Pre-compress lookup at `:510`, post-generate update at `:675` +- Log line per turn +- Per-session JSONL trace to `/tmp/pflash_bandit/.jsonl` +- **Exit gate**: 3-turn curl-driven session shows keep_ratio actually shifting + +### Day 4 — Harness validation: claude_code +- `run_backend_pair.sh CLIENT=claude_code` × {fixed keep=0.05, fixed keep=0.20, bandit-default starting at 0.10} +- Compare per-turn accept_rate, total session wall, OK_DONE +- **Exit gate**: bandit Pareto-dominates at least one fixed setting on ≥ 2 of 3 sessions + +### Day 5 — Harness validation: opencode +- Same A/B on opencode (tool-loop). Hermes skipped (config gap). +- Cross-client compare: does the bandit converge to similar regions? +- **Exit gate**: no client crashes; observable per-session keep_ratio trajectory committed + +### Day 6 — PR prep +- `pflash/README.md` update with no-knob behavior + `session_id` opt-in +- `--help` text: `--prefill-keep-ratio` becomes the bandit's *initial prior* (additive, not breaking) +- PR description with A/B data, bandit formula, test plan +- **Exit gate**: PR opened against `main` with green CI + +### Day 7 — Buffer + ship +- One regression chase +- Review comments +- **Exit gate**: mergeable + +## Bail conditions + +| Risk | Detection | Bail | +|---|---|---| +| DFlash accept_rate extraction is messier than expected (stderr scraping required) | Day 1 stderr inspection | Use a smaller log-grep PR first to extract reliable signal; defer bandit by 1 day | +| Bandit oscillates between bounds on real harness | Day 4 traces | Tighten step from 0.005 to 0.0025 OR widen EMA window per Codex's design | +| Cross-client variance too high | Day 5 cross-client compare | Per-client priors; ship bandit anyway with `--bandit-prior` per client | +| `--prefill-keep-ratio` default reinterpretation breaks downstream tooling | Day 6 review | Keep as fixed default; bandit opt-in via `extra_body.session_id` presence (already additive) | + +## What success looks like at end of week + +- **One PR** on `main`, ~220 LOC, no kernel touches +- **Default API contract**: client sends `/v1/messages` with no `keep_ratio` and no `pflash_mode`. Server self-tunes per session from DFlash chain accept_rate. Quality preserved (claude_code multi-turn 3/3 OK_DONE). No regression vs the static-keep=0.05 baseline. +- **Per-session JSONL traces** demonstrating bandit convergence on ≥ 2 of 3 client harnesses +- **README + `--help`** explaining the no-knob behavior + +## What we tell mrciffa at ship + +> Adaptive keep_ratio bandit landed on `main`. Server self-tunes per session from DFlash chain accept_rate. Client sends nothing — no `keep_ratio`, no `pflash_mode` — and the server picks the right compression for the workload turn-by-turn. Validated on claude_code and opencode multi-turn at 32K. ~220 LOC, one PR, no kernel changes. The skip+anchor work stays separate on the evidence branch as `pflash_mode=always` opt-in for users who explicitly want the prefill speedup. That's the MVP you asked for; the rest is extension material. + +## Drift discipline (the lesson from today) + +The chronos review confirmed that today's "drift" produced solid bench foundations (envelope, anchor coverage, composition, real-transcript study) but ALSO produced a paper plan, scaling roadmap, v2 ideas, and Momus/Codex critiques that are **text-only without experiments backing them**. This PLAN.md retains all of those as future work but **does not let them block the ship**. The bandit is the ship; everything else is a follow-up. + +If anyone — including me — proposes adding scope to this PR, the answer is "make it a follow-up PR." No exceptions. From 8dead15689424e25d287ec2fcb16a87624b90470 Mon Sep 17 00:00:00 2001 From: mrciffa Date: Wed, 27 May 2026 17:13:45 +0200 Subject: [PATCH 2/2] docs(readme): sweep stale dflash/ paths after dflash->server rename PR #282 renamed dflash/ -> server/ but README still referenced the old path in 7 quickstart commands (cmake -S dflash, --directory dflash, cd lucebox-hub/dflash). Users following the README would hit 'No such file or directory'. Sweep path-shaped references; leave binary names (test_dflash, dflash_server), submodule branch (luce-dflash), and prose mentions of the dflash algorithm as-is. --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index bdaabc4d8..adb53759b 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ uv sync # 3. build the C++/CUDA decoder (CUDA 12+, CMake 3.18+) # Default compiles for Pascal/Volta/Turing/Ampere (60/61/62/70/75/86; +120 on CUDA 12.8+, +sm_121/DGX Spark on CUDA 12.9+, +sm_110/Thor on CUDA 13.0+) so the binary runs on every supported card. # 3090-only users can add -DCMAKE_CUDA_ARCHITECTURES=86 to skip the other archs and build faster (~3 min). -cmake -B server/build -S dflash -DCMAKE_BUILD_TYPE=Release +cmake -B server/build -S server -DCMAKE_BUILD_TYPE=Release cmake --build server/build --target test_dflash -j cmake --build server/build --target test_generate -j cmake --build server/build --target dflash_server -j @@ -137,10 +137,10 @@ uv run hf download unsloth/Qwen3.6-27B-GGUF Qwen3.6-27B-Q4_K_M.gguf --local-dir uv run hf download Lucebox/Qwen3.6-27B-DFlash-GGUF dflash-draft-3.6-q8_0.gguf --local-dir server/models/draft/ # 5a. one-shot streaming generate -uv run --directory dflash python scripts/run.py --prompt "def fibonacci(n):" +uv run --directory server python scripts/run.py --prompt "def fibonacci(n):" # 5b. or reproduce the paper-style bench (HumanEval + GSM8K + Math500, ~15 min) -uv run --directory dflash python scripts/bench_llm.py +uv run --directory server python scripts/bench_llm.py ``` | Benchmark | AR (tok/s) | DFlash+DDTree (tok/s) | Speedup | @@ -186,7 +186,7 @@ nvcc --version ```bash # CUDA 12.9+ required for sm_121 nvcc --version # must show >= 12.9 -git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/dflash +git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/server cmake -B build -S . -DCMAKE_BUILD_TYPE=Release # CMake auto-adds sm_121 cmake --build build --target test_dflash -j ``` @@ -195,7 +195,7 @@ cmake --build build --target test_dflash -j ```bash # CUDA 13.0+ required for sm_110 / AGX Thor. nvcc --version -git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/dflash +git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/server cmake -B build -S . -DCMAKE_BUILD_TYPE=Release # CMake auto-adds the Thor arch your nvcc supports cmake --build build --target test_dflash -j ``` @@ -219,7 +219,7 @@ Speculative prefill for long prompts. A Qwen3-0.6B BF16 drafter scores token imp ```bash # 1. build dflash + BSA kernel (sm_80+ required for BSA, ~10 min cold compile) -git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/dflash +git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/server cmake -B build -S . -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_CUDA_ARCHITECTURES=86 \ -DDFLASH27B_ENABLE_BSA=ON @@ -267,7 +267,7 @@ DFLASH_FP_PROFILE=1 # log mean / score / select / forward stage timings **Same DFlash + PFlash stack on an AMD iGPU.** PR #119 ports the Phase 2 rocWMMA flashprefill kernels to HIP. End-to-end on a single Ryzen AI MAX+ 395 box (Radeon 8060S iGPU, gfx1151, 128 GiB LPDDR5X-8000 unified): **37.0 tok/s** DFlash decode on Qwen3.5-27B Q4_K_M, **27.6 s** TTFT at 16K context with NIAH retrieval intact. That is **3.08×** decode and **2.24×** prefill over llama.cpp HIP AR on the same iGPU. End-to-end wall clock at a realistic 16K prompt + 1K generation workload: **2.66×** faster than vanilla llama.cpp. ```bash -git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/dflash +git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/server # Build for gfx1151 (Strix Halo). Swap the arch for gfx1100 / gfx1201. cmake -B build -S . \