[LLVMGPU] Fit mma schedules inside shared memory limits (iree-org#16927)
This patch adds support for checking whether a matmul schedule would cause
promotion to create allocations that do not fit the shared memory size, and
shrinks the MMA schedule if so. The patch also updates the
check-resource-usage pass in the LLVMGPU pass pipeline to query the shared
memory limit from the target.
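
For illustration, the following is a minimal standalone sketch of the heuristic (type and
function names here are invented for the example, not the actual IREE API; the real code
also re-checks that the shrunk schedule still evenly divides the problem shape, which this
sketch omits):

#include <cstdint>
#include <optional>

// A "schedule" picks an MMA intrinsic shape plus per-subgroup tile counts and
// a subgroup (warp) grid.
struct ExampleSchedule {
  int64_t mSize, nSize, kSize;                // intrinsic shape
  int64_t mTileCount, nTileCount, kTileCount; // tiles per subgroup
  int64_t mWarpCount, nWarpCount;             // subgroups per workgroup
};

// Bytes needed to promote the LHS (M x K) and RHS (K x N) workgroup tiles.
static int64_t sharedMemoryBytes(const ExampleSchedule &s, int64_t lhsBits,
                                 int64_t rhsBits) {
  int64_t tileM = s.mSize * s.mTileCount * s.mWarpCount;
  int64_t tileN = s.nSize * s.nTileCount * s.nWarpCount;
  int64_t tileK = s.kSize * s.kTileCount;
  return (tileM * tileK * lhsBits + tileN * tileK * rhsBits) / 8;
}

// Greedily decrement counts (tile counts first, then warp counts) until the
// estimate fits the limit; give up once every count has reached 1.
static std::optional<ExampleSchedule>
fitInSharedMemory(ExampleSchedule s, int64_t lhsBits, int64_t rhsBits,
                  int64_t limitBytes) {
  while (sharedMemoryBytes(s, lhsBits, rhsBits) > limitBytes) {
    int64_t *counts[] = {&s.mTileCount, &s.nTileCount, &s.kTileCount,
                         &s.mWarpCount, &s.nWarpCount};
    bool shrunk = false;
    for (int64_t *count : counts) {
      if (*count > 1) {
        --*count;
        shrunk = true;
        break;
      }
    }
    if (!shrunk)
      return std::nullopt; // nothing left to shrink
  }
  return s;
}

For example, with the MFMA_F16_16x16x16_F32 intrinsic (16-bit operands),
mTileCount = nTileCount = 4, kTileCount = 2, and a 2x2 subgroup grid, the
estimate is (128*32 + 128*32) * 2 bytes = 16 KiB, which fits the 64 KiB limit
this patch uses for ROCm targets.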

---------

Co-authored-by: Quinn Dawkins <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
3 people authored Apr 11, 2024
1 parent 4437c43 commit 94971b4
Showing 9 changed files with 165 additions and 25 deletions.
95 changes: 89 additions & 6 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -5,6 +5,9 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"

#include <cstdint>

#include "llvm/ADT/APInt.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
@@ -17,10 +20,87 @@ using llvm::APIntOps::GreatestCommonDivisor;

namespace mlir::iree_compiler {

std::optional<GPUMMASchedule>
static int64_t calculateSharedMemoryUsedInBytes(const GPUMMASchedule &schedule,
                                                int64_t lhsBitwidth,
                                                int64_t rhsBitwidth) {
  int64_t tileM = schedule.mSize * schedule.mTileCount * schedule.mWarpCount;
  int64_t tileN = schedule.nSize * schedule.nTileCount * schedule.nWarpCount;
  int64_t tileK = schedule.kSize * schedule.kTileCount;
  return (tileM * tileK * lhsBitwidth + tileN * tileK * rhsBitwidth) / 8;
}

bool isValidSchedule(const GPUMatmulShapeType &problem,
                     const GPUMMASchedule &schedule) {
  bool isValidM = (problem.mSize % (schedule.mSize * schedule.mTileCount *
                                    schedule.mWarpCount)) == 0;
  bool isValidN = (problem.nSize % (schedule.nSize * schedule.nTileCount *
                                    schedule.nWarpCount)) == 0;
  bool isValidK = (problem.kSize % (schedule.kSize * schedule.kTileCount)) == 0;
  return isValidN && isValidM && isValidK;
}

FailureOr<GPUMMASchedule> fitScheduleInSharedMemory(
    const GPUMatmulShapeType &problem, ArrayRef<GPUMatmulShapeType> intrinsics,
    GPUMMASchedule schedule, int64_t sharedMemLimitInBytes) {
  int64_t lhsBitwidth =
      intrinsics[schedule.index].aType.getIntOrFloatBitWidth();
  int64_t rhsBitwidth =
      intrinsics[schedule.index].bType.getIntOrFloatBitWidth();

  while (!isValidSchedule(problem, schedule) ||
         calculateSharedMemoryUsedInBytes(schedule, lhsBitwidth, rhsBitwidth) >
             sharedMemLimitInBytes) {
    LLVM_DEBUG({
      llvm::dbgs() << "Shrinking schedule\n";
      llvm::dbgs() << "mSize: " << schedule.mSize << "\n";
      llvm::dbgs() << "nSize: " << schedule.nSize << "\n";
      llvm::dbgs() << "kSize: " << schedule.kSize << "\n";
      llvm::dbgs() << "mTileCount: " << schedule.mTileCount << "\n";
      llvm::dbgs() << "nTileCount: " << schedule.nTileCount << "\n";
      llvm::dbgs() << "kTileCount: " << schedule.kTileCount << "\n";
      llvm::dbgs() << "mWarpCount: " << schedule.mWarpCount << "\n";
      llvm::dbgs() << "nWarpCount: " << schedule.nWarpCount << "\n";
    });

    auto decrementIfPossible = [](int64_t &c) -> LogicalResult {
      if (c <= 1) {
        return failure();
      }
      --c;
      return success();
    };

    // Attempt to shrink the schedule along one of the dimensions.
    // TODO: A better solution should probably factor problem.mSize /
    // (mWarpCount * mTileCount * mSize) and then pop off the smallest factors
    // one at a time, preferably trying to keep the tile "generally square."
    if (succeeded(decrementIfPossible(schedule.mTileCount))) {
      continue;
    }
    if (succeeded(decrementIfPossible(schedule.nTileCount))) {
      continue;
    }
    if (succeeded(decrementIfPossible(schedule.kTileCount))) {
      continue;
    }
    if (succeeded(decrementIfPossible(schedule.mWarpCount))) {
      continue;
    }
    if (succeeded(decrementIfPossible(schedule.nWarpCount))) {
      continue;
    }

    // If no dimension can be shrunk, give up.
    return failure();
  }
  return schedule;
}

FailureOr<GPUMMASchedule>
deduceMMASchedule(const GPUMatmulShapeType &problem,
                  ArrayRef<GPUMatmulShapeType> intrinsics,
                  const GPUMMAHeuristicSeeds &seeds, bool canUpcastAcc) {
                  const GPUMMAHeuristicSeeds &seeds,
                  int64_t sharedMemLimitInBytes, bool canUpcastAcc) {
  for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) {
    if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType) {
      continue; // Cannot use this intrinsic for mismatched types
@@ -112,11 +192,14 @@ deduceMMASchedule(const GPUMatmulShapeType &problem,
llvm::dbgs() << " subgroup tile count (M, N, K) = (" << mTileCount
<< ", " << nTileCount << ", " << kTileCount << ")\n";
});
return GPUMMASchedule{index, intrinsic.mSize, intrinsic.nSize,
intrinsic.kSize, mWarpCount, nWarpCount,
mTileCount, nTileCount, kTileCount};
return fitScheduleInSharedMemory(
problem, intrinsics,
GPUMMASchedule{index, intrinsic.mSize, intrinsic.nSize, intrinsic.kSize,
mWarpCount, nWarpCount, mTileCount, nTileCount,
kTileCount},
sharedMemLimitInBytes);
}
return std::nullopt;
return failure();
}

} // namespace mlir::iree_compiler
5 changes: 3 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
@@ -46,9 +46,10 @@ struct GPUMMASchedule {

/// Returns a schedule for using one of the given MMA |intrinsics| to target the
/// input |problem|. Returns std::nullopt if we cannot find such a schedule.
std::optional<GPUMMASchedule>
FailureOr<GPUMMASchedule>
deduceMMASchedule(const GPUMatmulShapeType &problem,
                  ArrayRef<GPUMatmulShapeType> intrinsics,
                  const GPUMMAHeuristicSeeds &seeds, bool canUpcastAcc = false);
                  const GPUMMAHeuristicSeeds &seeds,
                  int64_t sharedMemLimitInBytes, bool canUpcastAcc = false);

} // namespace mlir::iree_compiler
28 changes: 20 additions & 8 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -80,6 +80,7 @@ struct TargetInfo {
  bool hasMmaSync = false;
  // These are listed in the order of preference, not necessarily monotonically.
  SmallVector<int64_t, 2> supportedSubgroupSizes = {32};
  int64_t sharedMemoryLimitInBytes = 65536;
};

struct TileWorkgroupSizePair {
@@ -196,6 +197,7 @@ static TargetInfo getCudaTargetInfo(mlir::FunctionOpInterface entryPoint) {
  // All the cuda target are assumed to have warp support.
  info.hasWarpShuffle = true;
  info.supportedSubgroupSizes = {32};
  info.sharedMemoryLimitInBytes = 163 * 1024;
  StringRef targetName = getTargetArch(entryPoint);
  // If no target name is set assume all the features are off.
  if (targetName == "")
@@ -221,6 +223,7 @@ static TargetInfo getCudaTargetInfo(mlir::FunctionOpInterface entryPoint) {
static TargetInfo getRocmTargetInfo(mlir::FunctionOpInterface entryPoint) {
  TargetInfo info;
  StringRef targetName = getTargetArch(entryPoint);
  info.sharedMemoryLimitInBytes = 65536;
  // If no target name is set assume all the features are off.
  if (targetName.empty())
    return info;
@@ -409,15 +412,18 @@ setConvolutionVectorDistributionConfig(mlir::FunctionOpInterface entryPoint,
                             /*bestMNTileCountPerSubgroup=*/8,
                             /*bestKTileCountPerSubgroup=*/2};

  int64_t sharedMemoryLimitInBytes = targetInfo.sharedMemoryLimitInBytes;

  // First try to find a schedule with an exactly matching intrinsic.
  std::optional<GPUMMASchedule> schedule =
      deduceMMASchedule(problem, intrinsics, seeds);
  if (!schedule) {
  FailureOr<GPUMMASchedule> schedule =
      deduceMMASchedule(problem, intrinsics, seeds, sharedMemoryLimitInBytes);
  if (failed(schedule)) {
    // Then try again by allowing upcasting accumulator.
    schedule =
        deduceMMASchedule(problem, intrinsics, seeds, /*canUpcastAcc=*/true);
        deduceMMASchedule(problem, intrinsics, seeds, sharedMemoryLimitInBytes,
                          /*canUpcastAcc=*/true);
  }
  if (!schedule) {
  if (failed(schedule)) {
    return failure();
  }

@@ -557,13 +563,16 @@ setMatmulVectorDistributionConfig(mlir::FunctionOpInterface entryPoint,
                             /*bestKTileCountPerSubgroup=*/4};
  }

  int64_t sharedMemoryLimitInBytes = targetInfo.sharedMemoryLimitInBytes;

  // First try to find a schedule with an exactly matching intrinsic.
  std::optional<GPUMMASchedule> schedule =
      deduceMMASchedule(problem, intrinsics, seeds);
      deduceMMASchedule(problem, intrinsics, seeds, sharedMemoryLimitInBytes);
  if (!schedule) {
    // Then try again by allowing upcasting accumulator.
    schedule =
        deduceMMASchedule(problem, intrinsics, seeds, /*canUpcastAcc=*/true);
        deduceMMASchedule(problem, intrinsics, seeds, sharedMemoryLimitInBytes,
                          /*canUpcastAcc=*/true);
  }
  if (!schedule) {
    return failure();
@@ -1482,7 +1491,6 @@ static LogicalResult setTransposeConfig(mlir::FunctionOpInterface entryPoint,
static LogicalResult
setArgmaxUkernelConfig(mlir::FunctionOpInterface entryPoint,
                       linalg::GenericOp op, const TargetInfo &targetInfo) {

  // Checks if UKernels are enabled.
  if (auto variantOp =
          entryPoint->getParentOfType<IREE::HAL::ExecutableVariantOp>()) {
@@ -1801,6 +1809,10 @@ static void propagateLoweringConfig(Operation *rootOperation,
}
}

int64_t getTargetSharedMemoryLimitInBytes(FunctionOpInterface entryPoint) {
  return getTargetInfo(entryPoint).sharedMemoryLimitInBytes;
}

//===----------------------------------------------------------------------===//
// Entry Point
//===----------------------------------------------------------------------===//
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.h
@@ -8,9 +8,14 @@
#define IREE_COMPILER_CODEGEN_LLVMGPU_KERNELCONFIG_H_

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

namespace mlir::iree_compiler {

// TODO: Ideally, we should be setting the resource limits as an attribute in
// the lowering configuration.
int64_t getTargetSharedMemoryLimitInBytes(FunctionOpInterface entryPoint);

LogicalResult initGPULaunchConfig(ModuleOp moduleOp);

} // namespace mlir::iree_compiler
12 changes: 6 additions & 6 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -12,6 +12,7 @@
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/LLVMGPU/KernelConfig.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
@@ -823,14 +824,13 @@ static void addLowerToLLVMGPUPasses(OpPassManager &pm, bool forROCDL) {
  addLowerAndOptimizeAddressComputationPasses(pm);

  // Run checks on shared memory usage.
  // TODO: query this from the target.
  int64_t limit = clLLVMGPUSharedMemoryLimit;
  auto getSharedMemoryLimit = [limit](mlir::FunctionOpInterface) {
    return limit;
  auto getSharedMemoryLimitInBytes = [](mlir::FunctionOpInterface entryPoint) {
    return getTargetSharedMemoryLimitInBytes(entryPoint);
  };
  // TODO: query this from the target.
  auto getIndexBitwidth = [](mlir::FunctionOpInterface) { return 64; };
  pm.addPass(
      createGPUCheckResourceUsagePass(getSharedMemoryLimit, getIndexBitwidth));
  pm.addPass(createGPUCheckResourceUsagePass(getSharedMemoryLimitInBytes,
                                             getIndexBitwidth));

  // SCF -> CF
  pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());
@@ -71,6 +71,7 @@ iree_lit_test_suite(
"ukernel_pipeline_transform.mlir",
"vector_distribute_conversion.mlir",
"vector_distribute_layout.mlir",
"vector_distribution_pipeline_test.mlir",
"vector_lowering.mlir",
"vector_to_gpu.mlir",
"workgroup_specialization_pipeline_test.mlir",
@@ -67,6 +67,7 @@ iree_lit_test_suite(
"ukernel_pipeline_transform.mlir"
"vector_distribute_conversion.mlir"
"vector_distribute_layout.mlir"
"vector_distribution_pipeline_test.mlir"
"vector_lowering.mlir"
"vector_to_gpu.mlir"
"workgroup_specialization_pipeline_test.mlir"
@@ -0,0 +1,34 @@
// RUN: iree-opt --split-input-file --iree-codegen-llvmgpu-use-vector-distribution \
// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-select-lowering-strategy, iree-llvmgpu-lower-executable-target, canonicalize)))' \
// RUN: %s | FileCheck %s

hal.executable @fit_shared_memory_schedule {
  hal.executable.variant public @rocm_hsaco_fb
      target(<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>],
               target_arch = "gfx942", ukernels = "none"}>) {
    hal.executable.export public @fit_shared_memory_schedule ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @fit_shared_memory_schedule() {
        %cst = arith.constant 0.000000e+00 : f32
        %c129181184 = arith.constant 129181184 : index
        %c18112 = arith.constant 18112 : index
        %c100980224 = arith.constant 100980224 : index
        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c129181184) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x80x1280xf16>>
        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c18112) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1280x1280xf16>>
        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c100980224) : !flow.dispatch.tensor<writeonly:tensor<64x80x1280xf32>>
        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [64, 80, 1280], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x80x1280xf16>> -> tensor<64x80x1280xf16>
        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [64, 1280, 1280], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x1280x1280xf16>> -> tensor<64x1280x1280xf16>
        %5 = tensor.empty() : tensor<64x80x1280xf32>
        %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<64x80x1280xf32>) -> tensor<64x80x1280xf32>
        %7 = linalg.batch_matmul ins(%3, %4 : tensor<64x80x1280xf16>, tensor<64x1280x1280xf16>) outs(%6 : tensor<64x80x1280xf32>) -> tensor<64x80x1280xf32>
        flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [64, 80, 1280], strides = [1, 1, 1] : tensor<64x80x1280xf32> -> !flow.dispatch.tensor<writeonly:tensor<64x80x1280xf32>>
        return
      }
    }
  }
}

// CHECK-LABEL: .executable.export public @fit_shared_memory_schedule
9 changes: 6 additions & 3 deletions compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -909,9 +909,12 @@ LogicalResult setCooperativeMatrixConfig(
  GPUMMAHeuristicSeeds seeds{numSubgroupsPerWorkgroup, numMNTilesPerSubgroup,
                             numKTilesPerSubgroup};

  std::optional<GPUMMASchedule> schedule =
      deduceMMASchedule(problem, intrinsics, seeds);
  if (!schedule)
  int64_t sharedMemoryLimitInBytes =
      targetEnv.getResourceLimits().getMaxComputeSharedMemorySize();

  FailureOr<GPUMMASchedule> schedule =
      deduceMMASchedule(problem, intrinsics, seeds, sharedMemoryLimitInBytes);
  if (failed(schedule))
    return failure();

  auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize;
