[Codegen][ROCDL] Replace custom generalization pass with upstream one (iree-org#16662)

After tiling in LLVMGPUVectorDistribute, the tiling configuration
attributes are no longer necessary. Additionally, named ops are
generalized before vectorization to make it easier to fold away unit
extent dims. At this point it is best to generalize all named ops so
that the unit-dim folding patterns apply more broadly, which lets us
switch to the upstream generalization pass and drop the local one that
only handled convolutions and contractions.
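
For reference, generalization rewrites a named Linalg op into the equivalent
linalg.generic with explicit indexing maps, which is the form the unit-dim
folding patterns work on. A minimal sketch of what the upstream pass produces
for a matmul (shapes and SSA names here are illustrative, not taken from this
commit):

// A named contraction such as:
%0 = linalg.matmul ins(%lhs, %rhs : tensor<16x8xf32>, tensor<8x32xf32>)
                   outs(%acc : tensor<16x32xf32>) -> tensor<16x32xf32>

// becomes, after generalization, the equivalent linalg.generic:
%1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                     affine_map<(d0, d1, d2) -> (d2, d1)>,
                     affine_map<(d0, d1, d2) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%lhs, %rhs : tensor<16x8xf32>, tensor<8x32xf32>)
    outs(%acc : tensor<16x32xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %2 = arith.mulf %a, %b : f32
    %3 = arith.addf %c, %2 : f32
    linalg.yield %3 : f32
  } -> tensor<16x32xf32>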
qedawkins authored Mar 5, 2024
1 parent 77758bd commit 7171014
Showing 8 changed files with 88 additions and 93 deletions.
@@ -43,35 +43,6 @@ generalizeCandidates(MLIRContext *context,
  return success();
}

namespace {
struct GPUGeneralizeNamedConvolutionAndContractionOpsPass
    : public GPUGeneralizeNamedConvolutionAndContractionOpsBase<
          GPUGeneralizeNamedConvolutionAndContractionOpsPass> {

  void runOnOperation() override;
};
} // namespace

void GPUGeneralizeNamedConvolutionAndContractionOpsPass::runOnOperation() {
  auto funcOp = getOperation();
  SmallVector<linalg::LinalgOp> namedOpCandidates;
  funcOp.walk([&](linalg::LinalgOp linalgOp) {
    if (isa<linalg::ConvolutionOpInterface>(*linalgOp))
      namedOpCandidates.push_back(linalgOp);
    if (isa<linalg::ContractionOpInterface>(*linalgOp))
      namedOpCandidates.push_back(linalgOp);
  });

  if (failed(generalizeCandidates(&getContext(), namedOpCandidates))) {
    return signalPassFailure();
  }
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createGPUGeneralizeNamedConvolutionAndContractionOpsPass() {
  return std::make_unique<GPUGeneralizeNamedConvolutionAndContractionOpsPass>();
}

namespace {
struct GPUGeneralizeNamedOpsPass
    : public GPUGeneralizeNamedOpsBase<GPUGeneralizeNamedOpsPass> {
5 changes: 0 additions & 5 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
@@ -140,11 +140,6 @@ createWorkgroupSpecializationPass();
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createWorkGroupSwizzle(unsigned swizzleLogTile = 0);

// This pass generalizes named Linalg convolution and contraction ops to allow
// for better folding of unit dimensions.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createGPUGeneralizeNamedConvolutionAndContractionOpsPass();

// This pass generalizes named Linalg ops that are better off as generics.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createGPUGeneralizeNamedOpsPass();
6 changes: 0 additions & 6 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -31,12 +31,6 @@ def GPUDistributeSharedMemoryCopy :
  let constructor = "mlir::iree_compiler::createGPUDistributeSharedMemoryCopy()";
}

def GPUGeneralizeNamedConvolutionAndContractionOps :
    InterfacePass<"iree-codegen-gpu-generalize-named-convolution-and-contraction-ops", "mlir::FunctionOpInterface"> {
  let summary = "Convert named Linalg convolution and contraction ops to linalg.generic ops";
  let constructor = "mlir::iree_compiler::createGPUGeneralizeNamedConvolutionAndContractionOpsPass()";
}

def GPUGeneralizeNamedOps :
    InterfacePass<"iree-codegen-gpu-generalize-named-ops", "mlir::FunctionOpInterface"> {
  let summary = "Convert named Linalg ops to linalg.generic ops";
@@ -21,7 +21,6 @@ iree_lit_test_suite(
"gpu_check_resource_usage.mlir",
"gpu_distribute.mlir",
"gpu_distribute_shared_memory.mlir",
"gpu_generalize_named_convolution_and_contraction_ops.mlir",
"gpu_generalize_named_ops.mlir",
"gpu_lower_to_ukernels.mlir",
"gpu_nested_layout_vector_distribution.mlir",
@@ -17,7 +17,6 @@ iree_lit_test_suite(
"gpu_check_resource_usage.mlir"
"gpu_distribute.mlir"
"gpu_distribute_shared_memory.mlir"
"gpu_generalize_named_convolution_and_contraction_ops.mlir"
"gpu_generalize_named_ops.mlir"
"gpu_lower_to_ukernels.mlir"
"gpu_nested_layout_contract_amdgpu.mlir"

This file was deleted: gpu_generalize_named_convolution_and_contraction_ops.mlir, the lit test for the removed pass.

12 changes: 6 additions & 6 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -527,13 +527,13 @@ void addGPUVectorDistributePassPipeline(OpPassManager &pm) {
  nestedModulePM.addNestedPass<func::FuncOp>(
      createGPUTensorTileToSerialLoops());

  // Generalize convolutions and contraction ops so that we can fold away unit
  // extent dims. All convolutions are expected to have the kernel dimensions
  // tiled to 1 by this point, so folding unit dims like this directly maps it
  // to a matrix multiplication. After vectorization we expect to get a pure
  // matmul (or a transposed variant) as a `vector.contract`.
  // Generalize all named ops so that we can fold away unit extent dims. By this
  // point, all tiling is finished so the tiling configurations on those ops can
  // be safely dropped. This additionally allows vectorization of convolution to
  // `vector.contract` as filter dimensions are expected to be tiled to 1 by
  // this point.
  nestedModulePM.addNestedPass<func::FuncOp>(
      createGPUGeneralizeNamedConvolutionAndContractionOpsPass());
      createLinalgGeneralizeNamedOpsPass());
  LinalgFoldUnitExtentDimsPassOptions options;
  options.useRankReducingSlices = true;
  nestedModulePM.addNestedPass<func::FuncOp>(
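To illustrate the new pipeline comment above: once the filter window of a
convolution has been tiled to 1x1, generalizing the op and folding the unit
extent dims leaves a plain contraction over the channel dimension. A rough
sketch with made-up tile shapes (the %in_2d/%filter_2d/%acc_2d values stand in
for the collapsed operands; the collapse_shape/extract_slice ops introduced by
the folding are omitted):

// Tiled named convolution; the 1x1 filter window carries only unit dims.
%conv = linalg.conv_2d_nhwc_hwcf
    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%in, %filter : tensor<1x1x4x32xf32>, tensor<1x1x32x64xf32>)
    outs(%acc : tensor<1x1x4x64xf32>) -> tensor<1x1x4x64xf32>

// After generalization and unit-dim folding, roughly a matmul-shaped generic
// that vectorization can lower to a vector.contract:
%res = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                     affine_map<(d0, d1, d2) -> (d2, d1)>,
                     affine_map<(d0, d1, d2) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%in_2d, %filter_2d : tensor<4x32xf32>, tensor<32x64xf32>)
    outs(%acc_2d : tensor<4x64xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %0 = arith.mulf %a, %b : f32
    %1 = arith.addf %c, %0 : f32
    linalg.yield %1 : f32
  } -> tensor<4x64xf32>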
@@ -228,3 +228,85 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {
// CHECK: scf.for {{.*}} = %c0 to %c768 step %c32 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x4x1x1x1x4xf32>)
// CHECK-COUNT-16: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK-COUNT-8: vector.transfer_write {{.+}} : vector<4x1xf32>, memref<2x256x512x256xf32, #hal.descriptor_type<storage_buffer>>

// -----

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {
  mma_intrinsics = [#iree_gpu.mfma_layout<F16_16x16x16_F32>, #iree_gpu.mfma_layout<F16_32x32x8_F32>],
  target_arch = "gfx942",
  ukernels = "none"
}>
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
#pipeline_layout = #hal.pipeline.layout<
  push_constants = 2,
  sets = [
    <0, bindings = [
      <0, storage_buffer, ReadOnly>,
      <1, storage_buffer, ReadOnly>,
      <2, storage_buffer>
    ]>
  ]>
hal.executable public @main_dispatch_expanded_matmul {
  hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
    hal.executable.export public @generic_2x1024x20x64x1280_f16 ordinal(0) layout(#pipeline_layout) attributes {
      hal.interface.bindings = [
        #hal.interface.binding<0, 0>,
        #hal.interface.binding<0, 1>,
        #hal.interface.binding<0, 2>
      ]} {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @generic_2x1024x20x64x1280_f16() {
        %cst = arith.constant 0.000000e+00 : f16
        %c0 = arith.constant 0 : index
        %0 = hal.interface.constant.load[0] : i32
        %1 = hal.interface.constant.load[1] : i32
        %2 = arith.index_castui %0 : i32 to index
        %3 = arith.index_castui %1 : i32 to index
        %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x1024x1280xf16>>
        %5 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x64x1280xf16>>
        %6 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%3) : !flow.dispatch.tensor<writeonly:tensor<2x1024x20x64xf16>>
        %7 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0], sizes = [2, 1024, 1280], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x1024x1280xf16>> -> tensor<2x1024x1280xf16>
        %8 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0], sizes = [20, 64, 1280], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x64x1280xf16>> -> tensor<20x64x1280xf16>
        %9 = tensor.empty() : tensor<2x1024x20x64xf16>
        %10 = linalg.fill ins(%cst : f16) outs(%9 : tensor<2x1024x20x64xf16>) -> tensor<2x1024x20x64xf16>
        %11 = linalg.generic {
          indexing_maps = [#map, #map1, #map2],
          iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
        } ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>)
          outs(%10 : tensor<2x1024x20x64xf16>) {
        ^bb0(%in: f16, %in_0: f16, %out: f16):
          %12 = arith.mulf %in, %in_0 : f16
          %13 = arith.addf %out, %12 : f16
          linalg.yield %13 : f16
        } -> tensor<2x1024x20x64xf16>
        flow.dispatch.tensor.store %11, %6, offsets = [0, 0, 0, 0], sizes = [2, 1024, 20, 64], strides = [1, 1, 1, 1] : tensor<2x1024x20x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x1024x20x64xf16>>
        return
      }
    }
  }
}


// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute
// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mfma_layout<F16_16x16x16_F32>,
// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2, subgroup_m_tile_count = 2, subgroup_n_tile_count = 2, subgroup_k_tile_count = 8>

// CHECK-LABEL: hal.executable.export public @generic_2x1024x20x64x1280_f16
// CHECK-SAME: subgroup_size = 64
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [128 : index, 2 : index, 1 : index]

// CHECK-LABEL: func.func @generic_2x1024x20x64x1280_f16
// CHECK-NOT: vector.transfer_read
// CHECK: scf.for {{.*}} = %c0 to %c1280 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x1x4xf16>)
// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
// along the K dimension. So in total 32 mfma ops.
// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: scf.yield %{{.+}} : vector<2x2x1x1x1x4xf16>
// CHECK-COUNT-4: vector.transfer_write {{.+}} : vector<4x1xf16>
