From 356b2947ad4931414ed0de55cbbede4df42602d9 Mon Sep 17 00:00:00 2001
From: Rachit Gupta
Date: Thu, 16 Oct 2025 00:21:52 -0500
Subject: [PATCH 1/2] [AIESW-14742] merge pad op into avgpool pad attribute for
 xcompiler models

---
 src/Compiler/OnnxToMlirPasses.cpp             |   3 +
 src/Compiler/OnnxToMlirPasses.hpp             |   2 +-
 src/Dialect/ONNX/Transforms/CMakeLists.txt    |   1 +
 .../ONNX/Transforms/FusePadAvgpool.cpp        | 183 ++++++++++++++++++
 src/Pass/Passes.hpp                           |   4 +
 src/Tools/onnx-mlir-opt/RegisterPasses.cpp    |   4 +
 test/mlir/onnx/onnx_fuse_pad_in_avgpool.mlir  |  66 +++++++
 7 files changed, 262 insertions(+), 1 deletion(-)
 create mode 100644 src/Dialect/ONNX/Transforms/FusePadAvgpool.cpp
 create mode 100644 test/mlir/onnx/onnx_fuse_pad_in_avgpool.mlir

diff --git a/src/Compiler/OnnxToMlirPasses.cpp b/src/Compiler/OnnxToMlirPasses.cpp
index 291ca8e464..d9d3b07501 100644
--- a/src/Compiler/OnnxToMlirPasses.cpp
+++ b/src/Compiler/OnnxToMlirPasses.cpp
@@ -80,6 +80,9 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
     }
   }
 
+  if (opts.enableFusePadIntoAvgpool)
+    pm.addPass(createFusePadIntoAvgpoolPass());
+
   // Simplify shape-related ops.
   pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass(
       opts.enableQuarkQuantizedLegalization));
diff --git a/src/Compiler/OnnxToMlirPasses.hpp b/src/Compiler/OnnxToMlirPasses.hpp
index a7532ec926..f2ba7197a2 100644
--- a/src/Compiler/OnnxToMlirPasses.hpp
+++ b/src/Compiler/OnnxToMlirPasses.hpp
@@ -19,7 +19,7 @@ struct OnnxToMlirOptions {
   bool enableRemoveDqQOp = true;
   bool enableRemoveDqQAroundOp = true;
   bool enableRemoveBinary = false;
-
+  bool enableFusePadIntoAvgpool = false;
   bool disableRecomposeOption = false;
   bool enableONNXHybridPass = true;
   bool enableConvOptPass = true;
diff --git a/src/Dialect/ONNX/Transforms/CMakeLists.txt b/src/Dialect/ONNX/Transforms/CMakeLists.txt
index 7862443935..fa35be8927 100644
--- a/src/Dialect/ONNX/Transforms/CMakeLists.txt
+++ b/src/Dialect/ONNX/Transforms/CMakeLists.txt
@@ -44,6 +44,7 @@ add_onnx_mlir_library(OMONNXRewrite
   ConstProp.cpp
   QDQAroundOpOpt.cpp
   QDQOpt.cpp
+  FusePadAvgPool.cpp
   DQBinaryQOpt.cpp
   ConvOpt.cpp
   Decompose.cpp
diff --git a/src/Dialect/ONNX/Transforms/FusePadAvgpool.cpp b/src/Dialect/ONNX/Transforms/FusePadAvgpool.cpp
new file mode 100644
index 0000000000..af925009aa
--- /dev/null
+++ b/src/Dialect/ONNX/Transforms/FusePadAvgpool.cpp
@@ -0,0 +1,183 @@
+//===- FusePadAvgpool.cpp - Fold zero-valued Pad into AveragePool ---------===//
+//
+// Folds an onnx.Pad that only adds zero-valued spatial padding into the `pads`
+// attribute of the onnx.AveragePool that consumes it:
+//
+//   %p = onnx.Pad(%x, %pads, %zero) {mode = "constant"}
+//   %y = onnx.AveragePool(%p) {pads = P}
+// becomes
+//   %y = onnx.AveragePool(%x) {pads = P + spatial(%pads)}
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
+#include "src/Pass/Passes.hpp"
+
+using namespace mlir;
+using namespace onnx_mlir;
+
+namespace {
+
+/// Returns true when `value` is the `none` placeholder of an omitted optional
+/// operand.
+bool isOmitted(Value value) { return isa<NoneType>(value.getType()); }
+
+/// Returns the dense elements of the onnx.Constant defining `value`, or a null
+/// attribute when `value` is not produced by such a constant.
+DenseElementsAttr getDenseConstant(Value value) {
+  DenseElementsAttr dense;
+  if (auto constOp = value.getDefiningOp<ONNXConstantOp>())
+    dense = dyn_cast_or_null<DenseElementsAttr>(constOp.getValueAttr());
+  return dense;
+}
+
+/// Returns true when the Pad fill value is known to be zero: the operand is
+/// either omitted (ONNX then pads with 0) or a non-empty floating-point
+/// constant whose elements are all +/-0.0.
+bool isZeroFill(Value fill) {
+  if (isOmitted(fill))
+    return true;
+  DenseElementsAttr dense = getDenseConstant(fill);
+  if (!dense || dense.empty() || !isa<FloatType>(dense.getElementType()))
+    return false;
+  return llvm::all_of(dense.getValues<APFloat>(),
+      [](const APFloat &element) { return element.isZero(); });
+}
+
+/// Rewrites AveragePool(Pad(x)) into AveragePool(x) with enlarged `pads` when
+/// the two forms are guaranteed to compute the same values.
+struct FusePadIntoAveragePoolPattern
+    : public OpRewritePattern<ONNXAveragePoolOp> {
+  using OpRewritePattern<ONNXAveragePoolOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      ONNXAveragePoolOp avgOp, PatternRewriter &rewriter) const override {
+    auto padOp = avgOp.getX().getDefiningOp<ONNXPadOp>();
+    if (!padOp)
+      return failure();
+
+    // Only zero-filled constant padding can be expressed as pooling pads.
+    StringAttr modeAttr = padOp.getModeAttr();
+    if (modeAttr && modeAttr.getValue() != "constant")
+      return failure();
+    if (!isZeroFill(padOp.getConstantValue()))
+      return failure();
+    // With explicit `axes` the pads vector no longer follows the
+    // [begin_0..begin_{r-1}, end_0..end_{r-1}] layout assumed below.
+    if (!isOmitted(padOp.getAxes()))
+      return failure();
+
+    // `pads` of the pool is ignored unless auto_pad is NOTSET, and with
+    // ceil_mode the output extent depends on where the trailing padding
+    // starts, so stay conservative in both cases.
+    if (avgOp.getAutoPad() != "NOTSET" || avgOp.getCeilMode() != 0)
+      return failure();
+
+    DenseElementsAttr padsAttr = getDenseConstant(padOp.getPads());
+    if (!padsAttr || !isa<IntegerType>(padsAttr.getElementType()))
+      return failure();
+    SmallVector<int64_t> padAmounts;
+    for (const APInt &amount : padsAttr.getValues<APInt>())
+      padAmounts.push_back(amount.getSExtValue());
+
+    // Layout is [begin_N, begin_C, begin_spatial..., end_N, end_C,
+    // end_spatial...], so a rank-r input carries 2 * r entries.
+    const size_t rank = padAmounts.size() / 2;
+    if (padAmounts.size() % 2 != 0 || rank < 3)
+      return failure();
+    const size_t numSpatial = rank - 2;
+    if (avgOp.getKernelShape().size() != numSpatial)
+      return failure();
+
+    // AveragePool cannot crop (negative pads) nor pad batch/channel dims.
+    if (llvm::any_of(padAmounts, [](int64_t amount) { return amount < 0; }))
+      return failure();
+    if (padAmounts[0] != 0 || padAmounts[1] != 0 || padAmounts[rank] != 0 ||
+        padAmounts[rank + 1] != 0)
+      return failure();
+
+    // Start from the pool's own pads: [begin_spatial..., end_spatial...].
+    SmallVector<int64_t> mergedPads(2 * numSpatial, 0);
+    if (ArrayAttr poolPads = avgOp.getPadsAttr()) {
+      if (poolPads.size() != mergedPads.size())
+        return failure();
+      llvm::transform(poolPads, mergedPads.begin(), [](Attribute pad) {
+        return cast<IntegerAttr>(pad).getInt();
+      });
+    }
+
+    // The zeros written by Pad are real data and always count in the
+    // divisor. After fusion they only count when count_include_pad is 1. The
+    // flag can be switched on safely only if the pool had no padding of its
+    // own, because then it did not influence the result before.
+    const bool countsPadding = avgOp.getCountIncludePad() != 0;
+    const bool poolHadPadding =
+        llvm::any_of(mergedPads, [](int64_t pad) { return pad != 0; });
+    if (!countsPadding && poolHadPadding)
+      return failure();
+
+    for (size_t i = 0; i < numSpatial; ++i) {
+      mergedPads[i] += padAmounts[2 + i];
+      mergedPads[numSpatial + i] += padAmounts[rank + 2 + i];
+    }
+
+    // Update in place so every other attribute of the pool is preserved.
+    rewriter.modifyOpInPlace(avgOp, [&] {
+      avgOp->setOperand(0, padOp.getData());
+      avgOp->setAttr(
+          avgOp.getPadsAttrName(), rewriter.getI64ArrayAttr(mergedPads));
+      if (!countsPadding)
+        avgOp->setAttr(avgOp.getCountIncludePadAttrName(),
+            rewriter.getIntegerAttr(
+                rewriter.getIntegerType(64, /*isSigned=*/true), 1));
+    });
+
+    // The Pad may still feed other consumers; only drop it once it is dead.
+    if (padOp->use_empty())
+      rewriter.eraseOp(padOp);
+
+    return success();
+  }
+};
+
+/// Pass that greedily applies FusePadIntoAveragePoolPattern.
+struct FusePadIntoAveragePoolPass
+    : public PassWrapper<FusePadIntoAveragePoolPass,
+          OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FusePadIntoAveragePoolPass)
+
+  StringRef getArgument() const override { return "fuse-pad-into-avgpool"; }
+  StringRef getDescription() const override {
+    return "Fuse ONNXPadOp into ONNXAveragePoolOp when mode=constant and "
+           "pad=0.";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<FusePadIntoAveragePoolPattern>(ctx);
+
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+namespace onnx_mlir {
+std::unique_ptr<mlir::Pass> createFusePadIntoAvgpoolPass() {
+  return std::make_unique<FusePadIntoAveragePoolPass>();
+}
+} // namespace onnx_mlir
diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp
index 8e4855e6dc..e2b88ec4b9 100644
--- a/src/Pass/Passes.hpp
+++ b/src/Pass/Passes.hpp
@@ -58,7 +58,11 @@ std::unique_ptr<mlir::Pass> createQDQAroundOpOptONNXToONNXPass();
 std::unique_ptr<mlir::Pass> createQDQOptONNXToONNXPass();
 
 std::unique_ptr<mlir::Pass> createFoldDQBinaryQPass();
 
+/// Pass that folds a zero-valued constant Pad into the pads attribute of the
+/// AveragePool consuming it.
+std::unique_ptr<mlir::Pass> createFusePadIntoAvgpoolPass();
+
 /// Pass for instrument the ops in specific stage.
 std::unique_ptr<mlir::Pass> createInstrumentPass();
 std::unique_ptr<mlir::Pass> createInstrumentPass(
diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
index ba6635b466..c7ce02a418 100644
--- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
+++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
@@ -86,6 +86,10 @@ void registerOMPasses(int optLevel) {
   mlir::registerPass(
       []() -> std::unique_ptr<mlir::Pass> { return createInstrumentPass(); });
 
+  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
+    return createFusePadIntoAvgpoolPass();
+  });
+
   mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
     return createInstrumentCleanupPass();
   });
diff --git a/test/mlir/onnx/onnx_fuse_pad_in_avgpool.mlir b/test/mlir/onnx/onnx_fuse_pad_in_avgpool.mlir
new file mode 100644
index 0000000000..76bcb25c47
--- /dev/null
+++ b/test/mlir/onnx/onnx_fuse_pad_in_avgpool.mlir
@@ -0,0 +1,66 @@
+// RUN: onnx-mlir-opt --fuse-pad-into-avgpool %s | FileCheck %s
+
+// Zero-valued spatial padding is merged into the pool's existing pads.
+func.func @test_fuse_pad_avgpool(%arg0: tensor<1x1x4x4xf32>) -> tensor<1x1x8x8xf32> {
+  %0 = onnx.Constant dense<[0, 0, 1, 1, 0, 0, 2, 2]> : tensor<8xi64>
+  %1 = onnx.Constant dense<0.000000e+00> : tensor<f32>
+  %2 = "onnx.NoValue"() {value} : () -> none
+  %3 = "onnx.Pad"(%arg0, %0, %1, %2) {mode = "constant"} : (tensor<1x1x4x4xf32>, tensor<8xi64>, tensor<f32>, none) -> tensor<1x1x7x7xf32>
+  %4 = "onnx.AveragePool"(%3) {
+    auto_pad = "NOTSET",
+    ceil_mode = 0 : si64,
+    count_include_pad = 1 : si64,
+    kernel_shape = [2, 2],
+    pads = [1, 1, 1, 1],
+    strides = [1, 1]} : (tensor<1x1x7x7xf32>) -> tensor<1x1x8x8xf32>
+  return %4 : tensor<1x1x8x8xf32>
+}
+
+// CHECK-LABEL: func.func @test_fuse_pad_avgpool
+// CHECK-NOT: onnx.Pad
+// CHECK: %[[POOL:.*]] = "onnx.AveragePool"(%arg0)
+// CHECK-SAME: kernel_shape = [2, 2]
+// CHECK-SAME: pads = [2, 2, 3, 3]
+// CHECK-SAME: strides = [1, 1]
+// CHECK: return %[[POOL]]
+
+// With count_include_pad = 0 the pool's own padding is excluded from the
+// divisor while the zeros written by Pad are not, so the ops must stay apart.
+func.func @test_keep_pad_count_exclude(%arg0: tensor<1x1x4x4xf32>) -> tensor<1x1x7x7xf32> {
+  %0 = onnx.Constant dense<[0, 0, 1, 1, 0, 0, 1, 1]> : tensor<8xi64>
+  %1 = onnx.Constant dense<0.000000e+00> : tensor<f32>
+  %2 = "onnx.NoValue"() {value} : () -> none
+  %3 = "onnx.Pad"(%arg0, %0, %1, %2) {mode = "constant"} : (tensor<1x1x4x4xf32>, tensor<8xi64>, tensor<f32>, none) -> tensor<1x1x6x6xf32>
+  %4 = "onnx.AveragePool"(%3) {
+    auto_pad = "NOTSET",
+    ceil_mode = 0 : si64,
+    count_include_pad = 0 : si64,
+    kernel_shape = [2, 2],
+    pads = [1, 1, 1, 1],
+    strides = [1, 1]} : (tensor<1x1x6x6xf32>) -> tensor<1x1x7x7xf32>
+  return %4 : tensor<1x1x7x7xf32>
+}
+
+// CHECK-LABEL: func.func @test_keep_pad_count_exclude
+// CHECK: %[[PAD:.*]] = "onnx.Pad"
+// CHECK: "onnx.AveragePool"(%[[PAD]])
+// CHECK-SAME: pads = [1, 1, 1, 1]
+
+// Padding on the channel dimension cannot be expressed by AveragePool.
+func.func @test_keep_pad_channel(%arg0: tensor<1x1x4x4xf32>) -> tensor<1x2x5x5xf32> {
+  %0 = onnx.Constant dense<[0, 1, 1, 1, 0, 0, 1, 1]> : tensor<8xi64>
+  %1 = onnx.Constant dense<0.000000e+00> : tensor<f32>
+  %2 = "onnx.NoValue"() {value} : () -> none
+  %3 = "onnx.Pad"(%arg0, %0, %1, %2) {mode = "constant"} : (tensor<1x1x4x4xf32>, tensor<8xi64>, tensor<f32>, none) -> tensor<1x2x6x6xf32>
+  %4 = "onnx.AveragePool"(%3) {
+    auto_pad = "NOTSET",
+    ceil_mode = 0 : si64,
+    count_include_pad = 1 : si64,
+    kernel_shape = [2, 2],
+    strides = [1, 1]} : (tensor<1x2x6x6xf32>) -> tensor<1x2x5x5xf32>
+  return %4 : tensor<1x2x5x5xf32>
+}
+
+// CHECK-LABEL: func.func @test_keep_pad_channel
+// CHECK: %[[PAD:.*]] = "onnx.Pad"
+// CHECK: "onnx.AveragePool"(%[[PAD]])
From 3a4d6360033844d0770a12f65f72d8a455b7aad5 Mon Sep 17 00:00:00 2001
From: Rachit Gupta
Date: Thu, 16 Oct 2025 00:41:01 -0500
Subject: [PATCH 2/2] file name update

---
 src/Dialect/ONNX/Transforms/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/Dialect/ONNX/Transforms/CMakeLists.txt b/src/Dialect/ONNX/Transforms/CMakeLists.txt
index fa35be8927..87d6185408 100644
--- a/src/Dialect/ONNX/Transforms/CMakeLists.txt
+++ b/src/Dialect/ONNX/Transforms/CMakeLists.txt
@@ -44,7 +44,7 @@ add_onnx_mlir_library(OMONNXRewrite
   ConstProp.cpp
   QDQAroundOpOpt.cpp
   QDQOpt.cpp
-  FusePadAvgPool.cpp
+  FusePadAvgpool.cpp
   DQBinaryQOpt.cpp
   ConvOpt.cpp
   Decompose.cpp