support decomposition of aten.adaptive_max_pool2d #3954

Open
wants to merge 2 commits into base: main
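For context on what the new pattern lowers: when the adaptive pooling windows all have the same size and spacing (in particular, when the input spatial size is an integer multiple of the output size), `aten.adaptive_max_pool2d` reduces to a plain `max_pool2d` whose kernel and stride are derived from the two sizes. The sketch below checks that equivalence at the PyTorch level; the 1x3x8x8 input and 2x2 output are hypothetical values chosen only so the sizes divide evenly, not shapes taken from this PR.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes chosen so that inH % outH == 0 and inW % outW == 0.
x = torch.randn(1, 3, 8, 8)
out_h, out_w = 2, 2
in_h, in_w = x.shape[-2:]

# Same formulas as the comment in DecomposeComplexOps.cpp:
#   stride = in // out, kernel = in - (out - 1) * stride, padding = 0
stride_h, stride_w = in_h // out_h, in_w // out_w
kernel_h = in_h - (out_h - 1) * stride_h
kernel_w = in_w - (out_w - 1) * stride_w

adaptive_vals, adaptive_idx = F.adaptive_max_pool2d(
    x, (out_h, out_w), return_indices=True)
plain_vals, plain_idx = F.max_pool2d(
    x, kernel_size=(kernel_h, kernel_w), stride=(stride_h, stride_w),
    padding=0, return_indices=True)

assert torch.equal(adaptive_vals, plain_vals)
# Indices should also agree when the windows coincide (ties aside, which
# random floats will not produce).
assert torch.equal(adaptive_idx, plain_idx)
```

The pattern below emits `aten.max_pool2d_with_indices` only when the indices result is actually used, and plain `aten.max_pool2d` otherwise.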
210 changes: 82 additions & 128 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -8034,105 +8034,17 @@ class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
} // namespace

namespace {
// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d`.
// Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices`
// op.
class DecomposeAtenAdaptiveMaxPool1dOp
: public OpRewritePattern<AtenAdaptiveMaxPool1dOp> {
using OpRewritePattern<AtenAdaptiveMaxPool1dOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAdaptiveMaxPool1dOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op.getContext();

Value input = op.getSelf();
std::optional<unsigned> maybeRank = getTensorRank(input);
if (!maybeRank) {
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
}
unsigned rank = *maybeRank;
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank - 1));
Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);

Value outputShape = op.getOutputSize();
SmallVector<Value> outputShapeSizesTorchInt;
getListConstructElements(outputShape, outputShapeSizesTorchInt);
Value outputSize = outputShapeSizesTorchInt[0];

Value constantOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);

int64_t outputSizeInt;
if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
return rewriter.notifyMatchFailure(
op, "the output size of adaptive_max_pool1d must be a constant int");
}

SmallVector<Value, 1> kernelSize;
if (outputSizeInt == 1) {
BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
kernelSize.push_back(
inputShape[rank - 1] == kUnknownSize
? inputSize
: rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
} else {
if (!isAssumingStrictSymbolicShapes(rewriter)) {
Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"unimplemented: only support cases where input and output size are "
"equal for non-unit output size");
}
kernelSize.push_back(constantOne);
}

Value kernelSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
Value strideList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne});
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantZero});
Value dialationList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne});

if (op.getResult(1).use_empty()) {
auto maxPool = rewriter.create<AtenMaxPool1dOp>(
loc, op.getType(0), input, kernelSizeList, strideList,
paddingSizeList, dialationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
} else {
auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
loc, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
paddingSizeList, dialationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, maxPool.getResults());
}
return success();
}
};
} // namespace

namespace {
// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.

// The logic of this decomposition is totally same with
// the DecomposeAtenAdaptiveAvgPool2dOp, that means currently only following two
// cases are supported:
// or `aten.max_pool1d`.
//
// Only following two cases are supported:
// 1. inputSize = outputSize
// 2. outputSize = 1
class DecomposeAtenAdaptiveAvgPool1dOp
: public OpRewritePattern<AtenAdaptiveAvgPool1dOp> {
using OpRewritePattern<AtenAdaptiveAvgPool1dOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAdaptiveAvgPool1dOp op,
template <typename AtenOpT>
class DecomposeAtenAdaptivePool1dOp : public OpRewritePattern<AtenOpT> {
using OpRewritePattern<AtenOpT>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenOpT op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op.getContext();
@@ -8145,11 +8057,10 @@ class DecomposeAtenAdaptiveAvgPool1dOp
unsigned rank = *maybeRank;
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank - 1));
Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
Value inputSize = rewriter.createOrFold<AtenSizeIntOp>(loc, input, sizeDim);

Value outputShape = op.getOutputSize();
SmallVector<Value> outputShapeSizesTorchInt;
getListConstructElements(outputShape, outputShapeSizesTorchInt);
getListConstructElements(op.getOutputSize(), outputShapeSizesTorchInt);
Value outputSize = outputShapeSizesTorchInt[0];

Value constantOne = rewriter.create<Torch::ConstantIntOp>(
@@ -8162,18 +8073,12 @@ class DecomposeAtenAdaptiveAvgPool1dOp
int64_t outputSizeInt;
if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
return rewriter.notifyMatchFailure(
op, "the output size of adaptive_pool_1d must be a constant int");
op, "the output size of adaptive pool1d must be a constant int");
}

SmallVector<Value, 1> kernelSize;
if (outputSizeInt == 1) {
BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
kernelSize.push_back(
inputShape[rank - 1] == kUnknownSize
? inputSize
: rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
kernelSize.push_back(inputSize);
} else {
if (!isAssumingStrictSymbolicShapes(rewriter)) {
Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
@@ -8194,16 +8099,40 @@
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantZero});

rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
/*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
return success();
if constexpr (std::is_same_v<AtenAdaptiveAvgPool1dOp, AtenOpT>) {
rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
/*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
return success();
} else if constexpr (std::is_same_v<AtenAdaptiveMaxPool1dOp, AtenOpT>) {
Value dilationList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne});
if (op.getResult(1).use_empty()) {
auto maxPool = rewriter.create<AtenMaxPool1dOp>(
loc, op.getType(0), input, kernelSizeList, strideList,
paddingSizeList, dilationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
} else {
auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
loc, op.getType(0), op.getType(1), input, kernelSizeList,
strideList, paddingSizeList, dilationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, maxPool.getResults());
}
return success();
}
return rewriter.notifyMatchFailure(
op, "unimplemented: unsupported template op");
}
};
} // namespace

namespace {
// Decompose `aten.adaptiveAvgPool2d` op into `aten.avgPool2d` op.
// Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op.
// Decompose `aten.adaptive_max_pool2d` op into `aten.max_pool2d` or
// `aten.max_pool2d_with_indices` op.
//
// For AdaptiveAvgPool2d op, when the input size is an integer multiple of
// output size the kernelSize, stride and padding is calculated as follows:
@@ -8213,10 +8142,10 @@ namespace {
// kernelW = inW - [(outW - 1) * strideW] = strideW
// paddingH = 0, paddingW = 0
//
class DecomposeAtenAdaptiveAvgPool2dOp
: public OpRewritePattern<AtenAdaptiveAvgPool2dOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op,
template <typename AtenOpT>
class DecomposeAtenAdaptivePool2dOp : public OpRewritePattern<AtenOpT> {
using OpRewritePattern<AtenOpT>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenOpT op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
@@ -8232,15 +8161,14 @@ class DecomposeAtenAdaptiveAvgPool2dOp
Value dimH = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank - 2));
inputHW.push_back(
/*inH=*/rewriter.create<AtenSizeIntOp>(loc, input, dimH));
/*inH=*/rewriter.createOrFold<AtenSizeIntOp>(loc, input, dimH));
Value dimW = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rank - 1));
inputHW.push_back(
/*inW=*/rewriter.create<AtenSizeIntOp>(loc, input, dimW));
/*inW=*/rewriter.createOrFold<AtenSizeIntOp>(loc, input, dimW));

Value outputShape = op.getOutputSize();
SmallVector<Value> outputShapeSizesTorchInt;
getListConstructElements(outputShape, outputShapeSizesTorchInt);
getListConstructElements(op.getOutputSize(), outputShapeSizesTorchInt);

// TODO: Add support for cases other than:
// inH % outH != 0 or inW % outW != 0 where
@@ -8321,11 +8249,32 @@ class DecomposeAtenAdaptiveAvgPool2dOp
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantZero, constantZero});

rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
/*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue,
/*divisorOverride=*/constantNone);
return success();
if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveAvgPool2dOp>) {
rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
/*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue,
/*divisorOverride=*/constantNone);
return success();
} else if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveMaxPool2dOp>) {
Value dilationList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne, constantOne});
if (op.getResult(1).use_empty()) {
auto maxPool = rewriter.create<AtenMaxPool2dOp>(
loc, op.getType(0), input, kernelSizeList, strideList,
paddingSizeList, dilationList, /*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
} else {
auto maxPool = rewriter.create<AtenMaxPool2dWithIndicesOp>(
loc, op.getType(0), op.getType(1), input, kernelSizeList,
strideList, paddingSizeList, dilationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, maxPool.getResults());
}
return success();
}
return rewriter.notifyMatchFailure(
op, "unimplemented: unsupported template op");
}
};
} // namespace
@@ -11640,9 +11589,14 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveMaxPool1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAdaptivePool1dOp<AtenAdaptiveMaxPool1dOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAdaptivePool1dOp<AtenAdaptiveAvgPool1dOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAdaptivePool2dOp<AtenAdaptiveMaxPool2dOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAdaptivePool2dOp<AtenAdaptiveAvgPool2dOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);
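The formulas in the comment above (strideH = inH // outH, kernelH = inH - (outH - 1) * strideH, padding 0, and likewise for W) are what both templated patterns rely on; the 1-D pattern additionally special-cases output_size == 1 by using the whole dimension as the kernel with stride 1. A minimal sketch of that unit-output case, with a hypothetical (2, 4, 9) input chosen only for illustration:

```python
import torch
import torch.nn.functional as F

# Hypothetical 1-D input of shape (N, C, L).
x = torch.randn(2, 4, 9)

# output_size == 1: the decomposition uses kernel = L, stride = 1, padding = 0,
# i.e. a single window over the whole length.
adaptive_vals, adaptive_idx = F.adaptive_max_pool1d(x, 1, return_indices=True)
plain_vals, plain_idx = F.max_pool1d(
    x, kernel_size=x.shape[-1], stride=1, padding=0, return_indices=True)
assert torch.equal(adaptive_vals, plain_vals)
assert torch.equal(adaptive_idx, plain_idx)

# The same kernel list is used for adaptive_avg_pool1d, which the shared
# pattern lowers to avg_pool1d.
assert torch.allclose(
    F.adaptive_avg_pool1d(x, 1),
    F.avg_pool1d(x, kernel_size=x.shape[-1], stride=1, padding=0))
```

The with-indices asserts mirror the `aten.max_pool1d_with_indices` branch; when the indices result of `aten.adaptive_max_pool1d` is unused, the pattern instead emits plain `aten.max_pool1d`.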
2 changes: 2 additions & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -509,6 +509,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenToPrimDeviceOp>();
target.addIllegalOp<AtenAdaptiveAvgPool1dOp>();
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
target.addIllegalOp<AtenAdaptiveMaxPool1dOp>();
target.addIllegalOp<AtenAdaptiveMaxPool2dOp>();
target.addIllegalOp<AtenClampMinOp>();
target.addIllegalOp<AtenClampMinTensorOp>();
target.addIllegalOp<AtenClampMaxOp>();
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2802,6 +2802,7 @@
"AdaptiveMaxPool2dDynamicNoBatch_basic",
"AdaptiveMaxPool2dDynamicWithIndices_basic",
"AdaptiveMaxPool2dDynamic_basic",
"AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveMaxPool2dStaticWithIndices_basic",
"AdaptiveMaxPool2dStatic_basic",
"AdaptiveMaxPool3dDynamicNoBatch_basic",
46 changes: 46 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1921,6 +1921,52 @@ def AdaptiveMaxPool2dStaticWithIndices_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 512, 10, 16))


class AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.amp2d = torch.nn.AdaptiveMaxPool2d((2, 2))

@export
@annotate_args(
[
None,
([1, 3, 7, 7], torch.float32, True),
]
)
def forward(self, x):
return self.amp2d(x)


@register_test_case(
module_factory=lambda: AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule()
)
def AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 3, 7, 7))


class AdaptiveMaxPool2dUnitOutputSizeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.amp2d = torch.nn.AdaptiveMaxPool2d((1, 1))

@export
@annotate_args(
[
None,
([1, 512, 7, 7], torch.float32, True),
]
)
def forward(self, x):
return self.amp2d(x)


@register_test_case(
module_factory=lambda: AdaptiveMaxPool2dUnitOutputSizeStaticModule()
)
def AdaptiveMaxPool2dUnitOutputSizeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 512, 7, 7))


# AdaptiveMaxPool3d


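As a quick way to see why the new AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule case (7x7 input, 2x2 output) is expressible with a fixed kernel and stride, the following standalone snippet reproduces its expected output with plain max_pool2d. It is only an illustration of the window math, not part of the e2e harness.

```python
import torch
import torch.nn.functional as F

# Mirrors the new test's shapes: 1x3x7x7 input, (2, 2) output.
x = torch.rand(1, 3, 7, 7)
out_h = out_w = 2
in_h, in_w = x.shape[-2:]
stride_h, stride_w = in_h // out_h, in_w // out_w   # 3, 3
kernel_h = in_h - (out_h - 1) * stride_h            # 4
kernel_w = in_w - (out_w - 1) * stride_w            # 4

expected = F.adaptive_max_pool2d(x, (out_h, out_w))
lowered = F.max_pool2d(x, (kernel_h, kernel_w), stride=(stride_h, stride_w))
assert torch.equal(expected, lowered)
```

For 7 to 2 the derived windows are [0, 4) and [3, 7), exactly the windows adaptive pooling picks, which is why the fixed kernel/stride lowering matches here even though 7 is not a multiple of 2.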