
Commit d43cde5

Add option in memory planning to put shared state on same location across entry points
Differential Revision: D82250153

Pull Request resolved: #14230
1 parent 5b99d4d commit d43cde5

File tree

5 files changed: +244 additions, -10 deletions
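
For orientation before the per-file diffs, here is a minimal usage sketch of the new option, distilled from the test added in this commit (the model and flow mirror MultiEntryPointStatefulModel in exir/tests/test_memory_planning.py; the import paths follow that test file and the MemoryPlanningPass import location is assumed, so they may need adjusting in other trees):

    import torch
    from executorch.exir import ExecutorchBackendConfig, to_edge
    from executorch.exir.capture._capture import patch_forward
    from executorch.exir.passes import MemoryPlanningPass
    from torch.export import export


    class StatefulModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # One buffer ("state") shared by every entry point.
            self.register_buffer("state", torch.zeros(2, 2))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.state.add_(x).view(-1) * 2

        def get_state(self) -> torch.Tensor:
            return self.state


    model = StatefulModel().eval()
    forward = export(model, (torch.ones(1),))
    with patch_forward(model, model.get_state):
        get_state = export(model, ())

    edge = to_edge({"forward": forward, "get_state": get_state})
    # share_mutable_buffers=True asks memory planning to place "state" at the
    # same mem_id/mem_offset in every entry point's plan.
    et = edge.to_executorch(
        ExecutorchBackendConfig(
            memory_planning_pass=MemoryPlanningPass(share_mutable_buffers=True),
            emit_mutable_buffer_names=True,
        )
    )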

exir/emit/_emitter.py

Lines changed: 25 additions & 2 deletions
@@ -93,7 +93,8 @@
 from executorch.exir.types import LeafValueSpec, ValueSpec
 from torch._subclasses.fake_tensor import FakeTensor

-from torch.export.exported_program import ExportedProgram
+from torch.export.exported_program import ExportedProgram, ExportGraphSignature
+from torch.fx.node import Node
 from torch.utils import _pytree as pytree

 from typing_extensions import TypeAlias
@@ -209,11 +210,11 @@ class _AbstractValue:
 ]


-# pyre-ignore[13]: Attribute `node` is never initialized.
 class _Emitter(torch.fx.Interpreter):
     """An abstract interpreter (https://wiki.mozilla.org/Abstract_Interpretation) used to emit the
     given traced torch.fx.GraphModule to the flatbuffer schema."""

+    # pyre-ignore[13]: Attribute `node` is never initialized.
     node: torch.fx.Node

     def __init__(
@@ -1633,6 +1634,28 @@ def placeholder( # noqa: C901
         if isinstance(target, str) and isinstance(spec, TensorSpec):
             fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)

+            def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> bool:
+                """
+                Check if the node is a buffer according to the provided graph signature.
+                Returns True only for placeholder nodes that map to a buffer.
+                """
+                if node.op == "placeholder":
+                    if isinstance(node.target, str):
+                        if node.target in graph_signature.inputs_to_buffers:
+                            return True
+                return False
+
+            # If the spec does not appear in the mutable section of the graph signature, it still might
+            # be considered a mutable buffer overall if it has already been memory planned. That would
+            # suggest the same abstract buffer is mutable in another entry point, so we compel it to be
+            # considered mutable in all entry points at emission, just as the user did with
+            # memory planning.
+            is_mutable_buffer |= (
+                _is_buffer(self.node, self.exported_program.graph_signature)
+                and spec.mem_id is not None
+                and spec.mem_offset is not None
+            )
+
         # If the placeholder has a constant_tag, it is external to the PTE file
         # and requires a fqn and location=TensorDataLocation.EXTERNAL
         if constant_tag is not None:
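
To make the new emitter check concrete, here is a hedged illustration of the ExportGraphSignature fields it consults, for the stateful example model added elsewhere in this commit (the placeholder and output names below are illustrative, not taken from an actual export):

    # forward mutates the buffer, so torch.export records it in both maps:
    #   graph_signature.inputs_to_buffers == {"b_state": "state"}   # placeholder name -> buffer fqn
    #   graph_signature.buffers_to_mutate == {"add": "state"}       # output name -> buffer fqn
    #
    # get_state only reads the buffer, so it shows up as a plain buffer input:
    #   graph_signature.inputs_to_buffers == {"b_state": "state"}
    #   graph_signature.buffers_to_mutate == {}
    #
    # The addition above therefore promotes "state" to a mutable buffer in the
    # read-only entry point too, but only when memory planning has already given
    # its spec a mem_id and mem_offset (i.e. share_mutable_buffers was used).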

exir/memory_planning.py

Lines changed: 3 additions & 0 deletions
@@ -245,6 +245,8 @@ def verify_graph_input_output(self) -> None:
             assert len(specs) > 0, "Expect tensor specs"
             specs = list(filter(lambda spec: not spec.const, specs))
             if len(specs) == 0:
+                # All outputs are const, so there is no memory to allocate; just report success.
+                graph_output_allocated = self.alloc_graph_output
                 continue
             allocated = any(
                 spec is None or spec.mem_offset is not None for spec in specs
@@ -408,6 +410,7 @@ def collect_specs_from_nodes( # noqa: C901
     ignore_graph_input: bool = False,
     ignore_graph_output: bool = False,
     ignore_mutable_buffers: bool = False,
+    share_mutable_buffers: bool = False,
     ignore_const: bool = True,
     ignore_out_var_node: bool = True,
     dedup: bool = True,

exir/passes/memory_planning_pass.py

Lines changed: 152 additions & 3 deletions
@@ -4,10 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import itertools
 import logging
 import warnings
+from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple

 import torch
 from executorch.exir._warnings import deprecated
@@ -16,14 +18,18 @@
 from executorch.exir.memory_planning import (
     _is_out_var_node,
     apply_algo,
+    collect_specs_from_nodes,
+    filter_nodes,
     get_node_tensor_specs,
     MemoryPlanningAlgorithmSuite,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
 from executorch.exir.pass_base import PassBase, PassResult
-from executorch.exir.tensor import ALIGNMENT
+from executorch.exir.tensor import ALIGNMENT, TensorSpec
+from torch import fx
 from torch.export.exported_program import ExportGraphSignature
+from torch.fx import Node


 # copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function
@@ -37,6 +43,106 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
     return str(any_callable)


+def _is_buffer(
+    node: Node, graph_signature: ExportGraphSignature
+) -> Tuple[bool, Optional[str]]:
+    """
+    Check if the node is a buffer according to the provided graph signature.
+    If it is one, return its fqn as well.
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                return (True, fqn)
+    return (False, None)
+
+
+def _is_mutable_buffer(
+    node: Node, graph_signature: ExportGraphSignature
+) -> Tuple[bool, Optional[str]]:
+    """
+    Check if the node is a mutable buffer according to the provided graph signature.
+    If it is one, return its fqn as well.
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                # If the buffer is mutated then record that.
+                if fqn in graph_signature.buffers_to_mutate.values():
+                    return True, fqn
+    return False, None
+
+
+def _get_spec_from_node(node: fx.Node) -> TensorSpec:
+    specs = get_node_tensor_specs(node)
+    return specs[0]
+
+
+def _insert_mutable_buffer_specs(
+    state: "_MemoryPlanningState", gm: torch.fx.GraphModule, gs: ExportGraphSignature
+):
+    for node in gm.graph.nodes:
+        is_mutable, fqn = _is_mutable_buffer(node, gs)
+        if is_mutable:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if (
+                getattr(spec, "mem_id", None) is not None
+                or getattr(spec, "mem_offset", None) is not None
+            ):
+                raise ValueError(
+                    "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
+                )
+            if fqn not in state.mutable_buffers.keys():
+                state.mutable_buffers[fqn] = set()
+            state.mutable_buffers[fqn].add(spec)
+            continue
+        is_buffer, fqn = _is_buffer(node, gs)
+        # If it is not a mutable buffer it might still just appear to be a buffer in this entry point
+        # (think model.get_state()), so cache it and later double-check that this buffer never appears mutable.
+        if is_buffer:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if (
+                getattr(spec, "mem_id", None) is not None
+                or getattr(spec, "mem_offset", None) is not None
+            ):
+                raise ValueError(
+                    "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
+                )
+            if fqn not in state.maybe_mutable_buffers.keys():
+                state.maybe_mutable_buffers[fqn] = set()
+            state.maybe_mutable_buffers[fqn].add(spec)
+
+
+def _check_default_mem_ids(gm: torch.fx.GraphModule):
+    for node in gm.graph.nodes:
+        for spec in collect_specs_from_nodes(
+            filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
+            None,
+            ignore_graph_input=False,
+            ignore_const=False,
+            ignore_out_var_node=False,
+            dedup=False,
+            do_assertion=False,
+            ignore_dynamic_unbound_tensor=False,
+        ):
+            mem_id = getattr(spec, "mem_id", None)
+            if mem_id is not None and mem_id != 1:
+                raise ValueError(
+                    "Cannot share mutable buffers if all other tensors are not on the default mem_id of 1"
+                )
+
+
+@dataclass
+class _MemoryPlanningState:
+    mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    maybe_mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    graph_modules: List[torch.fx.GraphModule] = field(default_factory=list)
+
+
 class MemoryPlanningPass(PassBase):
     def __init__(
         self,
@@ -45,6 +151,7 @@ def __init__(
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
         alloc_mutable_buffers: bool = True,
+        share_mutable_buffers: bool = False,
         alignment: int = ALIGNMENT,
     ) -> None:
         r"""
@@ -55,12 +162,18 @@
         """
         if memory_planning_algo is None:
             memory_planning_algo = MemoryPlanningAlgorithmSuite()
+        if share_mutable_buffers and not alloc_mutable_buffers:
+            raise ValueError(
+                "share_mutable_buffers is only meaningful when alloc_mutable_buffers is True"
+            )
         self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
         self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
         self.alloc_mutable_buffers = alloc_mutable_buffers
+        self.share_mutable_buffers = share_mutable_buffers
         self.alignment = alignment
+        self.state = _MemoryPlanningState()

     def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
         """
@@ -134,9 +247,17 @@ def run(
             graph_signature,
             self.alloc_graph_input,
             self.alloc_graph_output,
-            self.alloc_mutable_buffers,
+            # If we are sharing the mutable buffers then do not allocate them in the
+            # memory planning algo; instead collect all of the specs over all the entry
+            # points and then allocate them directly in the run_multimethod call.
+            self.alloc_mutable_buffers and not self.share_mutable_buffers,
         )

+        if self.share_mutable_buffers and graph_signature is not None:
+            self.state.graph_modules.append(graph_module)
+            _check_default_mem_ids(graph_module)
+            _insert_mutable_buffer_specs(self.state, graph_module, graph_signature)
+
         # TODO: make the verifier do the work recursively to handle
         # control flow
         verifier = Verifier(
@@ -164,3 +285,31 @@ def run(
         # I dont know if that is a valid thing but if it is we should adjust verify_storage_reuse function
         verifier.verify_storage_reuse()
         return PassResult(graph_module, True)
+
+    def run_multimethod(self):
+        "Resolve any memory planning done across entry points"
+        if self.share_mutable_buffers:
+            arena: int = 0
+
+            # Every spec that shares an fqn is the same tensor! So we give it the same id and offset
+            # anywhere it appears.
+            for fqn, specs_set in self.state.mutable_buffers.items():
+                specs = list(specs_set)
+                # If the same buffer appears in mutable and maybe mutable then we know it is in fact mutable.
+                if fqn in self.state.maybe_mutable_buffers.keys():
+                    specs.extend(self.state.maybe_mutable_buffers[fqn])
+                for spec in specs:
+                    # Assume a default memory planning placed all activations on 1, place shared state on 2.
+                    spec.mem_id = 2
+                    spec.realign(self.alignment)
+                    # State is persistent, so the memory never overlaps.
+                    spec.mem_offset = arena
+                # They should all be the same size since they are the same tensor, so just bump off the first.
+                arena += specs[0].allocated_memory
+
+            for graph_module in self.state.graph_modules:
+                if len(graph_module.meta["non_const_buffer_sizes"]) != 2:
+                    raise ValueError(
+                        "Cannot share mutable state if not using default memory ids"
+                    )
+                graph_module.meta["non_const_buffer_sizes"].append(arena)
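
A short worked sketch of the layout run_multimethod produces, assuming two hypothetical shared buffers and that TensorSpec.allocated_memory already reflects the alignment set by realign (both buffer names and sizes below are made up for illustration):

    # State collected by run() across entry points:
    #   "cache" -> specs whose allocated_memory is 32 bytes
    #   "step"  -> specs whose allocated_memory is 16 bytes
    #
    # run_multimethod then assigns, to every spec of a given fqn:
    #   cache: mem_id = 2, mem_offset = 0     # arena starts at 0
    #   step:  mem_id = 2, mem_offset = 32    # arena is bumped once per fqn
    #
    # Finally each entry point's meta["non_const_buffer_sizes"] grows from
    # [reserved, activation_arena] to [reserved, activation_arena, 48], so the
    # runtime can hand every method the same 48-byte buffer for mem_id 2.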

exir/program/_program.py

Lines changed: 12 additions & 5 deletions
@@ -1681,7 +1681,7 @@ def to_backend(
         return epm

     @et_logger("to_executorch")
-    def to_executorch(
+    def to_executorch(  # noqa (FLAKE8) C901
         self,
         config: Optional[ExecutorchBackendConfig] = None,
     ) -> "ExecutorchProgramManager":
@@ -1745,11 +1745,9 @@ def to_executorch(
             memory_planning_pass = config.memory_planning_pass
             # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
             if hasattr(memory_planning_pass, "run"):
-                new_gm_res = memory_planning_pass.run(  # pyre-ignore[16]
-                    new_gm, new_signature
-                )
+                new_gm_res = memory_planning_pass.run(new_gm, new_signature)
             else:
-                new_gm_res = memory_planning_pass(new_gm)  # pyre-ignore[29]
+                new_gm_res = memory_planning_pass(new_gm)

             # WARNING: DO NOT ADD ANY MORE PASSES AFTER MEMORY PLANNING PASS.
             # THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER.
@@ -1758,6 +1756,15 @@ def to_executorch(

             _copy_module(program.graph_module, new_gm)
             execution_programs[name] = program
+        # After running memory planning on all entry points, we can run the cross-entry-point memory planning.
+        if isinstance(config.memory_planning_pass, dict):
+            for memory_planning_pass in config.memory_planning_pass.values():
+                if hasattr(memory_planning_pass, "run_multimethod"):
+                    memory_planning_pass.run_multimethod()
+        else:
+            memory_planning_pass = config.memory_planning_pass
+            if hasattr(memory_planning_pass, "run_multimethod"):
+                memory_planning_pass.run_multimethod()

         et_pm = ExecutorchProgramManager(
             execution_programs,
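
A condensed view of the control flow this hunk sets up when a single MemoryPlanningPass instance serves all methods (simplified pseudocode mirroring the surrounding variable names):

    # for name, program in edge_programs.items():
    #     ... per-method passes ...
    #     memory_planning_pass.run(new_gm, new_signature)   # plans this entry point and collects
    #                                                       # shared-buffer specs into pass.state
    # memory_planning_pass.run_multimethod()                # one cross-entry-point resolution
    #
    # run_multimethod() must come after every method has been planned, since only then
    # does pass.state hold the specs for each shared buffer from all entry points.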

exir/tests/test_memory_planning.py

Lines changed: 52 additions & 0 deletions
@@ -14,6 +14,7 @@

 import torch
 from executorch.exir import ExecutorchBackendConfig, to_edge
+from executorch.exir.capture._capture import patch_forward
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.memory_planning import (
     _do_user_inputs_exist,
@@ -93,6 +94,24 @@ def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
         return (torch.randn(10), torch.randn(10))


+class MultiEntryPointStatefulModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.register_buffer("state", torch.zeros(2, 2))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.state.add_(x).view(-1) * 2
+
+    def set_state(self, state: torch.Tensor) -> None:
+        self.state.copy_(state)
+
+    def get_state(self) -> torch.Tensor:
+        return self.state
+
+    def get_example_inputs(self) -> Tuple[torch.Tensor, ...]:
+        return (torch.ones(1),)
+
+
 class ModelWithDifferentTensorSizes(torch.nn.Module):
     def __init__(self) -> None:
         super(ModelWithDifferentTensorSizes, self).__init__()
@@ -1081,3 +1100,36 @@ def test_multi_map(self) -> None:
                     verifier.storage_overlap(outer_spec, inner_spec),
                     f"Outer spec {outer_spec.shape=} {outer_spec.dtype=} {outer_spec.lifetime=} and inner spec {inner_spec} have storage overlap",
                 )
+
+    def test_multi_state_plan(self) -> None:
+        eager_module = MultiEntryPointStatefulModel().eval()
+        forward = export(eager_module, eager_module.get_example_inputs())
+        with patch_forward(eager_module, eager_module.get_state):
+            get_state = export(eager_module, ())
+        with patch_forward(eager_module, eager_module.set_state):
+            set_state = export(eager_module, (torch.zeros(1),))
+        edge = to_edge(
+            {"forward": forward, "set_state": set_state, "get_state": get_state}
+        )
+        et = edge.to_executorch(
+            ExecutorchBackendConfig(
+                memory_planning_pass=MemoryPlanningPass(share_mutable_buffers=True),
+                emit_mutable_buffer_names=True,
+            )
+        )
+        et_prog = et.executorch_program
+        count = 0
+        for plan in et_prog.execution_plan:
+            for value in plan.values:
+                if (
+                    hasattr(value.val, "allocation_info")
+                    and value.val.allocation_info is not None
+                    and value.val.allocation_info.memory_id == 2
+                ):
+                    count += 1
+                    self.assertEqual(value.val.allocation_info.memory_offset_low, 0)
+                    self.assertTrue(value.val.extra_tensor_info is not None)
+                    self.assertEqual(
+                        value.val.extra_tensor_info.fully_qualified_name, "state"
+                    )
+        self.assertEqual(count, 3)
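
A possible extra assertion continuing inside test_multi_state_plan (a sketch only; it assumes the serialized ExecutionPlan exposes name and non_const_buffer_sizes, mirroring the meta key appended in run_multimethod):

    # Each entry point should reserve an identically sized arena for the shared state.
    arena_sizes = {
        plan.name: plan.non_const_buffer_sizes[2] for plan in et_prog.execution_plan
    }
    self.assertEqual(len(set(arena_sizes.values())), 1)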
