Custom topk-logprobs kernel + Remove a redundant Device to Host logits copy.#434
Open
pramodith wants to merge 2 commits into
Open
Custom topk-logprobs kernel + Remove a redundant Device to Host logits copy.#434pramodith wants to merge 2 commits into
pramodith wants to merge 2 commits into
Conversation
Move the DDTree draft top-K + log-prob extraction off the CPU and onto the GPU, operating directly on the draft logits' device buffer instead of D2H-ing the full [vocab x n_positions] logits and running the OpenMP heap top-K. New split-K kernel (draft_topk_cuda.cu) with register-resident per-thread top-K and online-logsumexp; ~10x faster kernel, ~10% end-to-end tok/s. Also add a GPU verify-argmax shortcut: read the in-graph batched per-node argmax (sg_.argmax_tokens) to skip the verify-step vocab x N logits D2H + CPU argmax, validating each row and falling back to CPU argmax on any bad index. Both paths fall back to the existing CPU code and are runtime-toggleable: - DFLASH_GPU_DRAFT_TOPK=0 disables the GPU top-K - DFLASH_GPU_VERIFY_ARGMAX=0 forces the legacy CPU verify argmax The GPU top-K is compiled in for CUDA builds (DFLASH27B_HAVE_DRAFT_TOPK_CUDA). Applied in both the library (qwen35_dflash_target.cpp) and the standalone end-to-end harness (test_dflash.cpp), which keep separate copies of the decode/verify loop. Adds test_draft_topk_cuda.cpp checking the kernel bit-for-bit against the CPU reference, registered in CMake. bench_llm.py: fix tokenizer (Qwen3.6-27B) and dataset names (openai/...).
Contributor
There was a problem hiding this comment.
2 issues found across 7 files
Prompt for AI agents (unresolved issues)
Check if these issues are valid — if so, understand the root cause of each and fix them. If appropriate, use sub-agents to investigate and fix each issue separately.
<file name="server/src/common/geometric_draft_topk_cuda.h">
<violation number="1" location="server/src/common/geometric_draft_topk_cuda.h:28">
P2: Failure-signaling GPU top-k API should be marked `[[nodiscard]]` so callers cannot silently ignore fallback-critical errors.</violation>
</file>
<file name="server/src/qwen35/qwen35_dflash_target.cpp">
<violation number="1" location="server/src/qwen35/qwen35_dflash_target.cpp:308">
P2: GPU verify-argmax fast path returns early after only an in-range bounds check, skipping CPU verification that would catch silent in-range argmax mismatches.</violation>
</file>
Reply with feedback, questions, or to request a fix.
Re-trigger cubic
| // d_logits: device pointer to row-major [n_positions][vocab] f32 logits (the | ||
| // position stride is `vocab` floats — pass an offset pointer to skip | ||
| // leading positions). out_* are HOST buffers of size n_positions*K. | ||
| bool geometric_extract_draft_topk_cuda(const void * d_logits, |
Contributor
There was a problem hiding this comment.
P2: Failure-signaling GPU top-k API should be marked [[nodiscard]] so callers cannot silently ignore fallback-critical errors.
Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At server/src/common/geometric_draft_topk_cuda.h, line 28:
<comment>Failure-signaling GPU top-k API should be marked `[[nodiscard]]` so callers cannot silently ignore fallback-critical errors.</comment>
<file context>
@@ -0,0 +1,34 @@
+// d_logits: device pointer to row-major [n_positions][vocab] f32 logits (the
+// position stride is `vocab` floats — pass an offset pointer to skip
+// leading positions). out_* are HOST buffers of size n_positions*K.
+bool geometric_extract_draft_topk_cuda(const void * d_logits,
+ int n_positions, int vocab, int K,
+ float * out_log_probs,
</file context>
| for (int i = 0; i < N_actual; i++) { | ||
| if (posterior_out[i] < 0 || posterior_out[i] >= vocab) { ok = false; break; } | ||
| } | ||
| if (ok) return true; // fast path; otherwise fall through to CPU argmax |
Contributor
There was a problem hiding this comment.
P2: GPU verify-argmax fast path returns early after only an in-range bounds check, skipping CPU verification that would catch silent in-range argmax mismatches.
Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At server/src/qwen35/qwen35_dflash_target.cpp, line 308:
<comment>GPU verify-argmax fast path returns early after only an in-range bounds check, skipping CPU verification that would catch silent in-range argmax mismatches.</comment>
<file context>
@@ -278,13 +279,38 @@ bool Qwen35DFlashTarget::verify_tree(
+ for (int i = 0; i < N_actual; i++) {
+ if (posterior_out[i] < 0 || posterior_out[i] >= vocab) { ok = false; break; }
+ }
+ if (ok) return true; // fast path; otherwise fall through to CPU argmax
+ }
+
</file context>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does the PR do?
The extraction of top-k vocab indices from the logit scores of the draft model currently happens on the cpu. This involves a costly transfer of the entire logit scores from the GPU to CPU i.e. Device to Host. In this PR we move the top-k computation to the GPU entirely.
This PR also adds a verify-argmax fastpath that skips the
logits transfer from Device to Host and re-uses the argmax indices stored as an attribute. It can be toggled via DFLASH_GPU_VERIFY_ARGMAX. It re-uses the same CPU fallback on any out-of-range index.
A comment stated that previous builds had issues with GPU based verify-argmax, this PR adds a check to detect failures and resort to using the prior CPU method for running the top-k computation in this scenario.
The GPU kernel and verify-argmax is enabled by default but can be disabled via the
DFLASH_GPU_DRAFT_TOPKandDFLASH_GPU_VERIFY_ARGMAXflags.Files edited
bench_llm.pyto use the right tokenizer + correct the naming conventions of datasets being used.geometric_draft_topk_cuda.cu/h: Creates a new cuda kernel used by the draft model to identify the topk (where k is in the range 1-8) vocab indices from the logit scores.qwen35_dflash_target.cpp: Invokes the new.cukernel and falls back to the cpu execution of topk incase there are errors.test_dflash.cpp: Invokes the.cukernel if theDFLASH_GPU_DRAFT_TOPKflag is enabled. Uses the fast-path verification of draft tokens ifDFLASH_GPU_VERIFY_ARGMAX=1, if it is set to 2 it identifies/reports any mismatches between cpu and gpu based verification.test_draft_topk_cuda.cpp: Correctness tests for the cuda kernel by comparing it against the previously used cpu function.CMakeLists.txt: Register the new cuda kernel.Results
Results in an ~27% increase in tok/s. (105.63-82.80)/(82.8)
Numbers can be reproduced by running bench_llm.py after re-building with the changes.
The benchmark was run on an RTX 3090.
Results before
Results after
Reproducing Results
Baseline:
DFLASH_GPU_DRAFT_TOPK=0 DFLASH_GPU_VERIFY_ARGMAX=0 python server/scripts/bench_llm.py --bench HumanEvalLatest Results:
DFLASH_GPU_DRAFT_TOPK=1 DFLASH_GPU_VERIFY_ARGMAX=1 python server/scripts/bench_llm.py --bench HumanEval