diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 49f5f0ec3321..79ed790254c1 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -247,8 +247,7 @@ class ConvertAtenReflectionPad1dOp namespace { // Lower the aten.reflection.pad_2d operator into a sequence of -// tensor.extract_slice, linalg.generic, and tensor_insert_slice -// operations. +// tensor.extract_slice and tensor_insert_slice operations. // To understand the lowering, consider this pytorch example: // @@ -282,8 +281,8 @@ namespace { // center right: [[2,1]] // // The lowering uses a tensor.extract_slice operation to create each tile, -// a linalg.generic for the reflection, and a tensor.insert_slice to -// insert the tile in the resulting tensor. +// including the reversal of the order of the elements if necessary, +// and a tensor.insert_slice to insert the tile in the result tensor. class ConvertAtenReflectionPad2dOp : public OpConversionPattern { public: @@ -305,20 +304,12 @@ class ConvertAtenReflectionPad2dOp return rewriter.create(loc, x, y); }; - auto createAdds = [&](std::initializer_list values) { - assert(values.size() >= 2); - return std::accumulate(values.begin() + 1, values.end(), data(values)[0], - createAdd); - }; - auto createSub = [&](Value x, Value y) { return rewriter.create(loc, x, y); }; - auto createSubs = [&](std::initializer_list values) { - assert(values.size() >= 2); - return std::accumulate(values.begin() + 1, values.end(), data(values)[0], - createSub); + auto getIndexConst = [&](int c) { + return rewriter.create(loc, c); }; // Enums for specifying the coordinates of a tile. An "h" prefix @@ -349,7 +340,6 @@ class ConvertAtenReflectionPad2dOp }; Value input = adaptor.getSelf(); - MLIRContext *context = rewriter.getContext(); auto inputType = llvm::cast(input.getType()); auto outputType = llvm::cast( getTypeConverter()->convertType(op->getResult(0).getType())); @@ -372,37 +362,27 @@ class ConvertAtenReflectionPad2dOp assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] && "Bottom padding too large"); - Type indexType = rewriter.getIndexType(); - Value zero = getConstant(rewriter, loc, 0, indexType); - Value one = getConstant(rewriter, loc, 1, indexType); + Value zero = getIndexConst(0); + Value one = getIndexConst(1); + Value two = getIndexConst(2); + Value minusOne = getIndexConst(-1); Value tileWidth[3]; tileWidth[HCENTER] = hDimSize; for (auto h : {LEFT, RIGHT}) - tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType); + tileWidth[h] = getIndexConst(getHPadArgument(h)); Value tileHeight[3]; tileHeight[VCENTER] = vDimSize; for (auto v : {TOP, BOTTOM}) - tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType); - - // Helper to reflect/reverse the i-th dimension of an affine map - // without symbols. This only works if applied on a tensor - // for which the corresponding dimension has a statically - // known size which is good enough since we only apply - // it to reflect the padding slices. - auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, - int64_t size) { - AffineExpr d = map.getResult(i); - return map.replace(d, size - d - 1, numDims, 0); - }; + tileHeight[v] = getIndexConst(getVPadArgument(v)); // Create output shape and tensor SmallVector resultShape{inputShape}; - resultShape[vDim] = - createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]}); - resultShape[hDim] = - createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]}); + resultShape[vDim] = createAdd(createAdd(resultShape[vDim], tileHeight[TOP]), + tileHeight[BOTTOM]); + resultShape[hDim] = createAdd(createAdd(resultShape[hDim], tileWidth[LEFT]), + tileWidth[RIGHT]); Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, inputType.getElementType()); @@ -444,18 +424,29 @@ class ConvertAtenReflectionPad2dOp // Setup information about the tiles - // Compute the offsets for extracting the slice from the - // input. We need to skip the row or column through which - // the tile should be reflected, if any (none for the center tile). + // Compute the offsets for extracting the slice from the input. To + // reverse the order of the elements in the non-central tiles, + // extract the slices with negative strides and start from the + // last element of the input that should belong to the slice, + // skipping the "axis" element through which the elements are + // reflected: + // + // - The left tile is obtained by extracting elements + // tileWidth[LEFT] + 1, ..., 2 in this, i.e. reverse order. + // + // - The right tile is obtained by extracting elements + // hDimSize - 1, ..., hDimSize - tileWidth[RIGHT] - 1 in this, + // i.e. reverse order. + Value extractHOffset[3]; - extractHOffset[LEFT] = one; + extractHOffset[LEFT] = tileWidth[LEFT]; extractHOffset[HCENTER] = zero; - extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one}); + extractHOffset[RIGHT] = createSub(hDimSize, two); Value extractVOffset[3]; - extractVOffset[TOP] = one; + extractVOffset[TOP] = tileHeight[TOP]; extractVOffset[VCENTER] = zero; - extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one}); + extractVOffset[BOTTOM] = createSub(vDimSize, two); // Compute the horizontal and vertical offsets for inserting // the tiles in the resultTensor. @@ -469,47 +460,39 @@ class ConvertAtenReflectionPad2dOp insertVOffset[VCENTER] = tileHeight[TOP]; insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]); - auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; }; - auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; }; + // Define the strides for the tensor.extract_slice operations. + // Using a negative stride for a dimension reverses the order + // of the extracted elements as necessary for the reflection. + Value extractHStride[3]; + extractHStride[LEFT] = minusOne; + extractHStride[HCENTER] = one; + extractHStride[RIGHT] = minusOne; + + Value extractVStride[3]; + extractVStride[TOP] = minusOne; + extractVStride[VCENTER] = one; + extractVStride[BOTTOM] = minusOne; SmallVector iteratorTypes{ numDims, utils::IteratorType::parallel}; - auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context); SmallVector allOneStrides(numDims, one); auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) { - // Create the tile by extracting a slice from the input tenor. - SmallVector extractShape{inputShape}; - extractShape[hDim] = tileWidth[horizontalPos]; - extractShape[vDim] = tileHeight[verticalPos]; + SmallVector tileShape{inputShape}; + tileShape[hDim] = tileWidth[horizontalPos]; + tileShape[vDim] = tileHeight[verticalPos]; + // Create the tile by extracting a slice from the input tenor. SmallVector extractOffsets(numDims, zero); extractOffsets[hDim] = extractHOffset[horizontalPos]; extractOffsets[vDim] = extractVOffset[verticalPos]; - Value tile = rewriter.create( - loc, input, extractOffsets, extractShape, allOneStrides); + SmallVector extractStrides(numDims, one); + extractStrides[hDim] = extractHStride[horizontalPos]; + extractStrides[vDim] = extractVStride[verticalPos]; - // Reverse the tile along the horizontal, vertical, or both - // dimensions. - auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); - if (shouldHReflect(horizontalPos)) { - inputMap = - reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos)); - } - if (shouldVReflect(verticalPos)) { - inputMap = - reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos)); - } - - tile = rewriter - .create( - loc, llvm::cast(tile.getType()), tile, - tile, ArrayRef({inputMap, idMap}), iteratorTypes, - [](OpBuilder &b, Location nestedLoc, ValueRange args) { - b.create(nestedLoc, args[0]); - }) - .getResult(0); + Value tile = rewriter.create( + loc, input, extractOffsets, tileShape, extractStrides); // Insert the tile in the resultTensor. SmallVector insertOffsets(numDims, zero); @@ -517,7 +500,7 @@ class ConvertAtenReflectionPad2dOp insertOffsets[vDim] = insertVOffset[verticalPos]; resultTensor = rewriter.create( - loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); + loc, tile, resultTensor, insertOffsets, tileShape, allOneStrides); }; for (auto v : {TOP, BOTTOM, VCENTER})