Skip to content

support decomposition of aten.adaptive_max_pool2d #3954

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 82 additions & 128 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8038,105 +8038,17 @@ class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
} // namespace

namespace {
// Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices`
// op.
class DecomposeAtenAdaptiveMaxPool1dOp
: public OpRewritePattern<AtenAdaptiveMaxPool1dOp> {
using OpRewritePattern<AtenAdaptiveMaxPool1dOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAdaptiveMaxPool1dOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op.getContext();

Value input = op.getSelf();
std::optional<unsigned> maybeRank = getTensorRank(input);
if (!maybeRank) {
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
}
unsigned rank = *maybeRank;
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank - 1));
Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);

Value outputShape = op.getOutputSize();
SmallVector<Value> outputShapeSizesTorchInt;
getListConstructElements(outputShape, outputShapeSizesTorchInt);
Value outputSize = outputShapeSizesTorchInt[0];

Value constantOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);

int64_t outputSizeInt;
if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
return rewriter.notifyMatchFailure(
op, "the output size of adaptive_max_pool1d must be a constant int");
}

SmallVector<Value, 1> kernelSize;
if (outputSizeInt == 1) {
BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
kernelSize.push_back(
inputShape[rank - 1] == kUnknownSize
? inputSize
: rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
} else {
if (!isAssumingStrictSymbolicShapes(rewriter)) {
Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"unimplemented: only support cases where input and output size are "
"equal for non-unit output size");
}
kernelSize.push_back(constantOne);
}

Value kernelSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
Value strideList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne});
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantZero});
Value dialationList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne});

if (op.getResult(1).use_empty()) {
auto maxPool = rewriter.create<AtenMaxPool1dOp>(
loc, op.getType(0), input, kernelSizeList, strideList,
paddingSizeList, dialationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
} else {
auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
loc, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
paddingSizeList, dialationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, maxPool.getResults());
}
return success();
}
};
} // namespace

namespace {
// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.

// The logic of this decomposition is exactly the same as that of
// DecomposeAtenAdaptiveAvgPool2dOp, which means that currently only the
// following two cases are supported:
// or `aten.max_pool1d`.
//
// Only following two cases are supported:
// 1. inputSize = outputSize
// 2. outputSize = 1
class DecomposeAtenAdaptiveAvgPool1dOp
: public OpRewritePattern<AtenAdaptiveAvgPool1dOp> {
using OpRewritePattern<AtenAdaptiveAvgPool1dOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAdaptiveAvgPool1dOp op,
template <typename AtenOpT>
class DecomposeAtenAdaptivePool1dOp : public OpRewritePattern<AtenOpT> {
using OpRewritePattern<AtenOpT>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenOpT op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op.getContext();
Expand All @@ -8149,11 +8061,10 @@ class DecomposeAtenAdaptiveAvgPool1dOp
unsigned rank = *maybeRank;
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank - 1));
Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
Value inputSize = rewriter.createOrFold<AtenSizeIntOp>(loc, input, sizeDim);

Value outputShape = op.getOutputSize();
SmallVector<Value> outputShapeSizesTorchInt;
getListConstructElements(outputShape, outputShapeSizesTorchInt);
getListConstructElements(op.getOutputSize(), outputShapeSizesTorchInt);
Value outputSize = outputShapeSizesTorchInt[0];

Value constantOne = rewriter.create<Torch::ConstantIntOp>(
Expand All @@ -8166,18 +8077,12 @@ class DecomposeAtenAdaptiveAvgPool1dOp
int64_t outputSizeInt;
if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
return rewriter.notifyMatchFailure(
op, "the output size of adaptive_pool_1d must be a constant int");
op, "the output size of adaptive pool1d must be a constant int");
}

SmallVector<Value, 1> kernelSize;
if (outputSizeInt == 1) {
BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
kernelSize.push_back(
inputShape[rank - 1] == kUnknownSize
? inputSize
: rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
kernelSize.push_back(inputSize);
} else {
if (!isAssumingStrictSymbolicShapes(rewriter)) {
Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
Expand All @@ -8198,16 +8103,40 @@ class DecomposeAtenAdaptiveAvgPool1dOp
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantZero});

rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
/*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
return success();
if constexpr (std::is_same_v<AtenAdaptiveAvgPool1dOp, AtenOpT>) {
rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
/*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
return success();
} else if constexpr (std::is_same_v<AtenAdaptiveMaxPool1dOp, AtenOpT>) {
Value dilationList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne});
if (op.getResult(1).use_empty()) {
auto maxPool = rewriter.create<AtenMaxPool1dOp>(
loc, op.getType(0), input, kernelSizeList, strideList,
paddingSizeList, dilationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
} else {
auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
loc, op.getType(0), op.getType(1), input, kernelSizeList,
strideList, paddingSizeList, dilationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, maxPool.getResults());
}
return success();
}
return rewriter.notifyMatchFailure(
op, "unimplemented: unsupported template op");
}
};
} // namespace

namespace {
// Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op.
// Decompose `aten.adaptive_max_pool2d` op into `aten.max_pool2d` or
// `aten.max_pool2d_with_indices` op.
//
// For AdaptiveAvgPool2d op, when the input size is an integer multiple of
// output size the kernelSize, stride and padding is calculated as follows:
Expand All @@ -8217,10 +8146,10 @@ namespace {
// kernelW = inW - [(outW - 1) * strideW] = strideW
// paddingH = 0, paddingW = 0
//
class DecomposeAtenAdaptiveAvgPool2dOp
: public OpRewritePattern<AtenAdaptiveAvgPool2dOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op,
template <typename AtenOpT>
class DecomposeAtenAdaptivePool2dOp : public OpRewritePattern<AtenOpT> {
using OpRewritePattern<AtenOpT>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenOpT op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
Expand All @@ -8236,15 +8165,14 @@ class DecomposeAtenAdaptiveAvgPool2dOp
Value dimH = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank - 2));
inputHW.push_back(
/*inH=*/rewriter.create<AtenSizeIntOp>(loc, input, dimH));
/*inH=*/rewriter.createOrFold<AtenSizeIntOp>(loc, input, dimH));
Value dimW = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank - 1));
inputHW.push_back(
/*inW=*/rewriter.create<AtenSizeIntOp>(loc, input, dimW));
/*inW=*/rewriter.createOrFold<AtenSizeIntOp>(loc, input, dimW));

Value outputShape = op.getOutputSize();
SmallVector<Value> outputShapeSizesTorchInt;
getListConstructElements(outputShape, outputShapeSizesTorchInt);
getListConstructElements(op.getOutputSize(), outputShapeSizesTorchInt);

// TODO: Add support for cases other than:
// inH % outH != 0 or inW % outW != 0 where
Expand Down Expand Up @@ -8325,11 +8253,32 @@ class DecomposeAtenAdaptiveAvgPool2dOp
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantZero, constantZero});

rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
/*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue,
/*divisorOverride=*/constantNone);
return success();
if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveAvgPool2dOp>) {
rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
/*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue,
/*divisorOverride=*/constantNone);
return success();
} else if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveMaxPool2dOp>) {
Value dilationList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne, constantOne});
if (op.getResult(1).use_empty()) {
auto maxPool = rewriter.create<AtenMaxPool2dOp>(
loc, op.getType(0), input, kernelSizeList, strideList,
paddingSizeList, dilationList, /*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
} else {
auto maxPool = rewriter.create<AtenMaxPool2dWithIndicesOp>(
loc, op.getType(0), op.getType(1), input, kernelSizeList,
strideList, paddingSizeList, dilationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, maxPool.getResults());
}
return success();
}
return rewriter.notifyMatchFailure(
op, "unimplemented: unsupported template op");
}
};
} // namespace
Expand Down Expand Up @@ -11760,9 +11709,14 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveMaxPool1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAdaptivePool1dOp<AtenAdaptiveMaxPool1dOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAdaptivePool1dOp<AtenAdaptiveAvgPool1dOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAdaptivePool2dOp<AtenAdaptiveMaxPool2dOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAdaptivePool2dOp<AtenAdaptiveAvgPool2dOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenToPrimDeviceOp>();
target.addIllegalOp<AtenAdaptiveAvgPool1dOp>();
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
target.addIllegalOp<AtenAdaptiveMaxPool1dOp>();
target.addIllegalOp<AtenAdaptiveMaxPool2dOp>();
target.addIllegalOp<AtenClampMinOp>();
target.addIllegalOp<AtenClampMinTensorOp>();
target.addIllegalOp<AtenClampMaxOp>();
Expand Down
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2702,6 +2702,7 @@
"AdaptiveMaxPool2dDynamicNoBatch_basic",
"AdaptiveMaxPool2dDynamicWithIndices_basic",
"AdaptiveMaxPool2dDynamic_basic",
"AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveMaxPool2dStaticWithIndices_basic",
"AdaptiveMaxPool2dStatic_basic",
"AdaptiveMaxPool3dDynamicNoBatch_basic",
Expand Down
39 changes: 1 addition & 38 deletions projects/pt1/python/torch_mlir/torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,34 +145,6 @@ def _get_for_tracing(
return result


# The set of ops that are considered legal for each backend.
# These are currently quite load-bearing, since different backends might be
# missing patterns for decomposed forms of certain ops.
# TODO: Tighten up the definition of these "conditionally legal for backends"
# ops in the backend contract, and move these lists somewhere deeper in the
# compiler where each backend can "own" its set of legal ops.
BACKEND_LEGAL_OPS = {
OutputType.TOSA: [
"aten.flatten.using_ints",
"aten.native_layer_norm",
"aten.linear",
],
OutputType.LINALG_ON_TENSORS: [
"aten.flatten.using_ints",
"aten.adaptive_avg_pool1d",
"aten.adaptive_avg_pool2d",
"aten.unflatten.int",
],
OutputType.STABLEHLO: [
"aten.amax",
"aten.amin",
"aten.randn.generator",
"aten.normal_functional",
"aten.fmod.Tensor",
],
}


def _canon_extra_library(
extra_library, extra_library_file_name="custom_op_extra_library.mlir"
):
Expand Down Expand Up @@ -249,19 +221,10 @@ def compile(
if ignore_traced_shapes and not use_tracing:
raise Exception("`ignore_traced_shapes` requires `use_tracing`")

# We only allow `backend_legal_ops` to be specified for the `"torch"`
# output type because the other output types actually invoke their
# respective backends (Linalg, TOSA, or STABLEHLO), and those backends have
# very specific requirements about the ops which are legal.
# See `BACKEND_LEGAL_OPS` for more details.
if backend_legal_ops is not None:
if output_type != OutputType.TORCH:
raise Exception(
"`backend_legal_ops` is only valid with the " "`torch` output type"
)
backend_legal_ops = list(sorted(set(backend_legal_ops)))
else:
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
backend_legal_ops = []

# For FX-based models, automatically strip overloads.
if isinstance(model, torch.fx.GraphModule):
Expand Down
Loading
Loading