🚀 The feature, motivation and pitch
In some cases (e.g. CUDA graph capture for attention), it is important to know the number of prefill and decode requests in the batch. This information is also useful for launching kernels specialized for prefill or decode. Extracting it from the batch, however, requires operations that are usually not compatible with graph capture, e.g.:
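```python
# Data-dependent reduction on a device tensor; .item() forces a
# device-to-host sync, which is not permitted during CUDA graph capture.
prefill_mask = seq_len > 1
num_prefill = int(prefill_mask.sum().item())
```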
Instead, the number of prefill and decode requests can be computed ahead of graph launch, outside the captured region, and stored in the sequence info object.
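A minimal sketch of the idea, assuming the per-request sequence lengths are already available on the host as plain Python ints before graph launch. `SequenceInfo`, `update`, and `split_batch` are hypothetical names used only for illustration, not an existing API. The sketch also assumes prefill requests are placed before decode requests in the batch, and reuses the `seq_len > 1` rule above to classify a request as prefill.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class SequenceInfo:
    """Hypothetical host-side batch metadata (not an existing library class).

    Assumes all prefill requests come first, followed by all decode requests.
    """

    seq_lens: List[int] = field(default_factory=list)
    num_prefill: int = 0
    num_decode: int = 0

    def update(self, seq_lens: List[int]) -> None:
        # Runs on the host before graph launch: the lengths are plain Python
        # ints, so no device-to-host sync (e.g. tensor.item()) is involved.
        self.seq_lens = list(seq_lens)
        self.num_prefill = sum(1 for n in self.seq_lens if n > 1)
        self.num_decode = len(self.seq_lens) - self.num_prefill
        # Check the assumed prefill-first ordering: any prefill found after
        # the first num_prefill entries means the batch is interleaved.
        if any(n > 1 for n in self.seq_lens[self.num_prefill:]):
            raise ValueError("expected prefill requests before decode requests")


def split_batch(info: SequenceInfo) -> Tuple[List[int], List[int]]:
    # Slice boundaries come from the precomputed counts, so downstream code
    # can pick prefill- or decode-specialized paths without inspecting
    # device tensors.
    prefill = info.seq_lens[: info.num_prefill]
    decode = info.seq_lens[info.num_prefill :]
    return prefill, decode


info = SequenceInfo()
info.update([7, 12, 1, 1, 1])
print(info.num_prefill, info.num_decode)  # 2 3
print(split_batch(info))  # ([7, 12], [1, 1, 1])
```

Because the counts are ordinary host integers, they can be used to select or size a captured graph (or to dispatch prefill/decode kernels) without any synchronization inside the captured region.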
Alternatives
No response
Additional context
No response