
Commit 717a216

JacobSzwejbka authored and facebook-github-bot committed
generate shared state test program (#14407)
Summary: Had to hack around some legacy code here that assumes all entry points use the same input.

Differential Revision: D82329519
1 parent 4d0961e commit 717a216

5 files changed: +68 -22 lines


exir/passes/memory_planning_pass.py

Lines changed: 1 addition & 1 deletion
@@ -287,7 +287,7 @@ def run(
         return PassResult(graph_module, True)

     def run_multimethod(self):
-        "Resolve any memory planning done across entry points"
+        """Resolve any memory planning done across entry points, called after run is called on all entry points."""
         if self.share_mutable_buffers:
             arena: int = 0
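The new docstring pins down the calling contract: run plans each entry point on its own, and run_multimethod is invoked once afterwards to reconcile those plans when mutable buffers are shared. A minimal sketch of that sequence, assuming a hypothetical method_to_graph_module mapping and the constructor flags used in exported_module.py below (the real run signature takes more arguments, and the pipeline normally drives this from to_executorch):

from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass


def plan_memory_across_methods(method_to_graph_module):
    # Illustrative only: `method_to_graph_module` is an assumed
    # {method_name: graph_module} dict; the real `run` call may also
    # take a graph signature.
    mem_pass = MemoryPlanningPass(
        alloc_mutable_buffers=True,
        share_mutable_buffers=True,  # the flag this commit exercises
    )
    # First, plan every entry point independently.
    for graph_module in method_to_graph_module.values():
        mem_pass.run(graph_module)
    # Then resolve the plans across entry points so shared mutable
    # buffers end up in a common arena.
    mem_pass.run_multimethod()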

extension/module/test/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ def define_common_targets(is_fbcode=False):
         "ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
         "ET_MODULE_ADD_MUL_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.pte])",
         "ET_MODULE_ADD_MUL_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.ptd])",
+        "ET_MODULE_SHARED_STATE": "$(location fbcode//executorch/test/models:exported_programs[ModuleSharedState.pte])",
     }

     for aten_mode in get_aten_mode_options():
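The new entry hands the Module test environment the path of the generated ModuleSharedState.pte. As a rough sketch of how a test can consume such an env-provided program (the C++ Module tests read these variables natively; the Python pybindings calls below are assumptions used only for illustration):

import os

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Assumed: the build exports ET_MODULE_SHARED_STATE as wired up in targets.bzl above.
program = _load_for_executorch(os.environ["ET_MODULE_SHARED_STATE"])

print(program.run_method("get_state", ()))           # initial buffer value
program.forward((torch.zeros(1),))                    # forward mutates the state in-place
print(program.run_method("get_state", ()))           # reflects the mutation if the buffer is shared
program.run_method("set_state", (torch.zeros(1),))   # reset via the third entry point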

test/end2end/exported_module.py

Lines changed: 20 additions & 19 deletions
@@ -15,6 +15,7 @@
 import executorch.exir as exir
 import torch
 from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager, to_edge
+from executorch.exir.capture._capture import patch_forward
 from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
 from executorch.exir.passes import (
     DebugPass,
@@ -70,6 +71,7 @@ def export(
         export_joint_graph: bool = False,
         external_constants: bool = False,
         export_state_names: bool = False,
+        share_mutable_buffers: bool = False,
     ) -> "ExportedModule":
         """
         Creates a new ExportedModule for the specified module class.
@@ -134,10 +136,13 @@ def return_wrapper():
             # all exported methods must have the same signature so just pick the first one.
             methods[0],
         )
-        trace_inputs: Sequence = get_trace_inputs()
+        inputs: Sequence = get_trace_inputs()
         method_name_to_args = {}
         for method in methods:
-            method_name_to_args[method] = trace_inputs
+            if hasattr(eager_module, "get_random_inputs_per_method"):
+                # pyre-ignore
+                inputs = eager_module.get_random_inputs_per_method()[method]
+            method_name_to_args[method] = inputs

         method_name_to_dynamic_shapes = None
         if hasattr(eager_module, "get_dynamic_shapes"):
@@ -149,23 +154,18 @@ def return_wrapper():
                 method_name_to_dynamic_shapes[method] = trace_dynamic_shapes

         memory_planning_pass = MemoryPlanningPass(
-            alloc_mutable_buffers=not export_state_names
+            alloc_mutable_buffers=not export_state_names,
+            share_mutable_buffers=share_mutable_buffers,
         )
         if hasattr(eager_module, "get_memory_planning_pass"):
             memory_planning_pass = eager_module.get_memory_planning_pass()  # type: ignore[operator]

-        class WrapperModule(nn.Module):
-            def __init__(self, method):
-                super().__init__()
-                self.forward = method
-
         exported_methods = {}
         # These cleanup passes are required to convert the `add` op to its out
         # variant, along with some other transformations.
         for method_name, method_input in method_name_to_args.items():
             # if not isinstance(eager_module, torch.nn.Module):
             if export_joint_graph:
-                # _export was having issues with WrapperModule.
                 assert method_name == "forward"
                 ep = _export(
                     eager_module,
@@ -179,15 +179,16 @@ def __init__(self, method):
                 )
                 exported_methods[method_name] = _export_forward_backward(ep)
             else:
-                exported_methods[method_name] = export(
-                    eager_module,
-                    method_input,  # type: ignore[arg-type]
-                    dynamic_shapes=(
-                        method_name_to_dynamic_shapes[method_name]
-                        if method_name_to_dynamic_shapes
-                        else None
-                    ),
-                )
+                with patch_forward(eager_module, getattr(eager_module, method_name)):
+                    exported_methods[method_name] = export(
+                        eager_module,
+                        method_input,  # type: ignore[arg-type]
+                        dynamic_shapes=(
+                            method_name_to_dynamic_shapes[method_name]
+                            if method_name_to_dynamic_shapes
+                            else None
+                        ),
+                    )

         exec_prog = to_edge(
             exported_methods,
@@ -229,6 +230,6 @@ def __init__(self, method):
             methods=methods,
             executorch_program=exec_prog,
             exported_program=exported_program,
-            trace_inputs=trace_inputs,
+            trace_inputs=inputs,
             get_random_inputs_fn=get_random_inputs_fn,
         )
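The substantive change here is twofold: per-method example inputs now come from get_random_inputs_per_method when the module provides it, and each non-forward entry point is exported by temporarily patching the module's forward via patch_forward rather than through the old WrapperModule, since torch.export always traces forward. A rough sketch of what a patch-forward style helper does, assuming the standard swap-and-restore pattern (the real implementation lives in executorch.exir.capture._capture and may differ):

from contextlib import contextmanager

import torch


@contextmanager
def patch_forward_sketch(module: torch.nn.Module, method):
    # Temporarily make `method` the module's forward so torch.export traces it,
    # then restore the original forward even if export raises.
    original_forward = module.forward
    module.forward = method
    try:
        yield module
    finally:
        module.forward = original_forward


# Hypothetical usage mirroring the loop above:
#   with patch_forward_sketch(eager_module, getattr(eager_module, "get_state")):
#       exported = torch.export.export(eager_module, ())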

test/models/export_program.py

Lines changed: 45 additions & 2 deletions
@@ -262,6 +262,42 @@ def get_random_inputs(self):
         return (torch.randint(100, [1, 3], dtype=torch.long),)


+class ModuleSharedState(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.register_buffer("state", torch.ones(1))
+
+    def forward(self, x):
+        return self.state.add_(1) + x
+
+    def get_state(self):
+        return self.state
+
+    def set_state(self, x):
+        self.state.copy_(x)
+
+    # Including this is tech debt since we will immediately override it with the per method one.
+    # ExportedModule is really old infra though from before multiple methods were supported. So
+    # its really obnoxious to change.
+    def get_random_inputs(self):
+        return (torch.ones(1),)
+
+    def get_random_inputs_per_method(self):
+        return {
+            "forward": (torch.ones(1),),
+            "get_state": (),
+            "set_state": (torch.ones(1),),
+        }
+
+    @staticmethod
+    def get_method_names_to_export() -> List[str]:
+        return ["forward", "get_state", "set_state"]
+
+    @staticmethod
+    def share_mutable_buffers():
+        return True
+
+
 #
 # Main logic.
 #
@@ -280,21 +316,28 @@ def export_module_to_program(
         export_kwargs = module_class.get_export_kwargs()
     export_joint = False
     export_state_names = False
+    share_mutable_buffers = False
     if hasattr(module_class, "export_joint"):
-        export_joint = module_class.export_joint()  # pyre-ignore
+        # pyre-ignore[16]: pyre just cant figure it out
+        export_joint = module_class.export_joint()
     if hasattr(module_class, "export_state_names"):
+        # pyre-ignore[16]: pyre just cant figure it out
         export_state_names = module_class.export_state_names()
     if hasattr(module_class, "get_method_names_to_export"):
-        # pyre-ignore[16]: pyre doesn't know about get_export_kwargs.
+        # pyre-ignore[16]: pyre just cant figure it out
        methods = module_class.get_method_names_to_export()
     else:
         methods = ["forward"]
+    if hasattr(module_class, "share_mutable_buffers"):
+        # pyre-ignore[16]: pyre just cant figure it out
+        share_mutable_buffers = module_class.share_mutable_buffers()
     module = ExportedModule.export(
         module_class,
         methods,
         export_joint_graph=export_joint,
         external_constants=external_constants,
         export_state_names=export_state_names,
+        share_mutable_buffers=share_mutable_buffers,
         **export_kwargs,
     )
     return module.executorch_program
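For reference, the eager behavior the new test program captures: forward increments the registered state buffer in place before adding the input, while get_state and set_state expose the same buffer as separate entry points. A quick check against the class defined above (import path assumed for this sketch; no export involved):

import torch

from executorch.test.models.export_program import ModuleSharedState  # assumed path

m = ModuleSharedState()
print(m.get_state())       # tensor([1.])
print(m(torch.zeros(1)))   # forward bumps the state first -> tensor([2.])
print(m.get_state())       # tensor([2.]) -- the same buffer, now mutated
m.set_state(torch.zeros(1))
print(m.get_state())       # tensor([0.])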

test/models/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ def define_common_targets():
         "ModuleDynamicCatUnallocatedIO",
         "ModuleSimpleTrain",
         "ModuleStateful",
+        "ModuleSharedState",
     ]

     # Generates Executorch .pte program files for various modules at build time.
