Commit a98a5bf

Addressing feedback from Sayan (sahas3).
Ivan Garcia committed Feb 14, 2025
1 parent cfae304 commit a98a5bf
Showing 2 changed files with 54 additions and 69 deletions.
115 changes: 50 additions & 65 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -831,7 +831,7 @@ class ConvertAtenMaxUnpool3dOp final
 } // namespace
 
 namespace {
-// The following structures and the adsfdasf method
+// The following structures and the getNumOfDims method
 // are used to get the number of dimensions from the
 // average pooling type at compile time.
 template <typename OpTy> struct AtenAvgPoolTypeNumOfDims {
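The dimension-count trait named in this hunk is a standard template-specialization pattern. Below is a minimal self-contained sketch of the idea; the stand-in op structs and the 1-D default are illustrative, not copied from Pooling.cpp:

    // Stand-ins for the torch-mlir op classes; illustrative only.
    struct AtenAvgPool1dOp {};
    struct AtenAvgPool2dOp {};
    struct AtenAvgPool3dOp {};

    // A primary template plus one specialization per op type lets the
    // spatial rank be resolved at compile time.
    template <typename OpTy> struct AtenAvgPoolTypeNumOfDims {
      static constexpr int getNumOfDims() { return 1; }
    };
    template <> struct AtenAvgPoolTypeNumOfDims<AtenAvgPool2dOp> {
      static constexpr int getNumOfDims() { return 2; }
    };
    template <> struct AtenAvgPoolTypeNumOfDims<AtenAvgPool3dOp> {
      static constexpr int getNumOfDims() { return 3; }
    };

    template <typename OpTy> constexpr int getAvgPoolNumOfDims() {
      return AtenAvgPoolTypeNumOfDims<OpTy>::getNumOfDims();
    }

    static_assert(getAvgPoolNumOfDims<AtenAvgPool3dOp>() == 3,
                  "rank resolved at compile time");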
@@ -851,31 +851,6 @@ template <typename OpTy> constexpr int getAvgPoolNumOfDims() {
 }
 } // namespace
 
-namespace {
-// This structure, used solely in PoolSizeCalculator, provides
-// the intermediate values for each dimension to compute the
-// divisor of the average pooling operator.
-struct PoolSizeValues {
-  int64_t SpatialDimsInt64;
-  int64_t DimSpatialInt;
-  Value InputSpatialDimValues;
-  Value IndexODim;
-  Value ODim;
-  Value DDim;
-  Value PadDim;
-  Value ODimDDim;
-  Value IDim0;
-  Value IDim;
-  Value IDim0KDim;
-  Value IDimPadDim;
-  Value IDim1;
-  Value IDim1IDims0;
-  Value IDim0Clamped;
-  Value IDim1Clamped;
-  Value IDim1_IDim0;
-};
-} // namespace
-
 namespace {
 // This is a helper class to create the pooling size value
 // used in the divisor of the average pooling operator.
@@ -885,7 +860,7 @@ template <int NumOfDims> class PoolSizeCalculator {
                      ConversionPatternRewriter &rewriter, Location loc);
 
   // The algorithm for computing the divisor with
-  // count_include_pad is manily based on pytorch
+  // count_include_pad equal is mainly based on pytorch
   // implementation. The following code is comment
   // with pytorch code.
   // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
@@ -896,8 +871,8 @@ template <int NumOfDims> class PoolSizeCalculator {
                     SmallVectorImpl<int64_t> &paddingInts);
 
 private:
-  PoolSizeValues dims[NumOfDims];
-  ConversionPatternRewriter &rewriterHandle;
+  int64_t DimSizeFromSumPoolType[NumOfDims];
+  Value InputSpatialDimValues[NumOfDims];
   Location location;
 };
 
@@ -907,7 +882,7 @@ template <int NumOfDims>
 PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
     Value self, Value sumPool, ConversionPatternRewriter &rewriter,
     Location loc)
-    : rewriterHandle(rewriter), location(loc) {
+    : location(loc) {
   auto selfType = cast<RankedTensorType>(self.getType());
   const int64_t selfRank = selfType.getRank();
   RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
@@ -916,10 +891,10 @@ PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
   // Store dimensions in this order:
   // 0 => width, 1 => height, 2 => depth
   for (int i = 0; i < NumOfDims; ++i) {
-    dims[i].SpatialDimsInt64 = toPositiveDim(-(i + 1), selfRank);
-    dims[i].InputSpatialDimValues =
-        getDimOp(rewriterHandle, location, self, dims[i].SpatialDimsInt64);
-    dims[i].DimSpatialInt = toPositiveDim(-(i + 1), rank);
+    int64_t DimSizeFromSelfType = toPositiveDim(-(i + 1), selfRank);
+    InputSpatialDimValues[i] =
+        getDimOp(rewriter, location, self, DimSizeFromSelfType);
+    DimSizeFromSumPoolType[i] = toPositiveDim(-(i + 1), rank);
   }
 }
 
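For reference, toPositiveDim(-(i + 1), rank) normalizes a negative dimension index, so for AvgPool3d on an NCDHW input (selfRank == 5) the loop resolves i == 0 to dimension 4 (width), i == 1 to dimension 3 (height), and i == 2 to dimension 2 (depth), matching the comment above.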
@@ -930,42 +905,52 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
     SmallVectorImpl<int64_t> &paddingInts) {
   Value poolSize;
 
-  Value cstZero = rewriterHandle.create<arith::ConstantOp>(
-      location, rewriterHandle.getI64IntegerAttr(0));
+  Value cstZero = b.create<arith::ConstantOp>(location, b.getI64IntegerAttr(0));
 
   for (int i = 0; i < NumOfDims; ++i) {
-    dims[i].IndexODim =
-        b.create<linalg::IndexOp>(location, /*value=*/dims[i].DimSpatialInt);
-    dims[i].ODim = castIndexToInt64(b, location, dims[i].IndexODim);
-    dims[i].DDim = rewriterHandle.create<arith::ConstantOp>(
-        location, rewriterHandle.getI64IntegerAttr(strideInts[i]));
-    dims[i].PadDim = rewriterHandle.create<arith::ConstantOp>(
-        location, rewriterHandle.getI64IntegerAttr(paddingInts[i]));
-    dims[i].ODimDDim =
-        b.create<arith::MulIOp>(location, dims[i].ODim, dims[i].DDim);
-    dims[i].IDim0 =
-        b.create<arith::SubIOp>(location, dims[i].ODimDDim, dims[i].PadDim);
-    dims[i].IDim = castIndexToInt64(b, location, dims[i].InputSpatialDimValues);
-    dims[i].IDim0KDim = b.create<arith::AddIOp>(location, dims[i].IDim0,
-                                                kernelSizeIntValues[i]);
-    dims[i].IDimPadDim =
-        b.create<arith::AddIOp>(location, dims[i].IDim, dims[i].PadDim);
-    dims[i].IDim1 = b.create<arith::MinSIOp>(location, dims[i].IDim0KDim,
-                                             dims[i].IDimPadDim);
-    dims[i].IDim1IDims0 =
-        b.create<arith::SubIOp>(location, dims[i].IDim1, dims[i].IDim0);
-
-    dims[i].IDim0Clamped =
-        b.create<arith::MaxSIOp>(location, dims[i].IDim0, cstZero);
-    dims[i].IDim1Clamped =
-        b.create<arith::MinSIOp>(location, dims[i].IDim1, dims[i].IDim);
-    dims[i].IDim1_IDim0 = b.create<arith::SubIOp>(
-        location, dims[i].IDim1Clamped, dims[i].IDim0Clamped);
+    // See the link below for the PyTorch implementation where
+    // this is derived from:
+    // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
+    // Dim below stands for spatial dimension. Prior to the February
+    // 2025 change, these variables used "height" and "width" (or
+    // "h" and "w") in these intermediate variables instead of "Dim".
+    Value IndexODim;
+    Value ODim;
+    Value DDim;
+    Value PadDim;
+    Value ODimDDim;
+    Value IDim0;
+    Value IDim;
+    Value IDim0KDim;
+    Value IDimPadDim;
+    Value IDim1;
+    Value IDim0Clamped;
+    Value IDim1Clamped;
+    Value IDim1_IDim0_Clamped;
+    IndexODim = b.create<linalg::IndexOp>(location,
+                                          /*value=*/DimSizeFromSumPoolType[i]);
+    ODim = castIndexToInt64(b, location, IndexODim);
+    DDim = b.create<arith::ConstantOp>(location,
+                                       b.getI64IntegerAttr(strideInts[i]));
+    PadDim = b.create<arith::ConstantOp>(location,
+                                         b.getI64IntegerAttr(paddingInts[i]));
+    ODimDDim = b.create<arith::MulIOp>(location, ODim, DDim);
+    IDim0 = b.create<arith::SubIOp>(location, ODimDDim, PadDim);
+    IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]);
+    IDim0KDim =
+        b.create<arith::AddIOp>(location, IDim0, kernelSizeIntValues[i]);
+    IDimPadDim = b.create<arith::AddIOp>(location, IDim, PadDim);
+    IDim1 = b.create<arith::MinSIOp>(location, IDim0KDim, IDimPadDim);
+
+    IDim0Clamped = b.create<arith::MaxSIOp>(location, IDim0, cstZero);
+    IDim1Clamped = b.create<arith::MinSIOp>(location, IDim1, IDim);
+    IDim1_IDim0_Clamped =
+        b.create<arith::SubIOp>(location, IDim1Clamped, IDim0Clamped);
     if (i == 0) {
-      poolSize = dims[0].IDim1_IDim0;
+      poolSize = IDim1_IDim0_Clamped;
     } else {
       poolSize =
-          b.create<arith::MulIOp>(location, poolSize, dims[i].IDim1_IDim0);
+          b.create<arith::MulIOp>(location, poolSize, IDim1_IDim0_Clamped);
     }
   }
   return poolSize;
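In scalar terms, getPoolSize computes, for each spatial dimension, how many input elements the kernel window actually covers at the current output index, and multiplies these counts into the divisor used when count_include_pad is false. A plain-C++ sketch of the same arithmetic (an illustration of the emitted IR, not code from this patch):

    #include <algorithm>
    #include <cstdint>

    // Elements covered by the window at output index oDim in one spatial
    // dimension, clamped to the real (unpadded) input extent.
    int64_t windowSize(int64_t oDim, int64_t kernel, int64_t stride,
                       int64_t pad, int64_t inputDim) {
      int64_t iDim0 = oDim * stride - pad;                      // IDim0
      int64_t iDim1 = std::min(iDim0 + kernel, inputDim + pad); // IDim1
      int64_t lo = std::max(iDim0, int64_t{0});                 // IDim0Clamped
      int64_t hi = std::min(iDim1, inputDim);                   // IDim1Clamped
      return hi - lo;                                  // IDim1_IDim0_Clamped
    }

    // poolSize is the product over the spatial dims, e.g. for 3-D pooling:
    //   windowSize(od, ...) * windowSize(oh, ...) * windowSize(ow, ...)

The renamed result (IDim1_IDim0_Clamped in place of IDim1_IDim0) makes explicit that the clamped bounds, not the raw ones, feed the subtraction.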
8 changes: 4 additions & 4 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1464,7 +1464,7 @@ class AvgPool3dCountIncludePadFalse(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
-        self.ap2d = torch.nn.AvgPool3d(
+        self.ap3d = torch.nn.AvgPool3d(
             kernel_size=[3, 3, 3],
             stride=[1, 1, 1],
             padding=[1, 1, 1],
@@ -1481,7 +1481,7 @@ def __init__(self):
         ]
     )
     def forward(self, x):
-        return self.ap2d(x)
+        return self.ap3d(x)
 
 
 @register_test_case(module_factory=lambda: AvgPool3dCountIncludePadFalse())
@@ -1493,7 +1493,7 @@ class AvgPool3dCountIncludePadFalseWithoutPadding(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
-        self.ap2d = torch.nn.AvgPool3d(
+        self.ap3d = torch.nn.AvgPool3d(
             kernel_size=[3, 3, 3],
             stride=[1, 1, 1],
             padding=[0, 0, 0],
@@ -1510,7 +1510,7 @@ def __init__(self):
         ]
     )
     def forward(self, x):
-        return self.ap2d(x)
+        return self.ap3d(x)
 
 
 @register_test_case(
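For intuition on what these tests exercise: with kernel 3, stride 1, and padding 1, count_include_pad=False changes the divisor only at the borders. At output index 0 in one dimension the window is [0 * 1 - 1, min(-1 + 3, L + 1)) = [-1, 2), which clamps to [0, 2), i.e. 2 elements instead of 3; a corner voxel in the 3-D case therefore averages over 2 * 2 * 2 = 8 elements rather than 27. In the WithoutPadding variant the window never leaves the input, so the divisor is always 27 and both count_include_pad settings agree. The ap2d -> ap3d rename itself is cosmetic; both modules were already AvgPool3d.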
