# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+ import itertools
import logging
import warnings
+ from dataclasses import dataclass, field
from functools import partial
- from typing import Any, Callable, List, Optional
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import torch
from executorch.exir._warnings import deprecated
from executorch.exir.memory_planning import (
    _is_out_var_node,
    apply_algo,
+     collect_specs_from_nodes,
+     filter_nodes,
    get_node_tensor_specs,
    MemoryPlanningAlgorithmSuite,
    Verifier,
)
from executorch.exir.operator.convert import get_out_args_from_opoverload
from executorch.exir.pass_base import PassBase, PassResult
- from executorch.exir.tensor import ALIGNMENT
+ from executorch.exir.tensor import ALIGNMENT, TensorSpec
+ from torch import fx
from torch.export.exported_program import ExportGraphSignature
+ from torch.fx import Node


# copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function
@@ -37,6 +43,106 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
    return str(any_callable)


+ def _is_buffer(
+     node: Node, graph_signature: ExportGraphSignature
+ ) -> Tuple[bool, Optional[str]]:
+     """
+     Check if the node is a buffer according to the provided graph signature.
+     If it is, return its fqn as well.
+     """
+     if node.op == "placeholder":
+         if isinstance(node.target, str):
+             if node.target in graph_signature.inputs_to_buffers:
+                 fqn = graph_signature.inputs_to_buffers[node.target]
+                 return (True, fqn)
+     return (False, None)
+
+
+ def _is_mutable_buffer(
+     node: Node, graph_signature: ExportGraphSignature
+ ) -> Tuple[bool, Optional[str]]:
+     """
+     Check if the node is a mutable buffer according to the provided graph signature.
+     If it is, return its fqn as well.
+     """
+     if node.op == "placeholder":
+         if isinstance(node.target, str):
+             if node.target in graph_signature.inputs_to_buffers:
+                 fqn = graph_signature.inputs_to_buffers[node.target]
+                 # if the buffer is mutated then record that
+                 if fqn in graph_signature.buffers_to_mutate.values():
+                     return True, fqn
+     return False, None
+
+
+ def _get_spec_from_node(node: fx.Node) -> TensorSpec:
+     specs = get_node_tensor_specs(node)
+     return specs[0]
+
+
+ def _insert_mutable_buffer_specs(
+     state: "_MemoryPlanningState", gm: torch.fx.GraphModule, gs: ExportGraphSignature
+ ):
+     for node in gm.graph.nodes:
+         is_mutable, fqn = _is_mutable_buffer(node, gs)
+         if is_mutable:
+             assert fqn
+             spec = _get_spec_from_node(node)
+             if (
+                 getattr(spec, "mem_id", None) is not None
+                 or getattr(spec, "mem_offset", None) is not None
+             ):
+                 raise ValueError(
+                     "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
+                 )
+             if fqn not in state.mutable_buffers.keys():
+                 state.mutable_buffers[fqn] = set()
+             state.mutable_buffers[fqn].add(spec)
+             continue
+         is_buffer, fqn = _is_buffer(node, gs)
+         # If it is not a mutable buffer, it might just appear to be a buffer in this entry point (think model.get_state()),
+         # so cache it and later double-check that this buffer never appears mutable.
+         if is_buffer:
+             assert fqn
+             spec = _get_spec_from_node(node)
+             if (
+                 getattr(spec, "mem_id", None) is not None
+                 or getattr(spec, "mem_offset", None) is not None
+             ):
+                 raise ValueError(
+                     "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
+                 )
+             if fqn not in state.maybe_mutable_buffers.keys():
+                 state.maybe_mutable_buffers[fqn] = set()
+             state.maybe_mutable_buffers[fqn].add(spec)
+
+
+ def _check_default_mem_ids(gm: torch.fx.GraphModule):
+     for node in gm.graph.nodes:
+         for spec in collect_specs_from_nodes(
+             filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
+             None,
+             ignore_graph_input=False,
+             ignore_const=False,
+             ignore_out_var_node=False,
+             dedup=False,
+             do_assertion=False,
+             ignore_dynamic_unbound_tensor=False,
+         ):
+             mem_id = getattr(spec, "mem_id", None)
+             if mem_id is not None and mem_id != 1:
+                 raise ValueError(
+                     "Cannot share mutable buffers if all other tensors are not on the default mem_id of 1"
+                 )
+
+
+ @dataclass
+ class _MemoryPlanningState:
+     mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+     maybe_mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+     graph_modules: List[torch.fx.GraphModule] = field(default_factory=list)
+
+
class MemoryPlanningPass(PassBase):
    def __init__(
        self,
@@ -45,6 +151,7 @@ def __init__(
        alloc_graph_input: bool = True,
        alloc_graph_output: bool = True,
        alloc_mutable_buffers: bool = True,
+         share_mutable_buffers: bool = False,
        alignment: int = ALIGNMENT,
    ) -> None:
        r"""
@@ -55,12 +162,18 @@ def __init__(
        """
        if memory_planning_algo is None:
            memory_planning_algo = MemoryPlanningAlgorithmSuite()
+         if share_mutable_buffers and not alloc_mutable_buffers:
+             raise ValueError(
+                 "share_mutable_buffers is only meaningful when alloc_mutable_buffers is True"
+             )
        self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
        self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output
        self.alloc_mutable_buffers = alloc_mutable_buffers
+         self.share_mutable_buffers = share_mutable_buffers
        self.alignment = alignment
+         self.state = _MemoryPlanningState()

    def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
        """
@@ -134,9 +247,17 @@ def run(
            graph_signature,
            self.alloc_graph_input,
            self.alloc_graph_output,
-             self.alloc_mutable_buffers,
+             # If we are sharing the mutable buffers then do not allocate them in the
+             # memory planning algo; instead collect all of the specs over all the entry
+             # points and then allocate them directly in the run_multimethod call.
+             self.alloc_mutable_buffers and not self.share_mutable_buffers,
        )

+         if self.share_mutable_buffers and graph_signature is not None:
+             self.state.graph_modules.append(graph_module)
+             _check_default_mem_ids(graph_module)
+             _insert_mutable_buffer_specs(self.state, graph_module, graph_signature)
+
        # TODO: make the verifier do the work recursively to handle
        # control flow
        verifier = Verifier(
@@ -164,3 +285,31 @@ def run(
        # I don't know if that is a valid thing, but if it is we should adjust the verify_storage_reuse function
        verifier.verify_storage_reuse()
        return PassResult(graph_module, True)
+
+     def run_multimethod(self):
+         "Resolve any memory planning done across entry points"
+         if self.share_mutable_buffers:
+             arena: int = 0
+
+             # Every spec that shares an fqn is the same tensor! So we give it the same id and offset
+             # anywhere it appears.
+             for fqn, specs_set in self.state.mutable_buffers.items():
+                 specs = list(specs_set)
+                 # If the same buffer appears in both mutable and maybe-mutable, then we know it is in fact mutable.
+                 if fqn in self.state.maybe_mutable_buffers.keys():
+                     specs.extend(self.state.maybe_mutable_buffers[fqn])
+                 for spec in specs:
+                     # Assume default memory planning placed all activations on mem_id 1; place shared state on 2.
+                     spec.mem_id = 2
+                     spec.realign(self.alignment)
+                     # State is persistent, so the memory never overlaps.
+                     spec.mem_offset = arena
+                 # They should all be the same size since they are the same tensor, so just bump by the first.
+                 arena += specs[0].allocated_memory
+
+             for graph_module in self.state.graph_modules:
+                 if len(graph_module.meta["non_const_buffer_sizes"]) != 2:
+                     raise ValueError(
+                         "Cannot share mutable state if not using default memory ids"
+                     )
+                 graph_module.meta["non_const_buffer_sizes"].append(arena)
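
For readers of this diff, a minimal hypothetical driver sketch of how the new flag is intended to be used: run the same MemoryPlanningPass instance once per entry point, then call run_multimethod() to lay out the shared mutable buffers. The import path and the exported_programs mapping below are assumptions for illustration, not part of this change.

    from executorch.exir.passes import MemoryPlanningPass

    # Assumption: `exported_programs` maps entry-point names ("forward", "get_state", ...)
    # to objects exposing a graph_module and a graph_signature from torch.export.
    mp_pass = MemoryPlanningPass(
        alloc_mutable_buffers=True,
        share_mutable_buffers=True,  # defer mutable-buffer allocation to run_multimethod()
    )

    for ep in exported_programs.values():
        # Each call plans activations for one entry point; mutable-buffer specs are only
        # collected into mp_pass.state here, not allocated by the per-method algorithm.
        mp_pass.run(ep.graph_module, ep.graph_signature)

    # Give every shared buffer mem_id 2 and a stable offset, and append the shared
    # arena size to each graph module's non_const_buffer_sizes.
    mp_pass.run_multimethod()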