 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# pyre-unsafe
+from math import prod

 import torch
 from executorch.backends.arm._passes import ArmPass
@@ -28,42 +27,111 @@ def get_meandim_decomposition(op) -> tuple:
     raise RuntimeError(f"Can't get meandim decomposition for op {op}")


+def get_avgpool(op):
+    if op == exir_ops.edge.aten.mean.dim:
+        return exir_ops.edge.aten.avg_pool2d.default
+    if op == torch.ops.aten.mean.dim:
+        return torch.ops.aten.avg_pool2d.default
+    raise RuntimeError(f"Can't get meandim decomposition for op {op}")
+
+
+def get_view(op):
+    if op == exir_ops.edge.aten.mean.dim:
+        return exir_ops.edge.aten.view_copy.default
+    if op == torch.ops.aten.mean.dim:
+        return torch.ops.aten.view_copy.default
+    raise RuntimeError(f"Can't get meandim decomposition for op {op}")
+
+
 class DecomposeMeanDimPass(ArmPass):
     """
-    This pass decomposes meandim into a sum and mul node.
+    Decomposes meandim into avg_pool and/or sum + mul(1/N), depending on which dims the mean is taken over:
+        h,w -> avg_pool
+        n,c -> sum + mul(1/N)
+    For rank < 4, the input is first reshaped to 4D by padding with dims of size 1 from the left.

     Example:
-        y = mean_dim(x, dim, keepdim)
+        x = mean_dim(x, (0,2), keepdim=False)                  # x has shape (c,h,w)
     Becomes:
-        sum = sum.dim_IntList(x, dim, keepdim)
-        y = mul(sum, 1/N)
+        x = view_copy.default(x, new_shape=(1,c,h,w))          # Reshape to 4D to work with avg_pool
+        x = avg_pool2d.default(x, kernel=(1,w), stride=(1,1))  # Reduce w with avg_pool
+        x = sum.dim_IntList(x, dim=1, keepdim=True)            # Reduce c with sum
+        x = mul.Tensor(x, 1/c)                                 # Divide by number of channels to get the mean
+        x = view_copy.default(x, new_shape=(h,))               # Squeeze dims since keepdim=False
     """

     def call_operator(self, op, args, kwargs, meta):
         if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
             return super().call_operator(op, args, kwargs, meta)

         x = get_node_arg(args, 0)
-        dim = get_node_arg(args, 1)
-        keepdim = get_node_arg(args, 2, False)
-
-        # if dim == [-1, -2], mean.dim can be
-        # decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool.
-        if dim == [-1, -2]:
-            # Simply return the mean.dim operator for future decomposition.
-            return super().call_operator(op, args, kwargs, meta)
+        input_shape = x.data.size()
+        output_shape = meta["val"].size()
+        dims_to_reduce = get_node_arg(args, 1)
+        dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]

-        shape = meta["val"].size()
         dtype = meta["val"].dtype
-        input_shape = x.data.size()
-        N = 1
-        for d in dim:
-            N *= input_shape[d]
+        view_op = get_view(op)

+        if len(input_shape) > 4:
+            raise NotImplementedError(
+                f"{op} with rank > 4 is currently not supported for the TOSA backend."
+            )
+
+        # Unsqueeze to 4D
+        if len(input_shape) < 4:
+            pad_n = 4 - len(input_shape)
+            new_shape = [1] * pad_n + list(input_shape)
+            dims_to_reduce = [dim + pad_n for dim in dims_to_reduce]
+
+            x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
+
+        # Reduce (h,w) by avg pool
+        dims_to_reduce_by_avgpool = [dim for dim in dims_to_reduce if dim >= 2]
+        x = self._reduce_by_average_pool(op, x, dims_to_reduce_by_avgpool, meta)
+
+        # Reduce (n, c) by reduce sum
+        dims_to_reduce_by_sum = [dim for dim in dims_to_reduce if dim < 2]
+        x = self._reduce_by_sum(op, x, dims_to_reduce_by_sum, meta, dtype)
+
+        # Reshape to correct output shape if necessary
+        if x.data.size() != output_shape:
+            x = super().call_operator(view_op, (x, output_shape), {}, meta, True)
+
+        return x
+
+    def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
+        if len(dims) == 0:
+            return input_node
+
+        input_shape = input_node.data.size()
+        output_shape = meta["val"].size()
+        N = prod((n for i, n in enumerate(input_shape) if i in dims))
         sum_op, full_op, mul_op = get_meandim_decomposition(op)

-        sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta, True)
+        sum = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True)
         full = super().call_operator(
-            full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta, True
+            full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
         )
         return super().call_operator(mul_op, (sum, full), {}, meta, True)
+
+    def _reduce_by_average_pool(self, op, input_node, dims, meta):
+        if len(dims) == 0:
+            return input_node
+
+        avgpool_op = get_avgpool(op)
+        input_shape = input_node.data.size()
+
+        stride = [1, 1]
+        if dims in ([2, 3], [3, 2]):
+            kernel_size = [input_shape[2], input_shape[3]]
+        elif dims == [3]:
+            kernel_size = [1, input_shape[3]]
+        elif dims == [2]:
+            kernel_size = [input_shape[2], 1]
+        else:
+            raise RuntimeError(f"Bad dims {dims} for {op} decomposition of mean_dim.")
+
+        return super().call_operator(
+            avgpool_op, (input_node, kernel_size, stride), {}, meta, True
+        )
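
For reference, a minimal eager-mode sketch of the rewrite the new pass performs. It is not part of the pass itself: plain torch ops stand in for the edge-dialect operators, and it assumes a rank-3 input reduced over dims (0, 2), as in the docstring example.

    import torch
    import torch.nn.functional as F

    c, h, w = 3, 8, 16
    x = torch.randn(c, h, w)

    expected = x.mean(dim=(0, 2))               # reference result, shape (h,)

    y = x.view(1, c, h, w)                      # unsqueeze to 4D; dims to reduce become (1, 3)
    y = F.avg_pool2d(y, kernel_size=(1, w))     # reduce w (dim 3) with an average pool
    y = y.sum(dim=1, keepdim=True) * (1.0 / c)  # reduce c (dim 1) with sum + mul(1/N)
    y = y.view(h)                               # squeeze back, since keepdim=False

    assert torch.allclose(expected, y, atol=1e-6)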