diff --git a/backends/arm/test/passes/test_broadcast_args_pass.py b/backends/arm/test/passes/test_broadcast_args_pass.py index 719a0ddd622..3df27454f90 100644 --- a/backends/arm/test/passes/test_broadcast_args_pass.py +++ b/backends/arm/test/passes/test_broadcast_args_pass.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import operator -from typing import Tuple +from typing import Callable, Tuple import torch from executorch.backends.arm._passes import BroadcastArgsPass @@ -12,17 +12,19 @@ from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline -input_t = Tuple[torch.Tensor] # Input x +input_t = Tuple[torch.Tensor, torch.Tensor] class NeedsMultipleBroadcastsModel(torch.nn.Module): test_data = (torch.rand(1, 10), torch.rand(10, 1)) - def __init__(self, op: operator): + def __init__( + self, op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + ) -> None: self.op = op super().__init__() - def forward(self, x: torch.Tensor, y: torch.Tensor): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return self.op(x, y) diff --git a/backends/arm/test/passes/test_cast_int64_pass.py b/backends/arm/test/passes/test_cast_int64_pass.py index 7832fd87ed9..afcc0d1db36 100644 --- a/backends/arm/test/passes/test_cast_int64_pass.py +++ b/backends/arm/test/passes/test_cast_int64_pass.py @@ -21,7 +21,7 @@ class Int64Model(torch.nn.Module): "rand": (torch.rand(4),), } - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x + 3 diff --git a/backends/arm/test/passes/test_convert_expand_copy_to_repeat.py b/backends/arm/test/passes/test_convert_expand_copy_to_repeat.py index aa877c355bd..899472b2e8a 100644 --- a/backends/arm/test/passes/test_convert_expand_copy_to_repeat.py +++ b/backends/arm/test/passes/test_convert_expand_copy_to_repeat.py @@ -20,17 +20,17 @@ class Expand(torch.nn.Module): Basic expand model using torch.Tensor.expand function """ - def __init__(self): - super(Expand, self).__init__() + def __init__(self) -> None: + super().__init__() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.expand(3, 4) def get_inputs(self) -> input_t: return (torch.rand(3, 1),) -def test_expand_to_repeat_tosa_INT(): +def test_expand_to_repeat_tosa_INT() -> None: module = Expand() pipeline = PassPipeline[input_t]( module, diff --git a/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py b/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py index ddb31625849..7c7ad984e4c 100644 --- a/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py +++ b/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple, Union +from typing import Callable, ClassVar, Dict, Tuple, Union import pytest @@ -22,6 +22,10 @@ input_t1 = Tuple[torch.Tensor] # Input x input_t2 = Tuple[torch.Tensor, torch.Tensor] # Input x, y +Scalar = Union[bool, float, int] +ArangeNoneParam = Tuple[Callable[[], input_t1], Tuple[Scalar, Scalar, Scalar]] +FullNoneParam = Tuple[Callable[[], input_t1], Tuple[Tuple[int, ...], Scalar]] + ##################################################### ## Test arange(dtype=int64) -> arange(dtype=int32) ## @@ -29,11 +33,10 @@ class ArangeDefaultIncrementViewLessThan(torch.nn.Module): - - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return (torch.arange(10, dtype=torch.int64) + 1).view(-1, 1) < x - test_data = { + test_data: ClassVar[Dict[str, input_t1]] = { "randint": ( torch.randint( 0, @@ -46,7 +49,9 @@ def forward(self, x: torch.Tensor): @common.parametrize("test_data", ArangeDefaultIncrementViewLessThan.test_data) -def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1): +def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP( + test_data: input_t1, +) -> None: module = ArangeDefaultIncrementViewLessThan() aten_ops_checks = [ "torch.ops.aten.lt.Tensor", @@ -67,7 +72,9 @@ def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP(test_data: inp @common.parametrize("test_data", ArangeDefaultIncrementViewLessThan.test_data) -def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1): +def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT( + test_data: input_t1, +) -> None: module = ArangeDefaultIncrementViewLessThan() aten_ops_checks = [ "torch.ops.aten.lt.Tensor", @@ -88,11 +95,10 @@ def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT(test_data: in class ArangeStartIncrementViewLessThan(torch.nn.Module): - - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return (torch.arange(0, 10, dtype=torch.int64) + 1).view(-1, 1) < x - test_data = { + test_data: ClassVar[Dict[str, input_t1]] = { "randint": ( torch.randint( 0, @@ -105,7 +111,9 @@ def forward(self, x: torch.Tensor): @common.parametrize("test_data", ArangeStartIncrementViewLessThan.test_data) -def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1): +def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP( + test_data: input_t1, +) -> None: module = ArangeStartIncrementViewLessThan() aten_ops_checks = [ "torch.ops.aten.lt.Tensor", @@ -126,7 +134,9 @@ def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP(test_data: input @common.parametrize("test_data", ArangeStartIncrementViewLessThan.test_data) -def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1): +def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT( + test_data: input_t1, +) -> None: module = ArangeStartIncrementViewLessThan() aten_ops_checks = [ "torch.ops.aten.lt.Tensor", @@ -147,11 +157,10 @@ def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT(test_data: inpu class ArangeStartStepIncrementViewLessThan(torch.nn.Module): - - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return (torch.arange(0, 10, 2, dtype=torch.int64) + 1).view(-1, 1) < x - test_data = { + test_data: ClassVar[Dict[str, input_t1]] = { "randint": ( torch.randint( 0, @@ -166,7 +175,7 @@ def forward(self, x: torch.Tensor): @common.parametrize("test_data", ArangeStartStepIncrementViewLessThan.test_data) def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_FP( test_data: input_t1, -): +) -> None: module = ArangeStartStepIncrementViewLessThan() aten_ops_checks = [ "torch.ops.aten.lt.Tensor", @@ -189,7 +198,7 @@ def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_FP( @common.parametrize("test_data", ArangeStartStepIncrementViewLessThan.test_data) def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_INT( test_data: input_t1, -): +) -> None: module = ArangeStartStepIncrementViewLessThan() aten_ops_checks = [ "torch.ops.aten.lt.Tensor", @@ -225,7 +234,7 @@ def __init__(self, start: float, stop: float, step: float): def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.arange(*self.args) + x - test_data = { + test_data: ClassVar[Dict[str, ArangeNoneParam]] = { "int64": (lambda: (torch.randn(10, 1),), (0, 10, 1)), "float32_start": (lambda: (torch.randn(10, 1),), (0.0, 10, 1)), "float32_stop": (lambda: (torch.randn(10, 1),), (0, 10.0, 1)), @@ -238,11 +247,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @common.parametrize("test_data", ArangeAddDtypeNone.test_data) -def test_arange_dtype_none_tosa_FP(test_data): - input_data, init_data = test_data +def test_arange_dtype_none_tosa_FP(test_data: ArangeNoneParam) -> None: + input_factory, init_data = test_data pipeline = TosaPipelineFP[input_t1]( ArangeAddDtypeNone(*init_data), - input_data(), + input_factory(), ArangeAddDtypeNone.aten_op, ArangeAddDtypeNone.exir_op, ) @@ -250,11 +259,11 @@ def test_arange_dtype_none_tosa_FP(test_data): @common.parametrize("test_data", ArangeAddDtypeNone.test_data) -def test_arange_dtype_none_tosa_INT(test_data): - input_data, init_data = test_data +def test_arange_dtype_none_tosa_INT(test_data: ArangeNoneParam) -> None: + input_factory, init_data = test_data pipeline = TosaPipelineINT[input_t1]( ArangeAddDtypeNone(*init_data), - input_data(), + input_factory(), ArangeAddDtypeNone.aten_op, ArangeAddDtypeNone.exir_op, ) @@ -268,8 +277,7 @@ def test_arange_dtype_none_tosa_INT(test_data): class FullIncrementViewMulXLessThanY(torch.nn.Module): - - def forward(self, x: torch.Tensor, y: torch.Tensor): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return ( ( torch.full( @@ -286,7 +294,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): * x ) < y - test_data = { + test_data: ClassVar[Dict[str, input_t2]] = { "randint": ( torch.randint( 0, @@ -305,7 +313,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): @common.parametrize("test_data", FullIncrementViewMulXLessThanY.test_data) -def test_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1): +def test_convert_full_int64_dtype_to_int32_pass_tosa_FP( + test_data: input_t2, +) -> None: """ There are four int64 placeholders in the original graph: 1. _lifted_tensor_constant0: 1 @@ -347,7 +357,9 @@ def test_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1): @common.parametrize("test_data", FullIncrementViewMulXLessThanY.test_data) -def test_convert_full_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1): +def test_convert_full_int64_dtype_to_int32_pass_tosa_INT( + test_data: input_t2, +) -> None: """ For INT profile, _lifted_tensor_constant0 is still int64 after applying ConvertInt64ConstOpsToInt32Pass(). And an int64->int32 cast is inserted at the beginning of the graph. @@ -380,8 +392,7 @@ def test_convert_full_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1): class RejectFullIncrementViewMulXLessThanY(torch.nn.Module): - - def forward(self, x: torch.Tensor, y: torch.Tensor): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return ( ( torch.full( @@ -398,7 +409,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): * x ) < y - test_data = { + test_data: ClassVar[Dict[str, input_t2]] = { "randint": ( torch.randint( 0, @@ -420,7 +431,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): @pytest.mark.xfail( reason="MLETORCH-1254: Add operator support check for aten.arange and aten.full" ) -def test_reject_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1): +def test_reject_convert_full_int64_dtype_to_int32_pass_tosa_FP( + test_data: input_t2, +) -> None: module = RejectFullIncrementViewMulXLessThanY() aten_ops_checks = [ "torch.ops.aten.full.default", @@ -469,11 +482,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @common.parametrize("test_data", AddConstFullDtypeNone.test_data) -def test_full_dtype_none_tosa_FP(test_data): - input_data, init_data = test_data +def test_full_dtype_none_tosa_FP(test_data: FullNoneParam) -> None: + input_factory, init_data = test_data pipeline = TosaPipelineFP[input_t1]( AddConstFullDtypeNone(*init_data), - input_data(), + input_factory(), aten_op=[], exir_op=AddConstFullDtypeNone.exir_op, ) @@ -481,11 +494,11 @@ def test_full_dtype_none_tosa_FP(test_data): @common.parametrize("test_data", AddConstFullDtypeNone.test_data_bool) -def test_full_dtype_none_tosa_FP_bool(test_data): - input_data, init_data = test_data +def test_full_dtype_none_tosa_FP_bool(test_data: FullNoneParam) -> None: + input_factory, init_data = test_data pipeline = TosaPipelineFP[input_t1]( AddConstFullDtypeNone(*init_data), - input_data(), + input_factory(), aten_op=[], exir_op=AddConstFullDtypeNone.exir_op, ) @@ -501,9 +514,10 @@ def test_full_dtype_none_tosa_FP_bool(test_data): ) def test_full_dtype_none_tosa_INT(test_data): input_data, init_data = test_data + input_factory, init_data = test_data pipeline = TosaPipelineINT[input_t1]( AddConstFullDtypeNone(*init_data), - input_data(), + input_factory(), aten_op=[], exir_op=AddConstFullDtypeNone.exir_op, ) diff --git a/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py b/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py index ea7e03f8e21..bc7f8218183 100644 --- a/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py +++ b/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import Callable, Dict, Tuple import torch from executorch.backends.arm._passes import ConvertInt64OutputOpsToInt32Pass @@ -21,20 +21,20 @@ class CastingToInt64Model(torch.nn.Module): - def __init__(self, target_dtype): + def __init__(self, target_dtype: torch.dtype) -> None: super().__init__() self.target_dtype = target_dtype - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.to(dtype=self.target_dtype) -test_data_suite_convert = { +test_data_suite_convert: Dict[str, Callable[[], Tuple[torch.Tensor, torch.dtype]]] = { "fp32_input": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float32), torch.int64), "fp16_input": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float16), torch.int64), } -test_data_suite_remove = { +test_data_suite_remove: Dict[str, Callable[[], Tuple[torch.Tensor, torch.dtype]]] = { "int32_input": lambda: ( torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32), torch.int64, @@ -42,8 +42,13 @@ def forward(self, x: torch.Tensor): } +TestDataFactory = Callable[[], Tuple[torch.Tensor, torch.dtype]] + + @common.parametrize("test_data", test_data_suite_convert) -def test_convert_or_remove_casting_to_int64_convert_tosa_FP(test_data: Tuple): +def test_convert_or_remove_casting_to_int64_convert_tosa_FP( + test_data: TestDataFactory, +) -> None: test_tensor, target_dtype = test_data() module = CastingToInt64Model(target_dtype) @@ -61,7 +66,9 @@ def test_convert_or_remove_casting_to_int64_convert_tosa_FP(test_data: Tuple): @common.parametrize("test_data", test_data_suite_remove) -def test_convert_or_remove_casting_to_int64_remove_tosa_FP(test_data: Tuple): +def test_convert_or_remove_casting_to_int64_remove_tosa_FP( + test_data: TestDataFactory, +) -> None: test_tensor, target_dtype = test_data() module = CastingToInt64Model(target_dtype) @@ -86,7 +93,7 @@ def test_convert_or_remove_casting_to_int64_remove_tosa_FP(test_data: Tuple): class Int64OutputModel(torch.nn.Module): - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: # return torch.argmax(x) # RuntimeError: Int did not match Long; But this is expected as we expect _argmax_i32 to generate int32 output # return (10 * torch.argmax(x) + 10).to(dtype=torch.int32) # [1]. This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (function _resize_output_check) return (10 * torch.argmax(x, dim=-1) + 10) + 1.5 diff --git a/backends/arm/test/passes/test_convert_int_pow_to_muls.py b/backends/arm/test/passes/test_convert_int_pow_to_muls.py index 4eeff845749..bccde782f55 100644 --- a/backends/arm/test/passes/test_convert_int_pow_to_muls.py +++ b/backends/arm/test/passes/test_convert_int_pow_to_muls.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch from executorch.backends.arm._passes import ConvertIntPowToMuls @@ -12,7 +12,14 @@ from executorch.backends.arm.test.tester.test_pipeline import PassPipeline -input_t = Tuple[torch.nn.Module, int] # Input x +input_t = Tuple[torch.Tensor] # Inputs to the module + + +class ModuleWithInputs(Protocol): + def get_inputs(self) -> input_t: ... + + +TestParam = Tuple[ModuleWithInputs, int] class Square(torch.nn.Module): @@ -20,7 +27,7 @@ class Square(torch.nn.Module): Basic squaring """ - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.square() def get_inputs(self) -> input_t: @@ -32,18 +39,18 @@ class Pow(torch.nn.Module): Basic squaring """ - def __init__(self, exponent): + def __init__(self, exponent: int) -> None: super().__init__() self.exponent = exponent - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.pow(self.exponent) def get_inputs(self) -> input_t: return (torch.rand(4, 4),) -test_data = { +test_data: Dict[str, TestParam] = { "square": (Square(), 1), "pow_2": (Pow(2), 1), "pow_3": (Pow(3), 2), @@ -53,12 +60,12 @@ def get_inputs(self) -> input_t: @common.parametrize("data", test_data) -def test_convert_pow_to_muls(data): - module = data[0] - nbr_muls = data[1] +def test_convert_pow_to_muls(data: TestParam) -> None: + module_with_inputs, nbr_muls = data + module = cast(torch.nn.Module, module_with_inputs) pipeline = PassPipeline[input_t]( module, - module.get_inputs(), + module_with_inputs.get_inputs(), quantize=False, ops_before_pass={ "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 1, diff --git a/backends/arm/test/passes/test_convert_split_to_slice.py b/backends/arm/test/passes/test_convert_split_to_slice.py index fba52308ff0..3321693babd 100644 --- a/backends/arm/test/passes/test_convert_split_to_slice.py +++ b/backends/arm/test/passes/test_convert_split_to_slice.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch from executorch.backends.arm._passes.convert_split_to_slice import ( @@ -17,6 +17,10 @@ input_t = Tuple[torch.Tensor] # Input x +class ModuleWithInputs(Protocol): + def get_inputs(self) -> input_t: ... + + class Split(torch.nn.Module): """ Basic split model using torch.split function @@ -25,7 +29,7 @@ class Split(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(10),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]: return torch.split(x, 2) @@ -37,17 +41,21 @@ class SplitTensor(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(10),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]: return x.split(2) -modules = {"split_basic": Split(), "split_tensor": SplitTensor()} +modules: Dict[str, ModuleWithInputs] = { + "split_basic": Split(), + "split_tensor": SplitTensor(), +} @common.parametrize("module", modules) -def test_split_to_slice_tosa_INT(module): +def test_split_to_slice_tosa_INT(module: ModuleWithInputs) -> None: + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), quantize=True, ops_before_pass={ diff --git a/backends/arm/test/passes/test_convert_to_clamp.py b/backends/arm/test/passes/test_convert_to_clamp.py index cc854eeacd7..5072af000b0 100644 --- a/backends/arm/test/passes/test_convert_to_clamp.py +++ b/backends/arm/test/passes/test_convert_to_clamp.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import ClassVar, Dict, Tuple import torch from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass @@ -16,26 +16,26 @@ class HardTanh(torch.nn.Module): - test_data = {"rand": (torch.rand(1, 64, 64, 3),)} + test_data: ClassVar[Dict[str, input_t]] = {"rand": (torch.rand(1, 64, 64, 3),)} def __init__(self): super().__init__() self.hardtanh = torch.nn.Hardtanh() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.hardtanh(x) class ReLU(torch.nn.Module): - test_data = {"rand": (torch.rand(1, 64, 64, 3),)} + test_data: ClassVar[Dict[str, input_t]] = {"rand": (torch.rand(1, 64, 64, 3),)} def __init__(self): super().__init__() self.relu = torch.nn.ReLU() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.relu(x) @@ -45,7 +45,7 @@ def forward(self, x): @common.parametrize("test_data", HardTanh.test_data) -def test_tosa_FP_hardtahn(test_data: input_t): +def test_tosa_FP_hardtahn(test_data: input_t) -> None: module = HardTanh() op_checks_before_pass = { "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, @@ -69,7 +69,7 @@ def test_tosa_FP_hardtahn(test_data: input_t): @common.parametrize("test_data", ReLU.test_data) -def test_tosa_FP_relu(test_data: input_t): +def test_tosa_FP_relu(test_data: input_t) -> None: module = ReLU() op_checks_before_pass = { "executorch_exir_dialects_edge__ops_aten_relu_default": 1, diff --git a/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py b/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py index 4d686039456..405c3d7ca8f 100644 --- a/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py +++ b/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch from executorch.backends.arm._passes.decompose_avg_pool2d import DecomposeAvgPool2d @@ -13,6 +13,10 @@ input_t = Tuple[torch.Tensor] # Input x +class ModuleWithInputs(Protocol): + def get_inputs(self) -> input_t: ... + + class AvgPool2dWithStride(torch.nn.Module): """ avg_pool2d model with explicit stride parameter @@ -21,7 +25,7 @@ class AvgPool2dWithStride(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(1, 3, 8, 8),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) @@ -33,7 +37,7 @@ class AvgPool2dWithoutStride(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(1, 3, 8, 8),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=3) @@ -45,11 +49,11 @@ class AvgPool2dListKernel(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(1, 3, 8, 8),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.avg_pool2d(x, kernel_size=[2, 3]) -modules = { +modules: Dict[str, ModuleWithInputs] = { "avg_pool2d_with_stride": AvgPool2dWithStride(), "avg_pool2d_without_stride": AvgPool2dWithoutStride(), "avg_pool2d_list_kernel": AvgPool2dListKernel(), @@ -57,10 +61,11 @@ def forward(self, x): @common.parametrize("module", modules) -def test_decompose_avg_pool2d_tosa_MI(module): +def test_decompose_avg_pool2d_tosa_MI(module: ModuleWithInputs) -> None: """Test that DecomposeAvgPool2d pass works correctly with and without stride parameters.""" + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), quantize=False, ops_before_pass={ diff --git a/backends/arm/test/passes/test_decompose_cosine_similarity_pass.py b/backends/arm/test/passes/test_decompose_cosine_similarity_pass.py index 80a328f39c6..8dec8408584 100644 --- a/backends/arm/test/passes/test_decompose_cosine_similarity_pass.py +++ b/backends/arm/test/passes/test_decompose_cosine_similarity_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch @@ -16,6 +16,10 @@ input_t = Tuple[torch.Tensor, torch.Tensor] +class ModuleWithInputs(Protocol): + def get_inputs(self) -> input_t: ... + + class CosineSimilarityModel(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(2, 3, 4), torch.rand(2, 3, 4)) @@ -24,11 +28,11 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return torch.cosine_similarity(x1, x2, dim=1, eps=1e-6) -modules = {"cosine_basic": CosineSimilarityModel()} +modules: Dict[str, ModuleWithInputs] = {"cosine_basic": CosineSimilarityModel()} @common.parametrize("module", modules) -def test_decompose_cosine_similarity_tosa_INT(module): +def test_decompose_cosine_similarity_tosa_INT(module: ModuleWithInputs) -> None: ops_after_pass = { "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 5, @@ -40,8 +44,9 @@ def test_decompose_cosine_similarity_tosa_INT(module): "executorch_exir_dialects_edge__ops_aten_reciprocal_default": 1, } + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), ops_before_pass=None, ops_not_before_pass=None, diff --git a/backends/arm/test/passes/test_decompose_div_pass.py b/backends/arm/test/passes/test_decompose_div_pass.py index b52e264bf11..3d6293b2194 100644 --- a/backends/arm/test/passes/test_decompose_div_pass.py +++ b/backends/arm/test/passes/test_decompose_div_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass @@ -15,6 +15,10 @@ input_t = Tuple[torch.Tensor] # Input x +class ModuleWithInputs(Protocol): + def get_inputs(self) -> input_t: ... + + class Div(torch.nn.Module): """ Basic div model using torch.div @@ -23,7 +27,7 @@ class Div(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(10),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.div(x, 2) @@ -35,17 +39,18 @@ class DivTensor(torch.nn.Module): def get_inputs(self) -> input_t: return (torch.rand(10),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.div(2) -modules = {"div_basic": Div(), "div_tensor": DivTensor()} +modules: Dict[str, ModuleWithInputs] = {"div_basic": Div(), "div_tensor": DivTensor()} @common.parametrize("module", modules) -def test_decompose_div_tosa_FP(module): +def test_decompose_div_tosa_FP(module: ModuleWithInputs) -> None: + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), quantize=False, ops_before_pass={ diff --git a/backends/arm/test/passes/test_decompose_layernorm_pass.py b/backends/arm/test/passes/test_decompose_layernorm_pass.py index d3c2cd6efd7..02fed874765 100644 --- a/backends/arm/test/passes/test_decompose_layernorm_pass.py +++ b/backends/arm/test/passes/test_decompose_layernorm_pass.py @@ -24,7 +24,7 @@ def __init__(self): super(LayerNorm, self).__init__() self.layer_norm = torch.nn.LayerNorm(10) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.layer_norm(x) return x diff --git a/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py b/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py index 5b4c84edbfd..bd83bfc9a22 100644 --- a/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py +++ b/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch @@ -16,6 +16,12 @@ input_t = Tuple[torch.Tensor] +class ModuleWithInputs(Protocol): + ord: float | None + + def get_inputs(self) -> input_t: ... + + class VectorNormModel(torch.nn.Module): """ A test module with torch.linalg.vector_norm. @@ -24,7 +30,9 @@ class VectorNormModel(torch.nn.Module): We support only order 1 or 2. """ - def __init__(self, ord: float = None, dim=None, keepdim: bool = False): + def __init__( + self, ord: float | None = None, dim=None, keepdim: bool = False + ) -> None: super().__init__() self.ord = ord self.dim = dim @@ -55,7 +63,7 @@ def get_inputs(self) -> input_t: @common.parametrize("module", modules) -def test_decompose_vector_norm_tosa_INT(module): +def test_decompose_vector_norm_tosa_INT(module: ModuleWithInputs) -> None: """ This test creates a PassPipeline that applies the DecomposeLinearVectorNormPass. The expected primitive ops vary depending on the norm order: @@ -65,6 +73,7 @@ def test_decompose_vector_norm_tosa_INT(module): """ ord_val = module.ord if module.ord is not None else 2.0 + ops_after_pass: Dict[str, int] if ord_val == 1: ops_after_pass = { "executorch_exir_dialects_edge__ops_aten_abs_default": 1, @@ -75,9 +84,16 @@ def test_decompose_vector_norm_tosa_INT(module): "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2, "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 1, } + else: + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_abs_default": 1, + "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2, + "executorch_exir_dialects_edge__ops_aten_sum_dim_IntList": 1, + } + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), # The op is decomposed in legalization aten -> edge, so we are not able to check ops before ops_before_pass=None, diff --git a/backends/arm/test/passes/test_decompose_meandim_pass.py b/backends/arm/test/passes/test_decompose_meandim_pass.py index e771d74b5c4..ac7f3f883c4 100644 --- a/backends/arm/test/passes/test_decompose_meandim_pass.py +++ b/backends/arm/test/passes/test_decompose_meandim_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch @@ -17,6 +17,15 @@ input_t = Tuple[torch.Tensor] # Input x +class ModuleWithMeanAttrs(Protocol): + ops_after_pass: Dict[str, int] + ops_not_after_pass: list[str] + u55_ops_after_pass: Dict[str, int] + u55_ops_not_after_pass: list[str] + + def get_inputs(self) -> input_t: ... + + class MeanDim(torch.nn.Module): """ Basic mean model using torch.mean with keepdim = True @@ -36,7 +45,7 @@ class MeanDim(torch.nn.Module): def __init__(self): super(MeanDim, self).__init__() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.mean(x, (0, 1), True) def get_inputs(self) -> input_t: @@ -73,25 +82,25 @@ class MeanDimTensor(torch.nn.Module): def __init__(self): super(MeanDimTensor, self).__init__() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.mean((0, 2), False) def get_inputs(self) -> input_t: return (torch.rand(4, 4, 4, 4),) -modules = {"meandim_basic": MeanDim(), "meandim_tensor": MeanDimTensor()} +modules: Dict[str, ModuleWithMeanAttrs] = { + "meandim_basic": MeanDim(), + "meandim_tensor": MeanDimTensor(), +} @common.parametrize("module", modules) -def test_decompose_meandim_tosa_INT(module): +def test_decompose_meandim_tosa_INT(module: ModuleWithMeanAttrs) -> None: # Decompose meandim_pass requires initiating the pas with args, which is not supported # by RunPasses in the arm_tester -> PassPipeline cannot be used. - pipeline = TosaPipelineINT[input_t]( - module, - module.get_inputs(), - [], - ) + nn_module = cast(torch.nn.Module, module) + pipeline = TosaPipelineINT[input_t](nn_module, module.get_inputs(), []) pipeline.pop_stage("check_not.exir") pipeline.pop_stage("check_count.exir") pipeline.pop_stage("to_executorch") @@ -106,11 +115,12 @@ def test_decompose_meandim_tosa_INT(module): @common.parametrize("module", modules) -def test_decompose_meandim_u55_INT(module): +def test_decompose_meandim_u55_INT(module: ModuleWithMeanAttrs) -> None: # Decompose meandim_pass requires initiating the pas with args, which is not supported # by RunPasses in the arm_tester -> PassPipeline cannot be used. + nn_module = cast(torch.nn.Module, module) pipeline = EthosU55PipelineINT[input_t]( - module, module.get_inputs(), [], run_on_fvp=False + nn_module, module.get_inputs(), [], run_on_fvp=False ) pipeline.pop_stage("check_not.exir") pipeline.pop_stage("check_count.exir") diff --git a/backends/arm/test/passes/test_decompose_softmax_pass.py b/backends/arm/test/passes/test_decompose_softmax_pass.py index 3af1976e3f3..28d7bbb7fdf 100644 --- a/backends/arm/test/passes/test_decompose_softmax_pass.py +++ b/backends/arm/test/passes/test_decompose_softmax_pass.py @@ -22,7 +22,7 @@ def __init__(self): super(Softmax, self).__init__() self.softmax = torch.nn.Softmax(dim=1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.softmax(x) return x @@ -39,7 +39,7 @@ def __init__(self): super(SoftmaxLog, self).__init__() self.softmax = torch.nn.LogSoftmax(dim=1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.softmax(x) return x diff --git a/backends/arm/test/passes/test_decompose_var_pass.py b/backends/arm/test/passes/test_decompose_var_pass.py index c347a2f667c..2e31c9de817 100644 --- a/backends/arm/test/passes/test_decompose_var_pass.py +++ b/backends/arm/test/passes/test_decompose_var_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, Protocol, Tuple import torch from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass @@ -15,6 +15,10 @@ input_t = Tuple[torch.Tensor] # Input x +class ModuleWithInputs(Protocol): + def get_inputs(self) -> input_t: ... + + class VarDim(torch.nn.Module): """ Basic variance model using torch.Tensor.var function. @@ -24,7 +28,7 @@ def __init__(self, keepdim): super(VarDim, self).__init__() self.keepdim = keepdim - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.var(dim=-1, keepdim=self.keepdim) def get_inputs(self) -> input_t: @@ -40,14 +44,14 @@ def __init__(self, keepdim): super(VarCorrection, self).__init__() self.keepdim = keepdim - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.var(x, -1, keepdim=self.keepdim) def get_inputs(self) -> input_t: return (torch.rand(4, 4),) -modules = { +modules: Dict[str, ModuleWithInputs] = { "vardim_keepdim": VarDim(True), "vardim_no_keepdim": VarDim(False), "varcorrection_keepdim": VarCorrection(True), @@ -56,9 +60,10 @@ def get_inputs(self) -> input_t: @common.parametrize("module", modules) -def test_decompose_var_tosa_FP(module): +def test_decompose_var_tosa_FP(module: ModuleWithInputs) -> None: + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), quantize=False, ops_before_pass={ diff --git a/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py b/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py index 84573878aef..588428aa31b 100644 --- a/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py +++ b/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import Callable, Dict, Tuple import torch from executorch.backends.arm.test import common, conftest @@ -17,15 +17,15 @@ class FP32ToINT32Casting(torch.nn.Module): - def __init__(self, target_dtype): + def __init__(self, target_dtype: torch.dtype) -> None: super().__init__() self.target_dtype = target_dtype - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.to(self.target_dtype) -test_data_fp32_input = { +test_data_fp32_input: Dict[str, Callable[[], Tuple[torch.Tensor, torch.dtype]]] = { "fp32_input_rank1": lambda: ( torch.rand((4), dtype=torch.float32), torch.int32, @@ -46,7 +46,9 @@ def forward(self, x: torch.Tensor): @common.parametrize("test_data", test_data_fp32_input) -def test_decorate_fp32_to_int32_casting_tosa_FP(test_data: Tuple): +def test_decorate_fp32_to_int32_casting_tosa_FP( + test_data: Callable[[], Tuple[torch.Tensor, torch.dtype]] +) -> None: test_tensor, target_dtype = test_data() module = FP32ToINT32Casting(target_dtype) @@ -61,7 +63,9 @@ def test_decorate_fp32_to_int32_casting_tosa_FP(test_data: Tuple): @common.parametrize("test_data", test_data_fp32_input) -def test_decorate_fp32_to_int32_casting_tosa_INT(test_data: Tuple): +def test_decorate_fp32_to_int32_casting_tosa_INT( + test_data: Callable[[], Tuple[torch.Tensor, torch.dtype]] +) -> None: """ Casting operation involving floating-point dtypes will be rejected in INT/INT profile. Therefore, the DecorateFp32toInt32CastingPass is not required in this profile. diff --git a/backends/arm/test/passes/test_fold_qdq_pass.py b/backends/arm/test/passes/test_fold_qdq_pass.py index 994676ff442..dcf945d5bb4 100644 --- a/backends/arm/test/passes/test_fold_qdq_pass.py +++ b/backends/arm/test/passes/test_fold_qdq_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import ClassVar, Dict, Tuple import torch from executorch.backends.arm._passes import FoldAndAnnotateQParamsPass @@ -15,16 +15,16 @@ class SimpleQuantizeModel(torch.nn.Module): - test_data = { + test_data: ClassVar[Dict[str, input_t]] = { "rand": (torch.rand(1, 1280, 7, 7), torch.rand(1, 1280, 7, 7)), } - def forward(self, x, y): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + torch.max((x + x), (y + y)) @common.parametrize("test_data", SimpleQuantizeModel.test_data) -def test_fold_qdq_pass_tosa_INT(test_data: input_t): +def test_fold_qdq_pass_tosa_INT(test_data: input_t) -> None: """ Tests the FoldAndAnnotateQParamsPass which folds dq/q nodes into the node and stores the quantization parameters in meta. diff --git a/backends/arm/test/passes/test_fuse_batchnorm_pass.py b/backends/arm/test/passes/test_fuse_batchnorm_pass.py index 59fae7cafbd..08bf960da7d 100644 --- a/backends/arm/test/passes/test_fuse_batchnorm_pass.py +++ b/backends/arm/test/passes/test_fuse_batchnorm_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, ClassVar, Dict, Protocol, Tuple import torch from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass @@ -13,12 +13,19 @@ input_t = Tuple[torch.Tensor] # Input x +class ModuleWithBatchNormAttrs(Protocol): + ops_before_pass: Dict[str, int] + ops_after_pass: Dict[str, int] + + def get_inputs(self) -> input_t: ... + + class MergeOneOfTwoBN(torch.nn.Module): - ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2, "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, } - ops_after_pass = { + ops_after_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 0, "executorch_exir_dialects_edge__ops_aten_convolution_default": 2, } @@ -39,7 +46,7 @@ def __init__(self, affine: bool): def get_inputs(self) -> input_t: return (torch.randn(1, 3, 256, 256),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv2d(x) x = self.batch_norm2d(x) x = self.relu6(x) @@ -48,11 +55,11 @@ def forward(self, x): class MergeTwosOfTwoBN(torch.nn.Module): - ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2, "executorch_exir_dialects_edge__ops_aten_convolution_default": 2, } - ops_after_pass = { + ops_after_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 0, "executorch_exir_dialects_edge__ops_aten_convolution_default": 2, } @@ -76,7 +83,7 @@ def __init__(self, affine: bool): def get_inputs(self) -> input_t: return (torch.randn(1, 3, 256, 256),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv2d(x) x = self.batch_norm2d(x) x = self.relu6(x) @@ -86,11 +93,11 @@ def forward(self, x): class MergeMultipleUsersBN(torch.nn.Module): - ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2, "executorch_exir_dialects_edge__ops_aten_convolution_default": 3, } - ops_after_pass = { + ops_after_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 0, "executorch_exir_dialects_edge__ops_aten_convolution_default": 4, } @@ -114,7 +121,7 @@ def __init__(self, affine: bool): def get_inputs(self) -> input_t: return (torch.randn(1, 3, 256, 256),) - def forward(self, x): + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x1 = self.conv2d(x) x = self.batch_norm2d( x1 @@ -129,20 +136,25 @@ def forward(self, x): return z, a -modules = { - "merge_one_of_two_bn_affine": MergeOneOfTwoBN(True), - "merge_one_of_two_bn": MergeOneOfTwoBN(False), - "merge_two_of_two_bn_affine": MergeTwosOfTwoBN(True), - "merge_multiple_users_bn_affine": MergeMultipleUsersBN(True), +modules: Dict[str, ModuleWithBatchNormAttrs] = { + "merge_one_of_two_bn_affine": cast(ModuleWithBatchNormAttrs, MergeOneOfTwoBN(True)), + "merge_one_of_two_bn": cast(ModuleWithBatchNormAttrs, MergeOneOfTwoBN(False)), + "merge_two_of_two_bn_affine": cast( + ModuleWithBatchNormAttrs, MergeTwosOfTwoBN(True) + ), + "merge_multiple_users_bn_affine": cast( + ModuleWithBatchNormAttrs, MergeMultipleUsersBN(True) + ), } @common.parametrize("module", modules) -def test_fuse_batchnorm_tosa_FP(module: torch.nn.Module): +def test_fuse_batchnorm_tosa_FP(module: ModuleWithBatchNormAttrs) -> None: """Test various cases where the batchnorm should either be fused with a previous conv, or converted to a new conv.""" + nn_module = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + nn_module, module.get_inputs(), quantize=False, ops_before_pass=module.ops_before_pass, diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py index 417ad7bff2a..95492075c0d 100644 --- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import operator -from typing import Tuple +from typing import cast, ClassVar, Dict, Protocol, Tuple import torch from executorch.backends.arm._passes.fuse_constant_ops_pass import ( @@ -22,16 +22,26 @@ input_t2 = Tuple[torch.Tensor, torch.Tensor] +class ModuleWithFuseAttrs(Protocol): + ops_before_pass: Dict[str, int] + ops_after_pass: Dict[str, int] + ops_not_after_pass: list[str] + + def get_inputs(self) -> input_t: ... + + class FuseParameter(torch.nn.Module): - ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten_full_default": 1, "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2, "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, "executorch_exir_dialects_edge__ops_aten_addmm_default": 1, "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, } - ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1} - ops_not_after_pass = [ + ops_after_pass: ClassVar[Dict[str, int]] = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1 + } + ops_not_after_pass: ClassVar[list[str]] = [ "executorch_exir_dialects_edge__ops_aten_full_default", "executorch_exir_dialects_edge__ops_aten_view_copy_default", "executorch_exir_dialects_edge__ops_aten_permute_copy_default", @@ -51,34 +61,38 @@ def __init__( bias=bias, ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.fc(torch.ones(1)) + x class FuseBuffer(torch.nn.Module): - ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, } - ops_after_pass = { + ops_after_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, } - ops_not_after_pass = [ + ops_not_after_pass: ClassVar[list[str]] = [ "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" ] - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return (x + 1) * 2 class FuseLiftedTensor(torch.nn.Module): - ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten_select_copy_int": 1, "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, } - ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1} - ops_not_after_pass = ["executorch_exir_dialects_edge__ops_aten_select_copy_int"] + ops_after_pass: ClassVar[Dict[str, int]] = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1 + } + ops_not_after_pass: ClassVar[list[str]] = [ + "executorch_exir_dialects_edge__ops_aten_select_copy_int" + ] def __init__( self, @@ -92,18 +106,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class CatConst(torch.nn.Module): - ops_before_pass = { + ops_before_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten_cat_default": 1, } - ops_after_pass = { + ops_after_pass: ClassVar[Dict[str, int]] = { "executorch_exir_dialects_edge__ops_aten_cat_default": 1, } - ops_not_after_pass = [] + ops_not_after_pass: ClassVar[list[str]] = [] def __init__(self): super().__init__() - def forward(self, a, b): + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return torch.cat((a, b), dim=0) @@ -115,29 +129,29 @@ def __init__(self, in_out_features: int = 3, bias: bool = True): self.linear = torch.nn.Linear(in_out_features, in_out_features, bias=bias) self.example_input = torch.rand(in_out_features, in_out_features) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: y = torch.full_like(x, 1.0) return self.linear(y) + x - def get_example_input(self): + def get_example_input(self) -> torch.Tensor: return self.example_input -modules = { - "fuse_parameter": FuseParameter(), - "fuse_buffer": FuseBuffer(), - "fuse_const_tensor": FuseLiftedTensor(), +modules: Dict[str, ModuleWithFuseAttrs] = { + "fuse_parameter": cast(ModuleWithFuseAttrs, FuseParameter()), + "fuse_buffer": cast(ModuleWithFuseAttrs, FuseBuffer()), + "fuse_const_tensor": cast(ModuleWithFuseAttrs, FuseLiftedTensor()), } -cat_module = { - "fuse_cat": CatConst(), +cat_module: Dict[str, ModuleWithFuseAttrs] = { + "fuse_cat": cast(ModuleWithFuseAttrs, CatConst()), } @common.parametrize("module", modules) -def test_fuse_const_ops_tosa_FP(module: torch.nn.Module): +def test_fuse_const_ops_tosa_FP(module: ModuleWithFuseAttrs) -> None: pipeline = PassPipeline[input_t]( - module=module, + module=cast(torch.nn.Module, module), test_data=(torch.rand(1),), quantize=False, ops_before_pass=module.ops_before_pass, @@ -149,9 +163,9 @@ def test_fuse_const_ops_tosa_FP(module: torch.nn.Module): @common.parametrize("module", modules) -def test_fuse_const_ops_tosa_INT(module: torch.nn.Module): +def test_fuse_const_ops_tosa_INT(module: ModuleWithFuseAttrs) -> None: pipeline = PassPipeline[input_t]( - module, + cast(torch.nn.Module, module), (torch.rand(10, 10),), quantize=True, ops_before_pass=module.ops_before_pass, @@ -162,9 +176,9 @@ def test_fuse_const_ops_tosa_INT(module: torch.nn.Module): @common.parametrize("module", cat_module) -def test_fuse_const_ops_tosa_BI_cat(module: torch.nn.Module): +def test_fuse_const_ops_tosa_BI_cat(module: ModuleWithFuseAttrs) -> None: pipeline = PassPipeline[input_t2]( - module, + cast(torch.nn.Module, module), (torch.rand(3), torch.rand(2)), quantize=True, ops_before_pass=module.ops_before_pass, diff --git a/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py b/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py index f6e437ba034..22c4630d628 100644 --- a/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py @@ -4,12 +4,14 @@ # LICENSE file in the root directory of this source tree. from copy import deepcopy -from typing import Tuple +from typing import Callable, cast, ClassVar, Dict, Protocol, Tuple, TypeVar import torch from executorch.backends.arm._passes.fuse_equal_placeholders_pass import ( FuseEqualPlaceholdersPass, ) + +from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( PassPipeline, TosaPipelineFP, @@ -18,10 +20,26 @@ input_t = Tuple[torch.Tensor] # Input x +class ModuleWithEqualPlaceholderAttrs(Protocol): + ops_before_pass: Dict[str, int] + ops_after_pass: Dict[str, int] + ops_not_after_pass: list[str] + + def get_inputs(self) -> input_t: ... + + +T = TypeVar("T") +TestDecorator = Callable[[Callable[[T], None]], Callable[[T], None]] + + +def _typed_parametrize(test_data: Dict[str, T]) -> TestDecorator: + return cast(TestDecorator, common.parametrize("module", test_data)) + + class FuseWeightsConstants(torch.nn.Module): - ops_before_pass = {} - ops_after_pass = {} - ops_not_after_pass = [] + ops_before_pass: ClassVar[Dict[str, int]] = {} + ops_after_pass: ClassVar[Dict[str, int]] = {} + ops_not_after_pass: ClassVar[list[str]] = [] def __init__( self, @@ -33,18 +51,21 @@ def __init__( self.bias2 = deepcopy(self.bias1) self.bias3 = deepcopy(self.bias1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return ( torch.conv1d(x, self.weights1, self.bias1) + torch.conv1d(x, self.weights2, self.bias2) + self.bias3 ) + def get_inputs(self) -> input_t: + return (torch.rand(1, 2, 8),) + class FuseWeightsStateDict(torch.nn.Module): - ops_before_pass = {} - ops_after_pass = {} - ops_not_after_pass = [] + ops_before_pass: ClassVar[Dict[str, int]] = {} + ops_after_pass: ClassVar[Dict[str, int]] = {} + ops_not_after_pass: ClassVar[list[str]] = [] def __init__( self, @@ -53,15 +74,18 @@ def __init__( self.fc1 = torch.nn.Linear(in_features=8, out_features=2, bias=True) self.fc2 = deepcopy(self.fc1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.fc1(x) + self.fc2(x) + def get_inputs(self) -> input_t: + return (torch.rand(1, 2, 8),) + class NotFuseTensorWithDifferentType(torch.nn.Module): - ops_before_pass = {} - ops_after_pass = {} - ops_not_after_pass = [] + ops_before_pass: ClassVar[Dict[str, int]] = {} + ops_after_pass: ClassVar[Dict[str, int]] = {} + ops_not_after_pass: ClassVar[list[str]] = [] def forward(self, x: torch.Tensor, y: torch.Tensor): """ @@ -76,12 +100,20 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): return m, n -def test_fuse_equal_placeholders_constants_tosa_FP(): - module = FuseWeightsConstants() - data = (torch.rand(1, 2, 8),) +constants_modules: Dict[str, ModuleWithEqualPlaceholderAttrs] = { + "fuse_constants": cast(ModuleWithEqualPlaceholderAttrs, FuseWeightsConstants()), +} + +parametrize_constants = _typed_parametrize(constants_modules) + + +@parametrize_constants +def test_fuse_equal_placeholders_constants_tosa_FP( + module: ModuleWithEqualPlaceholderAttrs, +) -> None: pipeline = PassPipeline[input_t]( - module, - data, + cast(torch.nn.Module, module), + module.get_inputs(), quantize=False, ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, @@ -97,12 +129,11 @@ def test_fuse_equal_placeholders_constants_tosa_FP(): assert "_common" in constant_keys[1], "FuseEqualPlaceholders constants failed" -def test_fuse_equal_placeholders_state_dict_tosa_FP(): +def test_fuse_equal_placeholders_state_dict_tosa_FP() -> None: module = FuseWeightsStateDict() - data = (torch.rand(1, 2, 8),) pipeline = PassPipeline[input_t]( module, - data, + module.get_inputs(), quantize=False, ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, diff --git a/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py b/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py index efc1bebb610..7c32cee8534 100644 --- a/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py +++ b/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py @@ -10,7 +10,7 @@ from executorch.backends.arm.test.tester.test_pipeline import PassPipeline -input_t = Tuple[torch.Tensor] # Input x +input_t = Tuple[torch.Tensor, torch.Tensor] # weights, indices class Int64InputModel(torch.nn.Module): diff --git a/backends/arm/test/passes/test_insert_table_ops_pass.py b/backends/arm/test/passes/test_insert_table_ops_pass.py index 5e695c237a0..00ff0c96de1 100644 --- a/backends/arm/test/passes/test_insert_table_ops_pass.py +++ b/backends/arm/test/passes/test_insert_table_ops_pass.py @@ -3,8 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - -from typing import Tuple +from typing import ClassVar, Dict, Tuple import torch from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( @@ -18,16 +17,16 @@ class Sigmoid(torch.nn.Module): - test_data = { + test_data: ClassVar[Dict[str, input_t]] = { "rand": (torch.rand(4),), } - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.sigmoid() @common.parametrize("test_data", Sigmoid.test_data) -def test_insert_table_tosa_INT(test_data: input_t): +def test_insert_table_tosa_INT(test_data: input_t) -> None: module = Sigmoid() pipeline = PassPipeline[input_t]( module, diff --git a/backends/arm/test/passes/test_int32_cast_embedding_pass.py b/backends/arm/test/passes/test_int32_cast_embedding_pass.py index 7adca527d75..30e84fadde3 100644 --- a/backends/arm/test/passes/test_int32_cast_embedding_pass.py +++ b/backends/arm/test/passes/test_int32_cast_embedding_pass.py @@ -10,12 +10,12 @@ from executorch.backends.arm.test.tester.test_pipeline import PassPipeline -input_t = Tuple[torch.Tensor] # Input x +input_t = Tuple[torch.Tensor, torch.Tensor] class Int32Embedding(torch.nn.Module): - def forward(self, weights: torch.Tensor, indices: torch.Tensor): + def forward(self, weights: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: return torch.embedding(weights, indices) def get_inputs(self) -> input_t: diff --git a/backends/arm/test/passes/test_ioquantization_pass.py b/backends/arm/test/passes/test_ioquantization_pass.py index da3b81aa096..fc57e8fa5b0 100644 --- a/backends/arm/test/passes/test_ioquantization_pass.py +++ b/backends/arm/test/passes/test_ioquantization_pass.py @@ -14,7 +14,7 @@ from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs -input_t = Tuple[torch.Tensor] +input_t = Tuple[torch.Tensor, torch.Tensor] class SimpleModel(torch.nn.Module): diff --git a/backends/arm/test/passes/test_rescale_pass.py b/backends/arm/test/passes/test_rescale_pass.py index ecd1deadf4f..1ab4f5b6a03 100644 --- a/backends/arm/test/passes/test_rescale_pass.py +++ b/backends/arm/test/passes/test_rescale_pass.py @@ -174,7 +174,7 @@ def test_quantized_rescale_tosa_bi(test_data: tuple[torch.Tensor, torch.Tensor]) @common.parametrize("test_data", RescaleNetwork.test_data) @common.XfailIfNoCorstone300 -def test_quantized_rescale_u55(test_data: tuple[torch.Tensor, torch.Tensor]): +def test_quantized_rescale_u55(test_data: input_t): """Tests a model with many ops that requires rescales. As more ops are quantized to int32 and need the InsertRescalesPass, make sure that they play nicely together.""" module = RescaleNetwork() @@ -189,7 +189,7 @@ def test_quantized_rescale_u55(test_data: tuple[torch.Tensor, torch.Tensor]): @common.parametrize("test_data", RescaleNetwork.test_data) @common.XfailIfNoCorstone320 -def test_quantized_rescale_u85(test_data: tuple[torch.Tensor, torch.Tensor]): +def test_quantized_rescale_u85(test_data: input_t): """Tests a model with many ops that requires rescales. As more ops are quantized to int32 and need the InsertRescalesPass, make sure that they play nicely together.""" module = RescaleNetwork() diff --git a/backends/arm/test/passes/test_to_tosa_memory_format.py b/backends/arm/test/passes/test_to_tosa_memory_format.py index aed87c05799..486a906a0ff 100644 --- a/backends/arm/test/passes/test_to_tosa_memory_format.py +++ b/backends/arm/test/passes/test_to_tosa_memory_format.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import cast, Dict, List, Protocol, Tuple import torch from executorch.backends.arm._passes import ( @@ -21,19 +21,30 @@ input_t = Tuple[torch.Tensor] # Input x +class ModuleMetadata(Protocol): + ops_before_pass: Dict[str, int] + ops_after_pass: Dict[str, int] + ops_not_after_pass: List[str] + + def get_inputs(self) -> input_t: ... + + class NoNHWC(torch.nn.Module): """ Test-module with no ops requiring NHWC mermory format. """ - ops_after_pass = {"executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2} - ops_not_after_pass = [] + ops_before_pass: Dict[str, int] = {} + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2 + } + ops_not_after_pass: List[str] = [] - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + x return x - def get_inputs(self): + def get_inputs(self) -> input_t: return (torch.rand(1, 2, 2, 2),) @@ -42,8 +53,11 @@ class ParallelClusters(torch.nn.Module): Test-module with multiple parallel clusters of nodes requiring different memory formats. """ - ops_after_pass = {"executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2} - ops_not_after_pass = [] + ops_before_pass: Dict[str, int] = {} + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 2 + } + ops_not_after_pass: List[str] = [] def __init__(self): super().__init__() @@ -56,14 +70,14 @@ def __init__(self): self.maxpool = torch.nn.MaxPool2d(1, 1) self.avgpool = torch.nn.AvgPool2d(1, 1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x1 = self.conv(x) x2 = self.maxpool(x) x3 = self.avgpool(x) x4 = x * x return x1 + x2 + x3 + x4 - def get_inputs(self): + def get_inputs(self) -> input_t: return (torch.rand(1, 2, 2, 2),) @@ -72,9 +86,11 @@ class SerialClusters(torch.nn.Module): Test-module with multiple serial clusters of nodes requring different memory formats. """ - ops_before_pass = {} - ops_after_pass = {"executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4} - ops_not_after_pass = [] + ops_before_pass: Dict[str, int] = {} + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4 + } + ops_not_after_pass: List[str] = [] def __init__(self): super().__init__() @@ -90,7 +106,7 @@ def __init__(self): bias=True, ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = x * x x = self.conv(x) @@ -100,7 +116,7 @@ def forward(self, x): x = self.conv(x) return x - def get_inputs(self): + def get_inputs(self) -> input_t: return (torch.rand(2, 2, 2, 2),) @@ -109,17 +125,17 @@ class Reshapes(torch.nn.Module): Test-module with different configurations of views requiring different memory formats. """ - ops_before_pass = {} - ops_after_pass = { + ops_before_pass: Dict[str, int] = {} + ops_after_pass: Dict[str, int] = { "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 16 } - ops_not_after_pass = [] + ops_not_after_pass: List[str] = [] def __init__(self): super().__init__() self.maxpool = torch.nn.MaxPool2d(1, 1) # Use maxpool to force NHWC format - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.maxpool(x) x = x.view((2, 2, 4, 16, 1)) # N-C-HW-invariant intact, no transposes needed @@ -159,11 +175,11 @@ def forward(self, x): return x - def get_inputs(self): + def get_inputs(self) -> input_t: return (torch.rand(4, 4, 4, 4),) -modules = { +modules: Dict[str, ModuleMetadata] = { "no_nhwc": NoNHWC(), "parallel_clusters": ParallelClusters(), "serial_clusters": SerialClusters(), @@ -172,10 +188,11 @@ def get_inputs(self): @common.parametrize("module", modules) -def test_to_tosa_memory_format_tosa_INT(module): +def test_to_tosa_memory_format_tosa_INT(module: ModuleMetadata) -> None: # We cannot check op counts after a specific pass with the full pipeline + module_nn = cast(torch.nn.Module, module) pipeline = PassPipeline[input_t]( - module, + module_nn, module.get_inputs(), ops_after_pass=module.ops_after_pass, ops_not_after_pass=module.ops_not_after_pass, @@ -189,7 +206,8 @@ def test_to_tosa_memory_format_tosa_INT(module): @common.parametrize("module", modules) -def test_to_tosa_memory_format_tosa_INT_functional(module): +def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> None: # Also run the actual pass pipeline to ensure functional correctness. - pipeline = TosaPipelineINT[input_t](module, module.get_inputs(), []) + module_nn = cast(torch.nn.Module, module) + pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), []) pipeline.run() diff --git a/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py b/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py index fc405e21f2a..f6ff8b8c0bb 100644 --- a/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py +++ b/backends/arm/test/passes/test_unsqueeze_before_repeat_pass.py @@ -3,16 +3,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, Tuple +from typing import Dict, List, Tuple import torch from executorch.backends.arm._passes import UnsqueezeBeforeRepeatPass from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline -input_t = Tuple[ - torch.Tensor, Dict[str, int], list[str] -] # Input x, ops_after_pass, ops_not_after_pass +pipeline_input_t = Tuple[torch.Tensor, ...] +test_case_t = Tuple[ + pipeline_input_t, + Dict[str, int], + List[str], +] class Repeat(torch.nn.Module): @@ -20,10 +23,10 @@ class Repeat(torch.nn.Module): Basic repeat model. """ - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x.repeat(2, 2, 2, 2) - test_data: Dict[str, input_t] = { + test_data: Dict[str, test_case_t] = { "insert_view": ( (torch.rand((2, 3, 4)),), {"aten_repeat_default": 3, "aten_view_copy_default": 4}, @@ -38,14 +41,14 @@ def forward(self, x: torch.Tensor): @common.parametrize("test_data", Repeat.test_data) -def test_unsqueeze_before_repeat_tosa_FP(test_data: input_t): +def test_unsqueeze_before_repeat_tosa_FP(test_data: test_case_t): """ When rank(input) != number of repeated dimensions (=4 in Repeat module), insert view. """ module = Repeat() data, ops_after_pass, ops_not_after_pass = test_data - pipeline = PassPipeline( + pipeline = PassPipeline[pipeline_input_t]( module, data, quantize=False,