From 2b37e1d75d75b80b22046f80f86e39860f48ca26 Mon Sep 17 00:00:00 2001 From: Ivan Garcia Date: Mon, 28 Apr 2025 17:55:04 -0400 Subject: [PATCH 1/2] Average pooling clamped divisor should be done on all conditions where the kernel can go out of bounds --- lib/Conversion/TorchToLinalg/Pooling.cpp | 227 +++++--- projects/pt1/e2e_testing/xfail_sets.py | 19 + .../torch_mlir_e2e_test/test_suite/pooling.py | 491 ++++++++++++++++++ test/Conversion/TorchToLinalg/pooling.mlir | 34 +- 4 files changed, 704 insertions(+), 67 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 45268452a992..c4199a22f61e 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -856,8 +856,9 @@ namespace { // used in the divisor of the average pooling operator. template class PoolSizeCalculator { public: - PoolSizeCalculator(Value self, Value sumPool, - ConversionPatternRewriter &rewriter, Location loc); + PoolSizeCalculator(Value self, Value sumPool, bool countIncludePad, + bool ceilMode, ConversionPatternRewriter &rewriter, + Location loc); // The algorithm for computing the divisor with // count_include_pad equal is mainly based on pytorch @@ -871,18 +872,20 @@ template class PoolSizeCalculator { SmallVectorImpl &paddingInts); private: - int64_t DimSizeFromSumPoolType[NumOfDims]; - Value InputSpatialDimValues[NumOfDims]; + int64_t SumPoolTypeDimIndex[NumOfDims]; + Value InputSpatialDimSizes[NumOfDims]; Location location; + bool isCountIncludePad; + bool isCeilMode; }; } // namespace template PoolSizeCalculator::PoolSizeCalculator( - Value self, Value sumPool, ConversionPatternRewriter &rewriter, - Location loc) - : location(loc) { + Value self, Value sumPool, bool countIncludePad, bool ceilMode, + ConversionPatternRewriter &rewriter, Location loc) + : location(loc), isCountIncludePad(countIncludePad), isCeilMode(ceilMode) { auto selfType = cast(self.getType()); const int64_t selfRank = 
selfType.getRank(); RankedTensorType sumPoolType = cast(sumPool.getType()); @@ -891,57 +894,124 @@ PoolSizeCalculator::PoolSizeCalculator( // Store dimensions in this order: // 0 => width, 1 => height, 2 => depth for (int i = 0; i < NumOfDims; ++i) { - int64_t DimSizeFromSelfType = toPositiveDim(-(i + 1), selfRank); - InputSpatialDimValues[i] = - getDimOp(rewriter, location, self, DimSizeFromSelfType); - DimSizeFromSumPoolType[i] = toPositiveDim(-(i + 1), rank); + int64_t inputSpatialDimIndex = toPositiveDim(-(i + 1), selfRank); + InputSpatialDimSizes[i] = + getDimOp(rewriter, location, self, inputSpatialDimIndex); + SumPoolTypeDimIndex[i] = toPositiveDim(-(i + 1), rank); } } template Value PoolSizeCalculator::getPoolSize( - OpBuilder &b, SmallVectorImpl &kernelSizeIntValues, + OpBuilder &b, SmallVectorImpl &kernelDimSizes, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts) { Value poolSize; Value cstZero = b.createOrFold(location, b.getI64IntegerAttr(0)); + Value cstOne = + b.createOrFold(location, b.getI64IntegerAttr(1)); + Value cstTwo = + b.createOrFold(location, b.getI64IntegerAttr(2)); for (int i = 0; i < NumOfDims; ++i) { - // See the link below for the PyTorch implementation where this is - // derived from: - // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78 - // Dim below stands for spatial dimension. Prior to the February 2025 - // change, these variables used "height" and "width" (or "h" and "w") - // in these intermediate variables instead of "Dim". - Value IndexODim = + // The following code computes the clamped kernel size used to compute + // the divisor of the average pooling operator. 
Here is the formula that + // it represents: + // + // indexStartOffset = ceil((kernelSize - 1)/2) - padding + // + // clampedKernelSize = + // min(outIntIndex * stride + indexStartOffset + floor((kernelSize - 1)/2) + // + 1, + // InputSpatialDimSize + padding) - + // max(outIntIndex * stride + indexStartOffset - ceil((kernelSize - 1)/2), + // -padding) + // + // The outIntIndex is the current iteration value coming from the + // linalg.generic op and it represents the center of the kernel window. + // The padding above becomes zero if count_include_pad is false. + // The kernelSize - 1 is used to subtract the center element of the kernel + // from the kernel size before dividing by two. Note that PyTorch even + // kernel dimensions are biased to the lower side of the dimension. Hence + // the lower length uses ceiling. While the upper length uses floor. + // + // If count_include_pad is true, in most cases the divisor is just the + // product of kernel dimensions. But we still need this logic for the + // case in which the ceiling mode is true since the kernel window + // center can go into the padding outside of the input tensor. This + // introduces an implicit padding that is not controlled by the + // count_include_pad parameter. See the + // AvgPool2dCeilPaddingStridedIncludePadding E2E test for details. + + // The average pool properties of kernel size, strides, and padding are + // stored in the reverse order of the input tensor dimensions. The + // following code computes the index of the average pool property that + // corresponds to the current spatial dimension. + int avgPoolPropIdx = NumOfDims - i - 1; + + Value padding = b.createOrFold( + location, b.getI64IntegerAttr(paddingInts[avgPoolPropIdx])); + Value InputSpatialDimSize = + castIndexToInt64(b, location, InputSpatialDimSizes[i]); + // Subtract center element from kernel size before division by two. 
+ Value kernelSizeMinusOne = b.createOrFold( + location, kernelDimSizes[avgPoolPropIdx], cstOne); + // PyTorch even kernel dimensions are biased to the lower side of the + // dimension. Hence the lower length uses ceiling. + Value kernelLowerLength = b.createOrFold( + location, kernelSizeMinusOne, cstTwo); + // While the upper length uses floor. + Value kernelUpperLength = b.createOrFold( + location, kernelSizeMinusOne, cstTwo); + + // The more padding, the closer we can read to the lower bound of + // the input tensor. + Value indexStartOffset = + b.createOrFold(location, kernelLowerLength, padding); + + Value outIndex = b.create(location, - /*value=*/DimSizeFromSumPoolType[i]); - Value ODim = castIndexToInt64(b, location, IndexODim); - Value DDim = b.createOrFold( - location, b.getI64IntegerAttr(strideInts[i])); - Value PadDim = b.createOrFold( - location, b.getI64IntegerAttr(paddingInts[i])); - Value ODimDDim = b.createOrFold(location, ODim, DDim); - Value IDim0 = b.createOrFold(location, ODimDDim, PadDim); - Value IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]); - Value IDim0KDim = - b.createOrFold(location, IDim0, kernelSizeIntValues[i]); - Value IDimPadDim = b.createOrFold(location, IDim, PadDim); - Value IDim1 = - b.createOrFold(location, IDim0KDim, IDimPadDim); - - Value IDim0Clamped = - b.createOrFold(location, IDim0, cstZero); - Value IDim1Clamped = b.createOrFold(location, IDim1, IDim); - Value IDim1_IDim0_Clamped = - b.createOrFold(location, IDim1Clamped, IDim0Clamped); + /*value=*/SumPoolTypeDimIndex[i]); + Value outIntIndex = castIndexToInt64(b, location, outIndex); + + Value stride = b.createOrFold( + location, b.getI64IntegerAttr(strideInts[avgPoolPropIdx])); + + Value indexStrided = b.createOrFold( + location, b.createOrFold(location, outIntIndex, stride), + indexStartOffset); + + Value inputUpperBound = isCountIncludePad + ?
b.createOrFold( + location, InputSpatialDimSize, padding) + : InputSpatialDimSize; + + Value inputLowerBound = + isCountIncludePad + ? b.createOrFold(location, cstZero, padding) + : cstZero; + + Value upperBoundMinusOne = b.createOrFold( + location, indexStrided, kernelUpperLength); + Value upperBound = + b.createOrFold(location, upperBoundMinusOne, cstOne); + Value upperBoundClamped = + b.createOrFold(location, upperBound, inputUpperBound); + + Value lowerBound = b.createOrFold(location, indexStrided, + kernelLowerLength); + Value lowerBoundClamped = + b.createOrFold(location, lowerBound, inputLowerBound); + Value clampedKernelSize = b.createOrFold( + location, upperBoundClamped, lowerBoundClamped); + if (i == 0) { - poolSize = IDim1_IDim0_Clamped; + poolSize = clampedKernelSize; } else { - poolSize = b.createOrFold(location, poolSize, - IDim1_IDim0_Clamped); + poolSize = + b.createOrFold(location, poolSize, clampedKernelSize); } } return poolSize; @@ -961,10 +1031,10 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { // count_include_pad parameter is equal to false. 
static std::optional createAvgPoolValueCountIncludePadFalseCase( - bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter, Value self, Value sumPool, - Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + bool ceilMode, bool countIncludePad, OpTy op, + typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter, + Value self, Value sumPool, Value outputTensor, Type resultType, + SmallVectorImpl &kernelDimSizes, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, SmallVector &indexingMapsAvg, @@ -976,7 +1046,7 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { OpTy op, typename OpTy::Adaptor &adaptor, ConversionPatternRewriter &rewriter, Value self, Value sumPool, Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &kernelDimSizes, SmallVector &indexingMapsAvg, SmallVector &iteratorTypesAvg); }; @@ -1041,9 +1111,9 @@ LogicalResult ConvertAtenAvgPoolOp::matchAndRewrite( Dim + 2, utils::IteratorType::parallel); auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase( - countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor, - resultType, kernelSizeIntValues, strideInts, paddingInts, indexingMapsAvg, - iteratorTypesAvg); + ceilMode, countIncludePad, op, adaptor, rewriter, self, sumPool, + outputTensor, resultType, kernelSizeIntValues, strideInts, paddingInts, + indexingMapsAvg, iteratorTypesAvg); if (divisorOpResult) return *divisorOpResult; @@ -1057,10 +1127,10 @@ LogicalResult ConvertAtenAvgPoolOp::matchAndRewrite( template std::optional ConvertAtenAvgPoolOp:: createAvgPoolValueCountIncludePadFalseCase( - bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter, Value self, Value sumPool, - Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + bool ceilMode, bool countIncludePad, OpTy op, + typename OpTy::Adaptor adaptor, 
ConversionPatternRewriter &rewriter, + Value self, Value sumPool, Value outputTensor, Type resultType, + SmallVectorImpl &kernelDimSizes, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, SmallVector &indexingMapsAvg, @@ -1069,8 +1139,37 @@ std::optional ConvertAtenAvgPoolOp:: constexpr int avgPoolDims = getAvgPoolNumOfDims(); - bool noPadding = llvm::all_of(paddingInts, [](int64_t p) { return p == 0; }); - if (countIncludePad || noPadding) { + bool hasPadding = + !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; }); + bool allStridesUnitary = + llvm::all_of(strideInts, [](int64_t s) { return s == 1; }); + + // If the condition below is true, the divisor total must subtract the + // elements not counted (clamped divisor count). If false, the divisor + // is just the product of kernel dimensions. + bool divisorIsClamped = + (!countIncludePad && hasPadding) || (ceilMode && !allStridesUnitary); + // There are two ways to get the divisor clamped: through padding or + // ceiling mode. For the case when there is padding, the padding elements + // are omitted if count_include_pad == False (divisor is clamped). If + // there is no padding (padding == 0) then the count_include_pad value + // does not take effect. + // The divisor count can also be clamped through the ceil_mode. In this + // case, according to the Hout and Wout formula on this page: + // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d, + // the ceil_mode will round up on the stride division. The round up + // will give an extra element that will go out of bounds, where PyTorch + // adds zero padding. PyTorch does not count these implicit zero + // padding elements in the divisor, and this is not controlled by the + // count_include_pad argument. + // But also note that if all strides are 1 there are no fractions to + // round up, hence there is no ceiling rounding and the window will + // not go out of bounds.
For this case the divisor is just the + // product of kernel dimensions. + // Search for torch.nn.AvgPool2d E2E tests for coverage of these + // conditions. + + if (!divisorIsClamped) { // These cases are not handled here. return std::nullopt; } @@ -1082,8 +1181,8 @@ std::optional ConvertAtenAvgPoolOp:: Type resultElementType = cast(resultType).getElementType(); - PoolSizeCalculator poolSizeCalculator(self, sumPool, rewriter, - loc); + PoolSizeCalculator poolSizeCalculator( + self, sumPool, countIncludePad, ceilMode, rewriter, loc); // AtenAvgPool2/3dOp has an optional divisor_override // attribute while AtenAvgPool1dOp does not. @@ -1104,7 +1203,7 @@ std::optional ConvertAtenAvgPoolOp:: [&](OpBuilder &b, Location loc, ValueRange args) { if (!poolSize) { poolSize = poolSizeCalculator.getPoolSize( - b, kernelSizeIntValues, strideInts, paddingInts); + b, kernelDimSizes, strideInts, paddingInts); } Value divisor = convertScalarToDtype(b, loc, poolSize, resultElementType); @@ -1126,17 +1225,17 @@ LogicalResult ConvertAtenAvgPoolOp:: OpTy op, typename OpTy::Adaptor &adaptor, ConversionPatternRewriter &rewriter, Value self, Value sumPool, Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &kernelDimSizes, SmallVector &indexingMapsAvg, SmallVector &iteratorTypesAvg) { Location loc = op->getLoc(); Type resultElementType = cast(resultType).getElementType(); - Value divisor = kernelSizeIntValues[0]; - for (uint32_t i = 1; i < kernelSizeIntValues.size(); ++i) { - divisor = rewriter.createOrFold(loc, divisor, - kernelSizeIntValues[i]); + Value divisor = kernelDimSizes[0]; + for (uint32_t i = 1; i < kernelDimSizes.size(); ++i) { + divisor = + rewriter.createOrFold(loc, divisor, kernelDimSizes[i]); } // Only average pooling 2D/3D have optional divisor override. 
if constexpr (!std::is_same()) { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 42d7e01f9468..20b9a4749062 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -650,6 +650,13 @@ "Aten_EmbeddingBagExample_basic", "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleZerodDimBug_basic", + "AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic", + "AvgPool2dCeilPadNonUnitaryStrides_basic", + "AvgPool2dCeilNoPadStridedIncludePadding_basic", + "AvgPool2dCeilPaddingStridedIncludePadding_basic", + "AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded_basic", + "AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic", + "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic", "AvgPool2dDivisorOverrideModule_basic", "BernoulliTensorModule_basic", "BincountMinlengthModule_basic", @@ -2791,6 +2798,10 @@ "AvgPool2dSingleIntTupleParamsIncludePadModule_basic", "AvgPool2dSingleIntTupleParamsModule_basic", "AvgPool2dWithoutPadModule_basic", + "AvgPool1dNoPadCeilPadNotIncluded_basic", + "AvgPool1dPadCeilPadNotIncluded_basic", + "AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic", + "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic", "BatchMlpLayerModule_basic", "BincountMinlengthModule_basic", "BincountModule_basic", @@ -3533,6 +3544,11 @@ "AvgPool1dIntModule_basic", "AvgPool1dStaticModule_basic", "AvgPool2dCeilModeTrueModule_basic", + "AvgPool2dCeilPaddingStridedIncludePadding_basic", + "AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic", + "AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic", + "AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic", + "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic", "AvgPool2dDivisorOverrideModule_basic", "AvgPool2dFloatModule_basic", "AvgPool2dIntModule_basic", @@ -3939,6 +3955,9 @@ "AtenKthvalueFloat64Module_basic", "AtenKthvalueKeepDimModule_basic", "AtenKthvalueModule_basic", + 
"AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic", + "AvgPool2dCeilNoPadUnitaryStrides_basic", + "AvgPool2dCeilPadNonUnitaryStrides_basic", "AvgPool2dCountIncludePadFalseStaticModule_basic", "AvgPool3dStaticModule_basic", "Conv_Transpose1dModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 4a43b99033c1..ed8ec0faefa4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -2514,3 +2514,494 @@ def MaxUnpool3dModulePad0_basic(module, tu: TestUtils): output, indices = pool(input) module.forward(output, indices) + + +class AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa(torch.nn.Module): + # This test captures the torch-mlir issue reported here: + # https://github.com/llvm/torch-mlir/issues/4079 + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[0, 0], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa()) +def AvgPool2dCeilNoPadNonUnitaryStridesIreeSwa_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilNoPadUnitaryStrides(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[0, 0], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilNoPadUnitaryStrides()) +def AvgPool2dCeilNoPadUnitaryStrides_basic(module, 
tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilPadNonUnitaryStrides(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[1, 1], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilPadNonUnitaryStrides()) +def AvgPool2dCeilPadNonUnitaryStrides_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilNoPadStridedIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[0, 0], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilNoPadStridedIncludePadding()) +def AvgPool2dCeilNoPadStridedIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilNoPadUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[0, 0], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dCeilNoPadUnitaryStrideIncludePadding() +) +def AvgPool2dCeilNoPadUnitaryStrideIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class 
AvgPool2dCeilPaddingUnitaryStrideIncludePaddingFalse(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[1, 1], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dCeilPaddingUnitaryStrideIncludePaddingFalse() +) +def AvgPool2dCeilPaddingUnitaryStrideIncludePaddingFalse_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dFloorNoPadUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[0, 0], + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dFloorNoPadUnitaryStrideIncludePadding() +) +def AvgPool2dFloorNoPadUnitaryStrideIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dFloorPaddingUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[1, 1], + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dFloorPaddingUnitaryStrideIncludePadding() +) +def AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class 
AvgPool2dCeilPaddingUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[1, 1], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dCeilPaddingUnitaryStrideIncludePadding() +) +def AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilPaddingStridedIncludePadding(torch.nn.Module): + # Note that in this case the kernel window center will go into the padding. + # When this happens the padding elements are counted in the divisor, but + # the out-of-bounds elements from the ceiling are not counted + # (i.e., clamped from the divisor count). + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[1, 1], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilPaddingStridedIncludePadding()) +def AvgPool2dCeilPaddingStridedIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded(torch.nn.Module): + # Different sizes used for each kernel and stride dimension. No padding.
+ def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 2], + stride=[2, 3], + padding=[0, 0], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 3, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded() +) +def AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 4, low=-1)) + + +class AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded(torch.nn.Module): + # Different sizes used for each kernel, stride, and padding dimension. + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 4], + stride=[2, 3], + padding=[1, 2], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 3, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded() +) +def AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 4, low=-1)) + + +class AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded(torch.nn.Module): + # 3D version of AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded.
+ + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool3d( + kernel_size=[3, 2, 4], + stride=[3, 2, 5], + padding=[0, 0, 0], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 5, 7], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded() +) +def AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 5, 7, low=-1)) + + +class AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded(torch.nn.Module): + # 3-D version of AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded. + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool3d( + kernel_size=[3, 4, 7], + stride=[2, 3, 4], + padding=[1, 2, 3], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 3, 4, 7], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded() +) +def AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 4, 7, low=-1)) + + +class AvgPool1dNoPadCeilPadNotIncluded(torch.nn.Module): + # 1D version of AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded. 
+ + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool1d( + kernel_size=[2], + stride=[2], + padding=[1], + ceil_mode=True, + count_include_pad=False, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 5], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool1dNoPadCeilPadNotIncluded()) +def AvgPool1dNoPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 5, low=-1)) + + +class AvgPool1dPadCeilPadNotIncluded(torch.nn.Module): + # 1-D version of AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded. + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool1d( + kernel_size=[2], + stride=[2], + padding=[1], + ceil_mode=True, + count_include_pad=False, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 3], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool1dPadCeilPadNotIncluded()) +def AvgPool1dPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, low=-1)) diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index c065e624efa9..91043b83728a 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -126,7 +126,7 @@ func.func @forward_avg_pool2d_countincludepad_false(%arg0: !torch.vtensor<[1,3,6 // CHECK: linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<[1, 2]> : vector<2xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x3x64x58xf32>, tensor<4x5xf32>) outs(%[[OUT1:.*]] : tensor<1x3x61x27xf32>) -> tensor<1x3x61x27xf32> // CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x61x27xf32>) outs(%[[OUT2:.*]] : tensor<1x3x61x27xf32>) // CHECK-NEXT: ^bb0(%[[BIIN1:.*]]: f32, %[[BOUT1:.*]]: f32): - // 
CHECK-COUNT-4: arith.minsi + // CHECK-COUNT-1: arith.minsi // CHECK-COUNT-1: arith.divf // CHECK: linalg.yield %[[TMP1:.*]] : f32 // CHECK-NEXT: } -> tensor<1x3x61x27xf32> @@ -179,7 +179,7 @@ func.func @forward_avg_pool3dd_countincludepad_false(%arg0: !torch.vtensor<[1,3, // CHECK: linalg.pooling_ndhwc_sum {dilations = dense<1> : vector<3xi64>, strides = dense<[1, 2, 1]> : vector<3xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x7x66x58x3xf32>, tensor<4x5x5xf32>) outs(%[[OUT1:.*]] : tensor<1x4x31x54x3xf32>) -> tensor<1x4x31x54x3xf32> // CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x4x31x54xf32>) outs(%[[OUT2:.*]] : tensor<1x3x4x31x54xf32>) // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32): - // CHECK-COUNT-6: arith.minsi + // CHECK-COUNT-3: arith.minsi // CHECK-COUNT-1: arith.divf // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32 // CHECK-NEXT: } -> tensor<1x3x4x31x54xf32> @@ -221,7 +221,7 @@ func.func @forward_avg_pool1d_countincludepad_false(%arg0: !torch.vtensor<[1,512 // CHECK: linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%[[IN1:.*]], %[[IN2:.*]] : tensor<1x512x12xf32>, tensor<1xf32>) outs(%[[OUT1:.*]] : tensor<1x512x12xf32>) -> tensor<1x512x12xf32> // CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[IN3:.*]] : tensor<1x512x12xf32>) outs(%[[OUT2:.*]] : tensor<1x512x12xf32> // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32): - // CHECK-COUNT-2: arith.minsi + // CHECK-COUNT-1: arith.minsi // CHECK-COUNT-1: arith.divf // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32 // CHECK-NEXT: } -> tensor<1x512x12xf32> @@ -233,3 +233,31 @@ func.func @forward_avg_pool1d_countincludepad_false(%arg0: !torch.vtensor<[1,512 %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %false : !torch.vtensor<[1,512,10],f32>, !torch.list, 
!torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,12],f32> return %3 : !torch.vtensor<[1,512,12],f32> } + +// CHECK-LABEL: func @forward_avgpool_2d_ceil +func.func @forward_avgpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> { + // CHECK: %[[POOL_OUT:.*]] = linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%[[PADDED_IN:.*]], %[[KERNEL_IN:.*]] : tensor<1x1x6x6xf32>, tensor<3x3xf32>) outs(%[[OUT1:.*]] : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32> + // CHECK: linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL_OUT]] : tensor<1x1x2x2xf32>) outs(%[[GEN_OUT:.*]] : tensor<1x1x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32): + // CHECK-COUNT-3: arith.muli + // CHECK-COUNT-1: arith.sitofp + // CHECK-COUNT-1: arith.divf + // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32 + // CHECK-NEXT: } -> tensor<1x1x2x2xf32> + %int3 = torch.constant.int 3 + %int3_0 = torch.constant.int 3 + %int0 = torch.constant.int 0 + %int0_1 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int2_2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int1_3 = torch.constant.int 1 + %0 = torch.prim.ListConstruct %int3, %int3_0 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int0, %int0_1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int2, %int2_2, %int1, %int1_3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %true = torch.constant.bool true + %false = torch.constant.bool false + %none = torch.constant.none + %3 = torch.aten.avg_pool2d %arg0, %0, %2, %1, %true, %false, %none : !torch.vtensor<[1,1,4,4],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,2,2],f32> + return %3 : !torch.vtensor<[1,1,2,2],f32> +} From 3a0c98c66019d15ea452d8639393926241a3376a Mon 
Sep 17 00:00:00 2001 From: Ivan Garcia Date: Mon, 28 Apr 2025 18:18:21 -0400 Subject: [PATCH 2/2] Fix xfail_sets.y --- projects/pt1/e2e_testing/xfail_sets.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 20b9a4749062..871e73ee9fd7 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3544,6 +3544,8 @@ "AvgPool1dIntModule_basic", "AvgPool1dStaticModule_basic", "AvgPool2dCeilModeTrueModule_basic", + "AvgPool1dNoPadCeilPadNotIncluded_basic", + "AvgPool1dPadCeilPadNotIncluded_basic", "AvgPool2dCeilPaddingStridedIncludePadding_basic", "AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic", "AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic",