Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Compiler/OnnxToMlirPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
}
}

if (opts.enableFusePadIntoAvgpool)
pm.addPass(createFusePadIntoAvgpoolPass());

// Simplify shape-related ops.
pm.addPass(onnx_mlir::createSimplifyShapeRelatedOpsPass(
opts.enableQuarkQuantizedLegalization));
Expand Down
2 changes: 1 addition & 1 deletion src/Compiler/OnnxToMlirPasses.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct OnnxToMlirOptions {
bool enableRemoveDqQOp = true;
bool enableRemoveDqQAroundOp = true;
bool enableRemoveBinary = false;

bool enableFusePadIntoAvgpool = false;
bool disableRecomposeOption = false;
bool enableONNXHybridPass = true;
bool enableConvOptPass = true;
Expand Down
1 change: 1 addition & 0 deletions src/Dialect/ONNX/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ add_onnx_mlir_library(OMONNXRewrite
ConstProp.cpp
QDQAroundOpOpt.cpp
QDQOpt.cpp
FusePadAvgpool.cpp
DQBinaryQOpt.cpp
ConvOpt.cpp
Decompose.cpp
Expand Down
141 changes: 141 additions & 0 deletions src/Dialect/ONNX/Transforms/FusePadAvgpool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
#include "src/Pass/Passes.hpp"

using namespace mlir;
using namespace onnx_mlir;

namespace {

/// Rewrites `AveragePool(Pad(x, pads, 0, mode="constant"))` into a single
/// `AveragePool(x)` whose `pads` attribute absorbs the Pad amounts.
///
/// The fusion is only semantics-preserving when:
///   * the Pad feeds this pool exclusively (otherwise erasing it breaks IR),
///   * the pool uses explicit pads (`auto_pad == "NOTSET"`),
///   * `count_include_pad == 1` — Pad-then-pool counts the inserted zeros as
///     real elements of the average, which pool pads only do in this mode,
///   * the pad constant is 0 (or absent, which defaults to 0 in ONNX),
///   * the input is 4-D NCHW (8 pad entries) with zero padding on N and C.
struct FusePadIntoAveragePoolPattern
    : public OpRewritePattern<ONNXAveragePoolOp> {
  using OpRewritePattern<ONNXAveragePoolOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      ONNXAveragePoolOp avgOp, PatternRewriter &rewriter) const override {

    Value input = avgOp.getX();
    auto padOp = input.getDefiningOp<ONNXPadOp>();
    if (!padOp)
      return failure();

    // Erasing a Pad that still has other users would leave dangling uses.
    if (!padOp->hasOneUse())
      return failure();

    // Merging explicit pads is only meaningful with auto_pad == "NOTSET".
    if (avgOp.getAutoPad() != "NOTSET")
      return failure();

    // With count_include_pad == 0 the pool would exclude the (formerly real)
    // padded zeros from the denominator, changing the result.
    if (avgOp.getCountIncludePad() != 1)
      return failure();

    // Pad mode must be "constant" (the ONNX default when the attr is absent).
    StringAttr modeAttr = padOp.getModeAttr();
    StringRef mode = modeAttr ? modeAttr.getValue() : StringRef("constant");
    if (mode != "constant")
      return failure();

    Value padsInput = padOp.getPads();
    Value constantValInput = padOp.getConstantValue();

    // The pad amounts must be a compile-time constant.
    auto padsConstOp =
        dyn_cast_or_null<ONNXConstantOp>(padsInput.getDefiningOp());
    if (!padsConstOp)
      return failure();
    auto padsAttr = dyn_cast_or_null<ElementsAttr>(padsConstOp.getValueAttr());
    if (!padsAttr)
      return failure();

    // The pad value must be 0. An absent value (NoneType) defaults to 0 per
    // the ONNX spec, so it is accepted without a constant.
    if (!mlir::isa<NoneType>(constantValInput.getType())) {
      auto constOp =
          dyn_cast_or_null<ONNXConstantOp>(constantValInput.getDefiningOp());
      if (!constOp)
        return failure();
      auto constAttr = dyn_cast_or_null<ElementsAttr>(constOp.getValueAttr());
      if (!constAttr || constAttr.getNumElements() == 0)
        return failure();
      Attribute firstAttr = *constAttr.getValues<Attribute>().begin();
      double padValue;
      if (auto fAttr = mlir::dyn_cast<FloatAttr>(firstAttr))
        padValue = fAttr.getValueAsDouble();
      else if (auto iAttr = mlir::dyn_cast<IntegerAttr>(firstAttr))
        padValue = static_cast<double>(iAttr.getInt());
      else
        return failure(); // Unknown element kind: do not assume zero.
      if (padValue != 0.0)
        return failure();
    }

    // Collect the pad amounts. This pattern handles 4-D NCHW only, i.e. the
    // ONNX layout [N_b, C_b, H_b, W_b, N_e, C_e, H_e, W_e].
    SmallVector<int64_t> padsVals;
    for (Attribute val : padsAttr.getValues<Attribute>()) {
      auto iAttr = mlir::dyn_cast<IntegerAttr>(val);
      if (!iAttr)
        return failure();
      padsVals.push_back(iAttr.getInt());
    }
    if (padsVals.size() != 8)
      return failure();

    // Padding on batch/channel dims cannot be expressed on the pool.
    if (padsVals[0] != 0 || padsVals[1] != 0 || padsVals[4] != 0 ||
        padsVals[5] != 0)
      return failure();

    // Start from the pool's existing spatial pads [H_b, W_b, H_e, W_e].
    SmallVector<int64_t> mergedPads;
    if (auto existingPadsAttr = avgOp.getPadsAttr()) {
      for (Attribute v : existingPadsAttr)
        mergedPads.push_back(cast<IntegerAttr>(v).getInt());
    } else {
      mergedPads.assign(4, 0);
    }
    if (mergedPads.size() != 4)
      return failure();

    mergedPads[0] += padsVals[2]; // H begin
    mergedPads[1] += padsVals[3]; // W begin
    mergedPads[2] += padsVals[6]; // H end
    mergedPads[3] += padsVals[7]; // W end

    auto mergedPadsAttr =
        rewriter.getI64ArrayAttr(llvm::ArrayRef<int64_t>(mergedPads));

    // Re-create the pool directly on the Pad's input with the merged pads;
    // all other attributes are carried over unchanged.
    SmallVector<Value, 1> operands;
    operands.push_back(padOp.getData());

    NamedAttrList attrs;
    attrs.set(avgOp.getKernelShapeAttrName(), avgOp.getKernelShapeAttr());
    attrs.set(avgOp.getPadsAttrName(), mergedPadsAttr);
    attrs.set(avgOp.getStridesAttrName(), avgOp.getStridesAttr());
    attrs.set(avgOp.getCeilModeAttrName(), avgOp.getCeilModeAttr());
    attrs.set(
        avgOp.getCountIncludePadAttrName(), avgOp.getCountIncludePadAttr());

    auto newAvgOp = rewriter.create<ONNXAveragePoolOp>(
        avgOp.getLoc(), avgOp->getResultTypes(), operands, attrs);

    rewriter.replaceOp(avgOp, newAvgOp->getResults());
    // The Pad is now dead (it had a single use, which was avgOp).
    rewriter.eraseOp(padOp);

    return success();
  }
};

/// Function-level pass driving FusePadIntoAveragePoolPattern greedily.
struct FusePadIntoAveragePoolPass
    : public PassWrapper<FusePadIntoAveragePoolPass,
          OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FusePadIntoAveragePoolPass)

  /// Command-line name of the pass.
  StringRef getArgument() const override { return "fuse-pad-into-avgpool"; }

  /// One-line summary shown by --help.
  StringRef getDescription() const override {
    return "Fuse ONNXPadOp into ONNXAveragePoolOp when mode=constant and "
           "pad=0.";
  }

  /// Apply the fusion pattern over the function until fixpoint.
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet fusionPatterns(context);
    fusionPatterns.insert<FusePadIntoAveragePoolPattern>(context);
    if (failed(
            applyPatternsGreedily(getOperation(), std::move(fusionPatterns))))
      signalPassFailure();
  }
};

} // namespace

namespace onnx_mlir {
/// Creates the pass that fuses a zero-constant ONNXPadOp into the `pads`
/// attribute of the consuming ONNXAveragePoolOp (registered as
/// "fuse-pad-into-avgpool").
std::unique_ptr<Pass> createFusePadIntoAvgpoolPass() {
  return std::make_unique<FusePadIntoAveragePoolPass>();
}
} // namespace onnx_mlir
2 changes: 1 addition & 1 deletion src/Pass/Passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ std::unique_ptr<mlir::Pass> createQDQAroundOpOptONNXToONNXPass();

std::unique_ptr<mlir::Pass> createQDQOptONNXToONNXPass();
std::unique_ptr<mlir::Pass> createFoldDQBinaryQPass();

std::unique_ptr<mlir::Pass> createFusePadIntoAvgpoolPass();
/// Pass for instrument the ops in specific stage.
std::unique_ptr<mlir::Pass> createInstrumentPass();
std::unique_ptr<mlir::Pass> createInstrumentPass(
Expand Down
4 changes: 4 additions & 0 deletions src/Tools/onnx-mlir-opt/RegisterPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ void registerOMPasses(int optLevel) {
mlir::registerPass(
[]() -> std::unique_ptr<mlir::Pass> { return createInstrumentPass(); });

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return createFusePadIntoAvgpoolPass();
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return createInstrumentCleanupPass();
});
Expand Down
25 changes: 25 additions & 0 deletions test/mlir/onnx/onnx_fuse_pad_in_avgpool.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: onnx-mlir-opt --fuse-pad-into-avgpool %s | FileCheck %s

// Pad adds (1,1) at the start and (2,2) at the end of H/W on a 1x1x4x4 input
// (-> 1x1x7x7); the pool's own pads [1, 1, 1, 1] then produce 1x1x8x8.
// After fusion the Pad disappears and the pool carries pads [2, 2, 3, 3]
// (elementwise sum of the Pad's spatial entries and the existing pool pads).
func.func @test_fuse_pad_avgpool(%arg0: tensor<1x1x4x4xf32>) -> tensor<1x1x8x8xf32> {
  %0 = onnx.Constant dense<[0, 0, 1, 1, 0, 0, 2, 2]> : tensor<8xi64>
  %1 = onnx.Constant dense<0.000000e+00> : tensor<f32>
  %2 = "onnx.NoValue"() {value} : () -> none
  %3 = "onnx.Pad"(%arg0, %0, %1, %2) {mode = "constant"} : (tensor<1x1x4x4xf32>, tensor<8xi64>, tensor<f32>, none) -> tensor<1x1x7x7xf32>
  %4 = "onnx.AveragePool"(%3) {
    auto_pad = "NOTSET",
    ceil_mode = 0 : si64,
    count_include_pad = 1 : si64,
    kernel_shape = [2, 2],
    pads = [1, 1, 1, 1],
    strides = [1, 1]} : (tensor<1x1x7x7xf32>) -> tensor<1x1x8x8xf32>
  return %4 : tensor<1x1x8x8xf32>
}


// CHECK-LABEL: func.func @test_fuse_pad_avgpool
// CHECK-NOT: onnx.Pad
// CHECK: %[[POOL:.*]] = "onnx.AveragePool"(%arg0)
// CHECK-SAME: kernel_shape = [2, 2]
// CHECK-SAME: pads = [2, 2, 3, 3]
// CHECK-SAME: strides = [1, 1]
// CHECK: return %[[POOL]]