1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -26,6 +26,7 @@ jobs:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
          token: ${{ secrets.SUBMODULE_PAT || secrets.GITHUB_TOKEN }}

      - uses: Jimver/cuda-toolkit@v0.2.35
        with:
1 change: 1 addition & 0 deletions .gitignore
@@ -57,6 +57,7 @@ env/
*.qdrep
*.sqlite
bench-out/
dflash/bench/results/
profile-out/

# Model weights and caches (pull fresh from HF)
14 changes: 7 additions & 7 deletions README.md
@@ -127,7 +127,7 @@ uv sync
# 3. build the C++/CUDA decoder (CUDA 12+, CMake 3.18+)
# Default compiles for Pascal/Volta/Turing/Ampere (60/61/62/70/75/86; +120 on CUDA 12.8+, +sm_121/DGX Spark on CUDA 12.9+, +sm_110/Thor on CUDA 13.0+) so the binary runs on every supported card.
# 3090-only users can add -DCMAKE_CUDA_ARCHITECTURES=86 to skip the other archs and build faster (~3 min).
-cmake -B server/build -S dflash -DCMAKE_BUILD_TYPE=Release
+cmake -B server/build -S server -DCMAKE_BUILD_TYPE=Release
cmake --build server/build --target test_dflash -j
cmake --build server/build --target test_generate -j
cmake --build server/build --target dflash_server -j
@@ -137,10 +137,10 @@ uv run hf download unsloth/Qwen3.6-27B-GGUF Qwen3.6-27B-Q4_K_M.gguf --local-dir
uv run hf download Lucebox/Qwen3.6-27B-DFlash-GGUF dflash-draft-3.6-q8_0.gguf --local-dir server/models/draft/

# 5a. one-shot streaming generate
-uv run --directory dflash python scripts/run.py --prompt "def fibonacci(n):"
+uv run --directory server python scripts/run.py --prompt "def fibonacci(n):"

# 5b. or reproduce the paper-style bench (HumanEval + GSM8K + Math500, ~15 min)
-uv run --directory dflash python scripts/bench_llm.py
+uv run --directory server python scripts/bench_llm.py
```

| Benchmark | AR (tok/s) | DFlash+DDTree (tok/s) | Speedup |
@@ -186,7 +186,7 @@ nvcc --version
```bash
# CUDA 12.9+ required for sm_121
nvcc --version # must show >= 12.9
-git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/dflash
+git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/server
cmake -B build -S . -DCMAKE_BUILD_TYPE=Release # CMake auto-adds sm_121
cmake --build build --target test_dflash -j
```
@@ -195,7 +195,7 @@ cmake --build build --target test_dflash -j
```bash
# CUDA 13.0+ required for sm_110 / AGX Thor.
nvcc --version
-git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/dflash
+git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/server
cmake -B build -S . -DCMAKE_BUILD_TYPE=Release # CMake auto-adds the Thor arch your nvcc supports
cmake --build build --target test_dflash -j
```
@@ -219,7 +219,7 @@ Speculative prefill for long prompts. A Qwen3-0.6B BF16 drafter scores token imp

```bash
# 1. build dflash + BSA kernel (sm_80+ required for BSA, ~10 min cold compile)
-git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/dflash
+git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/server
cmake -B build -S . -DCMAKE_BUILD_TYPE=Release \
-DCMAKE_CUDA_ARCHITECTURES=86 \
-DDFLASH27B_ENABLE_BSA=ON
@@ -267,7 +267,7 @@ DFLASH_FP_PROFILE=1 # log mean / score / select / forward stage timings
**Same DFlash + PFlash stack on an AMD iGPU.** PR #119 ports the Phase 2 rocWMMA flashprefill kernels to HIP. End-to-end on a single Ryzen AI MAX+ 395 box (Radeon 8060S iGPU, gfx1151, 128 GiB LPDDR5X-8000 unified): **37.0 tok/s** DFlash decode on Qwen3.5-27B Q4_K_M, **27.6 s** TTFT at 16K context with NIAH retrieval intact. That is **3.08×** decode and **2.24×** prefill over llama.cpp HIP AR on the same iGPU. End-to-end wall clock at a realistic 16K prompt + 1K generation workload: **2.66×** faster than vanilla llama.cpp.

```bash
-git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/dflash
+git clone --recurse-submodules https://github.com/Luce-Org/lucebox-hub && cd lucebox-hub/server

# Build for gfx1151 (Strix Halo). Swap the arch for gfx1100 / gfx1201.
cmake -B build -S . \
5 changes: 5 additions & 0 deletions harness/clients/prompts/logic_check.txt
@@ -0,0 +1,5 @@
Answer these logic puzzles. End your answer with OK_DONE.

1. If all roses are flowers and some flowers fade quickly, can we conclude that some roses fade quickly?
2. A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost?
3. If you have a 3-litre jug and a 5-litre jug, how can you measure exactly 4 litres of water?
5 changes: 5 additions & 0 deletions harness/clients/prompts/math_check.txt
@@ -0,0 +1,5 @@
Solve the following math problems. End your answer with OK_DONE.

1. What is 17 * 23?
2. What is the sum of the first 10 prime numbers?
3. If a rectangle has width 7 and height 11, what is its area?
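
The three math prompts above have closed-form answers, so a harness check could compare a model's reply against computed values. The sketch below is not part of this PR: `first_primes` and `REFERENCE_ANSWERS` are hypothetical names, and it only assumes that such a comparison would be useful.

```python
from __future__ import annotations


def first_primes(count: int) -> list[int]:
    """Return the first `count` primes using trial division by earlier primes."""
    primes: list[int] = []
    candidate = 2
    while len(primes) < count:
        # A candidate is prime if no smaller prime divides it.
        if all(candidate % p for p in primes):
            primes.append(candidate)
        candidate += 1
    return primes


# Hypothetical reference table keyed by a short description of each prompt.
REFERENCE_ANSWERS = {
    "17 * 23": 17 * 23,
    "sum of first 10 primes": sum(first_primes(10)),
    "area of 7 x 11 rectangle": 7 * 11,
}
```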
43 changes: 41 additions & 2 deletions harness/clients/run_claude_code.sh
@@ -22,11 +22,45 @@ start_lucebox_server
trap stop_lucebox_server EXIT
wait_lucebox_server

# When PFLASH_SESSION_ID is set, start a thin proxy that injects
# extra_body.session_id into every /v1/messages request. The claude CLI
# cannot inject extra_body natively, so the proxy does it transparently.
PROXY_PID=""
CLIENT_BASE_URL="$BASE_URL"
if [[ -n "${PFLASH_SESSION_ID:-}" ]]; then
PROXY_PORT="${PFLASH_PROXY_PORT:-18082}"
python3 "$SCRIPT_DIR/session_inject_proxy.py" \
--host "$HOST" \
--port "$PROXY_PORT" \
--upstream "$BASE_URL" \
--session-id "$PFLASH_SESSION_ID" \
>> "$LOG_DIR/proxy.log" 2>&1 &
PROXY_PID=$!
_proxy_ready=0
for _i in $(seq 1 10); do
if curl -fsS "http://$HOST:$PROXY_PORT/health" >/dev/null 2>&1; then _proxy_ready=1; break; fi
sleep 1
if ! kill -0 "$PROXY_PID" 2>/dev/null; then
echo "session-inject proxy exited early; log: $LOG_DIR/proxy.log" >&2
cat "$LOG_DIR/proxy.log" >&2 || true
exit 1
fi
done
if [[ "$_proxy_ready" -eq 0 ]]; then
echo "session-inject proxy did not become ready after 10s; log: $LOG_DIR/proxy.log" >&2
cat "$LOG_DIR/proxy.log" >&2 || true
kill "$PROXY_PID" 2>/dev/null || true
exit 1
fi
CLIENT_BASE_URL="http://$HOST:$PROXY_PORT"
echo "[run_claude_code] session-inject proxy up on $CLIENT_BASE_URL (session=$PFLASH_SESSION_ID)"
fi

set +e
HOME="$HOME_DIR" \
ANTHROPIC_API_KEY="$API_KEY" \
-ANTHROPIC_BASE_URL="$BASE_URL" \
-CLAUDE_CODE_API_BASE_URL="$BASE_URL" \
+ANTHROPIC_BASE_URL="$CLIENT_BASE_URL" \
+CLAUDE_CODE_API_BASE_URL="$CLIENT_BASE_URL" \
CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 \
CLAUDE_CODE_DISABLE_TELEMETRY=1 \
CLAUDE_CODE_DISABLE_NONSTREAMING_FALLBACK=1 \
@@ -42,5 +76,10 @@ timeout "${CLAUDE_TIMEOUT}s" "$CLAUDE_BIN" \
RC=$?
set -e

if [[ -n "$PROXY_PID" ]] && kill -0 "$PROXY_PID" 2>/dev/null; then
kill "$PROXY_PID" 2>/dev/null || true
wait "$PROXY_PID" 2>/dev/null || true
fi

finish_report "$CLIENT_OUT" "$RC"
exit "$RC"
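
The readiness loop in the script above probes `/health`, sleeps, then checks whether the proxy process is still alive, for up to ten attempts. A minimal Python sketch of the same control flow follows. It is an illustration rather than code from this PR: `wait_until_ready`, `probe`, and `is_alive` are hypothetical names, and the probe and liveness checks are passed in as callables so the logic can run without a real server.

```python
from __future__ import annotations

import time
from typing import Callable


def wait_until_ready(
    probe: Callable[[], bool],
    is_alive: Callable[[], bool],
    attempts: int = 10,
    delay_s: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Return "ready", "exited", or "timeout", mirroring the shell loop's three outcomes."""
    for _ in range(attempts):
        if probe():          # corresponds to the curl /health check
            return "ready"
        sleep(delay_s)       # corresponds to `sleep 1`
        if not is_alive():   # corresponds to `kill -0 "$PROXY_PID"`
            return "exited"
    return "timeout"
```

Checking liveness after each failed probe lets the caller fail fast with the proxy log instead of waiting out the full timeout when the process has already died.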
144 changes: 144 additions & 0 deletions harness/clients/session_inject_proxy.py
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
"""Thin proxy that injects extra_body.session_id into /v1/messages requests.

Run between the claude CLI and the dflash server when PFLASH_SESSION_ID is set.
All other paths and methods are forwarded verbatim.

Usage:
python3 session_inject_proxy.py \\
--host 127.0.0.1 --port 18081 \\
--upstream http://127.0.0.1:18080 \\
--session-id <id>

The proxy listens on --port and forwards to --upstream, injecting
extra_body.session_id on every POST /v1/messages request.
"""

from __future__ import annotations

import argparse
import json
import os
import socket
import threading
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import urlparse
import http.client


class Handler(BaseHTTPRequestHandler):
    upstream: str = ""
    session_id: str = ""

    def log_message(self, fmt, *args):
        print("[session-proxy] %s" % (fmt % args), flush=True)

    def _upstream_conn(self) -> tuple[http.client.HTTPConnection, str]:
        url = urlparse(self.upstream)
        port = url.port or (443 if url.scheme == "https" else 80)
        cls = http.client.HTTPSConnection if url.scheme == "https" else http.client.HTTPConnection
        return cls(url.hostname, port, timeout=900), url.path.rstrip("/")

    def _forward_raw(self, body: bytes):
        """Forward request verbatim (no injection needed)."""
        conn, base = self._upstream_conn()
        headers = {
            k: v for k, v in self.headers.items()
            if k.lower() not in ("host", "content-length", "transfer-encoding")
        }
        headers["Content-Length"] = str(len(body))
        conn.request(self.command, base + self.path, body, headers)
        resp = conn.getresponse()
        self._relay_response(resp)

    def _relay_response(self, resp: http.client.HTTPResponse):
        """Relay upstream response back to client, handling SSE streaming."""
        content_type = resp.getheader("Content-Type", "")
        is_sse = "text/event-stream" in content_type

        self.send_response(resp.status)
        skip_headers = {"transfer-encoding", "content-length"}
        for k, v in resp.getheaders():
            if k.lower() not in skip_headers:
                self.send_header(k, v)

        if is_sse:
            self.send_header("Transfer-Encoding", "chunked")
            self.end_headers()
            # Stream chunk by chunk
            while True:
                chunk = resp.read(4096)
                if not chunk:
                    # Write terminal chunk
                    self.wfile.write(b"0\r\n\r\n")
                    self.wfile.flush()
                    break
                size = "%X\r\n" % len(chunk)
                self.wfile.write(size.encode("ascii"))
                self.wfile.write(chunk)
                self.wfile.write(b"\r\n")
                self.wfile.flush()
        else:
            data = resp.read()
            self.send_header("Content-Length", str(len(data)))
            self.end_headers()
            self.wfile.write(data)

    def _read_body(self) -> bytes:
        n = int(self.headers.get("Content-Length", "0"))
        if n <= 0:
            return b""
        return self.rfile.read(n)

    def do_GET(self):
        conn, base = self._upstream_conn()
        headers = {k: v for k, v in self.headers.items() if k.lower() != "host"}
        conn.request("GET", base + self.path, None, headers)
        resp = conn.getresponse()
        self._relay_response(resp)

    def do_POST(self):
        body = self._read_body()
        path = self.path

        # Inject session_id only on /v1/messages
        if self.session_id and path.startswith("/v1/messages"):
            try:
                obj = json.loads(body.decode("utf-8"))
                if "extra_body" not in obj:
                    obj["extra_body"] = {}
                if "session_id" not in obj["extra_body"]:
                    obj["extra_body"]["session_id"] = self.session_id
                body = json.dumps(obj).encode("utf-8")
            except Exception as exc:
                print(f"[session-proxy] JSON parse error, forwarding raw: {exc}", flush=True)

        self._forward_raw(body)


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--host", default="127.0.0.1")
    ap.add_argument("--port", type=int, default=18081)
    ap.add_argument("--upstream", default="http://127.0.0.1:18080")
    ap.add_argument("--session-id", default=os.environ.get("PFLASH_SESSION_ID", ""))
    args = ap.parse_args()

    if not args.session_id:
        print("[session-proxy] WARNING: no session_id set; proxy is pass-through only", flush=True)

    Handler.upstream = args.upstream.rstrip("/")
    Handler.session_id = args.session_id

    srv = ThreadingHTTPServer((args.host, args.port), Handler)
    print(
        f"[session-proxy] listening on http://{args.host}:{args.port} "
        f"-> {Handler.upstream} "
        f"(session_id={Handler.session_id!r})",
        flush=True,
    )
    srv.serve_forever()


if __name__ == "__main__":
    main()
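
The injection step in `do_POST` can be isolated as a pure function, which makes its three behaviours easy to check: add `extra_body.session_id` when absent, keep a client-supplied value, and fall back to the raw body on unparseable input. The sketch below is a hypothetical refactor, not code from this PR. `inject_session_id` is an assumed name, and it uses `dict.setdefault` in place of the explicit `in` checks.

```python
from __future__ import annotations

import json


def inject_session_id(body: bytes, session_id: str) -> bytes:
    """Return `body` with extra_body.session_id set, unless the client already set one."""
    if not session_id:
        return body  # pass-through when no session id is configured
    try:
        obj = json.loads(body.decode("utf-8"))
        extra = obj.setdefault("extra_body", {})
        extra.setdefault("session_id", session_id)  # never overwrite the client's value
        return json.dumps(obj).encode("utf-8")
    except Exception:
        # Non-JSON or non-object payloads are forwarded unchanged, as in the proxy.
        return body
```

Because the function returns the original bytes on any failure, a caller can always forward its result without further checks.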
12 changes: 12 additions & 0 deletions server/CMakeLists.txt
@@ -555,6 +555,18 @@ if(DFLASH27B_TESTS)
target_include_directories(test_gguf_mmap PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS})
target_link_libraries(test_gguf_mmap PRIVATE dflash_common)
endif()
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_adaptive_keep_ratio.cpp")
add_executable(test_adaptive_keep_ratio test/test_adaptive_keep_ratio.cpp)
target_include_directories(test_adaptive_keep_ratio PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS})
target_link_libraries(test_adaptive_keep_ratio PRIVATE dflash_common)
add_test(NAME adaptive_keep COMMAND test_adaptive_keep_ratio)
endif()
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_bandit_integration.cpp")
add_executable(test_bandit_integration test/test_bandit_integration.cpp)
target_include_directories(test_bandit_integration PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS})
target_link_libraries(test_bandit_integration PRIVATE dflash_common)
add_test(NAME bandit_integration COMMAND test_bandit_integration)
endif()
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_draft_vs_reference.cpp")
add_executable(test_draft_vs_reference test/test_draft_vs_reference.cpp)
target_link_libraries(test_draft_vs_reference PRIVATE dflash_common)
5 changes: 5 additions & 0 deletions server/src/common/model_backend.h
@@ -121,6 +121,11 @@ struct GenerateResult {
// can mark the answer as unreliable rather than treating the
// (truncated) content as a clean response.
bool degenerate_decode_close = false;
// DFlash chain accept rate: accepted_draft_tokens / total_draft_positions.
// 0.0 when spec decode did not run (AR fallback or no draft model).
float accept_rate = 0.0f;
// True when spec decode actually ran (accept_rate==0 still needs a bandit update).
bool spec_decode_ran = false;
};

// ─── Backend interface ──────────────────────────────────────────────────
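
The two new fields are deliberately separate: `accept_rate == 0.0` alone cannot distinguish "spec decode ran and nothing was accepted" from "spec decode never ran". The Python sketch below illustrates that distinction under stated assumptions. `SpecDecodeStats` and `spec_decode_stats` are hypothetical names, and the `max(1, ...)` guard follows the `total_draft_pos` computation in `do_spec_decode`.

```python
from dataclasses import dataclass


@dataclass
class SpecDecodeStats:
    """Hypothetical mirror of the two fields added to GenerateResult."""
    accept_rate: float = 0.0
    spec_decode_ran: bool = False


def spec_decode_stats(ran: bool, n_accept_sum: int, n_draft_steps: int, q_len: int) -> SpecDecodeStats:
    if not ran:
        # AR fallback or no draft model: both fields keep their defaults.
        return SpecDecodeStats()
    # Guard against a zero denominator when no draft step was taken.
    total_draft_pos = max(1, n_draft_steps * q_len)
    return SpecDecodeStats(accept_rate=n_accept_sum / total_draft_pos, spec_decode_ran=True)
```

A consumer such as a bandit update would branch on `spec_decode_ran` first and only then read `accept_rate`.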
9 changes: 9 additions & 0 deletions server/src/qwen35/qwen35_backend.cpp
@@ -578,6 +578,7 @@ GenerateResult Qwen35Backend::generate(const GenerateRequest & req,
// generation. Most requests never hit the tail because the
// model closes </think> naturally well before the budget edge.
if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io,
result.accept_rate, result.spec_decode_ran,
req.hint_tokens, &req.budget_hook,
&result.budget_forced_close,
&result.degenerate_decode_close)) {
@@ -648,6 +649,7 @@ GenerateResult Qwen35Backend::restore_and_generate(int slot,
// generation. Most requests never hit the tail because the
// model closes </think> naturally well before the budget edge.
if (!do_spec_decode(committed, req.n_gen, result.tokens, out_io,
result.accept_rate, result.spec_decode_ran,
req.hint_tokens, &req.budget_hook,
&result.budget_forced_close,
&result.degenerate_decode_close)) {
@@ -1072,10 +1074,14 @@ bool Qwen35Backend::sync_remote_draft_features(int start_pos, int n_tokens) {
bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
std::vector<int32_t> & out_tokens,
const DaemonIO & io,
float & out_accept_rate,
bool & out_spec_ran,
const std::vector<int32_t> * hint_tokens,
const BudgetHook * budget_hook,
bool * forced_close_out,
bool * degenerate_close_out) {
out_accept_rate = 0.0f;
out_spec_ran = false;
const int hidden = w_.n_embd;

// First token: use the argmax that do_prefill already sampled and stored.
@@ -1108,6 +1114,8 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
return ok;
}

out_spec_ran = true;

// ── DFlash spec-decode: draft → verify → accept → replay ──────────

DFlashTarget * target = dflash_target();
@@ -1349,6 +1357,7 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
const double decode_s = std::chrono::duration<double>(t_dec1 - t_dec0).count();
const int total_draft_pos = std::max(1, n_draft_steps * q_len);
const double accept_pct = 100.0 * (double)n_accept_sum / (double)total_draft_pos;
out_accept_rate = (float)((double)n_accept_sum / (double)total_draft_pos);
std::fprintf(stderr, "[spec-decode] tokens=%d time=%.3f s speed=%.2f tok/s "
"steps=%d accepted=%d/%d (%.1f%%) avg_commit=%.2f\n",
n_generated, decode_s,
4 changes: 4 additions & 0 deletions server/src/qwen35/qwen35_backend.h
@@ -219,9 +219,13 @@ class Qwen35Backend : public ModelBackend {
// close token mid-batch (verify-and-accept assumes the sampled
// tokens are the ones that got committed), so the boundary switch
// is the simplest correct integration.
// out_accept_rate receives accepted/total draft token ratio (0.0 if AR fallback).
// out_spec_ran is true when spec decode actually ran (even with 0 accepts).
bool do_spec_decode(int committed, int n_gen,
std::vector<int32_t> & out_tokens,
const DaemonIO & io,
float & out_accept_rate,
bool & out_spec_ran,
const std::vector<int32_t> * hint_tokens = nullptr,
const BudgetHook * budget_hook = nullptr,
bool * forced_close_out = nullptr,