Enable transpose-quantized_relu-transpose fusion. #10337

Merged 1 commit on Apr 22, 2025
42 changes: 20 additions & 22 deletions backends/cadence/aot/TARGETS
@@ -6,12 +6,12 @@

load("@fbcode_macros//build_defs:export_files.bzl", "export_file")
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
load(
"@fbsource//tools/build_defs:default_platform_defs.bzl",
"CXX",
)
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib")
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

oncall("odai_jarvis")

@@ -36,18 +36,18 @@ python_library(
"compiler.py",
],
deps = [
":passes",
":utils",
":memory_planning",
":ops_registrations",
":passes",
":replace_ops",
":memory_planning",
":utils",
"//caffe2:torch",
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
"//executorch/backends/cadence/aot/quantizer:quantizer",
"//executorch/backends/transforms:decompose_sdpa",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/exir:lib",
"//executorch/devtools:lib",
"//executorch/exir:lib",
],
)

@@ -57,19 +57,19 @@ python_library(
"export_example.py",
],
deps = [
":passes",
":utils",
":ops_registrations",
":passes",
":replace_ops",
":utils",
"//caffe2:torch",
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
"//executorch/backends/cadence/runtime:runtime",
"//executorch/backends/cadence/aot/quantizer:quantizer",
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
"//executorch/backends/cadence/runtime:runtime",
"//executorch/backends/transforms:decompose_sdpa",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/exir:lib",
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
"//executorch/devtools:lib",
"//executorch/exir:lib",
],
)

@@ -94,12 +94,12 @@ python_library(
"passes.py",
],
deps = [
":utils",
":fuse_ops",
":simplify_ops",
":replace_ops",
":reorder_ops",
":remove_ops",
":reorder_ops",
":replace_ops",
":simplify_ops",
":utils",
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
@@ -131,7 +131,6 @@ python_library(
],
)


export_file(name = "functions.yaml")

executorch_generated_lib(
@@ -191,9 +190,9 @@ python_library(
],
typing = True,
deps = [
"//caffe2:torch",
":ops_registrations",
":compiler_utils",
":ops_registrations",
"//caffe2:torch",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:utils",
"//executorch/exir:pass_base",
@@ -228,11 +227,11 @@ python_library(
"//caffe2:torch",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:simplify_ops",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
"//executorch/exir/dialects/edge:lib",
"//executorch/exir/passes:spec_prop_pass",
"//executorch/backends/transforms:remove_clone_ops"
],
)

@@ -283,13 +282,13 @@ python_unittest(
],
typing = True,
deps = [
":ops_registrations",
"//caffe2:torch",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
"//later:lib",
":ops_registrations"
],
)

@@ -319,8 +318,10 @@ python_unittest(
srcs = [
"tests/test_fusion_ops_passes.py",
],
supports_static_listing = False,
typing = True,
deps = [
"fbsource//third-party/pypi/parameterized:parameterized",
":compiler",
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
@@ -391,7 +392,6 @@ python_unittest(
],
)


python_library(
name = "memory_planning",
srcs = [
@@ -409,7 +409,6 @@ python_library(
],
)


python_library(
name = "memory_constraints",
srcs = [
@@ -425,7 +424,6 @@ python_library(
],
)


python_unittest(
name = "test_memory_passes",
srcs = [
5 changes: 2 additions & 3 deletions backends/cadence/aot/fuse_ops.py
@@ -901,9 +901,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
"""
Fuse dequantize-quantize op pairs to a single requantize op.
For the special case where quant params match, this will remove
both dequant and quant ops.
Fuse transpose op pairs to a single view op.
"""

# A list of ops that can be bypassed when looking for a
Expand All @@ -915,6 +913,7 @@ class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
exir_ops.edge.cadence.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
exir_ops.edge.cadence.quantized_relu.per_tensor,
}

def can_fuse_for_chain(
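Why this fusion is safe: quantized_relu.per_tensor is elementwise with per-tensor quantization parameters, so it is layout-independent, and the transpose feeding it can be pushed through to cancel against the transpose on its output, leaving at most a view. A minimal sketch of that intuition in plain PyTorch, using torch.relu as a stand-in for the quantized op (illustrative only, not the pass's implementation):

```python
import torch

# An elementwise op commutes with layout permutations, so
# transpose -> op -> transpose(back) equals the op on the original layout.
x = torch.randn(2, 3)
with_transposes = torch.relu(x.transpose(0, 1)).transpose(0, 1)
direct = torch.relu(x)
assert torch.equal(with_transposes, direct)  # the transpose pair cancels
```

This is why the change only needs to add the op to the pass's bypass set: the pass merely has to know that the op is indifferent to the permutation sitting between the two transposes.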
41 changes: 32 additions & 9 deletions backends/cadence/aot/tests/test_fusion_ops_passes.py
@@ -23,6 +23,8 @@
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ProxyValue
from parameterized import parameterized
from torch import nn


@@ -485,39 +487,60 @@ def test_fuse_then_transpose_pass(self):


class TestFuseTransposeOpPairsPass(TestFusionPassesBase):
def test_fuse_transpose_pairs(self):
def _create_operator(
self, builder: GraphBuilder, op: torch._ops.OpOverload, x: ProxyValue
) -> ProxyValue:
if op == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default:
return builder.call_operator(
op=op,
args=(x, 1.2, 3, 0, 127, torch.int8),
)
elif op == exir_ops.edge.cadence.quantized_relu.per_tensor:
return builder.call_operator(
op=op,
args=(x, 0, 0, 0, 0),
)
else:
raise ValueError(f"Unsupported op: {op}")

@parameterized.expand(
[
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.cadence.quantized_relu.per_tensor,
],
)
def test_fuse_transpose_pairs(self, op: torch._ops.OpOverload):
# Create a graph with transpose -> quant -> transpose.
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(2, 3))
transpose_node = builder.call_operator(
op=exir_ops.edge.aten.transpose_copy.int,
args=(x, 0, 1),
)
quant_node = builder.call_operator(
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(transpose_node, 1.2, 3, 0, 127, torch.int8),
)
quant_node = self._create_operator(builder, op, transpose_node)
transpose_node = builder.call_operator(
op=exir_ops.edge.aten.transpose_copy.int,
args=(quant_node, 0, 1),
)
builder.output(transpose_node)
builder.output([transpose_node])
gm = builder.get_graph_module()
self.check_op_counts(
gm,
expected_op_counts={
exir_ops.edge.aten.transpose_copy.int: 2,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
op: 1,
},
)

# Check that the pass fuses the two transpose ops.
gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module
fusion_pass_result = FuseTransposeOpPairsPass()(gm)
self.assertIsNotNone(fusion_pass_result)
gm_after_pass = fusion_pass_result.graph_module
self.check_op_counts(
gm_after_pass,
expected_op_counts={
exir_ops.edge.aten.transpose_copy.int: 0,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
op: 1,
},
)

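The test now uses parameterized.expand to run the same body for both quantize_per_tensor and quantized_relu, which is presumably also why the TARGETS change sets supports_static_listing = False: the generated test names cannot be enumerated statically. For reference, a minimal, self-contained illustration of how parameterized.expand fans one test body out over several inputs (the class and values here are hypothetical):

```python
import unittest

from parameterized import parameterized


class ExpandDemo(unittest.TestCase):
    # @parameterized.expand creates one test method per entry at class
    # definition time, e.g. test_double_0 and test_double_1 here.
    @parameterized.expand([(2,), (5,)])
    def test_double(self, n: int) -> None:
        self.assertEqual(n * 2, n + n)


if __name__ == "__main__":
    unittest.main()
```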