Skip to content

Custom topk-logprobs kernel + Remove a redundant Device to Host logits copy.#434

Open
pramodith wants to merge 2 commits into
Luce-Org:mainfrom
GeometricAGI:geometric-ai/optimize_topk
Open

Custom topk-logprobs kernel + Remove a redundant Device to Host logits copy.#434
pramodith wants to merge 2 commits into
Luce-Org:mainfrom
GeometricAGI:geometric-ai/optimize_topk

Conversation

@pramodith

@pramodith pramodith commented Jun 22, 2026

Copy link
Copy Markdown

What does the PR do?

The extraction of top-k vocab indices from the logit scores of the draft model currently happens on the cpu. This involves a costly transfer of the entire logit scores from the GPU to CPU i.e. Device to Host. In this PR we move the top-k computation to the GPU entirely.

This PR also adds a verify-argmax fastpath that skips the
logits transfer from Device to Host and re-uses the argmax indices stored as an attribute. It can be toggled via DFLASH_GPU_VERIFY_ARGMAX. It re-uses the same CPU fallback on any out-of-range index.

A comment stated that previous builds had issues with GPU based verify-argmax, this PR adds a check to detect failures and resort to using the prior CPU method for running the top-k computation in this scenario.

The GPU kernel and verify-argmax is enabled by default but can be disabled via the DFLASH_GPU_DRAFT_TOPK and DFLASH_GPU_VERIFY_ARGMAX flags.

Files edited

  • Minor bug fixes to bench_llm.py to use the right tokenizer + correct the naming conventions of datasets being used.
  • geometric_draft_topk_cuda.cu/h: Creates a new cuda kernel used by the draft model to identify the topk (where k is in the range 1-8) vocab indices from the logit scores.
  • qwen35_dflash_target.cpp: Invokes the new .cu kernel and falls back to the cpu execution of topk incase there are errors.
  • test_dflash.cpp: Invokes the .cu kernel if the DFLASH_GPU_DRAFT_TOPK flag is enabled. Uses the fast-path verification of draft tokens if DFLASH_GPU_VERIFY_ARGMAX=1, if it is set to 2 it identifies/reports any mismatches between cpu and gpu based verification.
  • test_draft_topk_cuda.cpp: Correctness tests for the cuda kernel by comparing it against the previously used cpu function.
  • CMakeLists.txt: Register the new cuda kernel.

Results

Results in an ~27% increase in tok/s. (105.63-82.80)/(82.8)
Numbers can be reproduced by running bench_llm.py after re-building with the changes.
The benchmark was run on an RTX 3090.

Results before

[bench] ==== HumanEval (n=10, n_gen=256) ====
  [01/10] n_tok=  92  AR= 33.96  DFlash=  98.69  AL= 6.74
  [02/10] n_tok= 146  AR= 33.93  DFlash=  76.72  AL= 5.33
  [03/10] n_tok= 142  AR= 33.55  DFlash=  99.51  AL= 7.11
  [04/10] n_tok= 128  AR= 34.39  DFlash=  75.50  AL= 5.57
  [05/10] n_tok= 180  AR= 33.09  DFlash=  87.92  AL= 6.74
  [06/10] n_tok= 126  AR= 33.31  DFlash=  71.80  AL= 5.45
  [07/10] n_tok=  59  AR= 33.48  DFlash=  71.72  AL= 5.45
  [08/10] n_tok= 149  AR= 33.33  DFlash=  75.89  AL= 5.69
  [09/10] n_tok= 133  AR= 33.19  DFlash=  80.56  AL= 6.10
  [10/10] n_tok= 103  AR= 33.10  DFlash=  89.65  AL= 6.56
  HumanEval mean: AR=33.53  DFlash=82.80  AL=6.07  2.47x

[bench] === SUMMARY ===
Task                AR    DFlash      AL   Speedup     Score
HumanEval        33.53     82.80    6.07     2.47x          

Results after

[bench] ==== HumanEval (n=10, n_gen=256) ====
  [01/10] n_tok=  92  AR= 33.66  DFlash= 118.42  AL= 6.74
  [02/10] n_tok= 146  AR= 34.09  DFlash=  93.11  AL= 5.33
  [03/10] n_tok= 142  AR= 34.33  DFlash= 123.95  AL= 7.11
  [04/10] n_tok= 128  AR= 33.73  DFlash=  97.60  AL= 5.57
  [05/10] n_tok= 180  AR= 34.15  DFlash= 117.04  AL= 6.74
  [06/10] n_tok= 126  AR= 34.01  DFlash=  94.94  AL= 5.45
  [07/10] n_tok=  59  AR= 34.60  DFlash=  94.64  AL= 5.45
  [08/10] n_tok= 149  AR= 33.02  DFlash=  97.45  AL= 5.69
  [09/10] n_tok= 133  AR= 33.90  DFlash= 104.93  AL= 6.10
  [10/10] n_tok= 103  AR= 34.06  DFlash= 114.25  AL= 6.56
  HumanEval mean: AR=33.95  DFlash=105.63  AL=6.07  3.11x

[bench] === SUMMARY ===
Task                AR    DFlash      AL   Speedup     Score
HumanEval        33.95    105.63    6.07     3.11x          

Reproducing Results

Baseline: DFLASH_GPU_DRAFT_TOPK=0 DFLASH_GPU_VERIFY_ARGMAX=0 python server/scripts/bench_llm.py --bench HumanEval

Latest Results: DFLASH_GPU_DRAFT_TOPK=1 DFLASH_GPU_VERIFY_ARGMAX=1 python server/scripts/bench_llm.py --bench HumanEval

Review in cubic

Move the DDTree draft top-K + log-prob extraction off the CPU and onto the
GPU, operating directly on the draft logits' device buffer instead of D2H-ing
the full [vocab x n_positions] logits and running the OpenMP heap top-K. New
split-K kernel (draft_topk_cuda.cu) with register-resident per-thread top-K
and online-logsumexp; ~10x faster kernel, ~10% end-to-end tok/s.

Also add a GPU verify-argmax shortcut: read the in-graph batched per-node
argmax (sg_.argmax_tokens) to skip the verify-step vocab x N logits D2H + CPU
argmax, validating each row and falling back to CPU argmax on any bad index.

Both paths fall back to the existing CPU code and are runtime-toggleable:
  - DFLASH_GPU_DRAFT_TOPK=0     disables the GPU top-K
  - DFLASH_GPU_VERIFY_ARGMAX=0  forces the legacy CPU verify argmax
The GPU top-K is compiled in for CUDA builds (DFLASH27B_HAVE_DRAFT_TOPK_CUDA).

Applied in both the library (qwen35_dflash_target.cpp) and the standalone
end-to-end harness (test_dflash.cpp), which keep separate copies of the
decode/verify loop. Adds test_draft_topk_cuda.cpp checking the kernel
bit-for-bit against the CPU reference, registered in CMake.

bench_llm.py: fix tokenizer (Qwen3.6-27B) and dataset names (openai/...).

@cubic-dev-ai cubic-dev-ai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 issues found across 7 files

Prompt for AI agents (unresolved issues)

Check if these issues are valid — if so, understand the root cause of each and fix them. If appropriate, use sub-agents to investigate and fix each issue separately.


<file name="server/src/common/geometric_draft_topk_cuda.h">

<violation number="1" location="server/src/common/geometric_draft_topk_cuda.h:28">
P2: Failure-signaling GPU top-k API should be marked `[[nodiscard]]` so callers cannot silently ignore fallback-critical errors.</violation>
</file>

<file name="server/src/qwen35/qwen35_dflash_target.cpp">

<violation number="1" location="server/src/qwen35/qwen35_dflash_target.cpp:308">
P2: GPU verify-argmax fast path returns early after only an in-range bounds check, skipping CPU verification that would catch silent in-range argmax mismatches.</violation>
</file>

Reply with feedback, questions, or to request a fix.

Re-trigger cubic

// d_logits: device pointer to row-major [n_positions][vocab] f32 logits (the
// position stride is `vocab` floats — pass an offset pointer to skip
// leading positions). out_* are HOST buffers of size n_positions*K.
bool geometric_extract_draft_topk_cuda(const void * d_logits,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: Failure-signaling GPU top-k API should be marked [[nodiscard]] so callers cannot silently ignore fallback-critical errors.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At server/src/common/geometric_draft_topk_cuda.h, line 28:

<comment>Failure-signaling GPU top-k API should be marked `[[nodiscard]]` so callers cannot silently ignore fallback-critical errors.</comment>

<file context>
@@ -0,0 +1,34 @@
+// d_logits: device pointer to row-major [n_positions][vocab] f32 logits (the
+//           position stride is `vocab` floats — pass an offset pointer to skip
+//           leading positions). out_* are HOST buffers of size n_positions*K.
+bool geometric_extract_draft_topk_cuda(const void * d_logits,
+                             int n_positions, int vocab, int K,
+                             float * out_log_probs,
</file context>

for (int i = 0; i < N_actual; i++) {
if (posterior_out[i] < 0 || posterior_out[i] >= vocab) { ok = false; break; }
}
if (ok) return true; // fast path; otherwise fall through to CPU argmax

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: GPU verify-argmax fast path returns early after only an in-range bounds check, skipping CPU verification that would catch silent in-range argmax mismatches.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At server/src/qwen35/qwen35_dflash_target.cpp, line 308:

<comment>GPU verify-argmax fast path returns early after only an in-range bounds check, skipping CPU verification that would catch silent in-range argmax mismatches.</comment>

<file context>
@@ -278,13 +279,38 @@ bool Qwen35DFlashTarget::verify_tree(
+        for (int i = 0; i < N_actual; i++) {
+            if (posterior_out[i] < 0 || posterior_out[i] >= vocab) { ok = false; break; }
+        }
+        if (ok) return true;  // fast path; otherwise fall through to CPU argmax
+    }
+
</file context>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant