fp32 lm-head returns non-contiguous logits, triggers NaN in vLLM processed_logprobs #2497

@joanvelja

Summary

src/prime_rl/inference/patches.py:1029-1037 (PR #2441) computes fp32 logits and slices the padded vocab dim:

logits = torch.mm(flat, lm_head.weight.t(), out_dtype=torch.float32)
...
logits = logits[..., : self.org_vocab_size]

When padded_vocab > org_vocab_size, the slice returns a non-contiguous view with shape (batch, org_vocab_size) and stride (padded_vocab, 1). vLLM's native Triton top-k/top-p kernel (_topk_topp_kernel) indexes rows as row_id * VOCAB_SIZE, not row_id * stride(0). For every row after the first, the kernel therefore starts reading at the wrong physical offset. It can then mask a logical row to all -inf, and processed_logprobs computes log_softmax(all -inf) = NaN.
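The addressing mismatch can be illustrated without a GPU. The following is a minimal pure-Python sketch, not the vLLM kernel: the buffer sizes, the `PAD` sentinel, and the `read_row` helper are all hypothetical and chosen only to make the offsets easy to follow.

```python
# Toy model of the layout described above; sizes are illustrative, not Olmo3's.
PADDED_VOCAB = 8   # physical row length in the underlying buffer
ORG_VOCAB = 6      # logical row length after slicing off the padding
N_ROWS = 3
PAD = -1.0         # sentinel standing in for padded-vocab columns

# Row-major buffer: real entries are row * 100 + col, padding entries are PAD.
buffer = [
    float(row * 100 + col) if col < ORG_VOCAB else PAD
    for row in range(N_ROWS)
    for col in range(PADDED_VOCAB)
]


def read_row(buf, row_id, n_cols, row_stride):
    """Return n_cols elements of a row, starting at row_id * row_stride."""
    start = row_id * row_stride
    return buf[start:start + n_cols]


# Stride-aware addressing (row_id * stride(0)) lands on the logical row.
strided_row1 = read_row(buffer, 1, ORG_VOCAB, PADDED_VOCAB)

# Shape-based addressing (row_id * VOCAB_SIZE) starts inside row 0's padding.
naive_row1 = read_row(buffer, 1, ORG_VOCAB, ORG_VOCAB)

print(strided_row1)  # [100.0, 101.0, 102.0, 103.0, 104.0, 105.0]
print(naive_row1)    # [-1.0, -1.0, 100.0, 101.0, 102.0, 103.0]
```

Row 0 is read identically under both schemes; the drift grows by (padded_vocab - org_vocab_size) elements per row. With the sizes reported below (100288 vs 100278), that is 10 elements per row, so row 191 would be read 1910 elements early.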

I am filing the kernel-side fix upstream with vLLM separately. PrimeRL should also guard at this boundary, because other logits processors can produce non-contiguous views too and the upstream fix may not land soon.

Observed in production

  • vLLM 0.20.1, V1 engine, logprobs_mode=processed_logprobs
  • Olmo3 (org vocab 100278, padded 100288), bf16 model, fp32 lm-head
  • chunked prefill, async scheduling, high-concurrency multi-replica serving
  • top_p=0.95, no top-k

Captured at TopKTopPSampler.forward_native:

{
  "shape": [192, 100278],
  "stride": [100288, 1],
  "is_contiguous": false,
  "pre_bad_finite": [100278],
  "post_all_neginf": [true],
  "post_neginf_count": [100278]
}

Pre-top-p row was finite fp32. Post-top-p row was all -inf. JSON serialization then failed with ValueError: Out of range float values are not JSON compliant: nan.
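The last two steps of that failure chain can be reproduced in plain Python. This is a sketch under assumptions: `log_softmax` below is a hypothetical list-based stand-in for the tensor op, and `json.dumps(..., allow_nan=False)` stands in for whatever strict serializer the server uses.

```python
import json
import math


def log_softmax(row):
    """Numerically stable log-softmax over a list of floats."""
    m = max(row)
    shifted = [x - m for x in row]          # -inf - (-inf) evaluates to NaN
    log_z = math.log(sum(math.exp(s) for s in shifted))
    return [s - log_z for s in shifted]


# A row that the top-p step masked entirely.
masked_row = [float("-inf")] * 4
logprobs = log_softmax(masked_row)
all_nan = all(math.isnan(x) for x in logprobs)

# A row with at least one surviving token stays well defined.
partial = log_softmax([0.0, float("-inf"), float("-inf")])

# Strict JSON encoding rejects NaN with a ValueError.
try:
    json.dumps({"logprobs": logprobs}, allow_nan=False)
    strict_json_failed = False
except ValueError:
    strict_json_failed = True

print(all_nan, strict_json_failed)  # True True
```

The point of the sketch is that a fully masked row is the only case that produces NaN here; a row with even one finite logit yields a logprob of 0.0 for the survivor and -inf for the rest.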

Controls

| Config | Out of range | BadRequest | OUTPUT_NAN |
| --- | --- | --- | --- |
| fp32 lm-head + native Triton (default) | 479 | 478 | many |
| fp32 lm-head + PyTorch top-p/top-k | 0 | 0 | 0 |
| fp32 lm-head + logits.contiguous() before Triton | 0 | 0 | 0 |

bf16 lm-head did not reproduce over one comparable canary. Not proof bf16 is safe; it implicates the fp32 path as the layout source.

Observer confirmed the contiguous mitigation did real work:

FORCE_CONTIGUOUS shape=(192, 100278) stride=(100288, 1)
empty_from_nonempty=0 out_nan_rows=0

Suggested fix

A small guard on the fp32 lm-head path, right after the slice:

logits = logits[..., : self.org_vocab_size]
if not logits.is_contiguous():
    logits = logits.contiguous()

Doesn't change logits values. Survived repeated live canaries. Cost: one [batch, vocab] fp32 copy per step.
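For a rough sense of that cost, the copy size can be worked out from the shape captured above. This is a back-of-the-envelope sketch: `copy_bytes` is a hypothetical helper, and the batch of 192 is just the one captured instance, not a fixed property of the deployment.

```python
FP32_BYTES = 4  # bytes per fp32 element


def copy_bytes(batch, vocab, elem_bytes=FP32_BYTES):
    """Bytes written by one contiguous copy of a [batch, vocab] tensor."""
    return batch * vocab * elem_bytes


# Shape taken from the captured sample: [192, 100278].
size = copy_bytes(192, 100278)
size_mib = size / (1024 ** 2)

print(size)                # 77013504
print(round(size_mib, 1))  # 73.4
```

So the guard adds roughly 73 MiB of extra fp32 writes per step at this batch size, scaling linearly with batch.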
