diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 60e0694bf6bc..9797fef5e7e0 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -5764,6 +5764,35 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
                              : nhwcToNchw4DTransposeDims);
   }
 
+  void
+  unsqueezeInputOutputFor2dPool(RankedTensorType inputTy, Value &input,
+                                Type &outputTy, Location loc,
+                                ConversionPatternRewriter &rewriter) const {
+    // 1d pool AtenOps mapped to TosaOp will already have the data in 4D format;
+    // here we can have 3D data only if the AtenOp itself is a 2d pool op with
+    // data in HWC format.
+
+    // Unsqueeze input tensor in HWC format to NHWC format to be
+    // compatible with tosa::AvgPool2dOp; batch is made explicitly 1.
+    SmallVector<int64_t> rank4Shape(inputTy.getShape());
+    assert(inputTy.getRank() == 3 &&
+           "Expected input to be 3 dimensional.");
+    rank4Shape.insert(rank4Shape.begin(), 1);
+    input = rewriter.create<tosa::ReshapeOp>(
+        loc,
+        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                              inputTy.getElementType()),
+        input, tosa::getTosaConstShape(rewriter, loc, rank4Shape));
+
+    // Unsqueeze output type
+    auto outRankedTy = cast<RankedTensorType>(outputTy);
+    assert(outRankedTy.getRank() == 3 &&
+           "Expected output rank to be same as input.");
+    SmallVector<int64_t> rank4ShapeOut(outRankedTy.getShape());
+    rank4ShapeOut.insert(rank4ShapeOut.begin(), 1);
+    outputTy = outRankedTy.clone(rank4ShapeOut);
+  }
+
   LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -5778,6 +5807,13 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
       return rewriter.notifyMatchFailure(
           op, "Failed to process inputs for pooling");
 
+    // Input has already been verified to be RankedTensorType.
+    auto inputTy = cast<RankedTensorType>(input.getType());
+    if (inputTy.getRank() != 4) {
+      unsqueezeInputOutputFor2dPool(inputTy, input, outputTy, op->getLoc(),
+                                    rewriter);
+    }
+
     Value pooledOutput;
     static_assert(std::is_same<TosaOpT, tosa::MaxPool2dOp>::value ||
                       std::is_same<TosaOpT, tosa::AvgPool2dOp>::value,
@@ -5805,14 +5841,14 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
             op, rewriter, pooledOutput);
 
     Value result = transposedOutput;
-    auto resultTy = dyn_cast<RankedTensorType>(
+    auto resultTy = cast<RankedTensorType>(result.getType());
+    auto expectedResultTy = dyn_cast<RankedTensorType>(
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()));
 
-    if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
-                  std::is_same<AtenOpT, AtenAvgPool1dOp>()) {
-      auto resultShape = resultTy.getShape();
-      auto resultElemTy = resultTy.getElementType();
+    if (resultTy.getRank() != expectedResultTy.getRank()) {
+      auto resultShape = expectedResultTy.getShape();
+      auto resultElemTy = expectedResultTy.getElementType();
 
       result = rewriter.create<tosa::ReshapeOp>(
           op->getLoc(),
@@ -5823,7 +5859,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
                   makeShapeTorchCompatible(resultShape)));
     }
 
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, result);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, expectedResultTy, result);
 
     return success();
   }
@@ -5851,7 +5887,7 @@ class ConvertAtenAdaptivePoolingOp
     auto inputElemTy = inputTy.getElementType();
 
     // Rank sanity check.
-    if (inputTy.getRank() != 4 && inputRank != 3)
+    if (inputRank != 4 && inputRank != 3)
       return rewriter.notifyMatchFailure(
           op, "NCHW->NHWC transpose requires 3D or 4D tensor");
 
@@ -5944,6 +5980,22 @@ static Type getOutputTypeForNonAdaptivePoolingOp(
       inputElemTy);
 }
 
+template <typename AtenOpT>
+void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
+                      int64_t val) {
+  // Expand pooling parameter (kernel, stride) to size 2 to be compatible with
+  // tosa::MaxPool2dOp or tosa::AvgPool2dOp
+  if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
+                std::is_same<AtenOpT, AtenAvgPool1dOp>())
+    params.push_back(val);
+
+  if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>() ||
+                std::is_same<AtenOpT, AtenMaxPool2dOp>()) {
+    if (params.size() == 1)
+      params.push_back(params[0]);
+  }
+}
+
 // Checks the validity of pooling parameters and stores them in the respective
 // vector. Also, gets the output type for the pooling op.
 template <typename AtenOpT, typename tosaOp>
@@ -5969,12 +6021,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
           m_TorchListOfConstantInts(kernelSizeInts)))
     return rewriter.notifyMatchFailure(
         op, "Non-const kernel_size for pooling op unsupported");
-
-  // Expand kernel size parameter to size 2 to be compatible with
-  // tosa::MaxPool2dOp or tosa::AvgPool2dOp
-  if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
-                std::is_same<AtenOpT, AtenAvgPool1dOp>())
-    kernelSizeInts.push_back(1);
+  expandPoolParams(op, kernelSizeInts, 1);
 
   if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
     return rewriter.notifyMatchFailure(
         op,
@@ -5986,22 +6033,13 @@
   if (strideInts.empty()) {
     strideInts.assign(kernelSizeInts);
   } else {
-    // Expand stride parameter to size 2 to be compatible with
-    // tosa::MaxPool2dOp or tosa::AvgPool2dOp
-    if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
-                  std::is_same<AtenOpT, AtenAvgPool1dOp>())
-      strideInts.push_back(1);
+    expandPoolParams(op, strideInts, 1);
   }
 
   if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
     return rewriter.notifyMatchFailure(
         op, "Non-const padding factor for pooling op unsupported");
-
-  // Expand padding parameter to size 2 to be compatible with
-  // tosa::MaxPool2dOp or tosa::AvgPool2dOp
-  if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
-                std::is_same<AtenOpT, AtenAvgPool1dOp>())
-    paddingInts.push_back(0);
+  expandPoolParams(op, paddingInts, 0);
 
   if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
                 std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
@@ -6033,6 +6071,7 @@
     return rewriter.notifyMatchFailure(
         op, "only support constant bool ceil_mode for pooling op");
 
+  expandPoolParams(op, dilationArray, 1);
   outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
       inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray,
       ceilMode);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 3dd78cc011f6..315226856cc0 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -405,6 +405,7 @@
     "AtenNonzero1DDynamicModule_basic",  # no lowering for torch.aten.sym_constrain_range_for_size
     "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
+    "AvgPool2dCHWModule_basic",
     "QuantizedReluInt32_basic",
     "QuantizedReluInt8_basic",
     "QuantizedReluUint8_basic",
@@ -528,6 +529,8 @@
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "Aten_TrilinearModuleSumAllDims_basic",
     "Aten_TrilinearModuleSumdims_basic",
+    "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
+    "AvgPool2dSingleIntTupleParamsModule_basic",
 }
 
 FX_IMPORTER_STABLEHLO_XFAIL_SET = {
@@ -952,6 +955,8 @@
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
+    "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
+    "AvgPool2dSingleIntTupleParamsModule_basic",
"BatchNorm1DModule_basic", "BatchNorm2DModule_basic", "BatchNorm3DModule_basic", @@ -2756,6 +2761,9 @@ "AtenTopKModule_basic", "AtenTopKSmallestModule_basic", "Aten_EmbeddingBagExample_basic", + "AvgPool2dCHWModule_basic", + "AvgPool2dSingleIntTupleParamsIncludePadModule_basic", + "AvgPool2dSingleIntTupleParamsModule_basic", "AvgPool2dWithoutPadModule_basic", "BatchMlpLayerModule_basic", "BincountMinlengthModule_basic", @@ -3355,6 +3363,7 @@ "AtenSymConstrainRangeForSize_basic", "AtenSymConstrainRange_basic", "Aten_AssertScalar_basic", + "AvgPool2dSingleIntTupleParamsIncludePadModule_basic", "ScatterAddDynamicModule_basic", "UniformModule_basic", "UniformStaticShapeModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index e2eaa4cfd0fe..6ba79d0ef24f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -1428,6 +1428,84 @@ def AvgPool2dWithoutPadModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0)) +class AvgPool2dCHWModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[6, 8], + stride=[2, 2], + ) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCHWModule()) +def AvgPool2dCHWModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 20, 20, low=0.5, high=1.0)) + + +class AvgPool2dSingleIntTupleParamsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=(6,), + stride=(2,), + padding=(1,), + count_include_pad=False, + ) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dSingleIntTupleParamsModule()) +def AvgPool2dSingleIntTupleParamsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0)) + + +class AvgPool2dSingleIntTupleParamsIncludePadModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=(6,), + stride=(2,), + padding=(1,), + count_include_pad=True, + ) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dSingleIntTupleParamsIncludePadModule() +) +def AvgPool2dSingleIntTupleParamsIncludePadModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0)) + + # ============================================================================== diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 546d7e6f0bba..c9fddb5fa6a5 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2259,6 +2259,45 @@ func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torc return %3 : !torch.vtensor<[1,192,35,35],f32> } +// ----- +// CHECK-LABEL: func.func @avgPool2dCHWInput( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,64,56],f32>) -> !torch.vtensor<[1,59,51],f32> { +// CHECK: %[[TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,64,56],f32> -> tensor<1x64x56xf32> +// 
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK: %[[C0:.*]] = torch.constant.int 0
+// CHECK: %[[C1:.*]] = torch.constant.int 1
+// CHECK: %[[C6:.*]] = torch.constant.int 6
+// CHECK: %[[L1:.*]] = torch.prim.ListConstruct %[[C6]], %[[C6]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[L2:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[L3:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[PERMS_IN:.*]] = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK: %[[TRANSPOSE_IN:.*]] = tosa.transpose %[[TENSOR]], %[[PERMS_IN]] : (tensor<1x64x56xf32>, tensor<3xi32>) -> tensor<64x56x1xf32>
+// CHECK: %[[CONST_SHAPE_IN:.*]] = tosa.const_shape {value = dense<[1, 64, 56, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[RESHAPE_IN:.*]] = tosa.reshape %[[TRANSPOSE_IN]], %[[CONST_SHAPE_IN]] : (tensor<64x56x1xf32>, !tosa.shape<4>) -> tensor<1x64x56x1xf32>
+// CHECK: %[[POOL:.*]] = tosa.avg_pool2d %[[RESHAPE_IN]] {acc_type = f32, kernel = array<i64: 6, 6>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x64x56x1xf32>) -> tensor<1x59x51x1xf32>
+// CHECK: %[[PERMS_OUT:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK: %[[TRANSPOSE_OUT:.*]] = tosa.transpose %[[POOL]], %[[PERMS_OUT]] : (tensor<1x59x51x1xf32>, tensor<4xi32>) -> tensor<1x1x59x51xf32>
+// CHECK: %[[CONST_SHAPE_OUT:.*]] = tosa.const_shape {value = dense<[1, 59, 51]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[RESHAPE_OUT:.*]] = tosa.reshape %[[TRANSPOSE_OUT]], %[[CONST_SHAPE_OUT]] : (tensor<1x1x59x51xf32>, !tosa.shape<3>) -> tensor<1x59x51xf32>
+// CHECK: %[[CAST:.*]] = tensor.cast %[[RESHAPE_OUT]] : tensor<1x59x51xf32> to tensor<1x59x51xf32>
+// CHECK: %[[TORCH:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<1x59x51xf32> -> !torch.vtensor<[1,59,51],f32>
+// CHECK: return %[[TORCH]]
+func.func @avgPool2dCHWInput(%arg0: !torch.vtensor<[1,64,56],f32>) -> !torch.vtensor<[1,59,51],f32> {
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int6 = torch.constant.int 6
+  %0 = torch.prim.ListConstruct %int6, %int6 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %true, %false, %none : !torch.vtensor<[1,64,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,59,51],f32>
+  return %3 : !torch.vtensor<[1,59,51],f32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> {