fuse_conv_with_clamp.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseClampPass(ExportPass):
    """
    Some activations, such as ReLU and hardtanh, can be fused with certain
    operators (e.g. convolution) that precede them.
    """

    FUSEABLE_OPS = [
        exir_ops.edge.aten.convolution.default,
    ]
    FUSEABLE_ACTIVATIONS = [
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.hardtanh.default,
    ]

    def get_output_min_max_from_activation(self, activation_node):
        # ReLU clamps to [0, +inf); use the largest float as a stand-in for +inf.
        if activation_node.target == exir_ops.edge.aten.relu.default:
            output_min = 0.0
            output_max = sys.float_info.max
        # hardtanh defaults to [-1, 1] unless explicit min/max args were given.
        elif activation_node.target == exir_ops.edge.aten.hardtanh.default:
            output_min = -1.0
            output_max = 1.0
            if len(activation_node.args) > 1:
                output_min = activation_node.args[1]
                output_max = activation_node.args[2]

        return output_min, output_max

    def call(self, graph_module: torch.fx.GraphModule):
        for activation_node in graph_module.graph.nodes:
            if activation_node.op == "call_function":
                if activation_node.target in self.FUSEABLE_ACTIVATIONS:
                    preceding_op = activation_node.args[0]
                    if (
                        preceding_op.op == "call_function"
                        and preceding_op.target in self.FUSEABLE_OPS
                    ):
                        # Delete the activation, folding its clamp range into
                        # the args of the fused op.
                        output_min_max = self.get_output_min_max_from_activation(
                            activation_node
                        )
                        new_args = list(preceding_op.args)
                        new_args.append(output_min_max[0])
                        new_args.append(output_min_max[1])
                        new_args = tuple(new_args)
                        activation_node.replace_all_uses_with(preceding_op)
                        graph_module.graph.erase_node(activation_node)

                        # Create and insert a node for the custom op
                        # `conv_with_clamp`, then replace the original conv.
                        with graph_module.graph.inserting_before(preceding_op):
                            conv_activation_node = graph_module.graph.create_node(
                                "call_function",
                                exir_ops.edge.et_vk.conv_with_clamp.default,
                                new_args,
                            )

                            preceding_op.replace_all_uses_with(conv_activation_node)
                            graph_module.graph.erase_node(preceding_op)

        graph_module.recompile()
        # Re-trace to propagate updated metadata through the rewritten graph.
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
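
# Usage sketch (illustrative; assumes the standard ExecuTorch export flow and
# that `model` / `example_inputs` are defined elsewhere):
#
#   import torch
#   from executorch.exir import to_edge
#
#   edge = to_edge(torch.export.export(model, example_inputs))
#   edge = edge.transform([FuseClampPass()])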