Skip to content

Arm backend: Merge decompose/convert meandim pass #10844

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass # noqa
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass # noqa
from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass # noqa
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
from .remove_clone_pass import RemoveClonePass # noqa
from .replace_scalar_with_tensor_pass import ( # noqa
Expand Down
4 changes: 1 addition & 3 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
ConvertAnyDefaultDimDimsPass,
ConvertExpandCopyToRepeatPass,
ConvertFullLikeToFullPass,
ConvertMeanDimToAveragePoolPass,
ConvertMinMaxPass,
ConvertMmToBmmPass,
ConvertSplitToSlicePass,
Expand Down Expand Up @@ -87,7 +86,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(ConvertMmToBmmPass())
self.add_pass(DecomposeLinearPass())
self.add_pass(ConvertMeanDimToAveragePoolPass())
self.add_pass(DecomposeMeanDimPass())
self.add_pass(ConvertFullLikeToFullPass())
self.add_pass(ConvertToClampPass())
self.add_pass(ConvertMinMaxPass())
Expand Down Expand Up @@ -140,7 +139,6 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(DecomposeVarPass())
self.add_pass(DecomposeMeanDimPass())
self.add_pass(DecomposeNotEqualPass())
self.add_pass(ConvertMeanDimToAveragePoolPass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeSoftmaxPass())
self.add_pass(DecomposeGeluPass())
Expand Down
110 changes: 89 additions & 21 deletions backends/arm/_passes/decompose_meandim_pass.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from math import prod

import torch
from executorch.backends.arm._passes import ArmPass
Expand All @@ -28,42 +27,111 @@ def get_meandim_decomposition(op) -> tuple:
raise RuntimeError(f"Can't get meandim decomposition for op {op}")


def get_avgpool(op):
    """Return the avg_pool2d op in the same dialect (edge or aten) as *op*.

    Args:
        op: A mean.dim operator, either edge-dialect or core-aten.

    Returns:
        The matching avg_pool2d.default operator.

    Raises:
        RuntimeError: If *op* is not a recognized mean.dim operator.
    """
    if op == exir_ops.edge.aten.mean.dim:
        return exir_ops.edge.aten.avg_pool2d.default
    if op == torch.ops.aten.mean.dim:
        return torch.ops.aten.avg_pool2d.default
    # Fixed error text: this helper resolves the avg_pool2d equivalent, not the
    # meandim decomposition (message was copy-pasted from a sibling helper).
    raise RuntimeError(f"Can't get avg_pool2d equivalent for op {op}")


def get_view(op):
    """Return the view_copy op in the same dialect (edge or aten) as *op*.

    Args:
        op: A mean.dim operator, either edge-dialect or core-aten.

    Returns:
        The matching view_copy.default operator.

    Raises:
        RuntimeError: If *op* is not a recognized mean.dim operator.
    """
    if op == exir_ops.edge.aten.mean.dim:
        return exir_ops.edge.aten.view_copy.default
    if op == torch.ops.aten.mean.dim:
        return torch.ops.aten.view_copy.default
    # Fixed error text: this helper resolves the view_copy equivalent, not the
    # meandim decomposition (message was copy-pasted from a sibling helper).
    raise RuntimeError(f"Can't get view_copy equivalent for op {op}")


class DecomposeMeanDimPass(ArmPass):
"""
This pass decomposes meandim into a sum and mul node.
Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on which dims the mean is taken for:
h,w -> avg_pool
n,c -> sum + mul(1/N)
For rank < 4, the input is first reshaped to 4D by padding with dim=1 from the left.

Example:
y = mean_dim(x, dim, keepdim)
x = mean_dim(x, (0,2), keepdim=False) # x = (c,h,w)
Becomes:
sum = sum.dim_IntList(x, dim, keepdim)
y = mul(sum, 1/N)
x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to work with avg_pool
x = avg_pool2d.default(x, kernel=(1,w), stride=(1,1)) # Reduce w with avg_pool
x = sum.dim_IntList(x, dim=1, keepdims=True) # Reduce c with sum
x = mul.Tensor(x, 1/c) # Divide by number of channels to get mean
x = view_copy.default(x, new_shape=(h)) # Squeeze dims since keepdims = False
"""

def call_operator(self, op, args, kwargs, meta):
    """Decompose mean.dim into avg_pool2d and/or sum + mul(1/N).

    Spatial dims (h, w) are reduced with avg_pool2d; batch/channel dims
    (n, c) are reduced with sum followed by a multiply with 1/N. Inputs of
    rank < 4 are first reshaped to 4D by left-padding the shape with 1s.

    NOTE(review): this span contained interleaved removed-diff residue
    (old `dim`/`keepdim`/`N` logic duplicated next to the new code); it has
    been reconstructed from the added lines of the diff only.

    Raises:
        NotImplementedError: For inputs of rank > 4 (unsupported by the
            TOSA backend).
    """
    if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
        return super().call_operator(op, args, kwargs, meta)

    x = get_node_arg(args, 0)
    input_shape = x.data.size()
    output_shape = meta["val"].size()
    dims_to_reduce = get_node_arg(args, 1)
    # Normalize negative dims (e.g. -1, -2) to non-negative indices.
    dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]

    dtype = meta["val"].dtype
    view_op = get_view(op)

    if len(input_shape) > 4:
        raise NotImplementedError(
            f"{op} with rank > 4 is currently not supported for the TOSA backend."
        )

    # Unsqueeze to 4D so that dims 2 and 3 are always the spatial (h, w) dims.
    if len(input_shape) < 4:
        pad_n = 4 - len(input_shape)
        new_shape = [1] * pad_n + list(input_shape)
        dims_to_reduce = [dim + pad_n for dim in dims_to_reduce]

        x = super().call_operator(view_op, (x, new_shape), {}, meta, True)

    # Reduce (h, w) by avg pool.
    dims_to_reduce_by_avgpool = [dim for dim in dims_to_reduce if dim >= 2]
    x = self._reduce_by_average_pool(op, x, dims_to_reduce_by_avgpool, meta)

    # Reduce (n, c) by reduce sum.
    dims_to_reduce_by_sum = [dim for dim in dims_to_reduce if dim < 2]
    x = self._reduce_by_sum(op, x, dims_to_reduce_by_sum, meta, dtype)

    # Reshape to the expected output shape (handles keepdim=False and the
    # 4D padding added above).
    if x.data.size() != output_shape:
        x = super().call_operator(view_op, (x, output_shape), {}, meta, True)

    return x

def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
    """Reduce *dims* of *input_node* to their mean via sum + mul(1/N).

    N is the product of the sizes of the reduced dims. Returns *input_node*
    unchanged when there is nothing to reduce.

    NOTE(review): this span contained interleaved removed-diff residue (the
    old `(x, dim, keepdim)` call and old `shape` usage); only the new-version
    lines are kept here.
    """
    if len(dims) == 0:
        return input_node

    input_shape = input_node.data.size()
    output_shape = meta["val"].size()
    # Number of elements averaged over: product of the reduced dim sizes.
    N = prod(n for i, n in enumerate(input_shape) if i in dims)
    sum_op, full_op, mul_op = get_meandim_decomposition(op)

    # Renamed from `sum` to avoid shadowing the builtin.
    sum_node = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True)
    full = super().call_operator(
        full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
    )
    return super().call_operator(mul_op, (sum_node, full), {}, meta, True)

def _reduce_by_average_pool(self, op, input_node, dims, meta):
    """Reduce the spatial dims in *dims* (2 and/or 3) of a 4D tensor with
    a single avg_pool2d whose kernel spans the reduced extent.

    A dim that is not reduced keeps a kernel extent of 1, so it passes
    through the pool untouched. Returns *input_node* unchanged when *dims*
    is empty.

    Raises:
        RuntimeError: If *dims* is not [2], [3], [2, 3], or [3, 2].
    """
    if len(dims) == 0:
        return input_node

    pool_op = get_avgpool(op)
    shape = input_node.data.size()

    key = tuple(dims)
    if key in ((2, 3), (3, 2)):
        kernel = [shape[2], shape[3]]
    elif key == (3,):
        kernel = [1, shape[3]]
    elif key == (2,):
        kernel = [shape[2], 1]
    else:
        raise RuntimeError(f"Bad dims {dims} for {op} decomposition of mean_dim.")

    return super().call_operator(
        pool_op, (input_node, kernel, [1, 1]), {}, meta, True
    )
54 changes: 0 additions & 54 deletions backends/arm/_passes/meandim_to_averagepool_pass.py

This file was deleted.

39 changes: 17 additions & 22 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,28 +262,23 @@ def is_node_supported(

if node.op != "call_function":
return True
if node.target == exir_ops.edge.aten.mean.dim:
dim = node.args[1]
needs_decomp = dim != [-1, -2]
else:
needs_decomp = node.target in [
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.native_layer_norm.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten._softmax.default,
exir_ops.edge.aten._log_softmax.default,
exir_ops.edge.aten.var.correction,
exir_ops.edge.aten.var.dim,
exir_ops.edge.aten.add.Scalar,
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.sub.Scalar,
exir_ops.edge.aten.mul.Scalar,
exir_ops.edge.aten.ne.Tensor,
exir_ops.edge.aten.ne.Scalar,
exir_ops.edge.aten.div.Scalar,
exir_ops.edge.aten.leaky_relu.default,
]
needs_decomp = node.target in [
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.native_layer_norm.default,
exir_ops.edge.aten._softmax.default,
exir_ops.edge.aten._log_softmax.default,
exir_ops.edge.aten.var.correction,
exir_ops.edge.aten.var.dim,
exir_ops.edge.aten.add.Scalar,
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.sub.Scalar,
exir_ops.edge.aten.mul.Scalar,
exir_ops.edge.aten.ne.Tensor,
exir_ops.edge.aten.ne.Scalar,
exir_ops.edge.aten.div.Scalar,
exir_ops.edge.aten.leaky_relu.default,
]
if needs_decomp:
self.reporter.report_reject(node, "Needs to be decomposed.")
return False
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/test/ops/test_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def test_native_layer_norm_tosa_BI(test_data):
model,
test_data,
"torch.ops.aten.sub.Tensor", # Just check for sub op included in the layernorm decomposition
symmetric_io_quantization=True,
)
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
pipeline.run()


Expand All @@ -95,8 +95,8 @@ def test_native_layer_norm_u55_BI(test_data):
test_data,
"torch.ops.aten.sub.Tensor", # Just check for sub op included in the layernorm decomposition
run_on_fvp=True,
symmetric_io_quantization=True,
)
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
pipeline.run()


Expand All @@ -109,6 +109,6 @@ def test_native_layer_norm_u85_BI(test_data):
test_data,
"torch.ops.aten.sub.Tensor", # Just check for sub op included in the layernorm decomposition
run_on_fvp=True,
symmetric_io_quantization=True,
)
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
pipeline.run()
Loading
Loading