[CodeGen] Add a pattern to fold extract_slice consumer into xfer.write. (iree-org#17067)

The LLVMCPU and SPIRV lit tests are updated because the
`OptimizeTensorInsertExtractSlices` pass now also runs `scf::ForOp`
canonicalization patterns, which drop unused for-op results and make the
analysis easier.

This is helpful for bufferization when masking is not involved, because the
chain becomes simpler and the new tensor.empty op can be replaced with
`flow.tensor.load` ops in bufferization pre-processing passes.
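
Roughly, the fold turns the chain below into a single unmasked write into a
right-sized `tensor.empty` (a minimal sketch distilled from the new lit test
further down; the `@before`/`@after` function names and SSA names are
illustrative):

```mlir
// Before: the consumer sees a static tensor.empty, a transfer_write, and an
// extract_slice of the written value.
func.func @before(%v: vector<1x64x128xf16>, %sz: index) -> tensor<1x?x128xf16> {
  %c0 = arith.constant 0 : index
  %empty = tensor.empty() : tensor<1x64x128xf16>
  %w = vector.transfer_write %v, %empty[%c0, %c0, %c0] {in_bounds = [true, true, true]}
      : vector<1x64x128xf16>, tensor<1x64x128xf16>
  %slice = tensor.extract_slice %w[0, 0, 0] [1, %sz, 128] [1, 1, 1]
      : tensor<1x64x128xf16> to tensor<1x?x128xf16>
  return %slice : tensor<1x?x128xf16>
}

// After: the write lands directly in a tensor.empty of the slice's shape, so
// the value feeding the consumer is one unmasked transfer_write.
func.func @after(%v: vector<1x64x128xf16>, %sz: index) -> tensor<1x?x128xf16> {
  %c0 = arith.constant 0 : index
  %empty = tensor.empty(%sz) : tensor<1x?x128xf16>
  %w = vector.transfer_write %v, %empty[%c0, %c0, %c0] {in_bounds = [true, false, true]}
      : vector<1x64x128xf16>, tensor<1x?x128xf16>
  return %w : tensor<1x?x128xf16>
}
```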
hanhanW authored Apr 19, 2024
1 parent f755b42 commit 125f420
Showing 7 changed files with 172 additions and 12 deletions.
@@ -8,6 +8,7 @@

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
@@ -204,11 +205,99 @@ class FoldInsertSliceIntoTransferWrite final
}
};

/// Fold tensor.extract_slice into vector.transfer_write if
/// 1. The vector.transfer_write op has only one use.
/// 2. All the offsets of the tensor.extract_slice op are zeros.
/// 3. The vector.transfer_write op does not have masks.
/// 4. The vector.transfer_write op writes to a tensor.empty op.
///
/// E.g.:
///
/// ```
/// %0 = vector.transfer_write %v, %t[%a, %b, %c]
/// {in_bounds = [true, true, true]}
/// : vector<1x64x128xf16>, tensor<1x64x128xf16>
/// %extracted_slice = tensor.extract_slice %0[0, 0, 0] [1, %3, 128] [1, 1, 1]
/// : tensor<1x64x128xf16> to tensor<1x?x128xf16>
/// ```
/// is rewritten to:
/// ```
/// %1 = tensor.empty(%3) : tensor<1x?x128xf16>
/// %2 = vector.transfer_write %v, %1[%a, %b, %c]
/// {in_bounds = [true, false, true]}
/// : vector<1x64x128xf16>, tensor<1x?x128xf16>
/// ```
class FoldExtractSliceIntoTransferWrite final
: public OpRewritePattern<tensor::ExtractSliceOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
PatternRewriter &rewriter) const override {
if (extractSliceOp.getDroppedDims().any()) {
return rewriter.notifyMatchFailure(
extractSliceOp,
"expect it is not a rank-reduced tensor.extract_slice op");
}
if (!llvm::all_of(extractSliceOp.getMixedOffsets(), isZeroIndex)) {
return rewriter.notifyMatchFailure(extractSliceOp,
"expect all the offsets are zeros");
}

auto xferOp =
extractSliceOp.getSource().getDefiningOp<vector::TransferWriteOp>();
if (!xferOp) {
return rewriter.notifyMatchFailure(
extractSliceOp, "expect the source is from transfer.vector_write op");
}
if (!xferOp->hasOneUse()) {
return rewriter.notifyMatchFailure(
extractSliceOp,
"expect the transfer.vector_write op has only one use");
}
if (!xferOp.getSource().getDefiningOp<tensor::EmptyOp>()) {
return rewriter.notifyMatchFailure(
extractSliceOp, "expect the transfer.vector_write op to write into a "
"tensor.empty op");
}
if (xferOp.getMask()) {
return failure();
}

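// All preconditions hold: build a tensor.empty sized to the extracted slice
// and redirect the write into it.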
Location loc = extractSliceOp.getLoc();
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
auto init = rewriter.create<tensor::EmptyOp>(
loc, mixedSizes, extractSliceOp.getType().getElementType());

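// Recompute in_bounds for the new destination: a dimension is in bounds only
// when its destination size is a known constant at least as large as the
// corresponding vector size; dynamic sizes are conservatively out of bounds.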
SmallVector<bool> inBounds;
inBounds.resize(mixedSizes.size());
for (auto [idx, vecSize, destSize] :
llvm::zip_equal(llvm::seq<int64_t>(0, inBounds.size()),
xferOp.getVectorType().getShape(), mixedSizes)) {
auto maybeCst = getConstantIntValue(destSize);
if (!maybeCst) {
inBounds[idx] = false;
continue;
}
if (*maybeCst >= vecSize) {
inBounds[idx] = true;
} else {
inBounds[idx] = false;
}
}

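// Rewrite the extract_slice as a transfer_write into the new destination,
// reusing the indices and permutation map of the original write.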
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
extractSliceOp, xferOp.getVector(), init, xferOp.getIndices(),
xferOp.getPermutationMap(), inBounds);

return success();
}
};

} // namespace

void mlir::iree_compiler::populateVectorTransferTensorSliceTransforms(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns
.add<FoldExtractSliceIntoTransferRead, FoldInsertSliceIntoTransferWrite>(
patterns.getContext(), benefit);
.add<FoldExtractSliceIntoTransferRead, FoldInsertSliceIntoTransferWrite,
FoldExtractSliceIntoTransferWrite>(patterns.getContext(), benefit);
}
@@ -51,6 +51,8 @@ void OptimizeTensorInsertExtractSlicesPass::runOnOperation() {

RewritePatternSet patterns(context);
populateVectorTransferTensorSliceTransforms(patterns);
scf::ForOp::getCanonicalizationPatterns(patterns, context);
vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -51,6 +51,7 @@ iree_lit_test_suite(
"lower_ukernel_to_calls.mlir",
"materialize_encoding_into_nop.mlir",
"materialize_user_configs.mlir",
"optimize_tensor_insert_extract_slices.mlir",
"pad_dynamic_alloc.mlir",
"polynomial_approximation.mlir",
"reconcile_translation_info.mlir",
@@ -47,6 +47,7 @@ iree_lit_test_suite(
"lower_ukernel_to_calls.mlir"
"materialize_encoding_into_nop.mlir"
"materialize_user_configs.mlir"
"optimize_tensor_insert_extract_slices.mlir"
"pad_dynamic_alloc.mlir"
"polynomial_approximation.mlir"
"reconcile_translation_info.mlir"
@@ -0,0 +1,67 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-optimize-tensor-insert-extract-slices))" --split-input-file %s | FileCheck %s

func.func @fold_extract_slice_consumer_into_xfer_write(%arg0: vector<1x64x128xf16>, %arg1: index) -> tensor<1x?x128xf16> {
%c0 = arith.constant 0 : index
%0 = tensor.empty() : tensor<1x64x128xf16>
%1 = vector.transfer_write %arg0, %0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x64x128xf16>, tensor<1x64x128xf16>
%extracted_slice = tensor.extract_slice %1[0, 0, 0] [1, %arg1, 128] [1, 1, 1] : tensor<1x64x128xf16> to tensor<1x?x128xf16>
return %extracted_slice : tensor<1x?x128xf16>
}
// CHECK-LABEL: func.func @fold_extract_slice_consumer_into_xfer_write
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[SZ:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[INIT:.+]] = tensor.empty(%[[SZ]]) : tensor<1x?x128xf16>
// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[VEC]], %[[INIT]]
// CHECK-SAME: [%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, false, true]}
// CHECK-SAME: : vector<1x64x128xf16>, tensor<1x?x128xf16>
// CHECK: return %[[WRITE]]

// -----

#map = affine_map<()[s0] -> (s0 * 64)>
#map1 = affine_map<()[s0] -> (s0 * 128)>
#map2 = affine_map<()[s0] -> (s0 * -64 + 968, 64)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map5 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
func.func @batch_matmul_with_padding_strategy(%arg0: tensor<1x?x1280xf16>, %arg1: tensor<1x1280x128xf16>) {
%cst = arith.constant dense<0.000000e+00> : vector<1x64x128xf16>
%c20 = arith.constant 20 : index
%c1 = arith.constant 1 : index
%cst_0 = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x968x1280xf16>>
%workgroup_id_z = hal.interface.workgroup.id[2] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%1 = affine.apply #map()[%workgroup_id_y]
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%2 = affine.apply #map1()[%workgroup_id_x]
%3 = affine.min #map2()[%workgroup_id_y]
%4 = tensor.empty() : tensor<1x64x128xf16>
%5 = vector.transfer_write %cst, %4[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x64x128xf16>, tensor<1x64x128xf16>
%6 = scf.for %arg2 = %c0 to %c20 step %c1 iter_args(%arg3 = %5) -> (tensor<1x64x128xf16>) {
%7 = affine.delinearize_index %arg2 into (%c20) : index
%8 = affine.apply #map()[%7]
%extracted_slice_1 = tensor.extract_slice %arg1[0, %8, 0] [1, 64, 128] [1, 1, 1] : tensor<1x1280x128xf16> to tensor<1x64x128xf16>
%extracted_slice_2 = tensor.extract_slice %arg0[0, 0, %8] [1, %3, 64] [1, 1, 1] : tensor<1x?x1280xf16> to tensor<1x?x64xf16>
%9 = vector.transfer_read %extracted_slice_2[%c0, %c0, %c0], %cst_0 {in_bounds = [true, false, true]} : tensor<1x?x64xf16>, vector<1x64x64xf16>
%10 = vector.transfer_read %extracted_slice_1[%c0, %c0, %c0], %cst_0 {in_bounds = [true, true, true]} : tensor<1x64x128xf16>, vector<1x64x128xf16>
%11 = vector.transfer_read %arg3[%c0, %c0, %c0], %cst_0 {in_bounds = [true, true, true]} : tensor<1x64x128xf16>, vector<1x64x128xf16>
%12 = vector.contract {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %9, %10, %11 : vector<1x64x64xf16>, vector<1x64x128xf16> into vector<1x64x128xf16>
%13 = vector.transfer_write %12, %arg3[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x64x128xf16>, tensor<1x64x128xf16>
scf.yield %13 : tensor<1x64x128xf16>
}
%extracted_slice = tensor.extract_slice %6[0, 0, 0] [1, %3, 128] [1, 1, 1] : tensor<1x64x128xf16> to tensor<1x?x128xf16>
flow.dispatch.tensor.store %extracted_slice, %0, offsets = [%workgroup_id_z, %1, %2], sizes = [1, %3, 128], strides = [1, 1, 1] : tensor<1x?x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x968x1280xf16>>
return
}
// CHECK-LABEL: func.func @batch_matmul_with_padding_strategy
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[SCF:.+]] = scf.for {{.+}} -> (vector<1x64x128xf16>) {
// CHECK: }
// CHECK: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<1x?x128xf16>
// CHECK: %[[WRITE:.+]] = vector.transfer_write %[[SCF]], %[[INIT]]
// CHECK-SAME: [%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, false, true]}
// CHECK-SAME: : vector<1x64x128xf16>, tensor<1x?x128xf16>
// CHECK: flow.dispatch.tensor.store %[[WRITE]]
@@ -22,16 +22,16 @@
// CHECK-NEXT: %[[OUT_SLICE:.*]] = tensor.extract_slice %[[OUT_TENSOR_1]]{{.*}} : tensor<1024x1024xf32> to tensor<8x?xf32>
// CHECK-NEXT: %[[OUT_SLICE_1:.*]] = tensor.extract_slice %[[OUT_SLICE]]{{.*}} : tensor<8x?xf32> to tensor<8x?xf32>
// CHECK-NEXT: %[[OUT_VEC:.*]] = vector.transfer_read %[[OUT_TENSOR_1]]{{.*}} : tensor<1024x1024xf32>, vector<8x[16]xf32>
// CHECK-NEXT: %[[INNER_LOOP:.*]]:3 = scf.for {{.*}} iter_args({{.*}}, %[[RES:.*]] = %[[OUT_VEC]]) -> (tensor<8x?xf32>, tensor<8x?xf32>, vector<8x[16]xf32>) {
// CHECK-NEXT: %[[INNER_LOOP:.*]] = scf.for {{.*}} iter_args(%[[RES:.*]] = %[[OUT_VEC]]) -> (vector<8x[16]xf32>) {
// CHECK-NEXT: %[[LHS:.*]] = vector.transfer_read {{.*}} : tensor<1024x1024xf32>, vector<8x1xf32>
// CHECK-NEXT: %[[RHS:.*]] = vector.transfer_read {{.*}} : tensor<1024x1024xf32>, vector<1x[16]xf32>
// CHECK-NEXT: %[[CONTRACT:.*]] = vector.contract {indexing_maps = [#map1, #map2, #map3],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[LHS]], %[[RHS]], %[[RES]] : vector<8x1xf32>, vector<1x[16]xf32> into vector<8x[16]xf32>
// CHECK-NEXT: scf.yield {{.*}}, %[[CONTRACT]] : tensor<8x?xf32>, tensor<8x?xf32>, vector<8x[16]xf32>
// CHECK-NEXT: scf.yield %[[CONTRACT]] : vector<8x[16]xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[OUT_WRITE:.*]] = vector.transfer_write %[[INNER_LOOP]]#2, %[[INNER_LOOP]]#1{{.*}} {{.*}} : vector<8x[16]xf32>, tensor<8x?xf32>
// CHECK-NEXT: %[[INSERT_SLICE:.*]] = tensor.insert_slice %[[OUT_WRITE]] into %[[INNER_LOOP]]#0{{.*}} : tensor<8x?xf32> into tensor<8x?xf32>
// CHECK-NEXT: %[[OUT_WRITE:.*]] = vector.transfer_write %[[INNER_LOOP]], %[[OUT_SLICE_1]]{{.*}} {{.*}} : vector<8x[16]xf32>, tensor<8x?xf32>
// CHECK-NEXT: %[[INSERT_SLICE:.*]] = tensor.insert_slice %[[OUT_WRITE]] into %[[OUT_SLICE]]{{.*}} : tensor<8x?xf32> into tensor<8x?xf32>
// CHECK-NEXT: tensor.insert_slice %[[INSERT_SLICE]] into %[[OUT_TENSOR_1]]{{.*}} : tensor<8x?xf32> into tensor<1024x1024xf32>

func.func @pipeline() {
@@ -177,31 +177,31 @@ func.func @matmul_2x8x128_fp16(%a: tensor<2x128xf16>, %b: tensor<128x8xf16>, %x:
// CHECK-LABEL: func.func @matmul_2x8x128_fp16
// CHECK-SAME: (%[[LHS:.+]]: tensor<2x128xf16>, %[[RHS:.+]]: tensor<128x8xf16>, %[[X:.+]]: tensor<2x8xf16>, %[[Y:.+]]: tensor<2x8xf16>)
// CHECK: %[[ZERO:.+]] = arith.constant dense<0.000000e+00> : vector<8xf16>
// CHECK: %[[FOR:.+]]:3 = scf.for %arg4 = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%arg5 = %{{.+}}, %arg6 = %[[ZERO]], %arg7 = %[[ZERO]])
// CHECK: %[[FOR:.+]]:2 = scf.for %arg4 = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%arg5 = %[[ZERO]], %arg6 = %[[ZERO]])
// CHECK-COUNT-2: vector.transfer_read %[[LHS]]{{.+}} : tensor<2x128xf16>, vector<8xf16>
// CHECK-COUNT-8: vector.transfer_read %[[RHS]]{{.+}} : tensor<128x8xf16>, vector<8xf16>
// CHECK-COUNT-32: vector.fma {{.+}} : vector<4xf16>
// CHECK: %[[ISS0:.+]] = vector.insert_strided_slice %{{.+}}, %[[ZERO]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: %[[ISS1:.+]] = vector.insert_strided_slice %{{.+}}, %[[ISS0]] {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: %[[ISS2:.+]] = vector.insert_strided_slice %{{.+}}, %[[ZERO]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: %[[ISS3:.+]] = vector.insert_strided_slice %{{.+}}, %[[ISS2]] {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: scf.yield %arg5, %[[ISS1]], %[[ISS3]] : tensor<2x8xf16>, vector<8xf16>, vector<8xf16>
// CHECK: scf.yield %[[ISS1]], %[[ISS3]] : vector<8xf16>, vector<8xf16>
// CHECK: }
// CHECK: %[[X0:.+]] = vector.transfer_read %[[X]]{{.+}} : tensor<2x8xf16>, vector<8xf16>
// CHECK: %[[X1:.+]] = vector.transfer_read %[[X]]{{.+}} : tensor<2x8xf16>, vector<8xf16>
// CHECK: %[[LHS0:.+]] = vector.extract_strided_slice %[[FOR]]#1 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[LHS0:.+]] = vector.extract_strided_slice %[[FOR]]#0 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[RHS0:.+]] = vector.extract_strided_slice %[[X0]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[DIV0:.+]] = arith.divf %[[LHS0]], %[[RHS0]]
// CHECK: %[[ISS0:.+]] = vector.insert_strided_slice %[[DIV0]], %[[ZERO]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: %[[LHS1:.+]] = vector.extract_strided_slice %[[FOR]]#1 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[LHS1:.+]] = vector.extract_strided_slice %[[FOR]]#0 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[RHS1:.+]] = vector.extract_strided_slice %[[X0]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[DIV1:.+]] = arith.divf %[[LHS1]], %[[RHS1]]
// CHECK: %[[ISS1:.+]] = vector.insert_strided_slice %[[DIV1]], %[[ISS0]] {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: %[[LHS2:.+]] = vector.extract_strided_slice %[[FOR]]#2 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[LHS2:.+]] = vector.extract_strided_slice %[[FOR]]#1 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[RHS2:.+]] = vector.extract_strided_slice %[[X1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[DIV2:.+]] = arith.divf %[[LHS2]], %[[RHS2]]
// CHECK: %[[ISS2:.+]] = vector.insert_strided_slice %[[DIV2]], %[[ZERO]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: %[[LHS3:.+]] = vector.extract_strided_slice %[[FOR]]#2 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[LHS3:.+]] = vector.extract_strided_slice %[[FOR]]#1 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[RHS3:.+]] = vector.extract_strided_slice %[[X1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[DIV3:.+]] = arith.divf %[[LHS3]], %[[RHS3]]
// CHECK: %[[ISS3:.+]] = vector.insert_strided_slice %[[DIV3]], %[[ISS2]] {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
