diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 35879d5026c..364d4bdf329 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -47,7 +47,6 @@ from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass # noqa from .match_arg_ranks_pass import MatchArgRanksPass # noqa from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass # noqa -from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass # noqa from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa from .remove_clone_pass import RemoveClonePass # noqa from .replace_scalar_with_tensor_pass import ( # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index c5ebace2834..1d9c2231b2f 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -17,7 +17,6 @@ ConvertAnyDefaultDimDimsPass, ConvertExpandCopyToRepeatPass, ConvertFullLikeToFullPass, - ConvertMeanDimToAveragePoolPass, ConvertMinMaxPass, ConvertMmToBmmPass, ConvertSplitToSlicePass, @@ -87,7 +86,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(ConvertSplitToSlicePass()) self.add_pass(ConvertMmToBmmPass()) self.add_pass(DecomposeLinearPass()) - self.add_pass(ConvertMeanDimToAveragePoolPass()) + self.add_pass(DecomposeMeanDimPass()) self.add_pass(ConvertFullLikeToFullPass()) self.add_pass(ConvertToClampPass()) self.add_pass(ConvertMinMaxPass()) @@ -140,7 +139,6 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(DecomposeVarPass()) self.add_pass(DecomposeMeanDimPass()) self.add_pass(DecomposeNotEqualPass()) - self.add_pass(ConvertMeanDimToAveragePoolPass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeSoftmaxPass()) self.add_pass(DecomposeGeluPass()) diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index 6af6caf0c3f..0e5fe03ab0f 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -1,10 +1,9 @@ # Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +from math import prod import torch from executorch.backends.arm._passes import ArmPass @@ -28,15 +27,37 @@ def get_meandim_decomposition(op) -> tuple: raise RuntimeError(f"Can't get meandim decomposition for op {op}") +def get_avgpool(op): + if op == exir_ops.edge.aten.mean.dim: + return exir_ops.edge.aten.avg_pool2d.default + if op == torch.ops.aten.mean.dim: + return torch.ops.aten.avg_pool2d.default + raise RuntimeError(f"Can't get meandim decomposition for op {op}") + + +def get_view(op): + if op == exir_ops.edge.aten.mean.dim: + return exir_ops.edge.aten.view_copy.default + if op == torch.ops.aten.mean.dim: + return torch.ops.aten.view_copy.default + raise RuntimeError(f"Can't get meandim decomposition for op {op}") + + class DecomposeMeanDimPass(ArmPass): """ - This pass decomposes meandim into a sum and mul node. + Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on which dims the mean is taken for: + h,w -> avg_pool + n,c -> sum + mul(1/N) + For rank < 4, the input is first reshaped to 4D by padding with dim=1 from the left. Example: - y = mean_dim(x, dim, keepdim) + x = mean_dim(x, (0,2), keepdim=False) # x = (c,h,w) Becomes: - sum = sum.dim_IntList(x, dim, keepdim) - y = mul(sum, 1/N) + x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to work with avg_pool + x = avg_pool2d.default(x, kernel=(1,w), stride=(1,1)) # Reduce w with avg_pool + x = sum.dim_IntList(x, dim=1, keepdims=True) # Reduce c with sum + x = mul.Tensor(x, 1/c) # Divide by number of channels to get mean + x = view_copy.default(x, new_shape=(h)) # Squeeze dims since keepdims = False """ def call_operator(self, op, args, kwargs, meta): @@ -44,26 +65,73 @@ def call_operator(self, op, args, kwargs, meta): return super().call_operator(op, args, kwargs, meta) x = get_node_arg(args, 0) - dim = get_node_arg(args, 1) - keepdim = get_node_arg(args, 2, False) - - # if dim == [-1, -2], mean.dim can be - # decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool. - if dim == [-1, -2]: - # Simply return the mean.dim operator for future decomposition. - return super().call_operator(op, args, kwargs, meta) + input_shape = x.data.size() + output_shape = meta["val"].size() + dims_to_reduce = get_node_arg(args, 1) + dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce] - shape = meta["val"].size() dtype = meta["val"].dtype - input_shape = x.data.size() - N = 1 - for d in dim: - N *= input_shape[d] + view_op = get_view(op) + if len(input_shape) > 4: + raise NotImplementedError( + f"{op} with rank > 4 is currently not supported for the TOSA backend." + ) + + # Unsqueeze to 4D + if len(input_shape) < 4: + pad_n = 4 - len(input_shape) + new_shape = [1] * pad_n + list(input_shape) + dims_to_reduce = [dim + pad_n for dim in dims_to_reduce] + + x = super().call_operator(view_op, (x, new_shape), {}, meta, True) + + # Reduce (h,w) by avg pool + dims_to_reduce_by_avgpool = [dim for dim in dims_to_reduce if dim >= 2] + x = self._reduce_by_average_pool(op, x, dims_to_reduce_by_avgpool, meta) + + # Reduce (n, c) by reduce sum + dims_to_reduce_by_sum = [dim for dim in dims_to_reduce if dim < 2] + x = self._reduce_by_sum(op, x, dims_to_reduce_by_sum, meta, dtype) + + # Reshape to correct output shape if necessary + if x.data.size() != output_shape: + x = super().call_operator(view_op, (x, output_shape), {}, meta, True) + + return x + + def _reduce_by_sum(self, op, input_node, dims, meta, dtype): + if len(dims) == 0: + return input_node + + input_shape = input_node.data.size() + output_shape = meta["val"].size() + N = prod((n for i, n in enumerate(input_shape) if i in dims)) sum_op, full_op, mul_op = get_meandim_decomposition(op) - sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta, True) + sum = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True) full = super().call_operator( - full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta, True + full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True ) return super().call_operator(mul_op, (sum, full), {}, meta, True) + + def _reduce_by_average_pool(self, op, input_node, dims, meta): + if len(dims) == 0: + return input_node + + avgpool_op = get_avgpool(op) + input_shape = input_node.data.size() + + stride = [1, 1] + if dims in ([2, 3], [3, 2]): + kernel_size = [input_shape[2], input_shape[3]] + elif dims == [3]: + kernel_size = [1, input_shape[3]] + elif dims == [2]: + kernel_size = [input_shape[2], 1] + else: + raise RuntimeError(f"Bad dims {dims} for {op} decomposition of mean_dim.") + + return super().call_operator( + avgpool_op, (input_node, kernel_size, stride), {}, meta, True + ) diff --git a/backends/arm/_passes/meandim_to_averagepool_pass.py b/backends/arm/_passes/meandim_to_averagepool_pass.py deleted file mode 100644 index 9a755191504..00000000000 --- a/backends/arm/_passes/meandim_to_averagepool_pass.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import Any, cast, Dict, Tuple - -import torch.fx - -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue - -Argument = Any - - -class ConvertMeanDimToAveragePoolPass(ExportPass): - """ - Replace a mean operation with dim = [-1, -2] and keep_dim = True with an average pool operation. - """ - - def call_operator( - self, - op: torch.fx.node.Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.mean.dim: - return super().call_operator(op, args, kwargs, meta) - - input_value = cast(ProxyValue, args[0]) - dim = cast(list, args[1]) - keep_dim = cast(bool, args[2]) if len(args) > 2 else False - - # averagepool2d gets converted to a mean operation with dim = [-1, -2] and keep_dim = True - # so check the dim argument for this case - if dim == [-1, -2] and keep_dim is True: - # Given the shape format of input is (N, C, H, W) - kernel_size = [ - input_value.to_tensor().size()[2], - input_value.to_tensor().size()[3], - ] - stride = [1, 1] - return super().call_operator( - exir_ops.edge.aten.avg_pool2d.default, - (input_value, kernel_size, stride), - {}, - meta, - ) - else: - return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 89a87b2637a..c732c91a20a 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -262,28 +262,23 @@ def is_node_supported( if node.op != "call_function": return True - if node.target == exir_ops.edge.aten.mean.dim: - dim = node.args[1] - needs_decomp = dim != [-1, -2] - else: - needs_decomp = node.target in [ - exir_ops.edge.aten.div.Tensor, - exir_ops.edge.aten._native_batch_norm_legit_no_training.default, - exir_ops.edge.aten.native_layer_norm.default, - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten._softmax.default, - exir_ops.edge.aten._log_softmax.default, - exir_ops.edge.aten.var.correction, - exir_ops.edge.aten.var.dim, - exir_ops.edge.aten.add.Scalar, - exir_ops.edge.aten.sqrt.default, - exir_ops.edge.aten.sub.Scalar, - exir_ops.edge.aten.mul.Scalar, - exir_ops.edge.aten.ne.Tensor, - exir_ops.edge.aten.ne.Scalar, - exir_ops.edge.aten.div.Scalar, - exir_ops.edge.aten.leaky_relu.default, - ] + needs_decomp = node.target in [ + exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten._native_batch_norm_legit_no_training.default, + exir_ops.edge.aten.native_layer_norm.default, + exir_ops.edge.aten._softmax.default, + exir_ops.edge.aten._log_softmax.default, + exir_ops.edge.aten.var.correction, + exir_ops.edge.aten.var.dim, + exir_ops.edge.aten.add.Scalar, + exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.sub.Scalar, + exir_ops.edge.aten.mul.Scalar, + exir_ops.edge.aten.ne.Tensor, + exir_ops.edge.aten.ne.Scalar, + exir_ops.edge.aten.div.Scalar, + exir_ops.edge.aten.leaky_relu.default, + ] if needs_decomp: self.reporter.report_reject(node, "Needs to be decomposed.") return False diff --git a/backends/arm/test/ops/test_layer_norm.py b/backends/arm/test/ops/test_layer_norm.py index d2d9aa0bc14..8d31ef992cb 100644 --- a/backends/arm/test/ops/test_layer_norm.py +++ b/backends/arm/test/ops/test_layer_norm.py @@ -81,8 +81,8 @@ def test_native_layer_norm_tosa_BI(test_data): model, test_data, "torch.ops.aten.sub.Tensor", # Just check for sub op included in the layernorm decomposition + symmetric_io_quantization=True, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() @@ -95,8 +95,8 @@ def test_native_layer_norm_u55_BI(test_data): test_data, "torch.ops.aten.sub.Tensor", # Just check for sub op included in the layernorm decomposition run_on_fvp=True, + symmetric_io_quantization=True, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() @@ -109,6 +109,6 @@ def test_native_layer_norm_u85_BI(test_data): test_data, "torch.ops.aten.sub.Tensor", # Just check for sub op included in the layernorm decomposition run_on_fvp=True, + symmetric_io_quantization=True, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index 43063058805..b512d6d13bc 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -53,6 +53,7 @@ def test_adaptive_avg_pool2d_tosa_BI(test_data): test_data(), AdaptiveAveragePool2d.aten_op, AdaptiveAveragePool2d.exir_op, + symmetric_io_quantization=True, ).run() @@ -65,6 +66,7 @@ def test_adaptive_avg_pool2d_u55_BI(test_data): AdaptiveAveragePool2d.aten_op, AdaptiveAveragePool2d.exir_op, run_on_fvp=True, + symmetric_io_quantization=True, ).run() @@ -77,21 +79,120 @@ def test_adaptive_avg_pool2d_u85_BI(test_data): AdaptiveAveragePool2d.aten_op, AdaptiveAveragePool2d.exir_op, run_on_fvp=True, + symmetric_io_quantization=True, ).run() class MeanDim(torch.nn.Module): test_data_suite: dict[str, tuple] = { - "zeros": lambda: (torch.zeros(1, 1280, 7, 7), -1, True), - "ones": lambda: (torch.ones(1, 1280, 7, 7), (-1, 2), False), - "rand": lambda: ( - torch.rand(1, 1280, 7, 7), - (-1), + "rank_1_keepdim": lambda: ( + torch.rand(7), + (0), + True, + ), + "rank_2_keepdim": lambda: ( + torch.rand(7, 7), + (0, 1), + True, + ), + "rank_3_keepdim": lambda: ( + torch.rand(7, 7, 7), + (0, 1, 2), + True, + ), + "rand_1_keepdim": lambda: ( + torch.rand(1, 7, 7, 7), + (1), + True, + ), + "rand_2_keepdim": lambda: ( + torch.rand(1, 7, 7, 7), + (2), + True, + ), + "rand_3_keepdim": lambda: ( + torch.rand(1, 7, 7, 7), + (3), True, ), - "randn": lambda: ( - torch.randn(1, 1280, 7, 7), - (-1, -2, -3), + "rand_12_keepdim": lambda: ( + torch.rand(1, 7, 7, 7), + (1, 2), + True, + ), + "rand_13_keepdim": lambda: ( + torch.rand(1, 7, 7, 7), + (1, 3), + True, + ), + "rand_23_keepdim": lambda: ( + torch.rand(1, 7, 7, 7), + (2, 3), + True, + ), + "rand_123_keepdim": lambda: ( + torch.rand(1, 7, 7, 7), + (1, 2, 3), + True, + ), + "rand_0123_keepdim": lambda: ( + torch.rand(1, 7, 7, 7), + (0, 1, 2, 3), + True, + ), + "rank_1": lambda: ( + torch.rand(7), + (-1), + False, + ), + "rank_2": lambda: ( + torch.rand(7, 7), + (-2, -1), + False, + ), + "rank_3": lambda: ( + torch.rand(7, 7, 7), + (-3, -2, -1), + False, + ), + "rand_1": lambda: ( + torch.rand(1, 7, 7, 7), + (-3), + False, + ), + "rand_2": lambda: ( + torch.rand(1, 7, 7, 7), + (-2), + False, + ), + "rand_3": lambda: ( + torch.rand(1, 7, 7, 7), + (-1), + False, + ), + "rand_12": lambda: ( + torch.rand(1, 7, 7, 7), + (-3, -2), + False, + ), + "rand_13": lambda: ( + torch.rand(1, 7, 7, 7), + (-3, -1), + False, + ), + "rand_23": lambda: ( + torch.rand(1, 7, 7, 7), + (-2, -1), + False, + ), + "rand_123": lambda: ( + torch.rand(1, 7, 7, 7), + (-3, -2, -1), + False, + ), + "rand_0123": lambda: ( + torch.rand(1, 7, 7, 7), + (-4, -3, -2, -1), False, ), } @@ -124,9 +225,9 @@ def test_mean_dim_tosa_BI(test_data): pipeline = TosaPipelineBI[input_t]( MeanDim(dim, keep_dim), (test_data,), - "torch.ops.aten.sum.dim_IntList", # Just check for sum op included in the mean decomposition + [], # Might be sum, avgpool, or both + symmetric_io_quantization=True, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() @@ -137,10 +238,10 @@ def test_mean_dim_u55_BI(test_data): pipeline = EthosU55PipelineBI[input_t]( MeanDim(dim, keep_dim), (test_data,), - "torch.ops.aten.sum.dim_IntList", # Just check for sum op included in the mean decomposition + [], # Might be sum, avgpool, or both run_on_fvp=True, - ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) + symmetric_io_quantization=True, + ).dump_artifact("export") pipeline.run() @@ -151,8 +252,8 @@ def test_mean_dim_u85_BI(test_data): pipeline = EthosU85PipelineBI[input_t]( MeanDim(dim, keep_dim), (test_data,), - "torch.ops.aten.sum.dim_IntList", # Just check for sum op included in the mean decomposition + [], # Might be sum, avgpool, or both run_on_fvp=True, + symmetric_io_quantization=True, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1) pipeline.run() diff --git a/backends/arm/test/passes/test_decompose_meandim_pass.py b/backends/arm/test/passes/test_decompose_meandim_pass.py index 511959e36cf..fe953198527 100644 --- a/backends/arm/test/passes/test_decompose_meandim_pass.py +++ b/backends/arm/test/passes/test_decompose_meandim_pass.py @@ -17,32 +17,68 @@ class MeanDim(torch.nn.Module): """ - Basic mean model using torch.mean function making sure keepdim=True (keepdim=False doesnt work for this pass for some reason) + Basic mean model using torch.mean with keepdim = True """ + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_mean_dim": 1, + } + ops_not_before_pass = [ + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default", + "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList", + "executorch_exir_dialects_edge__ops_aten_mul_Tensor", + ] + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 1, + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + } + + ops_not_after_pass = [ + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default", + "executorch_exir_dialects_edge__ops_aten_mean_dim", + ] + def __init__(self): super(MeanDim, self).__init__() def forward(self, x): - return torch.mean(x, 1, True) + return torch.mean(x, (0, 1), True) def get_inputs(self) -> input_t: - return (torch.rand(4, 4),) + return (torch.rand(4, 4, 4, 4),) class MeanDimTensor(torch.nn.Module): """ - Basic mean model using torch.Tensor.mean function making sure keepdim=True (keepdim=False doesnt work for this pass for some reason) + Basic mean model using torch.Tensor.mean with keepdim = False """ + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_mean_dim": 1, + } + ops_not_before_pass = [ + "executorch_exir_dialects_edge__ops_aten_mul_Tensor", + "executorch_exir_dialects_edge__ops_aten_full_default", + "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList", + ] + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_full_default": 1, + "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 1, + } + + ops_not_after_pass = ["executorch_exir_dialects_edge__ops_aten_mean_dim"] + def __init__(self): super(MeanDimTensor, self).__init__() def forward(self, x): - return x.mean(1, True) + return x.mean((0, 1), False) def get_inputs(self) -> input_t: - return (torch.rand(4, 4),) + return (torch.rand(4, 4, 4),) modules = {"meandim_basic": MeanDim(), "meandim_tensor": MeanDimTensor()} @@ -53,21 +89,10 @@ def test_decompose_meandim_tosa_MI(module): pipeline = PassPipeline[input_t]( module, module.get_inputs(), - quantize=False, - ops_before_pass={ - "executorch_exir_dialects_edge__ops_aten_mean_dim": 1, - }, - ops_not_before_pass=[ - "executorch_exir_dialects_edge__ops_aten_mul_Tensor", - "executorch_exir_dialects_edge__ops_aten_full_default", - "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList", - ], - ops_after_pass={ - "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, - "executorch_exir_dialects_edge__ops_aten_full_default": 1, - "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 1, - }, - ops_not_after_pass=["executorch_exir_dialects_edge__ops_aten_mean_dim"], + ops_before_pass=module.ops_before_pass, + ops_not_before_pass=module.ops_not_before_pass, + ops_after_pass=module.ops_after_pass, + ops_not_after_pass=module.ops_not_after_pass, pass_list=[DecomposeMeanDimPass], ) pipeline.run() diff --git a/backends/arm/test/passes/test_meandim_to_averagepool2d.py b/backends/arm/test/passes/test_meandim_to_averagepool2d.py deleted file mode 100644 index fbcb26d2542..00000000000 --- a/backends/arm/test/passes/test_meandim_to_averagepool2d.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -from typing import Tuple - -import torch -from executorch.backends.arm._passes import ConvertMeanDimToAveragePoolPass -from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.test_pipeline import PassPipeline - - -input_t = Tuple[torch.Tensor, torch.Tensor] # Input x - - -class MeanDim(torch.nn.Module): - def forward(self, x): - return torch.mean(x, dim=[-1, -2], keepdim=True) - - def get_inputs(self) -> input_t: - return (torch.rand(1, 1280, 7, 7),) - - ops_before_pass = {"executorch_exir_dialects_edge__ops_aten_mean_dim": 1} - ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1} - ops_not_after_pass = [ - "aten_sum_dim_int_list", - "aten_full_default", - "aten_mul_tensor", - ] - - -class MeanDim2(torch.nn.Module): - def forward(self, x): - return torch.mean(x, dim=1) - - def get_inputs(self) -> input_t: - return (torch.rand(1, 1280, 7, 7),) - - ops_before_pass = { - "aten_sum_dim_int_list": 3, - "aten_full_default": 4, - "aten_mul_tensor": 3, - } - ops_after_pass = { - "aten_sum_dim_int_list": 3, - "aten_full_default": 4, - "aten_mul_tensor": 3, - } - ops_not_after_pass = ["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"] - - -modules = { - "meandim_to_averagepool": MeanDim(), - "meandim_no_modification": MeanDim2(), -} - - -@common.parametrize("module", modules) -def test_meandim_to_avgpool_tosa_BI(module: torch.nn.Module): - """ - Tests the MeanDimToAveragePool2dPass which converts mean.dim to average_pool2d - for the special case where dim is [-1, -2] and keepdim is True. - """ - pipeline = PassPipeline[input_t]( - module, - module.get_inputs(), - quantize=True, - ops_before_pass=module.ops_before_pass, - ops_after_pass=module.ops_after_pass, - ops_not_after_pass=module.ops_not_after_pass, - pass_list=[ConvertMeanDimToAveragePoolPass], - ) - pipeline.pop_stage(-1) # Do not compare output - pipeline.run()