
Commit 647e1f1

Replace split_with_sizes_copy with slice_copy

Differential Revision: D73312379
Pull Request resolved: #10318
1 parent 7c150d4
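
The core observation behind the change, sketched in eager ATen terms (illustrative only; the pass itself rewrites Edge-dialect graph nodes, not eager tensors):

```python
import torch

x = torch.randn(1, 16, 8, 4)

# One multi-output copy op...
a, b = torch.ops.aten.split_with_sizes_copy.default(x, [8, 8], 1)

# ...computes the same result as one slice_copy per split, with running
# offsets along the split dimension.
a2 = torch.ops.aten.slice_copy.Tensor(x, 1, 0, 8, 1)   # dim=1, range [0, 8)
b2 = torch.ops.aten.slice_copy.Tensor(x, 1, 8, 16, 1)  # dim=1, range [8, 16)
assert torch.equal(a, a2) and torch.equal(b, b2)
```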

File tree: 2 files changed, +108 -2 lines


backends/cadence/aot/replace_ops.py (+80 -2)
```diff
@@ -17,8 +17,9 @@
 # pyre-unsafe
 
 import math
+import operator
 from operator import neg
-from typing import cast, Dict, Iterable, Sequence, Set, Tuple
+from typing import cast, Dict, Iterable, Optional, Sequence, Set, Tuple
 
 import torch
 import torch.fx
```
```diff
@@ -2182,6 +2183,82 @@ def call_operator(
         )
 
 
+# Adapted from fbcode/pyspeech/opt_passes/replace_ops.py
+@register_cadence_pass(CadencePassAttribute(opt_level=2))
+class ReplaceSplitWithSlicePass(ExportPass):
+    """
+    split_with_sizes() delegates to slice() op, so perform this replacement here.
+    This avoids the expense of delegation from ATen.
+    """
+
+    # For split_with_sizes, return the slice dim and extent for each split.
+    def get_split_sizes(
+        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
+    ) -> Optional[list[tuple[int, ...]]]:
+        # Parse the args of the split_with_sizes op
+        tensor_arg, split_sizes = node.args[0:2]
+        assert isinstance(tensor_arg, torch.fx.Node)
+        in_shape = get_shape(graph_module, tensor_arg)
+        split_dim = 0 if len(node.args) < 3 else node.args[2]
+        if in_shape is None:
+            return None
+
+        # Canonicalize the split dimension
+        assert isinstance(split_dim, int)
+        split_dim = split_dim if split_dim >= 0 else len(in_shape) + split_dim
+
+        # Create the slice op args corresponding to each split
+        slice_ops = []
+        split_start = 0
+        assert isinstance(split_sizes, list)
+        for split_size in split_sizes:
+            split_end = split_start + split_size
+            slice_args = (split_dim, split_start, split_end)
+            slice_ops.append(slice_args)
+            split_start = split_end
+
+        return slice_ops
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        graph = graph_module.graph
+        for node in graph.nodes:
+            if not isinstance(node.target, EdgeOpOverload):
+                continue
+            if (
+                get_edge_overload_packet(node.target)
+                != exir_ops.edge.aten.split_with_sizes_copy
+            ):
+                continue
+            # All the users of this split_with_sizes op must be getitem ops
+            if any(user.target != operator.getitem for user in node.users):
+                continue
+
+            # Get the slice dim and extent for each split
+            slice_ops = self.get_split_sizes(graph_module, node)
+            if slice_ops is None:
+                continue
+
+            # Go over each getitem user, and replace it with slice op
+            for user in list(node.users.keys()):
+                assert user.target == operator.getitem
+                item_idx = user.args[1]
+                assert item_idx < len(slice_ops)
+                cur_slice = slice_ops[item_idx]
+                with graph.inserting_before(user):
+                    cur_slice_node = graph.call_function(
+                        exir_ops.edge.aten.slice_copy.Tensor,
+                        (node.args[0], cur_slice[0], cur_slice[1], cur_slice[2], 1),
+                    )
+                user.replace_all_uses_with(cur_slice_node)
+                graph.erase_node(user)
+
+            graph.erase_node(node)
+
+        graph_module.recompile()
+        result = super().call(graph_module)
+        return result
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
```
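
A standalone sketch of the bookkeeping `get_split_sizes` performs (a hypothetical helper, not part of the commit): canonicalize a possibly negative split dim against the input rank, then turn running offsets into the `(dim, start, end)` triples that feed `slice_copy`:

```python
# Hypothetical re-statement of get_split_sizes' arithmetic.
def split_to_slice_args(
    in_rank: int, split_sizes: list[int], split_dim: int = 0
) -> list[tuple[int, int, int]]:
    # Canonicalize e.g. dim=-1 on a rank-4 input to dim=3.
    dim = split_dim if split_dim >= 0 else in_rank + split_dim
    slice_args, start = [], 0
    for size in split_sizes:
        slice_args.append((dim, start, start + size))
        start += size
    return slice_args

# For the shapes used in the test below: splitting a [1, 16, 8, 4] input
# into [8, 8] along dim 1 yields slices [0, 8) and [8, 16).
assert split_to_slice_args(4, [8, 8], 1) == [(1, 0, 8), (1, 8, 16)]
```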
```diff
@@ -2220,5 +2297,6 @@ class CadenceReplaceOpsInGraph:
         ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceWhereWithFullArgsWithWhereScalar,
-        # ReplaceGeluWithApproximateGeluPass,
+        ReplaceGeluWithApproximateGeluPass,
+        ReplaceSplitWithSlicePass,
     ]
```

backends/cadence/aot/tests/test_replace_ops_passes.py (+28 -0)
```diff
@@ -6,6 +6,7 @@
 
 # pyre-unsafe
 
+import operator
 import unittest
 from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union
 
@@ -40,6 +41,7 @@
     ReplaceScalarWithTensorArgPass,
     ReplaceSelectWithViewOpPass,
     ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
+    ReplaceSplitWithSlicePass,
     ReplaceSqueezeAndUnsqueezeWithViewPass,
     ReplaceTCopyWithTransposePass,
     ReplaceTransposedConvWithLinearPass,
@@ -1306,6 +1308,32 @@ def forward(self, input):
             6,
         )
 
+    def test_replace_split_with_sizes_with_slice(self):
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 16, 8, 4))
+        split = builder.call_operator(
+            exir_ops.edge.aten.split_with_sizes_copy.default, (x, [8, 8], 1)
+        )
+        # We need the outputs to be gathered by getitem ops
+        out0 = builder.call_operator(operator.getitem, (split, 0))
+        out1 = builder.call_operator(operator.getitem, (split, 1))
+        builder.output([out0, out1])
+        graph_module = builder.get_graph_module()
+
+        p = ReplaceSplitWithSlicePass()
+        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+
+        self.assertEqual(
+            count_node(
+                graph_after_passes, exir_ops.edge.aten.split_with_sizes_copy.default
+            ),
+            0,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor),
+            2,
+        )
+
 
 class TestReplaceIm2rowWithViewPass(unittest.TestCase):
     def test_no_replacement_for_conv(self):
```
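
The explicit `getitem` nodes the test builds mirror how FX represents multi-output ops in traced graphs, which is the pattern the pass's user check relies on. A small self-contained illustration with stock `torch.fx`, independent of the Cadence/Edge stack:

```python
import operator

import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        parts = torch.split(x, [8, 8], dim=1)  # multi-output op
        return parts[0] + 1.0, parts[1] * 2.0  # indexing creates getitem nodes


gm = torch.fx.symbolic_trace(M())

# Every consumer of the split's tuple output is an operator.getitem node,
# matching the precondition ReplaceSplitWithSlicePass verifies before rewriting.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.split:
        assert all(user.target is operator.getitem for user in node.users)
```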
