
Average pooling clamped divisor should be applied in all conditions where the kernel can go out of bounds #4144


Open · wants to merge 9 commits into main
lib/Conversion/TorchToLinalg/Pooling.cpp — 168 changes: 108 additions & 60 deletions
@@ -856,7 +856,7 @@ namespace {
// used in the divisor of the average pooling operator.
template <int NumOfDims> class PoolSizeCalculator {
public:
PoolSizeCalculator(Value self, Value sumPool,
PoolSizeCalculator(Value self, Value sumPool, bool countIncludePad,
ConversionPatternRewriter &rewriter, Location loc);

// The algorithm for computing the divisor with
@@ -871,18 +871,19 @@ template <int NumOfDims> class PoolSizeCalculator {
SmallVectorImpl<int64_t> &paddingInts);

private:
int64_t DimSizeFromSumPoolType[NumOfDims];
Value InputSpatialDimValues[NumOfDims];
int64_t SumPoolTypeDimIndex[NumOfDims];
Value InputSpatialDimSizes[NumOfDims];
Location location;
bool isCountIncludePad;
};

} // namespace

template <int NumOfDims>
PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
Value self, Value sumPool, ConversionPatternRewriter &rewriter,
Location loc)
: location(loc) {
Value self, Value sumPool, bool countIncludePad,
ConversionPatternRewriter &rewriter, Location loc)
: location(loc), isCountIncludePad(countIncludePad) {
auto selfType = cast<RankedTensorType>(self.getType());
const int64_t selfRank = selfType.getRank();
RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
@@ -891,16 +892,16 @@ PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
// Store dimensions in this order:
// 0 => width, 1 => height, 2 => depth
for (int i = 0; i < NumOfDims; ++i) {
int64_t DimSizeFromSelfType = toPositiveDim(-(i + 1), selfRank);
InputSpatialDimValues[i] =
getDimOp(rewriter, location, self, DimSizeFromSelfType);
DimSizeFromSumPoolType[i] = toPositiveDim(-(i + 1), rank);
int64_t inputSpatialDimIndex = toPositiveDim(-(i + 1), selfRank);
InputSpatialDimSizes[i] =
getDimOp(rewriter, location, self, inputSpatialDimIndex);
SumPoolTypeDimIndex[i] = toPositiveDim(-(i + 1), rank);
}
}
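
As a quick illustration of the loop above, a minimal Python sketch for a rank-4 NCHW input with NumOfDims == 2; to_positive_dim is a stand-in for the C++ toPositiveDim helper and is assumed to wrap a negative dim d to d + rank:

    def to_positive_dim(d, rank):
        return d + rank if d < 0 else d

    self_rank = 4  # N, C, H, W
    for i in range(2):
        # i == 0 -> index 3 (width), i == 1 -> index 2 (height), matching
        # the "0 => width, 1 => height, 2 => depth" ordering noted above.
        print(i, to_positive_dim(-(i + 1), self_rank))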

template <int NumOfDims>
Value PoolSizeCalculator<NumOfDims>::getPoolSize(
OpBuilder &b, SmallVectorImpl<Value> &kernelSizeIntValues,
OpBuilder &b, SmallVectorImpl<Value> &kernelDimSizes,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts) {
Value poolSize;
@@ -915,19 +916,26 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
// Dim below stands for spatial dimension. Prior to the February 2025
// change, these variables used "height" and "width" (or "h" and "w")
// in these intermediate variables instead of "Dim".

// The average pooling properties (kernel size, strides, and padding)
// are stored in the reverse order of the input tensor's spatial
// dimensions. The following computes the index of the property that
// corresponds to the current spatial dimension.
int avgPoolPropIdx = NumOfDims - i - 1;

Value IndexODim =
b.create<linalg::IndexOp>(location,
/*value=*/DimSizeFromSumPoolType[i]);
/*value=*/SumPoolTypeDimIndex[i]);
Value ODim = castIndexToInt64(b, location, IndexODim);
Value DDim = b.createOrFold<arith::ConstantOp>(
location, b.getI64IntegerAttr(strideInts[i]));
location, b.getI64IntegerAttr(strideInts[avgPoolPropIdx]));
Value PadDim = b.createOrFold<arith::ConstantOp>(
location, b.getI64IntegerAttr(paddingInts[i]));
location, b.getI64IntegerAttr(paddingInts[avgPoolPropIdx]));
Value ODimDDim = b.createOrFold<arith::MulIOp>(location, ODim, DDim);
Value IDim0 = b.createOrFold<arith::SubIOp>(location, ODimDDim, PadDim);
Value IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]);
Value IDim0KDim =
b.createOrFold<arith::AddIOp>(location, IDim0, kernelSizeIntValues[i]);
Value IDim = castIndexToInt64(b, location, InputSpatialDimSizes[i]);
Value IDim0KDim = b.createOrFold<arith::AddIOp>(
location, IDim0, kernelDimSizes[avgPoolPropIdx]);
Value IDimPadDim = b.createOrFold<arith::AddIOp>(location, IDim, PadDim);
Value IDim1 =
b.createOrFold<arith::MinSIOp>(location, IDim0KDim, IDimPadDim);
@@ -937,11 +945,15 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
Value IDim1Clamped = b.createOrFold<arith::MinSIOp>(location, IDim1, IDim);
Value IDim1_IDim0_Clamped =
b.createOrFold<arith::SubIOp>(location, IDim1Clamped, IDim0Clamped);

Value poolSizeDim =
!isCountIncludePad
? IDim1_IDim0_Clamped
: b.createOrFold<arith::SubIOp>(location, IDim1, IDim0);
if (i == 0) {
poolSize = IDim1_IDim0_Clamped;
poolSize = poolSizeDim;
} else {
poolSize = b.createOrFold<arith::MulIOp>(location, poolSize,
IDim1_IDim0_Clamped);
poolSize = b.createOrFold<arith::MulIOp>(location, poolSize, poolSizeDim);
}
}
return poolSize;
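
To make the IR sequence above easier to follow, here is a scalar Python sketch of the same per-dimension window-size computation; variable names mirror the IR values (IDim0, IDim1), and the sample numbers are illustrative only:

    def pool_size_dim(o, kernel, stride, pad, input_size, count_include_pad):
        i_dim0 = o * stride - pad                        # window start; may be < 0
        i_dim1 = min(i_dim0 + kernel, input_size + pad)  # window end; may pass the input
        if count_include_pad:
            return i_dim1 - i_dim0                       # padded positions are counted
        # Clamp both ends to the real input extent so padded positions
        # are excluded from the divisor.
        return min(i_dim1, input_size) - max(i_dim0, 0)

    # input_size=5, kernel=3, stride=2, pad=1:
    print([pool_size_dim(o, 3, 2, 1, 5, False) for o in range(3)])  # [2, 3, 2]
    print([pool_size_dim(o, 3, 2, 1, 5, True) for o in range(3)])   # [3, 3, 3]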
@@ -957,26 +969,35 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override;

// Creates the average pooling operation value when the
// count_include_pad parameter is equal to false.
static std::optional<LogicalResult>
createAvgPoolValueCountIncludePadFalseCase(
bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelSizeIntValues,
// Returns true when the divisor must exclude the elements that the
// kernel cannot reach (a clamped divisor count); returns false when
// the divisor is simply the product of the kernel dimensions.
static bool
doesAvgPoolDivisorNeedsClamping(bool ceilMode, bool countIncludePad,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts);

// Creates the average pooling operation value with a clamped divisor:
// the product of the kernel dimensions minus the elements that are
// not counted, e.g., explicit padding and ceil-mode implicit padding.
static LogicalResult createAveragePoolValueWithClampedDivisor(
bool ceilMode, bool countIncludePad, OpTy op,
typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter,
Value self, Value sumPool, Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelDimSizes,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts,
SmallVector<AffineMap> &indexingMapsAvg,
SmallVector<utils::IteratorType> &iteratorTypesAvg);

// Creates the average pooling operation value when the
// count_include_pad parameter is equal to true.
static LogicalResult createAvgPoolValueCountIncludePadTrueCase(
// Creates the average pooling operation value with a
// regular divisor; i.e., the product of kernel dimensions.
static LogicalResult createAveragePoolValueWithRegularDivisor(
OpTy op, typename OpTy::Adaptor &adaptor,
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<Value> &kernelDimSizes,
SmallVector<AffineMap> &indexingMapsAvg,
SmallVector<utils::IteratorType> &iteratorTypesAvg);
};
@@ -1040,27 +1061,59 @@ LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::matchAndRewrite(
SmallVector<utils::IteratorType> iteratorTypesAvg(
Dim + 2, utils::IteratorType::parallel);

auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase(
countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor,
resultType, kernelSizeIntValues, strideInts, paddingInts, indexingMapsAvg,
iteratorTypesAvg);
if (divisorOpResult)
return *divisorOpResult;
if (doesAvgPoolDivisorNeedsClamping(ceilMode, countIncludePad, strideInts,
paddingInts)) {
return createAveragePoolValueWithClampedDivisor(
ceilMode, countIncludePad, op, adaptor, rewriter, self, sumPool,
outputTensor, resultType, kernelSizeIntValues, strideInts, paddingInts,
indexingMapsAvg, iteratorTypesAvg);
}

return createAvgPoolValueCountIncludePadTrueCase(
return createAveragePoolValueWithRegularDivisor(
op, adaptor, rewriter, self, sumPool, outputTensor, resultType,
kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg);
}

return success();
template <typename OpTy, typename PoolingOpTy, int Dim>
bool ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
doesAvgPoolDivisorNeedsClamping(bool ceilMode, bool countIncludePad,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts) {
  // There are two ways the divisor can become clamped: through padding
  // or through ceil mode. When there is padding and
  // count_include_pad == false, the padding elements are omitted from
  // the divisor (it is clamped). When there is no padding
  // (padding == 0), count_include_pad has no effect.
  // The divisor can also be clamped through ceil_mode. Per the Hout and
  // Wout formulas in
  // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d,
  // ceil_mode rounds up the stride division, which can produce an extra
  // output element whose window goes out of bounds; PyTorch implicitly
  // zero-pads that region but never counts those implicit padding
  // elements in the divisor, regardless of count_include_pad.
  // Note, however, that if all strides are 1 there is no fraction to
  // round up, so ceil_mode cannot push the window out of bounds and the
  // divisor is simply the product of the kernel dimensions.
  // See the torch.nn.AvgPool2d E2E tests for coverage of these
  // conditions.

bool hasPadding =
!llvm::all_of(paddingInts, [](int64_t p) { return p == 0; });
bool allStridesUnitary =
llvm::all_of(strideInts, [](int64_t s) { return s == 1; });

return (!countIncludePad && hasPadding) || (ceilMode && !allStridesUnitary);
}
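
Both triggers can be checked directly against PyTorch; a small sketch (shapes and values chosen here for illustration):

    import torch

    x = torch.ones(1, 1, 5, 5)

    # Trigger 1: explicit padding with count_include_pad=False. The corner
    # window covers only 4 real elements, so the divisor is clamped to 4
    # (not 9) and the average over an all-ones input stays 1.0.
    p = torch.nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
    print(p(x)[0, 0, 0, 0])  # tensor(1.)

    # Trigger 2: ceil_mode=True with non-unit strides and no padding. The
    # extra ceil-mode window hangs off the input; its divisor is clamped to
    # the in-bounds count (1), regardless of count_include_pad.
    q = torch.nn.AvgPool2d(2, stride=2, ceil_mode=True, count_include_pad=True)
    print(q(x)[0, 0, -1, -1])  # tensor(1.), not 0.25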

template <typename OpTy, typename PoolingOpTy, int Dim>
std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
createAvgPoolValueCountIncludePadFalseCase(
bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelSizeIntValues,
LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
createAveragePoolValueWithClampedDivisor(
bool ceilMode, bool countIncludePad, OpTy op,
typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter,
Value self, Value sumPool, Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelDimSizes,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts,
SmallVector<AffineMap> &indexingMapsAvg,
@@ -1069,11 +1122,6 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::

constexpr int avgPoolDims = getAvgPoolNumOfDims<OpTy>();

bool noPadding = llvm::all_of(paddingInts, [](int64_t p) { return p == 0; });
if (countIncludePad || noPadding) {
// These cases are not handled here.
return std::nullopt;
}
if (avgPoolDims < 1) {
return rewriter.notifyMatchFailure(
op, "Unexpected type. Only expected AtenAvgPool1dOp, AtenAvgPool2dOp, "
@@ -1082,8 +1130,8 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::

Type resultElementType = cast<RankedTensorType>(resultType).getElementType();

PoolSizeCalculator<avgPoolDims> poolSizeCalculator(self, sumPool, rewriter,
loc);
PoolSizeCalculator<avgPoolDims> poolSizeCalculator(
self, sumPool, countIncludePad, rewriter, loc);

// AtenAvgPool2/3dOp has an optional divisor_override
// attribute while AtenAvgPool1dOp does not.
@@ -1104,7 +1152,7 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
[&](OpBuilder &b, Location loc, ValueRange args) {
if (!poolSize) {
poolSize = poolSizeCalculator.getPoolSize(
b, kernelSizeIntValues, strideInts, paddingInts);
b, kernelDimSizes, strideInts, paddingInts);
}
Value divisor =
convertScalarToDtype(b, loc, poolSize, resultElementType);
@@ -1122,21 +1170,21 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::

template <typename OpTy, typename PoolingOpTy, int Dim>
LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
createAvgPoolValueCountIncludePadTrueCase(
createAveragePoolValueWithRegularDivisor(
OpTy op, typename OpTy::Adaptor &adaptor,
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<Value> &kernelDimSizes,
SmallVector<AffineMap> &indexingMapsAvg,
SmallVector<utils::IteratorType> &iteratorTypesAvg) {
Location loc = op->getLoc();

Type resultElementType = cast<RankedTensorType>(resultType).getElementType();

Value divisor = kernelSizeIntValues[0];
for (uint32_t i = 1; i < kernelSizeIntValues.size(); ++i) {
divisor = rewriter.createOrFold<arith::MulIOp>(loc, divisor,
kernelSizeIntValues[i]);
Value divisor = kernelDimSizes[0];
for (uint32_t i = 1; i < kernelDimSizes.size(); ++i) {
divisor =
rewriter.createOrFold<arith::MulIOp>(loc, divisor, kernelDimSizes[i]);
}
// Only average pooling 2D/3D have optional divisor override.
if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
projects/pt1/e2e_testing/xfail_sets.py — 21 changes: 21 additions & 0 deletions
@@ -650,6 +650,13 @@
"Aten_EmbeddingBagExample_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic",
"AvgPool2dCeilPadNonUnitaryStrides_basic",
"AvgPool2dCeilNoPadStridedIncludePadding_basic",
"AvgPool2dCeilPaddingStridedIncludePadding_basic",
"AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
"AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
"AvgPool2dDivisorOverrideModule_basic",
"BernoulliTensorModule_basic",
"BincountMinlengthModule_basic",
@@ -2791,6 +2798,10 @@
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
"AvgPool2dSingleIntTupleParamsModule_basic",
"AvgPool2dWithoutPadModule_basic",
"AvgPool1dNoPadCeilPadNotIncluded_basic",
"AvgPool1dPadCeilPadNotIncluded_basic",
"AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic",
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
"BatchMlpLayerModule_basic",
"BincountMinlengthModule_basic",
"BincountModule_basic",
@@ -3533,6 +3544,13 @@
"AvgPool1dIntModule_basic",
"AvgPool1dStaticModule_basic",
"AvgPool2dCeilModeTrueModule_basic",
"AvgPool1dNoPadCeilPadNotIncluded_basic",
"AvgPool1dPadCeilPadNotIncluded_basic",
"AvgPool2dCeilPaddingStridedIncludePadding_basic",
"AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic",
"AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic",
"AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
"AvgPool2dDivisorOverrideModule_basic",
"AvgPool2dFloatModule_basic",
"AvgPool2dIntModule_basic",
@@ -3939,6 +3957,9 @@
"AtenKthvalueFloat64Module_basic",
"AtenKthvalueKeepDimModule_basic",
"AtenKthvalueModule_basic",
"AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic",
"AvgPool2dCeilNoPadUnitaryStrides_basic",
"AvgPool2dCeilPadNonUnitaryStrides_basic",
"AvgPool2dCountIncludePadFalseStaticModule_basic",
"AvgPool3dStaticModule_basic",
"Conv_Transpose1dModule_basic",