# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import sympy
import torch
import torch.utils._pytree as pytree
from executorch.exir.delegate import LoweredBackendModule
from executorch.exir.dynamic_shape import (
    calculate_dynamic_shape_spec,
    DynamicMemoryPlanningMode,
)
from executorch.exir.pass_base import Argument, ExportPass
from executorch.exir.pass_infra.node_metadata import NodeMetadata
from executorch.exir.pass_infra.proxy_value import ProxyValue
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
from executorch.exir.schema import TensorShapeDynamism
from executorch.exir.sym_util import collect_free_symbols, eval_expr
from executorch.exir.tensor import TensorSpec
from torch._subclasses import FakeTensor
from torch.fx import GraphModule


@dataclass
class DSInfo:
    """
    Dynamic shape information we are tracking for each dynamic shape symbol.
    """

    # the output of format_node() for the node introducing the symbol
    node_debug_str: str
    # upper bound value or None for fully dynamic memory planning
    ubval: Optional[int]


class DynamicShapePropPass(ExportPass):
    """
    In general, for each op, this pass propagates dynamic shape information
    from op inputs to op outputs.

    For cond/map nodes, we need to pass dynamic shape information to the
    submodules' placeholder nodes, propagate the dynamic shape information
    through the graphs of the submodules, and finally set the node's dynamic
    shape info based on the submodules' output nodes' dynamic shape info.
    """

    def __init__(
        self, mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND
    ):
        """
        mode controls how we do memory planning for dynamic shape tensors.
        In UPPER_BOUND mode, we plan a dynamic shape tensor's memory based on
        its upper bound shape;
        in FULL_DYNAMIC mode, the compiler does not allocate memory for
        dynamic shape tensors; the runtime does the allocation.
        """
        super().__init__()
        self.mode = mode
        self.sym_to_dsinfo = {}
        self.shape_env = None

    @contextmanager
    def apply_upper_bounds(self):
        """
        Context manager to use upper bound values to evaluate expressions.
        """
        try:
            if self.shape_env:
                old_var_to_val = dict(self.shape_env.var_to_val)
                for sym, dsinfo in self.sym_to_dsinfo.items():
                    assert dsinfo.ubval is not None
                    self.shape_env.var_to_val[sym] = sympy.Integer(dsinfo.ubval)
            yield
        finally:
            if self.shape_env:
                self.shape_env.var_to_val = old_var_to_val

    def copy_dsinfo_btw_specs(self, src_spec: TensorSpec, dst_spec: TensorSpec):
        dst_spec.shape_dynamism = src_spec.shape_dynamism
        dst_spec._upper_bound_shape = src_spec._upper_bound_shape

    def inject_dsinfo_to_graph(
        self,
        subgm: GraphModule,
        inputs: Union[List[ProxyValue], Tuple[ProxyValue, ...]],
        ignore_first_ph: bool = False,
    ):
        """
        ignore_first_ph: This argument is added for the map node. For the map
        node, the first placeholder is special; we need to ignore it here and
        handle it separately.
        """
        phs = [n for n in subgm.graph.nodes if n.op == "placeholder"]
        if ignore_first_ph:
            phs = phs[1:]
        assert len(phs) == len(inputs)
        for ph, inp in zip(phs, inputs):
            dst_spec: TensorSpec = ph.meta["spec"]
            src_spec: TensorSpec = inp.node.meta["spec"]
            self.copy_dsinfo_btw_specs(src_spec, dst_spec)

    def inject_xs_dsinfo_to_graph(self, subgm: GraphModule, xs: ProxyValue):
        """
        xs means the first argument of the map node.
        Even if xs is an upper bound tensor, it's possible that the first
        placeholder of subgm is still a static shape tensor if only the first
        dimension of xs is dynamic. But we don't have this optimization yet.
        If xs is dynamic, we treat the first placeholder of subgm as dynamic.
        """
        ph = next(n for n in subgm.graph.nodes if n.op == "placeholder")
        src_spec: TensorSpec = xs.node.meta["spec"]
        dst_spec = ph.meta["spec"]
        self.copy_dsinfo_btw_specs(src_spec, dst_spec)
        # update dst_spec to remove the leading (mapped) dimension
        if dst_spec._upper_bound_shape:
            dst_spec._upper_bound_shape = dst_spec._upper_bound_shape[1:]

    def verify_dsinfo_from_both_branches(
        self, true_gm: GraphModule, false_gm: GraphModule
    ):
        """
        For a cond node, the true and false branches should return outputs
        with the same shape.
        """
        *_, true_out = true_gm.graph.nodes
        *_, false_out = false_gm.graph.nodes
        true_out = pytree.tree_flatten(true_out)[0]
        false_out = pytree.tree_flatten(false_out)[0]
        assert len(true_out) == len(false_out)
        for true_out_item, false_out_item in zip(true_out, false_out):
            true_spec = true_out_item.meta["spec"]
            false_spec = false_out_item.meta["spec"]
            assert true_spec.shape_dynamism == false_spec.shape_dynamism
            assert true_spec._upper_bound_shape == false_spec._upper_bound_shape

    def extract_dsinfo_from_graph(self, subgm: GraphModule, meta: NodeMetadata):
        *_, out_node = subgm.graph.nodes
        dst_spec_list = pytree.tree_flatten(meta["spec"])[0]
        src_spec_list = pytree.tree_flatten(out_node.meta["spec"])[0]
        for src_spec, dst_spec in zip(src_spec_list, dst_spec_list):
            self.copy_dsinfo_btw_specs(src_spec, dst_spec)

    def call_cond(
        self,
        pred: ProxyValue,
        true_fn: torch.fx.GraphModule,
        false_fn: torch.fx.GraphModule,
        inputs: List[Any],
        meta: NodeMetadata,
    ) -> ProxyValue:
        self.inject_dsinfo_to_graph(true_fn, inputs)
        self.inject_dsinfo_to_graph(false_fn, inputs)
        retval = super().call_cond(pred, true_fn, false_fn, inputs, meta)
        self.verify_dsinfo_from_both_branches(true_fn, false_fn)
        # Note: 'meta' will override the metadata in retval,
        # so we update 'meta' rather than 'retval' here.
        self.extract_dsinfo_from_graph(true_fn, meta)
        return retval

    def call_map(
        self,
        f: torch.fx.GraphModule,
        xs: ProxyValue,
        args: Tuple[ProxyValue, ...],
        meta: NodeMetadata,
    ) -> ProxyValue:
        self.inject_dsinfo_to_graph(f, args, True)
        self.inject_xs_dsinfo_to_graph(f, xs)
        retval = super().call_map(f, xs, args, meta)
        # We are being a bit conservative: if xs or f's output has a dynamic
        # shape, we treat the output of the map node as dynamic shape.
        xs_spec = xs.node.meta["spec"]
        *_, subgm_out = f.graph.nodes
        subgm_out_spec = subgm_out.meta["spec"]
        # Take advantage of the fact that the static TensorShapeDynamism is minimal
        result_spec = meta["spec"]
        result_spec.shape_dynamism = max(
            spec.shape_dynamism
            for spec in pytree.tree_flatten((xs_spec, subgm_out_spec))[0]
        )
        if result_spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_BOUND:
            # On the right-hand side of the assignment we use 'upper_bound_shape'
            # rather than '_upper_bound_shape'. The former returns the static shape
            # for static tensors, which is what we want.
            result_spec._upper_bound_shape = (
                xs_spec.upper_bound_shape[:1] + subgm_out_spec.upper_bound_shape
            )
        return retval

    def add_symint_upperbound(
        self, node_debug_str: str, symint: torch.SymInt, ubval: int
    ):
        if not isinstance(symint, torch.SymInt):
            return
        expr = symint.node.expr
        if isinstance(expr, sympy.Symbol):
            self.sym_to_dsinfo[expr] = DSInfo(node_debug_str, ubval)
        if self.shape_env is None:
            self.shape_env = symint.node.shape_env
        else:
            assert symint.node.shape_env is self.shape_env

    def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
        output = super().placeholder(name, arg, meta)
        # TODO: handle full dynamic
        if (
            self.mode == DynamicMemoryPlanningMode.UPPER_BOUND
            and meta.data.get("spec", None) is not None
            and meta.data.get("val", None) is not None
        ):
            spec = meta.data["spec"]
            val = meta.data["val"]
            if not isinstance(val, FakeTensor):
                return output
            if spec.shape_dynamism != TensorShapeDynamism.DYNAMIC_BOUND:
                return output
            for sym, ubval in zip(val.shape, spec._upper_bound_shape):
                assert self.node_debug_str is not None
                self.add_symint_upperbound(self.node_debug_str, sym, ubval)
        return output

    def eval_symint_to_ubval(self, symint: torch.SymInt) -> int:
        return eval_expr(symint)

    def decide_upper_bound_from_symbols(self, meta):
        with self.apply_upper_bounds():
            meta = meta.data
            if meta.get("val", None) is None or meta.get("spec", None) is None:
                return
            vallist, _ = pytree.tree_flatten(meta["val"])
            speclist, _ = pytree.tree_flatten(meta["spec"])
            for val, spec in zip(vallist, speclist):
                if not isinstance(val, FakeTensor) or not isinstance(spec, TensorSpec):
                    continue
                free_symbols = collect_free_symbols(val.shape)
                if len(free_symbols & set(self.sym_to_dsinfo.keys())) == 0:
                    spec.shape_dynamism = TensorShapeDynamism.STATIC
                    spec._upper_bound_shape = None
                    continue
                spec.shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND
                # evaluate the upper bound shape
                spec._upper_bound_shape = [
                    self.eval_symint_to_ubval(s) for s in val.shape
                ]

    def call_delegate(
        self,
        lowered_module: LoweredBackendModule,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """
        Override this method so we can properly calculate the dynamic shape
        information for the output of the delegate.
        """
        if self.mode == DynamicMemoryPlanningMode.UPPER_BOUND:
            self.decide_upper_bound_from_symbols(meta)
        else:
            raise RuntimeError("NYI: delegation support in full dynamic mode")
        return super().call_delegate(lowered_module, args, kwargs, meta)

    def call_operator(self, op, args, kwargs, meta):
        """
        If any of the arguments has a dynamic shape, mark the output as
        dynamic shape.
        """
        # no need to do dynamic shape propagation for these ops
        if op.target in _EXECUTORCH_SYM_OPS:
            return super().call_operator(op, args, kwargs, meta)
        if self.mode == DynamicMemoryPlanningMode.UPPER_BOUND:
            self.decide_upper_bound_from_symbols(meta)
            return super().call_operator(op, args, kwargs, meta)
        ds_spec = calculate_dynamic_shape_spec(self.mode, op.target, args, kwargs)
        out_tensor_spec = meta["spec"]
        for ds_spec_item, tensor_spec_item in zip(
            pytree.tree_flatten(ds_spec)[0], pytree.tree_flatten(out_tensor_spec)[0]
        ):
            tensor_spec_item.shape_dynamism = ds_spec_item.shape_dynamism
            tensor_spec_item._upper_bound_shape = ds_spec_item.upper_bound_shape
        return super().call_operator(op, args, kwargs, meta)
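

# A minimal usage sketch, assuming the standard ExportPass calling convention
# (an ExportPass instance is invoked on a GraphModule and returns a PassResult);
# `exported_gm` below is a hypothetical placeholder for an exported graph module:
#
#     ds_prop = DynamicShapePropPass(mode=DynamicMemoryPlanningMode.UPPER_BOUND)
#     result = ds_prop(exported_gm)
#     # result.graph_module now carries TensorSpecs annotated with
#     # shape_dynamism and upper bound shapes for memory planning.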