[mlir][AMDGPU] Plumb address space 7 through MLIR, add address_space attr. #125594
Conversation
@llvm/pr-subscribers-mlir-amdgpu @llvm/pr-subscribers-backend-amdgpu

Author: Krzysztof Drewniak (krzysz00)

Changes

This commit adds support for casting memrefs into fat raw buffer pointers to the AMDGPU dialect.

Fat raw buffer pointers - or, in LLVM terms, ptr addrspace(7) - allow encapsulating a buffer descriptor (as produced by the make.buffer.rsrc intrinsic or provided by some API) in a pointer that supports ordinary pointer operations like load or store. This lets people take advantage of the additional semantics that buffer_load and similar instructions provide without forcing the use of entirely separate amdgpu.raw_buffer_* operations. Operations on fat raw buffer pointers are translated to the corresponding LLVM intrinsics by the backend.

This commit also defines a #amdgpu.address_space<> attribute so that AMDGPU-specific memory spaces can be represented. Only #amdgpu.address_space<fat_raw_buffer> will work correctly with the memref dialect; the other possible address spaces are included for completeness.

Patch is 37.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125594.diff

6 Files Affected:
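Before the diff itself, here is a minimal usage sketch pieced together from the op and attribute definitions in the patch below; the function and value names are illustrative, not taken from the PR's tests.

```mlir
func.func @load_via_fat_pointer(%buf: memref<16x16xf32>, %i: index, %j: index) -> f32 {
  // Wrap the memref as a buffer fat pointer (ptr addrspace(7) after lowering to LLVM).
  %fat = amdgpu.fat_raw_buffer_cast %buf
      : memref<16x16xf32> to memref<16x16xf32, #amdgpu.address_space<fat_raw_buffer>>
  // Ordinary memref operations on the result lower to buffer intrinsics.
  %v = memref.load %fat[%i, %j] : memref<16x16xf32, #amdgpu.address_space<fat_raw_buffer>>
  return %v : f32
}
```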
diff --git a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
index e7637a6013e68a..bb4e7bc037a373 100644
--- a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
+++ b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
@@ -16,18 +16,26 @@ namespace mlir {
class LLVMTypeConverter;
class RewritePatternSet;
+class TypeConverter;
class Pass;
#define GEN_PASS_DECL_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
-/// Note: The ROCDL target does not support the LLVM bfloat type at this time
-/// and so this function will add conversions to change all `bfloat` uses
-/// to `i16`.
-void populateAMDGPUToROCDLConversionPatterns(const LLVMTypeConverter &converter,
+/// Note: This function will also add conversions for the AMDGPU-specific
+/// address spaces, but those can be added separately using
+/// populateAMDGPUMemorySpaceAttributeConversions().
+void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
amdgpu::Chipset chipset);
+/// Remap AMDGPU memory spaces to LLVM address spaces
+/// by mapping amdgpu::AddressSpace::fat_raw_buffer to ptr addrspace(7),
+/// amdgpu::AddressSpace::buffer_rsrc to ptr addrspace(8), and
+/// amdgpu::AddressSpace::fat_strided_buffer to ptr addrspace(9).
+void populateAMDGPUMemorySpaceAttributeConversions(
+ TypeConverter &typeConverter);
+
std::unique_ptr<Pass> createConvertAMDGPUToROCDLPass();
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 69745addfd748e..6c42849fc71f13 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -9,8 +9,11 @@
#ifndef AMDGPU
#define AMDGPU
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/EnumAttr.td"
+include "mlir/IR/Properties.td"
include "mlir/IR/OpBase.td"
def AMDGPU_Dialect : Dialect {
@@ -32,6 +35,45 @@ def AMDGPU_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
}
+//===----------------------------------------------------------------------===//
+// AMDGPU general attribute definitions
+//===----------------------------------------------------------------------===//
+
+def AMDGPU_AddressSpace : I32EnumAttr<"AddressSpace",
+ "AMDGPU-specific address spaces",
+ [
+ I32EnumAttrCase<"FatRawBuffer", 0, "fat_raw_buffer">,
+ I32EnumAttrCase<"BufferRsrc", 1, "buffer_rsrc">,
+ I32EnumAttrCase<"FatStructuredBuffer", 2, "fat_structured_buffer">,
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
+ "address_space"> {
+ let description = [{
+ AMDGPU-specific memory spaces that may not have exact analogues on other
+ GPU targets or backends.
+
+ - fat_raw_buffer is the memory space used when a memref is stored as
+ as a "buffer fat pointer" - that is, a buffer resource (that is set up to
+ use raw byte-level indexing) along with its offset. The AMDGPU backend
+ implements ptr addrspace(7) to represent these fat pointers so that
+ buffer resources (which allow advanced features like bounds checking or
+ cache swizzling) can be used like ordinary LLVM pointers or memrefs.
+ See also the fat_raw_buffer_cast operation
+ - buffer_rsrc is the memory space for ptr addrspace(8), representing a
+ buffer resource. It should not be used for memrefs, since it does not support
+ indexing
+ - fat_structured_buffer represents ptr addrspace(9), a buffer resource
+ that carries both an index and offset field, which are used for complex
+ structured indexing that is primarily seen in graphics applications. This
+ is also incompatible with the simple indexing model supported by memref.
+ }];
+ let assemblyFormat = [{ `<` $value `>` }];
+}
+
//===----------------------------------------------------------------------===//
// AMDGPU Op definitions
//===----------------------------------------------------------------------===//
@@ -118,6 +160,69 @@ def AMDGPU_PackedStochRoundFp8Op :
let hasVerifier = 1;
}
+def AMDGPU_FatRawBufferCastOp :
+ AMDGPU_Op<"fat_raw_buffer_cast",
+ [Pure,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>,
+ ViewLikeOpInterface, AttrSizedOperandSegments]>,
+ Arguments<(ins AnyMemRef:$source,
+ Optional<I32>:$validBytes,
+ Optional<I<14>>:$cacheSwizzleStride,
+ DefaultValuedProp<BoolProp, "true">:$boundsCheck,
+ UnitProp:$resetOffset)>,
+ Results<(outs AnyMemRef:$result)> {
+ let summary = "Create a raw buffer fat pointer that matches `memref`";
+ let description = [{
+ Wraps the memory pointed to by `in` as a raw buffer fat pointer, or,
+ in LLVM terms, a ptr addrspace(7), returning a memref that has the same
+ sizes and layout but the `#amdgpu.address_space<fat_raw_buffer>`
+ address space.
+
+ This memref can be used with standard memref operations like `memref.load`,
+ `memref.store`, and `memref.atomicrmw`, which will be lowered to the relevant
+ buffer intrinsics. (`vector.masked_load/store` will work once there's backend
+ support for lowering them, and then this document will be updated)
+
+ If `validBytes` is given, it is the number of bytes that will be valid as
+ an offset to `out`. If it is not provided, this will be inferred from
+ the size of the memref during lowering. This size is
+ max_d (sizes[d] * strides[d]) * sizeof(element type).
+
+ The flags of the buffer descriptor will be set up to enable raw usage -
+ for example, stride = 0, add_tid = 0, and so on. The `boundsCheck`
+ property determines if bounds checking is enabled or not (on architectures
+ where this can be controlled - that is, on RDNA chips).
+
+ If `cacheSwizzleStride` is provided, L1 cache swizzling will be enabled
+ on architectures that support it. This swizzling, unlike the main swizzling
+ mode (whose usage makes a buffer non-raw) does not affect index calculation,
+ but does affect cache behavior. Mixing access between cache-swizzled raw
+ buffers and other forms of memory access, like ordinary pointer loads or
+ unswizzled buffer pointers can cause incorrect behavior and must be avoided.
+
+ This operation preserves the sizes, strides, and offset of the input
+ memref - they'll be added in by `memref.load` later. However, if
+ `resetOffset` is set, that offset will be added to the base pointer.
+ If the value of the memref's offset is not independent of the lane/thread ID,
+ this will lead to substantially decreased performance due to the need for
+ a waterfall loop on the base address of the buffer resource.
+ }];
+
+ let extraClassDeclaration = [{
+ Value getViewSource() { return getSource(); }
+ }];
+
+ let assemblyFormat = [{
+ $source oilist (`validBytes` `(` $validBytes `)`
+ | `cacheSwizzleStride` `(` $cacheSwizzleStride `)`
+ | `boundsCheck` `(` $boundsCheck `)`
+ | `resetOffset` $resetOffset )
+ attr-dict `:` type($source) `to` type($result)
+ }];
+
+ let hasVerifier = 1;
+}
+
/// Raw buffer load
def AMDGPU_RawBufferLoadOp :
AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index 0a2e6bb5e9fe49..3de57c923178ad 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -18,7 +18,9 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h.inc"
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 51f5d7a161b903..d41b4d00f7a365 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -19,6 +19,8 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
+#include "../LLVMCommon/MemRefDescriptor.h"
+
#include "llvm/ADT/STLExtras.h"
#include <optional>
@@ -30,6 +32,11 @@ namespace mlir {
using namespace mlir;
using namespace mlir::amdgpu;
+// Define commonly used chipsets versions for convenience.
+static constexpr Chipset kGfx908 = Chipset(9, 0, 8);
+static constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
+static constexpr Chipset kGfx940 = Chipset(9, 4, 0);
+
/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
Location loc, Value val) {
@@ -76,11 +83,164 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
return index ? index : createI32Constant(rewriter, loc, 0);
}
+/// Compute the contents of the `num_records` field for a given memref
+/// descriptor - that is, the number of bytes that's one element past the
+/// greatest possible valid index into the memref.
+static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
+ MemRefType memrefType,
+ MemRefDescriptor &memrefDescriptor,
+ ArrayRef<int64_t> strides,
+ uint32_t elementByteWidth) {
+ if (memrefType.hasStaticShape() &&
+ !llvm::any_of(strides, ShapedType::isDynamic)) {
+ int64_t size = memrefType.getRank() == 0 ? 1 : 0;
+ ArrayRef<int64_t> shape = memrefType.getShape();
+ for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+ size = std::max(shape[i] * strides[i], size);
+ size = size * elementByteWidth;
+ assert(size < std::numeric_limits<uint32_t>::max() &&
+ "the memref buffer is too large");
+ return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
+ }
+ Value maxIndex;
+ for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
+ Value size = memrefDescriptor.size(rewriter, loc, i);
+ Value stride = memrefDescriptor.stride(rewriter, loc, i);
+ Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
+ maxIndex = maxIndex
+ ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+ : maxThisDim;
+ }
+ return rewriter.create<LLVM::MulOp>(
+ loc, convertUnsignedToI32(rewriter, loc, maxIndex),
+ createI32Constant(rewriter, loc, elementByteWidth));
+}
+
+static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
+ Value basePointer, Value numRecords,
+ bool boundsCheck, amdgpu::Chipset chipset,
+ Value cacheSwizzleStride = nullptr) {
+ // The stride value is generally 0. However, on MI-300 and onward, you can
+ // enable a cache swizzling mode by setting bit 14 of the stride field
+ // and setting that stride to a cache stride.
+ Type i16 = rewriter.getI16Type();
+ Value stride;
+ if (chipset.majorVersion == 9 && chipset >= kGfx940 && cacheSwizzleStride) {
+ Value cacheStrideZext =
+ rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
+ Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
+ loc, i16, rewriter.getI16IntegerAttr(1 << 14));
+ stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
+ /*isDisjoint=*/true);
+ } else {
+ stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
+ rewriter.getI16IntegerAttr(0));
+ }
+ // Get the number of elements.
+ // Flag word:
+ // bits 0-11: dst sel, ignored by these intrinsics
+ // bits 12-14: data format (ignored, must be nonzero, 7=float)
+ // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
+ // bit 19: In nested heap (0 here)
+ // bit 20: Behavior on unmap (0 means "return 0 / ignore")
+ // bits 21-22: Index stride for swizzles (N/A)
+ // bit 23: Add thread ID (0)
+ // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
+ // bits 25-26: Reserved (0)
+ // bit 27: Buffer is non-volatile (CDNA only)
+ // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
+ // none, 3 = either swizzles or testing against offset field) RDNA only
+ // bits 30-31: Type (must be 0)
+ uint32_t flags = (7 << 12) | (4 << 15);
+ if (chipset.majorVersion >= 10) {
+ flags |= (1 << 24);
+ uint32_t oob = boundsCheck ? 3 : 2;
+ flags |= (oob << 28);
+ }
+ Value flagsConst = createI32Constant(rewriter, loc, flags);
+ Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
+ Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
+ loc, rsrcType, basePointer, stride, numRecords, flagsConst);
+ return resource;
+}
+
namespace {
-// Define commonly used chipsets versions for convenience.
-constexpr Chipset kGfx908 = Chipset(9, 0, 8);
-constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
-constexpr Chipset kGfx940 = Chipset(9, 4, 0);
+struct FatRawBufferCastLowering
+ : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
+ FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
+ chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value memRef = adaptor.getSource();
+ Value unconvertedMemref = op.getSource();
+ MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
+ MemRefDescriptor descriptor(memRef);
+
+ DataLayout dataLayout = DataLayout::closest(op);
+ int64_t elementByteWidth =
+ dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
+
+ int64_t unusedOffset = 0;
+ SmallVector<int64_t, 5> strideVals;
+ if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
+ return op.emitOpError("Can't lower non-stride-offset memrefs");
+
+ Value numRecords = adaptor.getValidBytes();
+ if (!numRecords)
+ numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
+ strideVals, elementByteWidth);
+
+ Value basePointer;
+ if (adaptor.getResetOffset())
+ basePointer =
+ descriptor.bufferPtr(rewriter, loc, *getTypeConverter(), memrefType);
+ else
+ basePointer = descriptor.alignedPtr(rewriter, loc);
+
+ Value offset;
+ if (adaptor.getResetOffset())
+ offset = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+ rewriter.getIndexAttr(0));
+ else
+ offset = descriptor.offset(rewriter, loc);
+
+ // No need to unpack() and pack() all the individual sizes and strides,
+ // so we'll just extract the arrays.
+ Value sizes = rewriter.create<LLVM::ExtractValueOp>(
+ loc, descriptor, kSizePosInMemRefDescriptor);
+ Value strides = rewriter.create<LLVM::ExtractValueOp>(
+ loc, descriptor, kStridePosInMemRefDescriptor);
+
+ Value rsrc = makeBufferRsrc(rewriter, loc, basePointer, numRecords,
+ adaptor.getBoundsCheck(), chipset,
+ adaptor.getCacheSwizzleStride());
+ Value fatPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+ loc, LLVM::LLVMPointerType::get(op.getContext(), 7), rsrc);
+
+ Value result = MemRefDescriptor::undef(
+ rewriter, loc,
+ getTypeConverter()->convertType(op.getResult().getType()));
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
+ result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
+ kOffsetPosInMemRefDescriptor);
+ result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
+ kSizePosInMemRefDescriptor);
+ result = rewriter.create<LLVM::InsertValueOp>(loc, result, strides,
+ kStridePosInMemRefDescriptor);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
@@ -122,7 +282,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
Type i32 = rewriter.getI32Type();
- Type i16 = rewriter.getI16Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -199,60 +358,10 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Value ptr = memrefDescriptor.bufferPtr(
rewriter, loc, *this->getTypeConverter(), memrefType);
- // The stride value is always 0 for raw buffers. This also disables
- // swizling.
- Value stride = rewriter.create<LLVM::ConstantOp>(
- loc, i16, rewriter.getI16IntegerAttr(0));
- // Get the number of elements.
- Value numRecords;
- if (memrefType.hasStaticShape() &&
- !llvm::any_of(strides, ShapedType::isDynamic)) {
- int64_t size = memrefType.getRank() == 0 ? 1 : 0;
- ArrayRef<int64_t> shape = memrefType.getShape();
- for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
- size = std::max(shape[i] * strides[i], size);
- size = size * elementByteWidth;
- assert(size < std::numeric_limits<uint32_t>::max() &&
- "the memref buffer is too large");
- numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
- } else {
- Value maxIndex;
- for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
- Value size = memrefDescriptor.size(rewriter, loc, i);
- Value stride = memrefDescriptor.stride(rewriter, loc, i);
- Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
- maxIndex =
- maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
- : maxThisDim;
- }
- numRecords = rewriter.create<LLVM::MulOp>(
- loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
- }
-
- // Flag word:
- // bits 0-11: dst sel, ignored by these intrinsics
- // bits 12-14: data format (ignored, must be nonzero, 7=float)
- // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
- // bit 19: In nested heap (0 here)
- // bit 20: Behavior on unmap (0 means "return 0 / ignore")
- // bits 21-22: Index stride for swizzles (N/A)
- // bit 23: Add thread ID (0)
- // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
- // bits 25-26: Reserved (0)
- // bit 27: Buffer is non-volatile (CDNA only)
- // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
- // none, 3 = either swizzles or testing against offset field) RDNA only
- // bits 30-31: Type (must be 0)
- uint32_t flags = (7 << 12) | (4 << 15);
- if (chipset.majorVersion >= 10) {
- flags |= (1 << 24);
- uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
- flags |= (oob << 28);
- }
- Value flagsConst = createI32Constant(rewriter, loc, flags);
- Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
- Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
- loc, rsrcType, ptr, stride, numRecords, flagsConst);
+ Value numRecords = getNumRecords(
+ rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
+ Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
+ adaptor.getBoundsCheck(), chipset);
args.push_back(resource);
// Indexing (voffset)
...
[truncated]
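To illustrate the optional arguments, here is a hedged sketch that follows the assembly format and `Arguments` clause declared above; `%valid` (an `i32`) and `%swz` (an `i14`) are illustrative SSA values, not part of the patch.

```mlir
// validBytes overrides the num_records size otherwise inferred from the memref;
// boundsCheck(false) disables bounds checking where the target allows it, and
// resetOffset folds the memref's offset into the base pointer.
%fat = amdgpu.fat_raw_buffer_cast %src validBytes(%valid) cacheSwizzleStride(%swz)
    boundsCheck(false) resetOffset
    : memref<1024xi8> to memref<1024xi8, #amdgpu.address_space<fat_raw_buffer>>
```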
Force-pushed from 2df9493 to 07a8e93 (compare).
✅ With the latest revision this PR passed the undef deprecator.
Force-pushed from 07a8e93 to 63039ff (compare).
Looks good overall
✅ With the latest revision this PR passed the C/C++ code formatter.
@krzysz00 seems like there's a failing test on Windows.
Note to self: this is probably an issue with the order of evaluation within function calls being underspecified - I just need to find and fix the offending expression. (And probably also add some strided metadata patterns.)
Force-pushed from 44d296f to 1ee8405 (compare).
…attr.

This commit adds support for casting memrefs into fat raw buffer pointers to the AMDGPU dialect.

Fat raw buffer pointers - or, in LLVM terms, ptr addrspace(7) - allow encapsulating a buffer descriptor (as produced by the make.buffer.rsrc intrinsic or provided by some API) in a pointer that supports ordinary pointer operations like load or store. This lets people take advantage of the additional semantics that buffer_load and similar instructions provide without forcing the use of entirely separate amdgpu.raw_buffer_* operations. Operations on fat raw buffer pointers are translated to the corresponding LLVM intrinsics by the backend.

This commit also defines a #amdgpu.address_space<> attribute so that AMDGPU-specific memory spaces can be represented. Only #amdgpu.address_space<fat_raw_buffer> will work correctly with the memref dialect; the other possible address spaces are included for completeness.

Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Prashant Kumar <[email protected]>
Force-pushed from 1ee8405 to ceed9d2 (compare).
Looks good overall, just a couple of issues
LGTM