Skip to content

Commit 7c8ba71

Browse files
authored
[TRTLLM-8832][feat] fully async _select_generated_logits with tests (#8628)
Signed-off-by: ixlmar <[email protected]>
1 parent 4fd5813 commit 7c8ba71

File tree

7 files changed

+495
-17
lines changed

7 files changed

+495
-17
lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,6 +1366,7 @@ def _select_generated_logits(
13661366
req_num_generation_steps: torch.Tensor,
13671367
num_context_logits_prefix_sum: list[int],
13681368
generation_requests_total_steps: int,
1369+
num_logits_to_keep: int,
13691370
) -> torch.Tensor:
13701371
# raw_logits should contain only the generated logits.
13711372
# If return context logits is requested, select only the generated logits.
@@ -1394,9 +1395,10 @@ def _select_generated_logits(
13941395
req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[
13951396
: (len(scheduled_requests.context_requests) + 1)
13961397
].clone()
1397-
req_num_steps_fictitious_cuda[-1] = generation_requests_total_steps
1398-
next_context_req_offsets_cuda[-1] = (
1399-
next_context_req_offsets_cuda[-2] + req_num_steps_fictitious_cuda[-1]
1398+
req_num_steps_fictitious_cuda[-1].fill_(generation_requests_total_steps)
1399+
next_context_req_offsets_cuda[-1].copy_(
1400+
next_context_req_offsets_cuda[-2] + req_num_steps_fictitious_cuda[-1],
1401+
non_blocking=True,
14001402
)
14011403
else:
14021404
req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[
@@ -1412,6 +1414,7 @@ def _select_generated_logits(
14121414
indices_to_keep_cuda = torch_multi_arange(
14131415
starts=(next_context_req_offsets_cuda - req_num_steps_fictitious_cuda),
14141416
ends=next_context_req_offsets_cuda,
1417+
output_length=num_logits_to_keep,
14151418
)
14161419

14171420
raw_logits_cuda = raw_logits_cuda[indices_to_keep_cuda]
@@ -1455,6 +1458,7 @@ def _process_requests(
14551458
if scheduled_requests.generation_requests
14561459
else 0
14571460
),
1461+
num_logits_to_keep=sum_steps,
14581462
)
14591463

14601464
# Handle embedding bias

tensorrt_llm/_torch/pyexecutor/sampling_utils.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -343,25 +343,43 @@ def sample_grouped_strategies(
343343
)
344344

345345

346+
class _AcceptSyncCompute:
    """Marker type whose single instance is the ``ACCEPT_SYNC_COMPUTE`` sentinel.

    ``torch_multi_arange`` recognises the sentinel with an ``isinstance`` check
    on its ``output_length`` argument and, when it matches, passes no explicit
    output size to ``repeat_interleave``. A dedicated type (instead of ``None``)
    makes callers opt in explicitly to a possible device synchronization.
    """


# Pass as ``output_length`` to explicitly accept a possible device sync.
ACCEPT_SYNC_COMPUTE = _AcceptSyncCompute()
351+
352+
346353
# Inspired by https://github.com/pytorch/pytorch/issues/80577; note also the
347354
# suggestion to consider torch.nested.
348355
def torch_multi_arange(
349356
ends: torch.Tensor,
350357
*,
358+
output_length: int | _AcceptSyncCompute,
351359
starts: Optional[torch.Tensor] = None,
352360
steps: Optional[torch.Tensor] = None,
353361
) -> torch.Tensor:
354362
"""Efficiently compute torch.cat([torch.arange(b, e, d) for b, e, d in zip(starts, ends, steps)]).
355363
356364
Starts, ends, steps need to share dtype and shape. Invalid ranges like range(1, 2, -1) are
357365
silently discarded. 'steps' defaults to 1 and 'starts' defaults to 0.
366+
367+
Provide 'output_length' to avoid synchronization when using device tensors or pass
368+
`ACCEPT_SYNC_COMPUTE` to explicitly accept the possibility of a device sync (for device tensors)
369+
or when tensors are known to reside on the host.
358370
"""
359371
if steps is not None:
360372
assert ends.dtype == steps.dtype
361373
assert ends.shape == steps.shape
374+
assert ends.device == steps.device
362375
if starts is not None:
363376
assert ends.dtype == starts.dtype
364377
assert ends.shape == starts.shape
378+
assert ends.device == starts.device
379+
output_length_arg = None if isinstance(output_length, _AcceptSyncCompute) else output_length
380+
381+
if ends.numel() == 0:
382+
return ends.clone()
365383

366384
# This algorithm combines torch.repeat_interleaved() and torch.cumsum() to
367385
# construct the result.
@@ -378,29 +396,37 @@ def torch_multi_arange(
378396
repeats = repeats.clone()
379397
repeats -= starts
380398
if steps is not None:
381-
repeats = (repeats + steps - 1).div(steps, rounding_mode="floor")
382-
repeats = repeats.clip(0) # ignore invalid ranges
399+
repeats *= steps.sign()
400+
steps_abs = steps.abs()
401+
repeats = (repeats + steps_abs - 1).div(steps_abs, rounding_mode="floor")
402+
repeats = repeats.clip(min=0) # ignore invalid ranges
383403
range_ends = repeats - 1 # last element in each range
384404
if steps is not None:
385405
range_ends *= steps
386406
if starts is not None:
387407
range_ends += starts
388408
prev_range_ends = range_ends.roll(1) # last element in preceding range (or 0)
389-
prev_range_ends[0] = 0
390-
ones = (
391-
torch.tensor(1, dtype=ends.dtype, pin_memory=True)
392-
.to(device=ends.device, non_blocking=True)
393-
.broadcast_to(ends.shape)
394-
)
409+
prev_range_ends[0].fill_(0)
410+
ones = torch.ones((), dtype=ends.dtype, device=ends.device)
411+
zeros = torch.zeros((), dtype=ends.dtype, device=ends.device)
395412
if steps is None:
396-
steps = ones
413+
steps = ones.broadcast_to(ends.shape)
397414
jumps = -prev_range_ends # delta from one range to the next
398415
if starts is not None:
399416
jumps += starts
417+
# NB: Apply correction for empty ranges
418+
jumps_corrections = torch.where(repeats == 0, jumps, zeros).cumsum(0, dtype=ends.dtype)
419+
jumps += jumps_corrections
400420
seq = torch.cat((jumps.unsqueeze(-1), steps.unsqueeze(-1)), dim=1).view(-1)
401421
#
402422
# 2. Construct output via torch.repeat_interleave() and torch.cumsum()
403-
seq_repeats = torch.cat((ones.unsqueeze(-1), (repeats - 1).unsqueeze(-1)), dim=1).view(-1)
404-
seq = seq.repeat_interleave(seq_repeats)
405-
seq = seq.cumsum(0)
423+
# NB: For a resulting empty range, repeats - 1 == -1. In this case, we
424+
# should set repeats for delta and increment both to 0 instead.
425+
jump_repeats = torch.where(repeats == 0, zeros, ones)
426+
step_repeats = torch.where(repeats == 0, zeros, repeats - 1)
427+
seq_repeats = torch.cat((jump_repeats.unsqueeze(-1), step_repeats.unsqueeze(-1)), dim=1).view(
428+
-1
429+
)
430+
seq = seq.repeat_interleave(seq_repeats, output_size=output_length_arg)
431+
seq = seq.cumsum(0, dtype=ends.dtype)
406432
return seq

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ l0_a10:
1515
tests:
1616
# ------------- PyTorch tests ---------------
1717
- unittest/_torch/sampler/test_torch_sampler.py
18+
- unittest/_torch/sampler/test_torch_multi_arange.py
19+
- unittest/utils/test_util.py
1820
- unittest/_torch/modeling/test_modeling_mistral.py
1921
- unittest/_torch/modeling/test_modeling_pixtral.py
2022
# NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from contextlib import nullcontext
16+
from itertools import product
17+
from typing import Iterable, Optional
18+
19+
import numpy as np
20+
import pytest
21+
import torch
22+
from utils.util import assert_no_cuda_sync, force_ampere
23+
24+
from tensorrt_llm._torch.pyexecutor.sampling_utils import (ACCEPT_SYNC_COMPUTE,
25+
torch_multi_arange)
26+
27+
# Single-call test vectors for torch_multi_arange.
#
# Each entry is a tuple (starts, ends, steps, expected):
#   * starts / steps may be None, in which case the corresponding argument is
#     passed as None to torch_multi_arange (i.e. its default is used),
#   * ends is always a list (possibly empty),
#   * expected lists the values the call should produce; an empty list means
#     the range is empty or invalid (e.g. start >= end with a positive step).
BASE_CASES = [
    (None, [], None, []),
    ([], [], None, []),
    (None, [], [], []),
    ([], [], [], []),
    (None, [1], None, [0]),
    (None, [-1], None, []),
    (None, [3], None, [0, 1, 2]),
    (None, [-3], None, []),
    ([-5], [-3], None, [-5, -4]),
    ([-5], [-2], [2], [-5, -3]),
    ([-5], [-1], [2], [-5, -3]),
    ([-5], [-3], [3], [-5]),
    ([-3], [-5], None, []),
    ([-3], [-5], [-1], [-3, -4]),
    ([-3], [-5], [-3], [-3]),
    ([-3], [-5], [1], []),
    ([-5], [-3], [-2], []),
    ([-3], [2], None, [-3, -2, -1, 0, 1]),
    ([-3], [2], [2], [-3, -1, 1]),
    ([-3], [3], [2], [-3, -1, 1]),
    ([2], [5], None, [2, 3, 4]),
    ([2], [5], [2], [2, 4]),
    ([2], [6], [2], [2, 4]),
]
52+
53+
54+
def _build_multi_arange_case() -> tuple[Iterable, Iterable, Iterable, Iterable]:
    """Build one multi-range test case by concatenating drawn base cases.

    Draws 128 indices into ``BASE_CASES`` from a NumPy generator seeded with
    42, keeps only the cases whose expected output is non-empty, and flattens
    the kept cases into single ``starts`` / ``ends`` / ``steps`` / ``expected``
    lists. Omitted (``None``) starts are filled with zeros and omitted steps
    with ones so that all three inputs have the same length.

    NOTE: because cases with an empty expected output are dropped, the
    combined case never places an empty range next to non-empty ones.

    Returns:
        The tuple ``(starts, ends, steps, expected)`` of flat lists.
    """
    rng = np.random.default_rng(seed=42)

    starts: list = []
    ends: list = []
    steps: list = []
    expected: list = []
    for index in rng.choice(len(BASE_CASES), 128):
        case_starts, case_ends, case_steps, case_expected = BASE_CASES[index]
        if len(case_expected) == 0:
            continue  # only ranges that yield at least one value are combined
        count = len(case_ends)
        starts.extend(case_starts if case_starts is not None else [0] * count)
        ends.extend(case_ends)
        steps.extend(case_steps if case_steps is not None else [1] * count)
        expected.extend(case_expected)
    return starts, ends, steps, expected
71+
72+
73+
@force_ampere
@pytest.mark.parametrize(
    "device, allow_sync, dtype, starts, ends, steps, expected",
    [
        pytest.param(device, allow_sync, dtype, *case)
        for dtype, case, device, allow_sync in product(
            (torch.int32, torch.int64),
            BASE_CASES + [_build_multi_arange_case()],
            ("cpu", "cuda"),
            (False, True),
        )
        # The sync-free variant (allow_sync=False) is only exercised on CUDA.
        if device == "cuda" or allow_sync
    ],
)
def test_torch_multi_arange(
    device: str,
    allow_sync: bool,
    dtype: torch.dtype,
    starts: Optional[Iterable],
    ends: Iterable,
    steps: Optional[Iterable],
    expected: Iterable,
):
    """Check torch_multi_arange against the expected values of one case.

    For CUDA cases with ``allow_sync=False`` the expected output length is
    passed as ``output_length`` and the call is wrapped in
    ``assert_no_cuda_sync()``; otherwise ``ACCEPT_SYNC_COMPUTE`` is passed and
    no sync check is performed.
    """
    torch_device = torch.device(device)

    def _to_tensor(data: Optional[Iterable]) -> Optional[torch.Tensor]:
        # None stands for "argument omitted" and is forwarded unchanged.
        if data is None:
            return None
        return torch.tensor(data, device=torch_device, dtype=dtype)

    starts_tensor = _to_tensor(starts)
    ends_tensor = _to_tensor(ends)
    steps_tensor = _to_tensor(steps)
    expected_tensor = _to_tensor(expected)

    on_accelerator = device != "cpu"
    if on_accelerator and not allow_sync:
        output_length = expected_tensor.numel()
    else:
        output_length = ACCEPT_SYNC_COMPUTE

    def _run() -> torch.Tensor:
        return torch_multi_arange(
            ends_tensor,
            starts=starts_tensor,
            steps=steps_tensor,
            output_length=output_length,
        )

    if on_accelerator:
        # Pre-allocates a large chunk of memory, because PyTorch caching memory allocator
        # can sync otherwise.
        scratch = torch.ones((2**30, ), device=device)
        del scratch
        # Warmup to avoid syncs due to lazy loading of kernels
        _run()

    # The stream context is entered first, then the (optional) sync guard.
    with torch.cuda.Stream(), (nullcontext() if allow_sync else assert_no_cuda_sync()):
        result = _run()

    torch.testing.assert_close(result, expected_tensor)

0 commit comments

Comments
 (0)