
Commit df354bf

hsharma35 authored and facebook-github-bot committed
Enable transpose-quantized_relu-transpose fusion. (#10337)
Summary: Add quantized_relu support when fusing transpose pairs.

Reviewed By: mcremon-meta

Differential Revision: D73300693
1 parent 647e1f1 commit df354bf
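Why the fusion is safe: quantized_relu acts elementwise, so it commutes with transposition, and a cancelling transpose pair around it can be folded away just as it already is for the quantize/dequantize ops in the bypass list. A minimal PyTorch sanity check of that underlying identity (a sketch only: plain torch.relu on an int8 tensor stands in for the Cadence quantized_relu kernel, whose extra quantization arguments do not affect layout):

import torch

# Plain ReLU on an int8 tensor stands in for cadence.quantized_relu here;
# only its elementwise nature matters for the transpose-pair fusion.
x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)

# transpose -> relu -> transpose yields the same values as relu alone,
# which is why the surrounding transpose pair can be fused away.
with_transposes = torch.relu(x.transpose(0, 1)).transpose(0, 1)
assert torch.equal(with_transposes, torch.relu(x))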

File tree

3 files changed: +50 −32 lines changed

backends/cadence/aot/TARGETS (+20 −22)

@@ -6,12 +6,12 @@
 
 load("@fbcode_macros//build_defs:export_files.bzl", "export_file")
 load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
 load(
     "@fbsource//tools/build_defs:default_platform_defs.bzl",
     "CXX",
 )
 load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib")
-load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
 
 oncall("odai_jarvis")
 
@@ -36,18 +36,18 @@ python_library(
         "compiler.py",
     ],
     deps = [
-        ":passes",
-        ":utils",
+        ":memory_planning",
         ":ops_registrations",
+        ":passes",
         ":replace_ops",
-        ":memory_planning",
+        ":utils",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot/quantizer:fusion_pass",
         "//executorch/backends/cadence/aot/quantizer:quantizer",
         "//executorch/backends/transforms:decompose_sdpa",
         "//executorch/backends/transforms:remove_clone_ops",
-        "//executorch/exir:lib",
         "//executorch/devtools:lib",
+        "//executorch/exir:lib",
     ],
 )
 
@@ -57,19 +57,19 @@ python_library(
         "export_example.py",
     ],
     deps = [
-        ":passes",
-        ":utils",
         ":ops_registrations",
+        ":passes",
         ":replace_ops",
+        ":utils",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot/quantizer:fusion_pass",
-        "//executorch/backends/cadence/runtime:runtime",
         "//executorch/backends/cadence/aot/quantizer:quantizer",
-        "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
+        "//executorch/backends/cadence/runtime:runtime",
         "//executorch/backends/transforms:decompose_sdpa",
         "//executorch/backends/transforms:remove_clone_ops",
-        "//executorch/exir:lib",
+        "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
         "//executorch/devtools:lib",
+        "//executorch/exir:lib",
     ],
 )
 
@@ -94,12 +94,12 @@ python_library(
         "passes.py",
     ],
     deps = [
-        ":utils",
         ":fuse_ops",
-        ":simplify_ops",
-        ":replace_ops",
-        ":reorder_ops",
         ":remove_ops",
+        ":reorder_ops",
+        ":replace_ops",
+        ":simplify_ops",
+        ":utils",
         "//caffe2:torch",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
@@ -131,7 +131,6 @@ python_library(
     ],
 )
 
-
 export_file(name = "functions.yaml")
 
 executorch_generated_lib(
@@ -191,9 +190,9 @@ python_library(
     ],
     typing = True,
     deps = [
-        "//caffe2:torch",
-        ":ops_registrations",
         ":compiler_utils",
+        ":ops_registrations",
+        "//caffe2:torch",
         "//executorch/backends/cadence/aot:pass_utils",
         "//executorch/backends/cadence/aot:utils",
         "//executorch/exir:pass_base",
@@ -228,11 +227,11 @@ python_library(
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:pass_utils",
         "//executorch/backends/cadence/aot:simplify_ops",
+        "//executorch/backends/transforms:remove_clone_ops",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/dialects/edge:lib",
         "//executorch/exir/passes:spec_prop_pass",
-        "//executorch/backends/transforms:remove_clone_ops"
     ],
 )
 
@@ -283,13 +282,13 @@ python_unittest(
     ],
     typing = True,
     deps = [
+        ":ops_registrations",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:graph_builder",
         "//executorch/backends/cadence/aot:pass_utils",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
         "//later:lib",
-        ":ops_registrations"
     ],
 )
 
@@ -319,8 +318,10 @@ python_unittest(
     srcs = [
         "tests/test_fusion_ops_passes.py",
     ],
+    supports_static_listing = False,
     typing = True,
     deps = [
+        "fbsource//third-party/pypi/parameterized:parameterized",
         ":compiler",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:compiler",
@@ -391,7 +392,6 @@ python_unittest(
     ],
 )
 
-
 python_library(
     name = "memory_planning",
     srcs = [
@@ -409,7 +409,6 @@ python_library(
     ],
 )
 
-
 python_library(
     name = "memory_constraints",
     srcs = [
@@ -425,7 +424,6 @@ python_library(
     ],
 )
 
-
 python_unittest(
     name = "test_memory_passes",
    srcs = [

backends/cadence/aot/fuse_ops.py (+2 −3)

@@ -901,9 +901,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
 class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
     """
-    Fuse dequantize-quantize op pairs to a single requantize op.
-    For the special case where quant params match, this will remove
-    both dequant and quant ops.
+    Fuse transpose op pairs to a single view op.
     """
 
     # A list of ops that can be bypassed when looking for a
@@ -915,6 +913,7 @@ class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
         exir_ops.edge.cadence.dequantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+        exir_ops.edge.cadence.quantized_relu.per_tensor,
     }
 
     def can_fuse_for_chain(
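The functional change is the one added line: exir_ops.edge.cadence.quantized_relu.per_tensor joins the set of ops the pass may bypass while pairing up transposes (the stale docstring, apparently copied from the dequantize-quantize fusion pass, is also corrected). As a rough illustration only, and not the actual FuseTransposeOpPairsPass (which works on edge-dialect ops across branches and fuses the pair into a view op rather than deleting it), here is a self-contained torch.fx sketch of removing a cancelling transpose pair across a bypassable elementwise op, with plain torch.relu as the stand-in:

import torch
import torch.fx as fx

# Ops a transpose may be "bypassed" across; plain relu stands in for
# cadence.quantized_relu in this sketch.
BYPASSABLE = {torch.relu}


def fuse_transpose_pairs(gm: fx.GraphModule) -> fx.GraphModule:
    """Remove transpose -> elementwise -> transpose chains where the two
    transposes swap the same pair of dims and therefore cancel."""
    for first in list(gm.graph.nodes):
        if first.op != "call_function" or first.target is not torch.transpose:
            continue
        users = list(first.users)
        if len(users) != 1 or users[0].target not in BYPASSABLE:
            continue
        mid = users[0]
        mid_users = list(mid.users)
        if len(mid_users) != 1 or mid_users[0].target is not torch.transpose:
            continue
        second = mid_users[0]
        if first.args[1:] != second.args[1:]:  # must swap the same dims
            continue
        # Feed the elementwise op directly from the original input, bypass
        # both transposes, then drop them from the graph.
        mid.replace_input_with(first, first.args[0])
        second.replace_all_uses_with(mid)
        gm.graph.erase_node(second)
        gm.graph.erase_node(first)
    gm.graph.lint()
    gm.recompile()
    return gm


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = torch.transpose(x, 0, 1)
        y = torch.relu(t)
        return torch.transpose(y, 0, 1)


gm = fx.symbolic_trace(Toy())
fuse_transpose_pairs(gm)
x = torch.randn(2, 3)
assert torch.equal(gm(x), torch.relu(x))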

backends/cadence/aot/tests/test_fusion_ops_passes.py (+28 −7)

@@ -23,6 +23,8 @@
 from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ProxyValue
+from parameterized import parameterized
 from torch import nn
 
 
@@ -485,18 +487,37 @@ def test_fuse_then_transpose_pass(self):
 
 
 class TestFuseTransposeOpPairsPass(TestFusionPassesBase):
-    def test_fuse_transpose_pairs(self):
+    def _create_operator(
+        self, builder: GraphBuilder, op: torch._ops.OpOverload, x: ProxyValue
+    ) -> ProxyValue:
+        if op == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default:
+            return builder.call_operator(
+                op=op,
+                args=(x, 1.2, 3, 0, 127, torch.int8),
+            )
+        elif op == exir_ops.edge.cadence.quantized_relu.per_tensor:
+            return builder.call_operator(
+                op=op,
+                args=(x, 0, 0, 0, 0),
+            )
+        else:
+            raise ValueError(f"Unsupported op: {op}")
+
+    @parameterized.expand(
+        [
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            exir_ops.edge.cadence.quantized_relu.per_tensor,
+        ],
+    )
+    def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload):
         # Create a graph with transpose -> quant -> transpose.
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(2, 3))
         transpose_node = builder.call_operator(
             op=exir_ops.edge.aten.transpose_copy.int,
             args=(x, 0, 1),
         )
-        quant_node = builder.call_operator(
-            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
-            args=(transpose_node, 1.2, 3, 0, 127, torch.int8),
-        )
+        quant_node = self._create_operator(builder, op, transpose_node)
         transpose_node = builder.call_operator(
             op=exir_ops.edge.aten.transpose_copy.int,
             args=(quant_node, 0, 1),
@@ -507,7 +528,7 @@ def test_fuse_transpose_pairs(self):
             gm,
             expected_op_counts={
                 exir_ops.edge.aten.transpose_copy.int: 2,
-                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
+                op: 1,
             },
         )
 
@@ -517,7 +538,7 @@
             gm_after_pass,
             expected_op_counts={
                 exir_ops.edge.aten.transpose_copy.int: 0,
-                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
+                op: 1,
             },
         )

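With parameterized.expand, the test body now runs once per op: the original quantize_per_tensor case plus the new cadence.quantized_relu.per_tensor case, with _create_operator supplying the op-specific arguments. The parameterized dependency and supports_static_listing = False added in TARGETS presumably support this, since the expanded test names are only generated at runtime. For reference, a minimal standalone example of the decorator (hypothetical test and values, not from this diff):

import unittest

from parameterized import parameterized


class DemoTest(unittest.TestCase):
    # Each entry in the list becomes its own generated test method,
    # so one body yields three separately reported test cases.
    @parameterized.expand([(1,), (2,), (3,)])
    def test_is_positive(self, n: int) -> None:
        self.assertGreater(n, 0)


if __name__ == "__main__":
    unittest.main()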