23 changes: 6 additions & 17 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2550,15 +2550,7 @@ def aten_ops_cdist_forward(


 def avg_pool_param_validator(pool_node: Node) -> bool:
-    ceil_mode = args_bounds_check(pool_node.args, 4, False)
     divisor_override = args_bounds_check(pool_node.args, 6)

-    if ceil_mode is not False:
-        _LOGGER.debug(
-            f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}."
-        )
-        return False
-
     if divisor_override is not None:
         _LOGGER.debug(
             f"Currently we don't support divisor_override, got divisor_override={divisor_override}."
@@ -2694,17 +2686,14 @@ def topk_sort_validator(k: int) -> bool:

 def max_pool_param_validator(pool_node: Node) -> bool:
     dilation = args_bounds_check(pool_node.args, 4, 1)
-    ceil_mode = args_bounds_check(pool_node.args, 5, False)

-    if dilation != 1:
-        _LOGGER.debug(f"Currently we don't support dilation, got dilation={dilation}.")
-        return False
+    if not isinstance(dilation, (list, tuple)):
+        dilation = (dilation,)

-    if ceil_mode is not False:
-        _LOGGER.debug(
-            f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}."
-        )
-        return False
+    for dil in dilation:
+        if dil != 1:
+            _LOGGER.debug("Currently we don't support dilation > 1 at any dimension.")
+            return False

     return True

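For readability, here is roughly how the two validators read after this change, reconstructed from the hunks above. The surrounding module (`args_bounds_check`, `_LOGGER`, the `Node` type) is assumed, and the unchanged tail of `avg_pool_param_validator` is not shown in the hunk, so it is filled in from context:

def avg_pool_param_validator(pool_node: Node) -> bool:
    # ceil_mode is now handled in the converter via TensorRT padding modes,
    # so only divisor_override remains a reason to reject the node.
    divisor_override = args_bounds_check(pool_node.args, 6)

    if divisor_override is not None:
        _LOGGER.debug(
            f"Currently we don't support divisor_override, got divisor_override={divisor_override}."
        )
        return False

    return True


def max_pool_param_validator(pool_node: Node) -> bool:
    # dilation may arrive as an int or a per-dimension sequence; reject only if
    # any dimension uses dilation > 1. ceil_mode is no longer rejected here.
    dilation = args_bounds_check(pool_node.args, 4, 1)

    if not isinstance(dilation, (list, tuple)):
        dilation = (dilation,)

    for dil in dilation:
        if dil != 1:
            _LOGGER.debug("Currently we don't support dilation > 1 at any dimension.")
            return False

    return True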
15 changes: 8 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/impl/pool.py
@@ -30,8 +30,9 @@ def avg_poolNd(
     count_include_pad: bool = True,
     divisor_override: Optional[int] = None,
 ) -> TRTTensor:
-    if ceil_mode is not False:
-        raise RuntimeError("ceil_mode is not yet supported!")
+    padding_mode = trt.PaddingMode.EXPLICIT_ROUND_DOWN
+    if ceil_mode:
+        padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

     if divisor_override is not None:
         raise RuntimeError("divisor_override is not yet supported!")
@@ -57,6 +58,7 @@ pool_layer.stride_nd = stride
     pool_layer.stride_nd = stride
     pool_layer.padding_nd = padding
     pool_layer.average_count_excludes_padding = not count_include_pad
+    pool_layer.padding_mode = padding_mode

     set_layer_name(pool_layer, target, name, source_ir)
     return pool_layer.get_output(0)
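The ceil_mode flag only changes how the pooling output length is rounded: with dilation fixed at 1, PyTorch computes floor((L_in + 2 * padding - kernel) / stride) + 1 and switches the floor to a ceil when ceil_mode=True, which is what TensorRT's EXPLICIT_ROUND_DOWN and EXPLICIT_ROUND_UP padding modes express. A minimal sketch of that formula (pool_out_size is an illustrative helper, not part of the codebase, and it ignores PyTorch's corner case of dropping a trailing window that would start entirely inside the right padding):

import math

def pool_out_size(in_size: int, kernel: int, stride: int, padding: int, ceil_mode: bool) -> int:
    # Pooling output-length formula with dilation = 1.
    num = in_size + 2 * padding - kernel
    return (math.ceil(num / stride) if ceil_mode else num // stride) + 1

# kernel=7, stride=2, padding=1 over a length-32 input:
print(pool_out_size(32, 7, 2, 1, ceil_mode=False))  # 14 -> EXPLICIT_ROUND_DOWN
print(pool_out_size(32, 7, 2, 1, ceil_mode=True))   # 15 -> EXPLICIT_ROUND_UP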
@@ -77,11 +79,9 @@ def max_poolNd(
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for pooling."

-    if dilation != 1:
-        raise RuntimeError("dilation is not yet supported!")
-
-    if ceil_mode is not False:
-        raise RuntimeError("ceil_mode is not yet supported!")
+    padding_mode = trt.PaddingMode.EXPLICIT_ROUND_DOWN
+    if ceil_mode:
+        padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

     dim = len(kernel_size)

@@ -103,6 +103,7 @@

     pool_layer.stride_nd = stride
     pool_layer.padding_nd = padding
+    pool_layer.padding_mode = padding_mode

     set_layer_name(pool_layer, target, name, source_ir)
     return pool_layer.get_output(0)
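With the padding mode wired through both pooling converters, a module using ceil_mode=True pooling can be compiled end to end through the dynamo path. A minimal usage sketch, not taken from this PR; the exact torch_tensorrt.compile arguments can vary between releases, and the tolerances mirror the ones used in the tests below:

import torch
import torch_tensorrt

class Pool(torch.nn.Module):
    def forward(self, x):
        # Same kernel/stride/padding as one of the new ceil_mode test cases.
        return torch.nn.functional.max_pool2d(
            x, kernel_size=(5, 4), stride=(2, 1), padding=(1, 0), ceil_mode=True
        )

model = Pool().eval().cuda()
x = torch.randn(1, 3, 32, 32, device="cuda")

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[x])
torch.testing.assert_close(trt_model(x), model(x), rtol=5e-3, atol=5e-3)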
24 changes: 23 additions & 1 deletion tests/py/dynamo/conversion/test_pool_aten.py
@@ -15,6 +15,8 @@ class TestPoolConverter(DispatchTestCase):
             ((4,), (1,), (1,)),
             ((5,), (2,), (0,)),
             ((7,), (2,), (1,)),
+            ((3,), (1,), (1,), 0, True),
+            ((7,), (2,), (1,), 0, True),
         ]
     )
     def test_avg_pool1d(
@@ -44,8 +46,11 @@ def forward(self, x):
             (3, 1, 1),
             ((2, 2), [], (1, 0)),
             ((4, 3), (1, 1), (1, 1)),
+            ((4, 3), (1, 1), (1, 1), True),
             ((5, 4), (2, 1), (1, 0)),
+            ((5, 4), (2, 1), (1, 0), True),
             ((7, 7), (1, 2), (0, 1)),
+            ((7, 7), (1, 2), (0, 1), True),
         ]
     )
     def test_avg_pool2d(
@@ -70,7 +75,7 @@ def forward(self, x):
                 )

         inputs = [torch.randn(1, 3, 32, 32)]
-        self.run_test(TestModule(), inputs, use_dynamo_tracer=True)
+        self.run_test(TestModule(), inputs, rtol=5e-03, atol=5e-03, use_dynamo_tracer=True)

     @parameterized.expand(
         [
@@ -80,6 +85,8 @@ def forward(self, x):
             ((4, 3, 2), (1, 1, 1), (1, 1, 0)),
             ((5, 4, 3), (2, 1, 2), (1, 0, 1)),
             ((7, 7, 7), (1, 2, 1), (0, 1, 1)),
+            ((7, 7, 7), (1, 2, 1), (0, 1, 1), True),
+            ((5, 4, 3), (2, 1, 2), (1, 0, 1), True),
         ]
     )
     def test_avg_pool3d(
@@ -168,6 +175,16 @@ def forward(self, x):
                 (1, 1),
                 (1, 1),
             ),
+            (
+                (1, 1, 1, 1),
+                (2, 2, 2, 2),
+                (3, 3, 3, 3),
+                torch.float,
+                (3, 3),
+                (1, 1),
+                (1, 1),
+                True,
+            ),
         ]
     )
     def test_dynamic_shape_pool2d(
@@ -258,6 +275,7 @@ def forward(self, x):
             ((4,), (1,), (1,)),
             ((5,), (2,), (0,)),
             ((7,), (2,), (1,)),
+            ((7,), (2,), (1,), 1, True),
         ]
     )
     def test_max_pool1d(
@@ -290,6 +308,9 @@ def forward(self, x):
             ((4, 3), (1, 1), (1, 1)),
             ((5, 4), (2, 1), (1, 0)),
             ((7, 7), (1, 2), (0, 1)),
+            ((4, 3), (1, 1), (1, 1), 1, True),
+            ((5, 4), (2, 1), (1, 0), 1, True),
+            ((7, 7), (1, 2), (0, 1), 1, True),
         ]
     )
     def test_max_pool2d(
@@ -322,6 +343,7 @@ def forward(self, x):
             ((4, 3, 2), (1, 1, 1), (1, 1, 0)),
             ((5, 4, 3), (2, 1, 2), (1, 0, 1)),
             ((7, 7, 7), (1, 2, 1), (0, 1, 1)),
+            ((7, 7, 7), (1, 2, 1), (0, 1, 1), 1, True),
         ]
     )
     def test_max_pool3d(
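As a sanity check on what the new parameterizations exercise, ceil_mode=True visibly changes the output shape for several of these configurations in plain eager PyTorch. A standalone illustration using the ((7, 7), (1, 2), (0, 1)) max-pool case, independent of the conversion tests:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 32, 32)

floor_out = F.max_pool2d(x, kernel_size=(7, 7), stride=(1, 2), padding=(0, 1), ceil_mode=False)
ceil_out = F.max_pool2d(x, kernel_size=(7, 7), stride=(1, 2), padding=(0, 1), ceil_mode=True)

print(floor_out.shape)  # torch.Size([1, 3, 26, 14])
print(ceil_out.shape)   # torch.Size([1, 3, 26, 15])

With EXPLICIT_ROUND_UP selected from ceil_mode, the converted engine should produce the same shapes.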