Simplify lowering of aten.reflection_pad2d to linalg #2772

lib/Conversion/TorchToLinalg/DataMovement.cpp (127 changes: 55 additions & 72 deletions)
@@ -247,8 +247,7 @@ class ConvertAtenReflectionPad1dOp
namespace {

// Lower the aten.reflection.pad_2d operator into a sequence of
-// tensor.extract_slice, linalg.generic, and tensor_insert_slice
-// operations.
+// tensor.extract_slice and tensor_insert_slice operations.

// To understand the lowering, consider this pytorch example:
//
@@ -282,8 +281,8 @@ namespace {
// center right: [[2,1]]
//
// The lowering uses a tensor.extract_slice operation to create each tile,
-// a linalg.generic for the reflection, and a tensor.insert_slice to
-// insert the tile in the resulting tensor.
+// including the reversal of the order of the elements if necessary,
+// and a tensor.insert_slice to insert the tile in the result tensor.
class ConvertAtenReflectionPad2dOp
: public OpConversionPattern<AtenReflectionPad2dOp> {
public:
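The original PyTorch example and the full tile breakdown sit in the collapsed context above. As a stand-in, here is a minimal, self-contained C++ sketch with hypothetical sizes (a 3x3 input, pads top = 1, bottom = 1, left = 2, right = 1); it computes reflection padding via the usual reflected-index formula so the nine-tile structure is visible in the output. Each padding tile is a reversed copy of rows/columns next to the border; the border row/column itself (the reflection axis) is never duplicated.

```cpp
// Hypothetical illustration, not part of the patch: reflection padding
// of a 3x3 input computed with the reflected-index formula.
#include <cstdio>

int main() {
  const int h = 3, w = 3;                             // input size
  const int top = 1, bottom = 1, left = 2, right = 1; // pad amounts
  int in[h][w];
  for (int i = 0; i < h; ++i)
    for (int j = 0; j < w; ++j)
      in[i][j] = i * w + j; // 0 1 2 / 3 4 5 / 6 7 8

  // Map an out-of-range index back into [0, size) by reflection.
  auto reflect = [](int idx, int size) {
    if (idx < 0)
      return -idx;               // mirror around the first element
    if (idx >= size)
      return 2 * size - idx - 2; // mirror around the last element
    return idx;
  };

  // Print the padded (h + top + bottom) x (w + left + right) result.
  for (int i = -top; i < h + bottom; ++i) {
    for (int j = -left; j < w + right; ++j)
      std::printf("%d ", in[reflect(i, h)][reflect(j, w)]);
    std::printf("\n");
  }
  return 0;
}
```

The first printed row, `5 4 3 4 5 4`, is row 1 of the input read through reflected column indices, which is exactly what the three top-side tiles must produce.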
@@ -305,20 +304,12 @@ class ConvertAtenReflectionPad2dOp
return rewriter.create<arith::AddIOp>(loc, x, y);
};

-auto createAdds = [&](std::initializer_list<Value> values) {
-assert(values.size() >= 2);
-return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
-createAdd);
-};
-
auto createSub = [&](Value x, Value y) {
return rewriter.create<arith::SubIOp>(loc, x, y);
};

-auto createSubs = [&](std::initializer_list<Value> values) {
-assert(values.size() >= 2);
-return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
-createSub);
+auto getIndexConst = [&](int c) {
+return rewriter.create<arith::ConstantIndexOp>(loc, c);
};

// Enums for specifying the coordinates of a tile. An "h" prefix
@@ -349,7 +340,6 @@ class ConvertAtenReflectionPad2dOp
};

Value input = adaptor.getSelf();
-MLIRContext *context = rewriter.getContext();
auto inputType = llvm::cast<RankedTensorType>(input.getType());
auto outputType = llvm::cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
@@ -372,37 +362,27 @@
assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
"Bottom padding too large");

-Type indexType = rewriter.getIndexType();
-Value zero = getConstant(rewriter, loc, 0, indexType);
-Value one = getConstant(rewriter, loc, 1, indexType);
+Value zero = getIndexConst(0);
+Value one = getIndexConst(1);
+Value two = getIndexConst(2);
+Value minusOne = getIndexConst(-1);

Value tileWidth[3];
tileWidth[HCENTER] = hDimSize;
for (auto h : {LEFT, RIGHT})
-tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType);
+tileWidth[h] = getIndexConst(getHPadArgument(h));

Value tileHeight[3];
tileHeight[VCENTER] = vDimSize;
for (auto v : {TOP, BOTTOM})
-tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType);
-
-// Helper to reflect/reverse the i-th dimension of an affine map
-// without symbols. This only works if applied on a tensor
-// for which the corresponding dimension has a statically
-// known size which is good enough since we only apply
-// it to reflect the padding slices.
-auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i,
-int64_t size) {
-AffineExpr d = map.getResult(i);
-return map.replace(d, size - d - 1, numDims, 0);
-};
+tileHeight[v] = getIndexConst(getVPadArgument(v));

// Create output shape and tensor
SmallVector<Value> resultShape{inputShape};
-resultShape[vDim] =
-createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]});
-resultShape[hDim] =
-createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]});
+resultShape[vDim] = createAdd(createAdd(resultShape[vDim], tileHeight[TOP]),
+tileHeight[BOTTOM]);
+resultShape[hDim] = createAdd(createAdd(resultShape[hDim], tileWidth[LEFT]),
+tileWidth[RIGHT]);

Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape,
inputType.getElementType());
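For the hypothetical 3x3 input with pads (top, bottom, left, right) = (1, 1, 2, 1) used in the sketch above, these two createAdd chains yield a result shape of (3 + 1 + 1) x (3 + 2 + 1) = 5 x 6.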
@@ -444,18 +424,29 @@

// Setup information about the tiles

-// Compute the offsets for extracting the slice from the
-// input. We need to skip the row or column through which
-// the tile should be reflected, if any (none for the center tile).
+// Compute the offsets for extracting the slice from the input. To
+// reverse the order of the elements in the non-central tiles,
+// extract the slices with negative strides, starting from the
+// last element of the input that should belong to the slice and
+// skipping the "axis" element through which the elements are
+// reflected:
+//
+// - The left tile is obtained by extracting elements
+//   tileWidth[LEFT] + 1, ..., 2, i.e. in reverse order.
+//
+// - The right tile is obtained by extracting elements
+//   hDimSize - 1, ..., hDimSize - tileWidth[RIGHT], i.e. in
+//   reverse order.

Value extractHOffset[3];
-extractHOffset[LEFT] = one;
+extractHOffset[LEFT] = tileWidth[LEFT];
extractHOffset[HCENTER] = zero;
-extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one});
+extractHOffset[RIGHT] = createSub(hDimSize, two);

Value extractVOffset[3];
-extractVOffset[TOP] = one;
+extractVOffset[TOP] = tileHeight[TOP];
extractVOffset[VCENTER] = zero;
-extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one});
+extractVOffset[BOTTOM] = createSub(vDimSize, two);

// Compute the horizontal and vertical offsets for inserting
// the tiles in the resultTensor.
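To make the extraction offsets concrete, here is a small sketch (hypothetical values only: hDimSize = 5, tileWidth[LEFT] = 2, tileWidth[RIGHT] = 3) enumerating which input columns the horizontal padding tiles read once the negative strides defined in the next hunk are applied; columns 0 and hDimSize - 1, the reflection axes, are never read:

```cpp
// Hypothetical illustration of the LEFT/RIGHT extraction offsets.
#include <cstdio>

int main() {
  const int hDimSize = 5;        // hypothetical input width
  const int left = 2, right = 3; // hypothetical pad widths

  // LEFT tile: offset tileWidth[LEFT], stride -1, tileWidth[LEFT] elements.
  std::printf("left tile reads columns: ");
  for (int k = 0; k < left; ++k)
    std::printf("%d ", left - k);         // 2 1  (column 0 skipped)

  // RIGHT tile: offset hDimSize - 2, stride -1, tileWidth[RIGHT] elements.
  std::printf("\nright tile reads columns: ");
  for (int k = 0; k < right; ++k)
    std::printf("%d ", hDimSize - 2 - k); // 3 2 1  (column 4 skipped)
  std::printf("\n");
  return 0;
}
```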
@@ -469,55 +460,47 @@
insertVOffset[VCENTER] = tileHeight[TOP];
insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]);

-auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; };
-auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; };
+// Define the strides for the tensor.extract_slice operations.
+// Using a negative stride for a dimension reverses the order
+// of the extracted elements as necessary for the reflection.
+Value extractHStride[3];
+extractHStride[LEFT] = minusOne;
+extractHStride[HCENTER] = one;
+extractHStride[RIGHT] = minusOne;
+
+Value extractVStride[3];
+extractVStride[TOP] = minusOne;
+extractVStride[VCENTER] = one;
+extractVStride[BOTTOM] = minusOne;

-SmallVector<utils::IteratorType> iteratorTypes{
-numDims, utils::IteratorType::parallel};
-auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context);
SmallVector<Value> allOneStrides(numDims, one);

auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) {
-// Create the tile by extracting a slice from the input tensor.
-SmallVector<Value> extractShape{inputShape};
-extractShape[hDim] = tileWidth[horizontalPos];
-extractShape[vDim] = tileHeight[verticalPos];
+SmallVector<Value> tileShape{inputShape};
+tileShape[hDim] = tileWidth[horizontalPos];
+tileShape[vDim] = tileHeight[verticalPos];

+// Create the tile by extracting a slice from the input tensor.
SmallVector<Value> extractOffsets(numDims, zero);
extractOffsets[hDim] = extractHOffset[horizontalPos];
extractOffsets[vDim] = extractVOffset[verticalPos];

-Value tile = rewriter.create<tensor::ExtractSliceOp>(
-loc, input, extractOffsets, extractShape, allOneStrides);
+SmallVector<Value> extractStrides(numDims, one);
+extractStrides[hDim] = extractHStride[horizontalPos];
+extractStrides[vDim] = extractVStride[verticalPos];

-// Reverse the tile along the horizontal, vertical, or both
-// dimensions.
-auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
-if (shouldHReflect(horizontalPos)) {
-inputMap =
-reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos));
-}
-if (shouldVReflect(verticalPos)) {
-inputMap =
-reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos));
-}
-
-tile = rewriter
-.create<linalg::GenericOp>(
-loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
-tile, ArrayRef({inputMap, idMap}), iteratorTypes,
-[](OpBuilder &b, Location nestedLoc, ValueRange args) {
-b.create<linalg::YieldOp>(nestedLoc, args[0]);
-})
-.getResult(0);
+Value tile = rewriter.create<tensor::ExtractSliceOp>(
+loc, input, extractOffsets, tileShape, extractStrides);

// Insert the tile in the resultTensor.
SmallVector<Value> insertOffsets(numDims, zero);
insertOffsets[hDim] = insertHOffset[horizontalPos];
insertOffsets[vDim] = insertVOffset[verticalPos];

resultTensor = rewriter.create<tensor::InsertSliceOp>(
-loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
+loc, tile, resultTensor, insertOffsets, tileShape, allOneStrides);
};

for (auto v : {TOP, BOTTOM, VCENTER})
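Putting the offsets, shapes, and strides together, the following is a minimal 1-D analogue of createTile, again hypothetical: extractSlice stands in for tensor.extract_slice (a negative stride yields a reversed copy), and plain concatenation stands in for the tensor.insert_slice calls.

```cpp
// Hypothetical 1-D analogue of the tile-based lowering.
#include <cstdio>
#include <vector>

// Copy `count` elements from `in`, starting at `offset`, stepping by
// `stride`; stride -1 reverses the element order.
static std::vector<int> extractSlice(const std::vector<int> &in, int offset,
                                     int count, int stride) {
  std::vector<int> out;
  for (int k = 0; k < count; ++k)
    out.push_back(in[offset + k * stride]);
  return out;
}

int main() {
  std::vector<int> row{10, 11, 12, 13, 14}; // hDimSize = 5
  const int left = 2, right = 2;            // hypothetical pad widths
  const int hDimSize = static_cast<int>(row.size());

  std::vector<int> l = extractSlice(row, left, left, -1);          // 12 11
  std::vector<int> c = extractSlice(row, 0, hDimSize, 1);          // 10..14
  std::vector<int> r = extractSlice(row, hDimSize - 2, right, -1); // 13 12

  std::vector<int> result; // "insert" the tiles left-to-right
  for (const auto &tile : {l, c, r})
    result.insert(result.end(), tile.begin(), tile.end());

  for (int v : result)
    std::printf("%d ", v); // 12 11 10 11 12 13 14 13 12
  std::printf("\n");
  return 0;
}
```

The printed row is exactly the reflection padding of 10 ... 14 by two elements on each side, which is the invariant the negative-stride slice extraction provides without any linalg.generic.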