[Task] : Handle CHW input for avgpool2d. #4042

Open · wants to merge 1 commit into base: main
87 changes: 63 additions & 24 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -5764,6 +5764,35 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
: nhwcToNchw4DTransposeDims);
}

void
unsqueezeInputOutputFor2dPool(RankedTensorType inputTy, Value &input,
Type &outputTy, Location loc,
ConversionPatternRewriter &rewriter) const {
// 1d pool AtenOps mapped to TosaOps will already have their data in 4D
// format; we can only see 3D data here if the AtenOp itself is a 2d pool
// op with data in HWC format.

// Unsqueeze input tensor in HWC format to NHWC format to be
// compatible with tosa::AvgPool2dOp, batch is made explicitly 1.
SmallVector<int64_t> rank4Shape(inputTy.getShape());
assert(inputTy.getRank() == 3 &&
"Expected input to be 3 dimensional.");
rank4Shape.insert(rank4Shape.begin(), 1);
input = rewriter.create<tosa::ReshapeOp>(
loc,
RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
inputTy.getElementType()),
input, tosa::getTosaConstShape(rewriter, loc, rank4Shape));

// Unsqueeze output type
auto outRankedTy = cast<RankedTensorType>(outputTy);
assert(outRankedTy.getRank() == 3 &&
"Expected output rank to be the same as the input.");
SmallVector<int64_t> rank4ShapeOut(outRankedTy.getShape());
rank4ShapeOut.insert(rank4ShapeOut.begin(), 1);
outputTy = outRankedTy.clone(rank4ShapeOut);
}
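
As a sanity check on the semantics this helper reproduces: PyTorch's AvgPool2d accepts an unbatched CHW tensor and treats it exactly like the same tensor with a batch dimension of 1 prepended. A minimal PyTorch sketch, reusing the kernel/stride/input shape of the new AvgPool2dCHWModule e2e test added below:

```python
import torch

pool = torch.nn.AvgPool2d(kernel_size=[6, 8], stride=[2, 2])
x = torch.rand(4, 20, 20)  # CHW input, no batch dimension

direct = pool(x)                            # shape (4, 8, 7)
via_nchw = pool(x.unsqueeze(0)).squeeze(0)  # add batch dim 1, pool, drop it
assert torch.allclose(direct, via_nchw)
```

The lowering does the equivalent on the TOSA side: the already-transposed HWC tensor is reshaped to NHWC with N = 1 so it fits tosa::AvgPool2dOp / tosa::MaxPool2dOp, and the result type is unsqueezed to match.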

LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -5778,6 +5807,13 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
return rewriter.notifyMatchFailure(
op, "Failed to process inputs for pooling");

// input has already been verified to be RankedTensorType
auto inputTy = cast<RankedTensorType>(input.getType());
if (inputTy.getRank() != 4) {
unsqueezeInputOutputFor2dPool(inputTy, input, outputTy, op->getLoc(),
rewriter);
}

Value pooledOutput;
static_assert(std::is_same<TosaOpT, tosa::MaxPool2dOp>::value ||
std::is_same<TosaOpT, tosa::AvgPool2dOp>::value,
@@ -5805,14 +5841,14 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
op, rewriter, pooledOutput);

Value result = transposedOutput;
auto resultTy = dyn_cast<TensorType>(
auto resultTy = cast<TensorType>(result.getType());
auto expectedResultTy = dyn_cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));

if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool1dOp>()) {
auto resultShape = resultTy.getShape();
auto resultElemTy = resultTy.getElementType();
if (resultTy.getRank() != expectedResultTy.getRank()) {
auto resultShape = expectedResultTy.getShape();
auto resultElemTy = expectedResultTy.getElementType();

result = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
@@ -5823,7 +5859,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
makeShapeTorchCompatible(resultShape)));
}

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, result);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, expectedResultTy, result);

return success();
}
@@ -5851,7 +5887,7 @@ class ConvertAtenAdaptivePoolingOp
auto inputElemTy = inputTy.getElementType();

// Rank sanity check.
if (inputTy.getRank() != 4 && inputRank != 3)
if (inputRank != 4 && inputRank != 3)
return rewriter.notifyMatchFailure(
op, "NCHW->NHWC transpose requires 3D or 4D tensor");

@@ -5944,6 +5980,22 @@ static Type getOutputTypeForNonAdaptivePoolingOp(
inputElemTy);
}

template <typename AtenOpT>
void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
int64_t val) {
// Expand a pooling parameter (kernel, stride, padding, dilation) to size 2
// to be compatible with tosa::MaxPool2dOp or tosa::AvgPool2dOp.
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool1dOp>())
params.push_back(val);

if constexpr (std::is_same<AtenOpT, AtenMaxPool2dOp>() ||
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
if (params.size() == 1)
params.push_back(params[0]);
}
}
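
expandPoolParams centralizes the parameter broadcasting that the 1d and 2d pool paths need: 1d pools get a dummy second spatial value appended, and 2d pools called with a single-element list (PyTorch allows kernel_size=(6,), stride=(2,), padding=(1,)) get that value duplicated. A rough Python analogue of the expansion rule, with a hypothetical helper name, purely for illustration:

```python
def expand_pool_params(params, is_1d_pool, fill):
    """Broadcast a pooling parameter list to length 2."""
    params = list(params)
    if is_1d_pool:
        # 1d pool: the missing second spatial dim gets a fixed value
        # (1 for kernel/stride/dilation, 0 for padding).
        params.append(fill)
    elif len(params) == 1:
        # 2d pool given a single value: use it for both spatial dims.
        params.append(params[0])
    return params

expand_pool_params([6], is_1d_pool=False, fill=1)  # -> [6, 6]
expand_pool_params([1], is_1d_pool=False, fill=0)  # -> [1, 1]
expand_pool_params([4], is_1d_pool=True, fill=0)   # -> [4, 0] (1d padding)
```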

// Checks the validity of pooling parameters and stores them in the respective
// vector. Also, gets the output type for the pooling op.
template <typename AtenOpT, typename tosaOp>
@@ -5969,12 +6021,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
m_TorchListOfConstantInts(kernelSizeInts)))
return rewriter.notifyMatchFailure(
op, "Non-const kernel_size for pooling op unsupported");

// Expand kernel size parameter to size 2 to be compatible with
// tosa::MaxPool2dOp or tosa::AvgPool2dOp
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool1dOp>())
kernelSizeInts.push_back(1);
expandPoolParams(op, kernelSizeInts, 1);

if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
return rewriter.notifyMatchFailure(
@@ -5986,22 +6033,13 @@
if (strideInts.empty()) {
strideInts.assign(kernelSizeInts);
} else {
// Expand stride parameter to size 2 to be compatible with
// tosa::MaxPool2dOp or tosa::AvgPool2dOp
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool1dOp>())
strideInts.push_back(1);
expandPoolParams(op, strideInts, 1);
}

if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
return rewriter.notifyMatchFailure(
op, "Non-const padding factor for pooling op unsupported");

// Expand padding parameter to size 2 to be compatible with
// tosa::MaxPool2dOp or tosa::AvgPool2dOp
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool1dOp>())
paddingInts.push_back(0);
expandPoolParams(op, paddingInts, 0);

if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
@@ -6033,6 +6071,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
return rewriter.notifyMatchFailure(
op, "only support constant bool ceil_mode for pooling op");

expandPoolParams(op, dilationArray, 1);
outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray,
ceilMode);
9 changes: 9 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -405,6 +405,7 @@
"AtenNonzero1DDynamicModule_basic", # no lowering for torch.aten.sym_constrain_range_for_size
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"AvgPool2dCHWModule_basic",
"QuantizedReluInt32_basic",
"QuantizedReluInt8_basic",
"QuantizedReluUint8_basic",
@@ -528,6 +529,8 @@
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleSumdims_basic",
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
"AvgPool2dSingleIntTupleParamsModule_basic",
}

FX_IMPORTER_STABLEHLO_XFAIL_SET = {
@@ -952,6 +955,8 @@
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
"AvgPool2dSingleIntTupleParamsModule_basic",
"BatchNorm1DModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
@@ -2756,6 +2761,9 @@
"AtenTopKModule_basic",
"AtenTopKSmallestModule_basic",
"Aten_EmbeddingBagExample_basic",
"AvgPool2dCHWModule_basic",
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
"AvgPool2dSingleIntTupleParamsModule_basic",
"AvgPool2dWithoutPadModule_basic",
"BatchMlpLayerModule_basic",
"BincountMinlengthModule_basic",
@@ -3355,6 +3363,7 @@
"AtenSymConstrainRangeForSize_basic",
"AtenSymConstrainRange_basic",
"Aten_AssertScalar_basic",
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
"ScatterAddDynamicModule_basic",
"UniformModule_basic",
"UniformStaticShapeModule_basic",
78 changes: 78 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1428,6 +1428,84 @@ def AvgPool2dWithoutPadModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))


class AvgPool2dCHWModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ap2d = torch.nn.AvgPool2d(
kernel_size=[6, 8],
stride=[2, 2],
)

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return self.ap2d(x)


@register_test_case(module_factory=lambda: AvgPool2dCHWModule())
def AvgPool2dCHWModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 20, 20, low=0.5, high=1.0))
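
With no padding, the expected output of this test follows out = (in - kernel) // stride + 1 per spatial dim: H = (20 - 6) // 2 + 1 = 8 and W = (20 - 8) // 2 + 1 = 7, so the unbatched CHW result has shape (4, 8, 7).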


class AvgPool2dSingleIntTupleParamsModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ap2d = torch.nn.AvgPool2d(
kernel_size=(6,),
stride=(2,),
padding=(1,),
count_include_pad=False,
)

@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return self.ap2d(x)


@register_test_case(module_factory=lambda: AvgPool2dSingleIntTupleParamsModule())
def AvgPool2dSingleIntTupleParamsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))


class AvgPool2dSingleIntTupleParamsIncludePadModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ap2d = torch.nn.AvgPool2d(
kernel_size=(6,),
stride=(2,),
padding=(1,),
count_include_pad=True,
)

@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return self.ap2d(x)


@register_test_case(
module_factory=lambda: AvgPool2dSingleIntTupleParamsIncludePadModule()
)
def AvgPool2dSingleIntTupleParamsIncludePadModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))
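
The only difference from the previous test is count_include_pad=True, which makes PyTorch divide every window sum by the full kernel area even where the window overlaps the zero padding; with False the divisor shrinks to the number of real (non-padded) elements in the window. A toy example, unrelated to the shapes used in these tests, showing the corner-window difference:

```python
import torch

x = torch.ones(1, 1, 4, 4)
incl = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=1, count_include_pad=True)
excl = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=1, count_include_pad=False)

# The corner window covers one real element and three padded zeros:
# include_pad divides by 4 -> 0.25, exclude_pad divides by 1 -> 1.0.
print(incl(x)[0, 0, 0, 0])  # tensor(0.2500)
print(excl(x)[0, 0, 0, 0])  # tensor(1.)
```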


# ==============================================================================


39 changes: 39 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -2259,6 +2259,45 @@ func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torc
return %3 : !torch.vtensor<[1,192,35,35],f32>
}

// -----
// CHECK-LABEL: func.func @avgPool2dCHWInput(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,64,56],f32>) -> !torch.vtensor<[1,59,51],f32> {
// CHECK: %[[TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,64,56],f32> -> tensor<1x64x56xf32>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[C6:.*]] = torch.constant.int 6
// CHECK: %[[L1:.*]] = torch.prim.ListConstruct %[[C6]], %[[C6]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[L2:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[L3:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PERMS_IN:.*]] = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[TRANSPOSE_IN:.*]] = tosa.transpose %[[TENSOR]], %[[PERMS_IN]] : (tensor<1x64x56xf32>, tensor<3xi32>) -> tensor<64x56x1xf32>
// CHECK: %[[CONST_SHAPE_IN:.*]] = tosa.const_shape {value = dense<[1, 64, 56, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[RESHAPE_IN:.*]] = tosa.reshape %[[TRANSPOSE_IN]], %[[CONST_SHAPE_IN]] : (tensor<64x56x1xf32>, !tosa.shape<4>) -> tensor<1x64x56x1xf32>
// CHECK: %[[POOL:.*]] = tosa.avg_pool2d %[[RESHAPE_IN]] {acc_type = f32, kernel = array<i64: 6, 6>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x64x56x1xf32>) -> tensor<1x59x51x1xf32>
// CHECK: %[[PERMS_OUT:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[TRANSPOSE_OUT:.*]] = tosa.transpose %[[POOL]], %[[PERMS_OUT]] : (tensor<1x59x51x1xf32>, tensor<4xi32>) -> tensor<1x1x59x51xf32>
// CHECK: %[[CONST_SHAPE_OUT:.*]] = tosa.const_shape {value = dense<[1, 59, 51]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[RESHAPE_OUT:.*]] = tosa.reshape %[[TRANSPOSE_OUT]], %[[CONST_SHAPE_OUT]] : (tensor<1x1x59x51xf32>, !tosa.shape<3>) -> tensor<1x59x51xf32>
// CHECK: %[[CAST:.*]] = tensor.cast %[[RESHAPE_OUT]] : tensor<1x59x51xf32> to tensor<1x59x51xf32>
// CHECK: %[[TORCH:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<1x59x51xf32> -> !torch.vtensor<[1,59,51],f32>
// CHECK: return %[[TORCH]]
func.func @avgPool2dCHWInput(%arg0: !torch.vtensor<[1,64,56],f32>) -> !torch.vtensor<[1,59,51],f32> {
%none = torch.constant.none
%false = torch.constant.bool false
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int6 = torch.constant.int 6
%0 = torch.prim.ListConstruct %int6, %int6 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %true, %false, %none : !torch.vtensor<[1,64,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,59,51],f32>
return %3 : !torch.vtensor<[1,59,51],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> {