Skip to content

Commit f8f5097

Browse files
Add Metal backend Python preprocessing, partitioning, and tests
Implement Python-side Metal backend interface including: - MetalBackend: AOT compilation with MPS device support - MetalPartitioner: Graph partitioning for Metal delegation Add comprehensive unit tests for partitioner and backend preprocessing. ghstack-source-id: dfe25a4 ghstack-comment-id: 3392174694 Pull-Request: #15015
1 parent 6fbb843 commit f8f5097

File tree

6 files changed

+626
-0
lines changed

6 files changed

+626
-0
lines changed
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import contextlib
8+
import os
9+
import typing
10+
from enum import Enum
11+
12+
from typing import Any, Dict, final, List, Optional, Set
13+
14+
import torch
15+
from executorch.backends.apple.metal.replace_slice_copy_with_slice import (
16+
ReplaceSliceCopyWithSlicePass,
17+
)
18+
from executorch.exir._serialize._named_data_store import NamedDataStore
19+
from executorch.exir._warnings import experimental
20+
from executorch.exir.backend.backend_details import (
21+
BackendDetails,
22+
ExportedProgram,
23+
PreprocessResult,
24+
)
25+
from executorch.exir.backend.compile_spec_schema import CompileSpec
26+
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
27+
from torch.export.passes import move_to_device_pass
28+
29+
30+
# exist fallback operators in et namespace;
31+
supported_fallback_kernels: Dict[str, Any] = {
32+
"aoti_torch_mps_addmm_out": None,
33+
"aoti_torch_mps_convolution": None,
34+
"aoti_torch_mps_mm_out": None,
35+
"at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
36+
}
37+
38+
# required fallback kernels but not supported
39+
missing_fallback_kernels: Set[str] = set()
40+
41+
42+
class COMPILE_SPEC_KEYS(Enum):
43+
METHOD_NAME = "method_name"
44+
45+
46+
# context manager for non-fallback guarantee
47+
# it will raise exception when generating fallback kernels during aoti compile
48+
@contextlib.contextmanager
49+
def collect_unsupported_fallback_kernels():
50+
original_generate_c_shim_extern_kernel_call = (
51+
CppWrapperCpu.generate_c_shim_extern_kernel_call
52+
)
53+
54+
def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
55+
self,
56+
kernel: str,
57+
args: list[str],
58+
device: str,
59+
*,
60+
debug_args: Optional[list[str]] = None,
61+
debug_handle: Optional[int] = None,
62+
):
63+
if kernel not in supported_fallback_kernels:
64+
missing_fallback_kernels.add(kernel)
65+
66+
original_generate_c_shim_extern_kernel_call(
67+
self, kernel, args, device, debug_args=debug_args, debug_handle=debug_handle
68+
)
69+
70+
CppWrapperCpu.generate_c_shim_extern_kernel_call = (
71+
generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
72+
)
73+
try:
74+
yield
75+
finally:
76+
CppWrapperCpu.generate_c_shim_extern_kernel_call = (
77+
original_generate_c_shim_extern_kernel_call
78+
)
79+
80+
81+
@final
82+
@experimental(
83+
"This API and all of Metal backend related functionality are experimental."
84+
)
85+
class MetalBackend(BackendDetails):
86+
@staticmethod
87+
def preprocess(
88+
edge_program: ExportedProgram,
89+
compile_specs: List[CompileSpec],
90+
) -> PreprocessResult:
91+
print("entering the lowerable parts in MetalBackend.preprocess....")
92+
# Move the edge_program from CPU to MPS for aoti compile
93+
mps_edge_program = move_to_device_pass(edge_program, "mps")
94+
95+
# replace slice_copy with slice
96+
ReplaceSliceCopyWithSlicePass()(mps_edge_program.graph_module)
97+
98+
edge_program_module = mps_edge_program.module()
99+
100+
# Grab all input placeholders from the graph
101+
user_input_names = mps_edge_program.graph_signature.user_inputs
102+
user_input_placeholders = []
103+
for node in mps_edge_program.graph.nodes:
104+
if node.op == "placeholder" and node.name in user_input_names:
105+
user_input_placeholders.append(node.meta["val"])
106+
107+
# Base options for all devices
108+
options: dict[str, typing.Any] = {
109+
# Do not link against the full PyTorch/libtorch library
110+
"aot_inductor.link_libtorch": False,
111+
# Package model constants and other generated files directly in the shared object (.so) file
112+
"aot_inductor.package_constants_in_so": True,
113+
# Enable maximum automatic tuning for optimal performance
114+
"max_autotune": True,
115+
# "aot_inductor.debug_compile": True,
116+
# "aot_inductor.force_mmap_weights": False,
117+
}
118+
119+
with collect_unsupported_fallback_kernels():
120+
so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
121+
if len(missing_fallback_kernels) > 0:
122+
formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
123+
raise RuntimeError(
124+
f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
125+
"Please add them to the AOTI backend."
126+
)
127+
128+
# pyre-ignorep[6]: Incompatible parameter type
129+
with open(so_path, "rb") as f:
130+
so_data = f.read()
131+
132+
named_data_store = NamedDataStore()
133+
method_name = MetalBackend.method_name_from_compile_specs(compile_specs)
134+
named_data_store.add_named_data(
135+
method_name + "_so_blob", so_data, 1, "aoti_metal_blob"
136+
)
137+
138+
# Clean up the generated so file; it has been packaged into the NamdeDataStore
139+
# pyre-ignorep[6]: Incompatible parameter type
140+
os.remove(so_path)
141+
142+
return PreprocessResult(
143+
processed_bytes=b"",
144+
debug_handle_map={},
145+
data_store_output=named_data_store.get_named_data_store_output(),
146+
)
147+
148+
@staticmethod
149+
def generate_method_name_compile_spec(
150+
method_name: str,
151+
) -> CompileSpec:
152+
"""
153+
Returns the compile spec representing the model compute precision, for additional details
154+
please refer to the documentation for ``coremltools.precision``.
155+
"""
156+
return CompileSpec(
157+
COMPILE_SPEC_KEYS.METHOD_NAME.value,
158+
method_name.encode("utf-8"),
159+
)
160+
161+
@staticmethod
162+
def method_name_from_compile_specs(
163+
compile_specs: List[CompileSpec],
164+
) -> str:
165+
"""
166+
Returns the method name from the compile specs.
167+
"""
168+
for spec in compile_specs:
169+
if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value:
170+
return spec.value.decode("utf-8")
171+
raise RuntimeError(
172+
f"Could not find method name in compile specs: {compile_specs}"
173+
)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Callable, Dict, final, List, Optional, Tuple
8+
9+
import torch
10+
from executorch.backends.apple.metal.metal_backend import MetalBackend # usort: skip
11+
from executorch.exir._warnings import experimental
12+
from executorch.exir.backend.compile_spec_schema import CompileSpec
13+
from executorch.exir.backend.partitioner import (
14+
DelegationSpec,
15+
Partitioner,
16+
PartitionResult,
17+
)
18+
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
19+
from torch.export.exported_program import ExportedProgram
20+
21+
22+
@final
23+
@experimental(
24+
"This API and all of Metal backend related functionality are experimental."
25+
)
26+
class MetalPartitioner(Partitioner):
27+
"""
28+
Metal partitioner for AOTInductor backend integration.
29+
30+
This partitioner creates a single partition containing all operators from the input graph.
31+
It skips core ATen decomposition, allowing the Metal backend to handle decomposition using
32+
AOTInductor's MPS-specific decomposition table.
33+
34+
Only operators that cannot be handled by the aoti-mps library will be excluded from
35+
the partition and fall back to ExecuTorch's default or custom handling.
36+
"""
37+
38+
def __init__(self, compile_spec: List[CompileSpec]) -> None:
39+
self.delegation_spec = DelegationSpec(MetalBackend.__name__, compile_spec)
40+
41+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
42+
"""
43+
Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
44+
"""
45+
46+
partition_tags: Dict[str, DelegationSpec] = {}
47+
tag = "tag0"
48+
49+
for node in exported_program.graph.nodes:
50+
if node.op != "call_function":
51+
continue
52+
node.meta["delegation_tag"] = tag
53+
54+
partition_tags[tag] = self.delegation_spec
55+
56+
tag_constant_data(exported_program)
57+
tag_mutated_buffer(exported_program)
58+
59+
return PartitionResult(
60+
tagged_exported_program=exported_program, partition_tags=partition_tags
61+
)
62+
63+
def ops_to_not_decompose(
64+
self, ep: ExportedProgram
65+
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
66+
"""
67+
Return a list of operations that should not be decomposed and let the AOT compiler handle them.
68+
Currently we skip ATen decompositon for all ops, and let the Metal backend handle them.
69+
"""
70+
do_not_decompose = set()
71+
72+
for node in ep.graph.nodes:
73+
if node.op == "call_function" and isinstance(
74+
node.target, torch._ops.OpOverload
75+
):
76+
do_not_decompose.add(node.target)
77+
return list(do_not_decompose), None
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import Dict, Iterable, Tuple
10+
11+
import torch
12+
from executorch.exir.dialects._ops import ops
13+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
14+
from executorch.exir.pass_base import ExportPass, PassResult
15+
from torch import fx
16+
17+
18+
_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
19+
torch.ops.aten.slice_copy.Tensor,
20+
ops.edge.aten.slice_copy.Tensor,
21+
)
22+
23+
_SLICE_TARGETS: Dict[
24+
torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload
25+
] = {
26+
torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
27+
ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
28+
}
29+
30+
31+
class ReplaceSliceCopyWithSlicePass(ExportPass):
32+
"""Replace non-mutated ``slice_copy`` results with ``slice`` views."""
33+
34+
def call(self, graph_module: fx.GraphModule) -> PassResult:
35+
graph_changed = False
36+
37+
for node in graph_module.graph.nodes:
38+
if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS:
39+
continue
40+
41+
if self._has_blocking_user(node, node.users.keys()):
42+
continue
43+
44+
node.target = _SLICE_TARGETS[node.target]
45+
graph_changed = True
46+
47+
if graph_changed:
48+
graph_module.graph.lint()
49+
graph_module.recompile()
50+
51+
return PassResult(graph_module, graph_changed)
52+
53+
def _has_blocking_user(self, node: fx.Node, users: Iterable[fx.Node]) -> bool:
54+
for user in users:
55+
if self._is_mutating_user(node, user) or self._is_view_user(node, user):
56+
return True
57+
return False
58+
59+
def _is_mutating_user(self, node: fx.Node, user: fx.Node) -> bool:
60+
if user.op == "call_method":
61+
# Treat in-place tensor methods conservatively as mutations only when the
62+
# method name ends with ``_`` which is the PyTorch convention for mutation.
63+
return isinstance(user.target, str) and user.target.endswith("_")
64+
65+
if user.op != "call_function":
66+
return False
67+
68+
target = user.target
69+
if not hasattr(target, "_schema"):
70+
return False
71+
72+
schema = target._schema # pyre-ignore[16]
73+
# Positional arguments
74+
for index, arg in enumerate(user.args):
75+
if arg is node and self._argument_mutates(schema, index):
76+
return True
77+
78+
# Keyword arguments
79+
for name, arg in user.kwargs.items():
80+
if arg is node and self._argument_mutates(schema, name):
81+
return True
82+
83+
return False
84+
85+
def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool:
86+
if user.op == "call_method":
87+
# Treat tensor methods conservatively and assume they may be view-producing.
88+
return True
89+
90+
if user.op != "call_function":
91+
return False
92+
93+
target = user.target
94+
if getattr(target, "is_view", False):
95+
for arg in user.args:
96+
if arg is node:
97+
return True
98+
for arg in user.kwargs.values():
99+
if arg is node:
100+
return True
101+
102+
return False
103+
104+
def _argument_mutates(
105+
self, schema: torch._C.FunctionSchema, key: int | str
106+
) -> bool:
107+
arguments = schema.arguments
108+
if isinstance(key, int):
109+
if key >= len(arguments):
110+
return False
111+
argument = arguments[key]
112+
else:
113+
argument = next((arg for arg in arguments if arg.name == key), None)
114+
if argument is None:
115+
return False
116+
117+
alias_info = argument.alias_info
118+
return bool(alias_info and alias_info.is_write)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+

0 commit comments

Comments
 (0)