Summary
`src/prime_rl/inference/patches.py:1029-1037` (PR #2441) computes fp32 logits and slices the padded vocab dim:
```python
logits = torch.mm(flat, lm_head.weight.t(), out_dtype=torch.float32)
...
logits = logits[..., : self.org_vocab_size]
```
When `padded_vocab > org_vocab_size`, the slice returns a non-contiguous view with stride `(padded_vocab, 1)`. vLLM's native Triton top-k/top-p kernel (`_topk_topp_kernel`) indexes rows as `row_id * VOCAB_SIZE`, not `row_id * stride(0)`. The kernel reads the wrong physical row, can mask a logical row to all `-inf`, and `processed_logprobs` then computes `log_softmax(all -inf) = NaN`.
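A toy, pure-Python sketch of the addressing mismatch described above. It is not vLLM's kernel code: `PADDED`, `ORG`, and both reader functions are hypothetical stand-ins, and it assumes the kernel's `VOCAB_SIZE` equals the logical (sliced) width. Each slot in the flat buffer is tagged with the row it belongs to, so a misread is visible.

```python
# Hypothetical toy sizes; on Olmo3 these would be 100288 (padded) and 100278 (org).
PADDED, ORG, ROWS = 8, 6, 3

# Flat row-major buffer as the fp32 lm-head matmul lays it out: ROWS x PADDED.
# Real vocab slots are tagged (row, col); padded tail slots are tagged ("pad", row).
buffer = [(r, c) if c < ORG else ("pad", r) for r in range(ROWS) for c in range(PADDED)]


def read_row_strided(r):
    """Correct addressing: logical row r starts at r * stride(0) == r * PADDED."""
    start = r * PADDED
    return buffer[start:start + ORG]


def read_row_assuming_dense(r):
    """Mismatched addressing (assumed): row r starts at r * VOCAB_SIZE == r * ORG."""
    start = r * ORG
    return buffer[start:start + ORG]


for r in range(ROWS):
    drift = r * (PADDED - ORG)  # how many slots too early the dense read starts
    print(r, drift, read_row_assuming_dense(r))
```

Row 0 is read correctly, row 1 picks up padding from row 0, and row 2 starts inside row 1. Under the same assumption, the real shapes drift by `10 * r` elements, so row 191 of the captured batch would start 1910 elements early.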
Filing a separate upstream vLLM issue for the kernel-side fix. PrimeRL should also guard at this boundary, because other logits processors can produce non-contiguous views too and the upstream fix may not land soon.
Observed in production
- vLLM 0.20.1, V1 engine, `logprobs_mode=processed_logprobs`
- Olmo3 (org vocab 100278, padded 100288), bf16 model, fp32 lm-head
- chunked prefill, async scheduling, high-concurrency multi-replica serving
- `top_p=0.95`, no top-k
Captured at `TopKTopPSampler.forward_native`:
```json
{
  "shape": [192, 100278],
  "stride": [100288, 1],
  "is_contiguous": false,
  "pre_bad_finite": [100278],
  "post_all_neginf": [true],
  "post_neginf_count": [100278]
}
```
The pre-top-p row was entirely finite fp32; the post-top-p row was all `-inf`. JSON serialization of the resulting NaN logprobs then failed with `ValueError: Out of range float values are not JSON compliant: nan`.
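The chain from a fully masked row to the serialization error can be reproduced with plain Python floats. This is a stand-in for `log_softmax` using the usual max-subtraction form, not vLLM's implementation, and it assumes the response encoder rejects non-finite floats the way `json.dumps(..., allow_nan=False)` does.

```python
import json
import math


def log_softmax(row):
    """Max-subtracted log-softmax over a list of floats."""
    m = max(row)                    # -inf when every entry is -inf
    shifted = [x - m for x in row]  # -inf - (-inf) is NaN under IEEE 754
    lse = math.log(sum(math.exp(s) for s in shifted))  # NaN propagates through exp/sum/log
    return [s - lse for s in shifted]


finite_row = [1.0, 2.0, 3.0]
masked_row = [float("-inf")] * 3    # shape of an all -inf post-top-p row

print(log_softmax(finite_row))      # finite logprobs
out = log_softmax(masked_row)
print(out)                          # [nan, nan, nan]

try:
    json.dumps({"logprobs": out}, allow_nan=False)
except ValueError as err:
    print("serialization failed:", err)
```

A finite row survives both steps; only the all `-inf` row turns into NaN and then trips the strict encoder.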
Controls
| Config | Out of range | BadRequest | OUTPUT_NAN |
|---|---|---|---|
| fp32 lm-head + native Triton (default) | 479 | 478 | many |
| fp32 lm-head + PyTorch top-p/top-k | 0 | 0 | 0 |
| fp32 lm-head + `logits.contiguous()` before Triton | 0 | 0 | 0 |
The bf16 lm-head path did not reproduce the failure over one comparable canary. That is not proof bf16 is safe, but it implicates the fp32 path as the source of the non-contiguous layout.
The observer confirmed the contiguous mitigation did real work:
```text
FORCE_CONTIGUOUS shape=(192, 100278) stride=(100288, 1)
empty_from_nonempty=0 out_nan_rows=0
```
Suggested fix
One-line guard on the fp32 lm-head path:
```python
logits = logits[..., : self.org_vocab_size]
if not logits.is_contiguous():
    logits = logits.contiguous()
```
Doesn't change logit values. Survived repeated live canaries. Cost: one `[batch, vocab]` fp32 copy per step (about 77 MB for the captured `[192, 100278]` batch).
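A minimal regression check for the guard, sketched with toy shapes. `slice_vocab_contiguous` is a hypothetical helper name, not a PrimeRL function; it wraps the same slice-then-`contiguous()` logic so the layout and values can be asserted without a model or vLLM.

```python
import torch


def slice_vocab_contiguous(logits: torch.Tensor, org_vocab_size: int) -> torch.Tensor:
    """Hypothetical helper: drop padded vocab columns and return a dense tensor."""
    logits = logits[..., :org_vocab_size]
    if not logits.is_contiguous():
        # Copy into a fresh dense buffer so stride(0) == org_vocab_size,
        # matching a kernel that addresses rows as row_id * VOCAB_SIZE.
        logits = logits.contiguous()
    return logits


# Toy sizes standing in for Olmo3's 100288 padded / 100278 real vocab.
batch, padded, org = 4, 16, 13
padded_logits = torch.randn(batch, padded, dtype=torch.float32)

view = padded_logits[..., :org]
assert not view.is_contiguous()
assert view.stride() == (padded, 1)     # the layout seen in production

fixed = slice_vocab_contiguous(padded_logits, org)
assert fixed.is_contiguous()
assert fixed.stride() == (org, 1)
assert torch.equal(fixed, view)         # values unchanged, only the layout differs
```

`Tensor.contiguous()` already returns `self` when the tensor is contiguous, so the explicit `is_contiguous()` check mainly documents when the copy happens; it also gives an obvious place to hang a log line like the observer output above.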
Cross-references