Skip to content

Commit 92575a0

Browse files
add aten.topk implementation (#2841)
1 parent 2b4d699 commit 92575a0

File tree

4 files changed

+126
-5
lines changed

4 files changed

+126
-5
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torch_tensorrt.dynamo.conversion.converter_utils import (
1717
dynamic_unsupported_with_args,
1818
enforce_tensor_types,
19+
get_positive_dim,
1920
is_only_operator_on_placeholder,
2021
)
2122
from torch_tensorrt.fx.types import TRTTensor
@@ -2411,6 +2412,28 @@ def aten_ops_adaptive_avg_poolNd(
24112412
)
24122413

24132414

2415+
def topk_validator(node: Node) -> bool:
2416+
k = node.args[1]
2417+
return topk_sort_validator(k)
2418+
2419+
2420+
def sort_validator(node: Node) -> bool:
2421+
shape = node.args[0].meta.get("tensor_meta").shape
2422+
dim = node.args[1]
2423+
dim = get_positive_dim(dim, len(shape))
2424+
k = shape[dim]
2425+
return topk_sort_validator(k)
2426+
2427+
2428+
def topk_sort_validator(k: int) -> bool:
2429+
if k > 3840:
2430+
_LOGGER.debug(
2431+
f"Currently only topk values up to 3840 are supported, got k={k}."
2432+
)
2433+
return False
2434+
return True
2435+
2436+
24142437
def max_pool_param_validator(pool_node: Node) -> bool:
24152438
dilation = args_bounds_check(pool_node.args, 4, 1)
24162439
ceil_mode = args_bounds_check(pool_node.args, 5, False)
@@ -2792,7 +2815,37 @@ def upsample_bilinear2d(
27922815
)
27932816

27942817

2795-
@dynamo_tensorrt_converter(torch.ops.aten.sort.default)
2818+
@dynamo_tensorrt_converter(
2819+
torch.ops.aten.topk.default, capability_validator=topk_validator
2820+
)
2821+
@enforce_tensor_types(
2822+
{
2823+
0: (TRTTensor,),
2824+
}
2825+
)
2826+
def aten_ops_topk(
2827+
ctx: ConversionContext,
2828+
target: Target,
2829+
args: Tuple[Argument, ...],
2830+
kwargs: Dict[str, Argument],
2831+
name: str,
2832+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2833+
return impl.topk.topk(
2834+
ctx,
2835+
target,
2836+
SourceIR.ATEN,
2837+
name,
2838+
args[0],
2839+
k=args[1],
2840+
dim=args_bounds_check(args, 2, -1),
2841+
largest=args_bounds_check(args, 3, True),
2842+
sorted=args_bounds_check(args, 4, True),
2843+
)
2844+
2845+
2846+
@dynamo_tensorrt_converter(
2847+
torch.ops.aten.sort.default, capability_validator=sort_validator
2848+
)
27962849
@enforce_tensor_types(
27972850
{
27982851
0: (TRTTensor,),

py/torch_tensorrt/dynamo/conversion/impl/topk.py

+33-4
Original file line numberDiff line numberDiff line change
@@ -113,21 +113,50 @@ def sort(
113113
descending: bool,
114114
return_indices: bool = True,
115115
) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]:
116-
if descending:
116+
dim = get_positive_dim(dim, len(input.shape))
117+
k = input.shape[dim]
118+
return topk(
119+
ctx,
120+
target,
121+
source_ir,
122+
name,
123+
input,
124+
k,
125+
dim,
126+
descending,
127+
sorted=None,
128+
return_indices=return_indices,
129+
)
130+
131+
132+
def topk(
133+
ctx: ConversionContext,
134+
target: Target,
135+
source_ir: Optional[SourceIR],
136+
name: str,
137+
input: TRTTensor,
138+
k: int,
139+
dim: int,
140+
largest: bool,
141+
sorted: Optional[bool],
142+
return_indices: bool = True,
143+
) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]:
144+
if largest:
117145
topk_layer = ctx.net.add_topk(
118146
input,
119147
trt.TopKOperation.MAX,
120-
input.shape[dim],
148+
k,
121149
get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
122150
)
123151
else:
124152
topk_layer = ctx.net.add_topk(
125153
input,
126154
trt.TopKOperation.MIN,
127-
input.shape[dim],
155+
k,
128156
get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
129157
)
130-
158+
# TensorRT ITopKLayer does not have a sorted flag; it always returns the sorted topk elements,
159+
# so no matter whether sorted is True or False, the returned topk Tensor object is always sorted
131160
set_layer_name(topk_layer, target, name, source_ir)
132161

133162
if return_indices:

tests/py/dynamo/conversion/test_sort_aten.py

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def forward(self, x):
2727
self.run_test(
2828
Sort(),
2929
inputs,
30+
enable_passes=True,
3031
)
3132

3233

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestSortConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
((3, 2, 4), 1, 0, True, True),
13+
((3, 3, 4), 2, -1, True, True),
14+
((3, 3, 4), 2, -1, False, True),
15+
((3850, 2), 3840, 0, False, True),
16+
((3, 3), 2, 0, True, True),
17+
((3, 3), 2, 1, True, False),
18+
((5, 3), 2, 1, False, False),
19+
((6, 4), 2, 1, False, False),
20+
# default dim:-1, largest:True, sorted:True
21+
((3, 5, 12), 3),
22+
]
23+
)
24+
def test_topk(self, input_shape, k, dim=-1, largest=True, sorted=True):
25+
class Topk(nn.Module):
26+
def forward(self, x):
27+
return torch.ops.aten.topk.default(x, k, dim, largest, sorted)
28+
29+
inputs = [torch.randn(*input_shape)]
30+
self.run_test(
31+
Topk(),
32+
inputs,
33+
enable_passes=True,
34+
)
35+
36+
37+
if __name__ == "__main__":
38+
run_tests()

0 commit comments

Comments
 (0)