# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.sym_util import eval_shape, eval_shape_upper_bound

_int64_max_dim_val = torch.iinfo(torch.int64).max - 1


def get_shape(input_node: torch.fx.Node):
    """
    Return the upper-bound shape of the node's value. If any dimension's upper
    bound is the default int64_max (i.e. effectively unbounded), fall back to
    evaluating the symbolic shape instead.
    Note that we must check the upper bound because, by default, it is int64_max.
    """
    input_val = input_node.meta["val"]
    upper_bound_shape = eval_shape_upper_bound(input_val.shape)
    for i in range(len(input_val.shape)):
        # Unbounded dimensions get int64 max values assigned to them.
        # This is a workaround for the case where export with dynamic shapes
        # does not use the constraint API but instead just traces the model
        # with tensors of the max size.
        if upper_bound_shape[i] >= _int64_max_dim_val:
            return eval_shape(input_val.shape)
    return upper_bound_shape
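
# Illustrative sketch of get_shape's behavior (the node and shapes below are
# hypothetical, not taken from a real graph):
#
#   node.meta["val"].shape == (s0, 16)    # s0 is a symbolic batch dimension
#   get_shape(node) -> [8, 16]            # when s0 was exported with upper bound 8
#   get_shape(node) -> eval_shape result  # fallback when the upper bound is the
#                                         # default int64_max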


def get_dqlinear_input(node: torch.fx.Node):
    ops = exir_ops.edge
    node_to_backtrack = node
    # Starting from the activation input, trace backwards through all
    # view_copy nodes until a dequant node is found. If any node encountered
    # during backtracking is not a view_copy, stop and return None.
    while node_to_backtrack.op != "placeholder":
        if (
            node_to_backtrack.op == "call_function"
            and node_to_backtrack.target
            == ops.quantized_decomposed.dequantize_per_tensor.tensor
        ):
            return node_to_backtrack
        if (
            node_to_backtrack.op == "call_function"
            and node_to_backtrack.target == ops.aten.view_copy.default
        ):
            node_to_backtrack = node_to_backtrack.args[0]
        else:
            return None
    return None
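
# Illustrative graph fragments for get_dqlinear_input (the chains below are
# hypothetical sketches, written producer -> consumer, with the last node
# being the one passed in):
#
#   dequantize_per_tensor.tensor -> view_copy -> view_copy
#       => returns the dequantize_per_tensor node
#   placeholder -> view_copy
#       => returns None (a placeholder is reached before any dequant node)
#   dequantize_per_tensor.tensor -> add -> view_copy
#       => returns None (backtracking stops at the first non-view_copy node)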


def replace_linear_view_copy_input_output(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Replaces the pattern: x -> view_copy -> view_copy -> linear -> view_copy -> y
    with:
    x -> linear -> y
    This is valid because linear nodes can handle input tensors with > 2 dimensions.
    """
    ops = exir_ops.edge
    for node in graph.nodes:
        if node.op == "call_function" and (node.target == ops.aten.linear.default):
            input_node = node.args[0]
            dqlinear_input = get_dqlinear_input(input_node)
            if dqlinear_input is not None and dqlinear_input != input_node:
                if len(input_node.args[0].users) == 1:
                    input_node.replace_all_uses_with(dqlinear_input)
                else:
                    print(
                        f"{input_node} has more than one user. Users: {input_node.users}"
                    )
            if len(node.users) == 1:
                users = list(node.users)
                maybe_view_copy = users[0]
                if maybe_view_copy.op == "call_function" and (
                    maybe_view_copy.target == ops.aten.view_copy.default
                ):
                    # Must re-read the input node since the original one may
                    # have been replaced above.
                    input_node = node.args[0]
                    input_shape = list(get_shape(input_node))
                    weight_node = node.args[1]
                    if "val" not in weight_node.meta:
                        raise ValueError(f"Val not found in meta of node {weight_node}")
                    weight_val = weight_node.meta["val"]
                    output_channels = weight_val.shape[0]
                    output_shape = input_shape
                    output_shape[-1] = output_channels
                    view_copy_out_shape = list(get_shape(maybe_view_copy))
                    if output_shape == view_copy_out_shape:
                        maybe_view_copy.replace_all_uses_with(node)
    graph.eliminate_dead_code()
    return graph
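
# Before/after sketch of the rewrite above (tensor shapes are hypothetical):
#
#   before: dq[2, 3, 8] -> view_copy[6, 8] -> linear(w[4, 8]) -> [6, 4]
#               -> view_copy[2, 3, 4] -> y
#   after:  dq[2, 3, 8] -> linear(w[4, 8]) -> [2, 3, 4] -> y
#
# The trailing view_copy is only dropped when its output shape matches what
# linear would produce on the un-flattened input, which is the shape check
# performed just above.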


def replace_addmm_mm_with_linear(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Replace calls to addmm/mm with a linear node.
    The reason is that this simplifies the downstream logic of lowering to just
    the linear node. It also removes various view_copy nodes; these are
    absorbed by the delegate, which ignores them entirely.
    Furthermore, removing view_copy nodes means there is no need to match
    against view copies, which simplifies the pattern that has to be matched.
    Simplified patterns are less brittle, since symbolic ints and sizes
    creeping into the graph were making them harder to match.
    """
    ops = exir_ops.edge
    for node in graph.nodes:
        if node.op == "call_function" and (
            node.target == ops.aten.mm.default or node.target == ops.aten.addmm.default
        ):
            with graph.inserting_after(node):
                if node.target == ops.aten.addmm.default:
                    weight_t_node = node.args[2]
                    if weight_t_node.target not in [
                        ops.aten.t_copy.default,
                        ops.aten.permute_copy.default,
                    ]:
                        # Skip this node as it appears to be a standalone `addmm`
                        continue
                    weight_node = weight_t_node.args[0]
                    args = (node.args[1], weight_node, node.args[0])
                    linear_node = graph.create_node(
                        "call_function", ops.aten.linear.default, args
                    )
                    node.replace_all_uses_with(linear_node)
                    output_val = linear_node.target(  # pyre-fixme[29]
                        args[0].meta["val"], args[1].meta["val"], args[2].meta["val"]
                    )
                else:
                    weight_t_node = node.args[1]
                    if weight_t_node.target not in [
                        ops.aten.t_copy.default,
                        ops.aten.permute_copy.default,
                    ]:
                        # Skip this node as it appears to be a standalone `mm`
                        continue
                    weight_node = weight_t_node.args[0]
                    args = (node.args[0], weight_node)
                    linear_node = graph.create_node(
                        "call_function", ops.aten.linear.default, args
                    )
                    node.replace_all_uses_with(linear_node)
                    output_val = linear_node.target(  # pyre-fixme[29]
                        args[0].meta["val"], args[1].meta["val"]
                    )
                linear_node.meta = node.meta
                # The val (and hence the shape) carried over from the original
                # node's meta would not be accurate, so substitute the
                # recomputed output value.
                linear_node.meta["val"] = output_val
    graph.eliminate_dead_code()
    return graph
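
# Sketch of the algebraic equivalence this rewrite relies on, in plain eager
# mode (tensor shapes are illustrative only):
#
#   x = torch.randn(2, 8)
#   w = torch.randn(4, 8)
#   b = torch.randn(4)
#   torch.allclose(torch.addmm(b, x, w.t()), torch.nn.functional.linear(x, w, b))
#   torch.allclose(torch.mm(x, w.t()), torch.nn.functional.linear(x, w))
#
# This is also why the matmul's second operand must come from a
# t_copy/permute_copy node: the un-transposed weight feeding that node can then
# be passed to linear directly.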


def apply_addmm_mm_to_linear_transform(graph: torch.fx.Graph) -> torch.fx.Graph:
    graph = replace_addmm_mm_with_linear(graph)
    graph = replace_linear_view_copy_input_output(graph)
    return graph


class AddmmToLinearTransform(ExportPass):
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph_module.graph = apply_addmm_mm_to_linear_transform(graph_module.graph)
        return PassResult(graph_module, True)
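
# Minimal usage sketch (assumes `edge_graph_module` is an edge-dialect
# torch.fx.GraphModule obtained elsewhere, e.g. from an exported program; the
# variable name is hypothetical):
#
#   result = AddmmToLinearTransform().call(edge_graph_module)
#   transformed_graph_module = result.graph_module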