From a98a5bfa4a25d21d0a9c99fabb6354d0358fb7d3 Mon Sep 17 00:00:00 2001 From: Ivan Garcia Date: Fri, 14 Feb 2025 12:16:35 -0500 Subject: [PATCH] Addressing feedback from Sayan (sahas3). --- lib/Conversion/TorchToLinalg/Pooling.cpp | 115 ++++++++---------- .../torch_mlir_e2e_test/test_suite/pooling.py | 8 +- 2 files changed, 54 insertions(+), 69 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 6f8ca853e6fc..0dcd4b6317dd 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -831,7 +831,7 @@ class ConvertAtenMaxUnpool3dOp final } // namespace namespace { -// The following structures and the adsfdasf method +// The following structures and the getNumOfDims method // are used to get the number of dimensions from the // average pooling type at compile time. template struct AtenAvgPoolTypeNumOfDims { @@ -851,31 +851,6 @@ template constexpr int getAvgPoolNumOfDims() { } } // namespace -namespace { -// This structure, used solely in PoolSizeCalculator, provides -// the intermediate values for each dimension to compute the -// divisor of the average pooling operator. -struct PoolSizeValues { - int64_t SpatialDimsInt64; - int64_t DimSpatialInt; - Value InputSpatialDimValues; - Value IndexODim; - Value ODim; - Value DDim; - Value PadDim; - Value ODimDDim; - Value IDim0; - Value IDim; - Value IDim0KDim; - Value IDimPadDim; - Value IDim1; - Value IDim1IDims0; - Value IDim0Clamped; - Value IDim1Clamped; - Value IDim1_IDim0; -}; -} // namespace - namespace { // This is a helper class to create the pooling size value // used in the divisor of the average pooling operator. @@ -885,7 +860,7 @@ template class PoolSizeCalculator { ConversionPatternRewriter &rewriter, Location loc); // The algorithm for computing the divisor with - // count_include_pad is manily based on pytorch + // count_include_pad equal is mainly based on pytorch // implementation. 
The following code is comment // with pytorch code. // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78 @@ -896,8 +871,8 @@ template class PoolSizeCalculator { SmallVectorImpl &paddingInts); private: - PoolSizeValues dims[NumOfDims]; - ConversionPatternRewriter &rewriterHandle; + int64_t DimSizeFromSumPoolType[NumOfDims]; + Value InputSpatialDimValues[NumOfDims]; Location location; }; @@ -907,7 +882,7 @@ template PoolSizeCalculator::PoolSizeCalculator( Value self, Value sumPool, ConversionPatternRewriter &rewriter, Location loc) - : rewriterHandle(rewriter), location(loc) { + : location(loc) { auto selfType = cast(self.getType()); const int64_t selfRank = selfType.getRank(); RankedTensorType sumPoolType = cast(sumPool.getType()); @@ -916,10 +891,10 @@ PoolSizeCalculator::PoolSizeCalculator( // Store dimensions in this order: // 0 => width, 1 => height, 2 => depth for (int i = 0; i < NumOfDims; ++i) { - dims[i].SpatialDimsInt64 = toPositiveDim(-(i + 1), selfRank); - dims[i].InputSpatialDimValues = - getDimOp(rewriterHandle, location, self, dims[i].SpatialDimsInt64); - dims[i].DimSpatialInt = toPositiveDim(-(i + 1), rank); + int64_t DimSizeFromSelfType = toPositiveDim(-(i + 1), selfRank); + InputSpatialDimValues[i] = + getDimOp(rewriter, location, self, DimSizeFromSelfType); + DimSizeFromSumPoolType[i] = toPositiveDim(-(i + 1), rank); } } @@ -930,42 +905,52 @@ Value PoolSizeCalculator::getPoolSize( SmallVectorImpl &paddingInts) { Value poolSize; - Value cstZero = rewriterHandle.create( - location, rewriterHandle.getI64IntegerAttr(0)); + Value cstZero = b.create(location, b.getI64IntegerAttr(0)); for (int i = 0; i < NumOfDims; ++i) { - dims[i].IndexODim = - b.create(location, /*value=*/dims[i].DimSpatialInt); - dims[i].ODim = castIndexToInt64(b, location, dims[i].IndexODim); - dims[i].DDim = rewriterHandle.create( - location, rewriterHandle.getI64IntegerAttr(strideInts[i])); - 
dims[i].PadDim = rewriterHandle.create( - location, rewriterHandle.getI64IntegerAttr(paddingInts[i])); - dims[i].ODimDDim = - b.create(location, dims[i].ODim, dims[i].DDim); - dims[i].IDim0 = - b.create(location, dims[i].ODimDDim, dims[i].PadDim); - dims[i].IDim = castIndexToInt64(b, location, dims[i].InputSpatialDimValues); - dims[i].IDim0KDim = b.create(location, dims[i].IDim0, - kernelSizeIntValues[i]); - dims[i].IDimPadDim = - b.create(location, dims[i].IDim, dims[i].PadDim); - dims[i].IDim1 = b.create(location, dims[i].IDim0KDim, - dims[i].IDimPadDim); - dims[i].IDim1IDims0 = - b.create(location, dims[i].IDim1, dims[i].IDim0); - - dims[i].IDim0Clamped = - b.create(location, dims[i].IDim0, cstZero); - dims[i].IDim1Clamped = - b.create(location, dims[i].IDim1, dims[i].IDim); - dims[i].IDim1_IDim0 = b.create( - location, dims[i].IDim1Clamped, dims[i].IDim0Clamped); + // See the link below for the PyTorch implementation where + // this is derived from: + // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78 + // Dim below stands for spatial dimension. Prior to the February + // 2025 change, these intermediate variables used "height" and + // "width" (or "h" and "w") instead of "Dim".
+ Value IndexODim; + Value ODim; + Value DDim; + Value PadDim; + Value ODimDDim; + Value IDim0; + Value IDim; + Value IDim0KDim; + Value IDimPadDim; + Value IDim1; + Value IDim0Clamped; + Value IDim1Clamped; + Value IDim1_IDim0_Clamped; + IndexODim = b.create(location, + /*value=*/DimSizeFromSumPoolType[i]); + ODim = castIndexToInt64(b, location, IndexODim); + DDim = b.create(location, + b.getI64IntegerAttr(strideInts[i])); + PadDim = b.create(location, + b.getI64IntegerAttr(paddingInts[i])); + ODimDDim = b.create(location, ODim, DDim); + IDim0 = b.create(location, ODimDDim, PadDim); + IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]); + IDim0KDim = + b.create(location, IDim0, kernelSizeIntValues[i]); + IDimPadDim = b.create(location, IDim, PadDim); + IDim1 = b.create(location, IDim0KDim, IDimPadDim); + + IDim0Clamped = b.create(location, IDim0, cstZero); + IDim1Clamped = b.create(location, IDim1, IDim); + IDim1_IDim0_Clamped = + b.create(location, IDim1Clamped, IDim0Clamped); if (i == 0) { - poolSize = dims[0].IDim1_IDim0; + poolSize = IDim1_IDim0_Clamped; } else { poolSize = - b.create(location, poolSize, dims[i].IDim1_IDim0); + b.create(location, poolSize, IDim1_IDim0_Clamped); } } return poolSize; diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index e308e38bc1e3..ce7d2e2bc42e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -1464,7 +1464,7 @@ class AvgPool3dCountIncludePadFalse(torch.nn.Module): def __init__(self): super().__init__() - self.ap2d = torch.nn.AvgPool3d( + self.ap3d = torch.nn.AvgPool3d( kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[1, 1, 1], @@ -1481,7 +1481,7 @@ def __init__(self): ] ) def forward(self, x): - return self.ap2d(x) + return self.ap3d(x) @register_test_case(module_factory=lambda: AvgPool3dCountIncludePadFalse()) @@ -1493,7 +1493,7 
@@ class AvgPool3dCountIncludePadFalseWithoutPadding(torch.nn.Module): def __init__(self): super().__init__() - self.ap2d = torch.nn.AvgPool3d( + self.ap3d = torch.nn.AvgPool3d( kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[0, 0, 0], @@ -1510,7 +1510,7 @@ def __init__(self): ] ) def forward(self, x): - return self.ap2d(x) + return self.ap3d(x) @register_test_case(