[mlir][AMDGPU] Plumb address space 7 through MLIR, add address_space attr. #125594

Merged

krzysz00 merged 3 commits into llvm:main from amdgpu-addrspace-7-mlir on Feb 26, 2025

Conversation

krzysz00
Contributor

@krzysz00 krzysz00 commented Feb 3, 2025

This commit adds support to the AMDGPU dialect for casting memrefs into fat raw buffer pointers.

Fat raw buffer pointers - or, in LLVM terms, ptr addrspace(7) - encapsulate a buffer descriptor (as produced by the make.buffer.rsrc intrinsic or provided by some API) in a pointer that supports ordinary pointer operations like load or store. This lets users take advantage of the additional semantics that buffer_load and similar instructions provide without forcing the use of entirely separate amdgpu.raw_buffer_* operations.

Operations on fat raw buffer pointers are translated to the corresponding LLVM intrinsics by the backend.

This commit also defines an #amdgpu.address_space<> attribute so that AMDGPU-specific memory spaces can be represented. Only #amdgpu.address_space<fat_raw_buffer> will work correctly with the memref dialect; the other address spaces are included for completeness.
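
As a rough illustration of how the new pieces fit together, here is a minimal sketch (the function name, shapes, and operands are assumptions made up for this example, not taken from the patch or its tests):

// Sketch only: wrap a plain memref as a fat raw buffer and load through it.
func.func @load_via_fat_raw_buffer(%src: memref<8xi32>, %i: index) -> i32 {
  // The cast keeps the sizes and layout, changing only the memory space.
  %buf = amdgpu.fat_raw_buffer_cast %src
      : memref<8xi32> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
  // Ordinary memref operations on the result lower to buffer intrinsics.
  %v = memref.load %buf[%i] : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
  return %v : i32
}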

@llvmbot
Member

llvmbot commented Feb 3, 2025

@llvm/pr-subscribers-mlir-amdgpu

@llvm/pr-subscribers-backend-amdgpu

Author: Krzysztof Drewniak (krzysz00)

Changes

This commit adds support to the AMDGPU dialect for casting memrefs into fat raw buffer pointers.

Fat raw buffer pointers - or, in LLVM terms, ptr addrspace(7) - encapsulate a buffer descriptor (as produced by the make.buffer.rsrc intrinsic or provided by some API) in a pointer that supports ordinary pointer operations like load or store. This lets users take advantage of the additional semantics that buffer_load and similar instructions provide without forcing the use of entirely separate amdgpu.raw_buffer_* operations.

Operations on fat raw buffer pointers are translated to the corresponding LLVM intrinsics by the backend.

This commit also defines an #amdgpu.address_space<> attribute so that AMDGPU-specific memory spaces can be represented. Only #amdgpu.address_space<fat_raw_buffer> will work correctly with the memref dialect; the other address spaces are included for completeness.


Patch is 37.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125594.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h (+12-4)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+105)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h (+2)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+193-63)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+53)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir (+109-12)
diff --git a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
index e7637a6013e68a..bb4e7bc037a373 100644
--- a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
+++ b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h
@@ -16,18 +16,26 @@ namespace mlir {
 
 class LLVMTypeConverter;
 class RewritePatternSet;
+class TypeConverter;
 class Pass;
 
 #define GEN_PASS_DECL_CONVERTAMDGPUTOROCDL
 #include "mlir/Conversion/Passes.h.inc"
 
-/// Note: The ROCDL target does not support the LLVM bfloat type at this time
-/// and so this function will add conversions to change all `bfloat` uses
-/// to `i16`.
-void populateAMDGPUToROCDLConversionPatterns(const LLVMTypeConverter &converter,
+/// Note: This function will also add conversions for the AMDGPU-specific
+/// address spaces, but those can be added separately using
+/// populateAMDGPUMemorySpaceAttributeConversions().
+void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                              RewritePatternSet &patterns,
                                              amdgpu::Chipset chipset);
 
+/// Remap AMDGPU memory spaces to LLVM address spaces
+/// by mapping amdgpu::AddressSpace::fat_raw_buffer to ptr addrspace(7),
+/// amdgpu::AddressSpace::buffer_rsrc to ptr addrspace(8), and
+/// amdgpu::AddressSpace::fat_strided_buffer to ptr addrspace(9).
+void populateAMDGPUMemorySpaceAttributeConversions(
+    TypeConverter &typeConverter);
+
 std::unique_ptr<Pass> createConvertAMDGPUToROCDLPass();
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 69745addfd748e..6c42849fc71f13 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -9,8 +9,11 @@
 #ifndef AMDGPU
 #define AMDGPU
 
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/EnumAttr.td"
+include "mlir/IR/Properties.td"
 include "mlir/IR/OpBase.td"
 
 def AMDGPU_Dialect : Dialect {
@@ -32,6 +35,45 @@ def AMDGPU_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AMDGPU general attribute definitions
+//===----------------------------------------------------------------------===//
+
+def AMDGPU_AddressSpace : I32EnumAttr<"AddressSpace",
+    "AMDGPU-specific address spaces",
+    [
+      I32EnumAttrCase<"FatRawBuffer",        0, "fat_raw_buffer">,
+      I32EnumAttrCase<"BufferRsrc",          1, "buffer_rsrc">,
+      I32EnumAttrCase<"FatStructuredBuffer", 2, "fat_structured_buffer">,
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
+    "address_space"> {
+  let description = [{
+    AMDGPU-specific memory spaces that may not have exact analogues on other
+    GPU targets or backends.
+
+    - fat_raw_buffer is the memory space used when a memref is stored as
+    as a "buffer fat pointer" - that is, a buffer resource (that is set up to
+    use raw byte-level indexing) along with its offset. The AMDGPU backend
+    implements ptr addrspace(7) to represent these fat pointers so that
+    buffer resources (which allow advanced features like bounds checking or
+    cache swizzling) can be used like ordinary LLVM pointers or memrefs.
+    See also the fat_raw_buffer_cast operation
+    - buffer_rsrc is the memory space for ptr addrspace(8), representing a
+    buffer resource. It should not be used for memrefs, since it does not support
+    indexing
+    - fat_structured_buffer represents ptr addrspace(9), a buffer resource
+    that carries both an index and offset field, which are used for complex
+    structured indexing that is primarily seen in graphics applications. This
+    is also incompatible with the simple indexing model supported by memref.
+  }];
+  let assemblyFormat = [{ `<` $value `>` }];
+}
+
 //===----------------------------------------------------------------------===//
 // AMDGPU Op definitions
 //===----------------------------------------------------------------------===//
@@ -118,6 +160,69 @@ def AMDGPU_PackedStochRoundFp8Op :
   let hasVerifier = 1;
 }
 
+def AMDGPU_FatRawBufferCastOp :
+    AMDGPU_Op<"fat_raw_buffer_cast",
+      [Pure,
+       DeclareOpInterfaceMethods<InferTypeOpInterface>,
+       ViewLikeOpInterface, AttrSizedOperandSegments]>,
+    Arguments<(ins AnyMemRef:$source,
+      Optional<I32>:$validBytes,
+      Optional<I<14>>:$cacheSwizzleStride,
+      DefaultValuedProp<BoolProp, "true">:$boundsCheck,
+      UnitProp:$resetOffset)>,
+    Results<(outs AnyMemRef:$result)> {
+  let summary = "Create a raw buffer fat pointer that matches `memref`";
+  let description = [{
+    Wraps the memory pointed to by `in` as a raw buffer fat pointer, or,
+    in LLVM terms, a ptr addrspace(7), returning a memref that has the same
+    sizes and layout but the `#amdgpu.address_space<fat_raw_buffer>`
+    address space.
+
+    This memref can be used with standard memref operations like `memref.load`,
+    `memref.store`, and `memref.atomicrmw`, which will be lowered to the relevant
+    buffer intrinsics. (`vector.masked_load/store` will work once there's backend
+    support for lowering them, and then this document will be updated)
+
+    If `validBytes` is given, it is the number of bytes that will be valid as
+    an offset to `out`. If it is not provided, this will be inferred from
+    the size of the memref during lowering. This size is
+    max_d (sizes[d] * strides[d]) * sizeof(element type)..
+
+    The flags of the buffer descriptor will be set up to enable raw usage -
+    for example, stride = 0, add_tid = 0, and so on. The `boundsCheck`
+    property determines if bounds checking is enabled or not (on architectures
+    where this can be controlled - that is, on RDNA chips).
+
+    If `cacheSwizzleStride` is provided, L1 cache swizzling will be enabled
+    on architectures that support it. This swizzling, unlike the main swizzling
+    mode (whose usage makes a buffer non-raw) does not affect index calculaton,
+    but does affect cache behavior. Mixing access between cache-swizzled raw
+    buffers and other forms of memory access, like ordinary pointer loads or
+    unswizzled buffer pointers can cause incorrect behavior and must be avoided.
+
+    This operation preserves the sizes, strides, and offset of the input
+    memref - they'll be added in by `memref.load` later. However, if
+    `resetOffset` is set, that offset will be added to the base pointer.
+    If the value of the memref's offset is not independent of the lane/thread ID,
+    this will lead to substantially decreased performance due to the need for
+    a waterfall loop on the base address of the buffer resource.
+  }];
+
+  let extraClassDeclaration = [{
+    Value getViewSource() { return getSource(); }
+  }];
+
+  let assemblyFormat = [{
+    $source oilist (`validBytes` `(` $validBytes `)`
+      | `cacheSwizzleStride` `(` $cacheSwizzleStride `)`
+      | `boundsCheck` `(` $boundsCheck `)`
+      | `resetOffset` $resetOffset )
+    attr-dict `:` type($source) `to` type($result)
+  }];
+
+  let hasVerifier = 1;
+}
+
 /// Raw buffer load
 def AMDGPU_RawBufferLoadOp :
     AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>,
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index 0a2e6bb5e9fe49..3de57c923178ad 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -18,7 +18,9 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h.inc"
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 51f5d7a161b903..d41b4d00f7a365 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -19,6 +19,8 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 
+#include "../LLVMCommon/MemRefDescriptor.h"
+
 #include "llvm/ADT/STLExtras.h"
 #include <optional>
 
@@ -30,6 +32,11 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::amdgpu;
 
+// Define commonly used chipsets versions for convenience.
+static constexpr Chipset kGfx908 = Chipset(9, 0, 8);
+static constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
+static constexpr Chipset kGfx940 = Chipset(9, 4, 0);
+
 /// Convert an unsigned number `val` to i32.
 static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                   Location loc, Value val) {
@@ -76,11 +83,164 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
   return index ? index : createI32Constant(rewriter, loc, 0);
 }
 
+/// Compute the contents of the `num_records` field for a given memref
+/// descriptor - that is, the number of bytes that's one element past the
+/// greatest possible valid index into the memref.
+static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
+                           MemRefType memrefType,
+                           MemRefDescriptor &memrefDescriptor,
+                           ArrayRef<int64_t> strides,
+                           uint32_t elementByteWidth) {
+  if (memrefType.hasStaticShape() &&
+      !llvm::any_of(strides, ShapedType::isDynamic)) {
+    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
+    ArrayRef<int64_t> shape = memrefType.getShape();
+    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+      size = std::max(shape[i] * strides[i], size);
+    size = size * elementByteWidth;
+    assert(size < std::numeric_limits<uint32_t>::max() &&
+           "the memref buffer is too large");
+    return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
+  }
+  Value maxIndex;
+  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
+    Value size = memrefDescriptor.size(rewriter, loc, i);
+    Value stride = memrefDescriptor.stride(rewriter, loc, i);
+    Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
+    maxIndex = maxIndex
+                   ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+                   : maxThisDim;
+  }
+  return rewriter.create<LLVM::MulOp>(
+      loc, convertUnsignedToI32(rewriter, loc, maxIndex),
+      createI32Constant(rewriter, loc, elementByteWidth));
+}
+
+static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
+                            Value basePointer, Value numRecords,
+                            bool boundsCheck, amdgpu::Chipset chipset,
+                            Value cacheSwizzleStride = nullptr) {
+  // The stride value is generally 0. However, on MI-300 and onward, you can
+  // enable a cache swizzling mode by setting bit 14 of the stride field
+  // and setting that stride to a cache stride.
+  Type i16 = rewriter.getI16Type();
+  Value stride;
+  if (chipset.majorVersion == 9 && chipset >= kGfx940 && cacheSwizzleStride) {
+    Value cacheStrideZext =
+        rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
+    Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
+        loc, i16, rewriter.getI16IntegerAttr(1 << 14));
+    stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
+                                         /*isDisjoint=*/true);
+  } else {
+    stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
+                                               rewriter.getI16IntegerAttr(0));
+  }
+  // Get the number of elements.
+  // Flag word:
+  // bits 0-11: dst sel, ignored by these intrinsics
+  // bits 12-14: data format (ignored, must be nonzero, 7=float)
+  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
+  // bit 19: In nested heap (0 here)
+  // bit 20: Behavior on unmap (0 means  "return 0 / ignore")
+  // bits 21-22: Index stride for swizzles (N/A)
+  // bit 23: Add thread ID (0)
+  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
+  // bits 25-26: Reserved (0)
+  // bit 27: Buffer is non-volatile (CDNA only)
+  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
+  //  none, 3 = either swizzles or testing against offset field) RDNA only
+  // bits 30-31: Type (must be 0)
+  uint32_t flags = (7 << 12) | (4 << 15);
+  if (chipset.majorVersion >= 10) {
+    flags |= (1 << 24);
+    uint32_t oob = boundsCheck ? 3 : 2;
+    flags |= (oob << 28);
+  }
+  Value flagsConst = createI32Constant(rewriter, loc, flags);
+  Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
+  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
+      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
+  return resource;
+}
+
 namespace {
-// Define commonly used chipsets versions for convenience.
-constexpr Chipset kGfx908 = Chipset(9, 0, 8);
-constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
-constexpr Chipset kGfx940 = Chipset(9, 4, 0);
+struct FatRawBufferCastLowering
+    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
+  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
+        chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value memRef = adaptor.getSource();
+    Value unconvertedMemref = op.getSource();
+    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
+    MemRefDescriptor descriptor(memRef);
+
+    DataLayout dataLayout = DataLayout::closest(op);
+    int64_t elementByteWidth =
+        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
+
+    int64_t unusedOffset = 0;
+    SmallVector<int64_t, 5> strideVals;
+    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
+      return op.emitOpError("Can't lower non-stride-offset memrefs");
+
+    Value numRecords = adaptor.getValidBytes();
+    if (!numRecords)
+      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
+                                 strideVals, elementByteWidth);
+
+    Value basePointer;
+    if (adaptor.getResetOffset())
+      basePointer =
+          descriptor.bufferPtr(rewriter, loc, *getTypeConverter(), memrefType);
+    else
+      basePointer = descriptor.alignedPtr(rewriter, loc);
+
+    Value offset;
+    if (adaptor.getResetOffset())
+      offset = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+                                                 rewriter.getIndexAttr(0));
+    else
+      offset = descriptor.offset(rewriter, loc);
+
+    // No need to unpack() and pack() all the individual sizes and strides,
+    // so we'll just extract the arrays.
+    Value sizes = rewriter.create<LLVM::ExtractValueOp>(
+        loc, descriptor, kSizePosInMemRefDescriptor);
+    Value strides = rewriter.create<LLVM::ExtractValueOp>(
+        loc, descriptor, kStridePosInMemRefDescriptor);
+
+    Value rsrc = makeBufferRsrc(rewriter, loc, basePointer, numRecords,
+                                adaptor.getBoundsCheck(), chipset,
+                                adaptor.getCacheSwizzleStride());
+    Value fatPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+        loc, LLVM::LLVMPointerType::get(op.getContext(), 7), rsrc);
+
+    Value result = MemRefDescriptor::undef(
+        rewriter, loc,
+        getTypeConverter()->convertType(op.getResult().getType()));
+    result = rewriter.create<LLVM::InsertValueOp>(
+        loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
+    result = rewriter.create<LLVM::InsertValueOp>(
+        loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
+    result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
+                                                  kOffsetPosInMemRefDescriptor);
+    result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
+                                                  kSizePosInMemRefDescriptor);
+    result = rewriter.create<LLVM::InsertValueOp>(loc, result, strides,
+                                                  kStridePosInMemRefDescriptor);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
 
 /// Define lowering patterns for raw buffer ops
 template <typename GpuOp, typename Intrinsic>
@@ -122,7 +282,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
     Type i32 = rewriter.getI32Type();
-    Type i16 = rewriter.getI16Type();
 
     // Get the type size in bytes.
     DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -199,60 +358,10 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     Value ptr = memrefDescriptor.bufferPtr(
         rewriter, loc, *this->getTypeConverter(), memrefType);
-    // The stride value is always 0 for raw buffers. This also disables
-    // swizling.
-    Value stride = rewriter.create<LLVM::ConstantOp>(
-        loc, i16, rewriter.getI16IntegerAttr(0));
-    // Get the number of elements.
-    Value numRecords;
-    if (memrefType.hasStaticShape() &&
-        !llvm::any_of(strides, ShapedType::isDynamic)) {
-      int64_t size = memrefType.getRank() == 0 ? 1 : 0;
-      ArrayRef<int64_t> shape = memrefType.getShape();
-      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
-        size = std::max(shape[i] * strides[i], size);
-      size = size * elementByteWidth;
-      assert(size < std::numeric_limits<uint32_t>::max() &&
-             "the memref buffer is too large");
-      numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
-    } else {
-      Value maxIndex;
-      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
-        Value size = memrefDescriptor.size(rewriter, loc, i);
-        Value stride = memrefDescriptor.stride(rewriter, loc, i);
-        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
-        maxIndex =
-            maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
-                     : maxThisDim;
-      }
-      numRecords = rewriter.create<LLVM::MulOp>(
-          loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
-    }
-
-    // Flag word:
-    // bits 0-11: dst sel, ignored by these intrinsics
-    // bits 12-14: data format (ignored, must be nonzero, 7=float)
-    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
-    // bit 19: In nested heap (0 here)
-    // bit 20: Behavior on unmap (0 means  "return 0 / ignore")
-    // bits 21-22: Index stride for swizzles (N/A)
-    // bit 23: Add thread ID (0)
-    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
-    // bits 25-26: Reserved (0)
-    // bit 27: Buffer is non-volatile (CDNA only)
-    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
-    //  none, 3 = either swizzles or testing against offset field) RDNA only
-    // bits 30-31: Type (must be 0)
-    uint32_t flags = (7 << 12) | (4 << 15);
-    if (chipset.majorVersion >= 10) {
-      flags |= (1 << 24);
-      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
-      flags |= (oob << 28);
-    }
-    Value flagsConst = createI32Constant(rewriter, loc, flags);
-    Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
-    Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
-        loc, rsrcType, ptr, stride, numRecords, flagsConst);
+    Value numRecords = getNumRecords(
+        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
+    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
+                                    adaptor.getBoundsCheck(), chipset);
     args.push_back(resource);
 
     // Indexing (voffset)
...
[truncated]
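
For reference, the assembly format defined in the AMDGPU.td hunk above lets the optional fields be supplied inline. A hypothetical example follows (operand names, types, and the particular flag choices are illustrative assumptions, not taken from the updated test file):

// Sketch only: override the valid size, request cache swizzling, disable bounds checks.
func.func @cast_with_options(%src: memref<?xf32>, %valid: i32, %swizzle: i14)
    -> memref<?xf32, #amdgpu.address_space<fat_raw_buffer>> {
  %buf = amdgpu.fat_raw_buffer_cast %src validBytes(%valid)
      cacheSwizzleStride(%swizzle) boundsCheck(false)
      : memref<?xf32> to memref<?xf32, #amdgpu.address_space<fat_raw_buffer>>
  return %buf : memref<?xf32, #amdgpu.address_space<fat_raw_buffer>>
}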

@llvmbot
Member

llvmbot commented Feb 3, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Krzysztof Drewniak (krzysz00)


@krzysz00 krzysz00 force-pushed the amdgpu-addrspace-7-mlir branch from 2df9493 to 07a8e93 on February 3, 2025 23:12

github-actions bot commented Feb 3, 2025

✅ With the latest revision this PR passed the undef deprecator.

@krzysz00 krzysz00 force-pushed the amdgpu-addrspace-7-mlir branch from 07a8e93 to 63039ff on February 6, 2025 20:17
Member

@kuhar kuhar left a comment


Looks good overall


github-actions bot commented Feb 10, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@krzysz00 krzysz00 requested review from pashu123 and kuhar February 10, 2025 17:58
@kuhar
Member

kuhar commented Feb 10, 2025

@krzysz00 seems like there's a failing test on windows

@krzysz00
Contributor Author

Note to self: probably an issue with the order of evaluation within a function call being underspecified - just need to find and fix the expression.

(and probably also add some strided metadata patterns)

@krzysz00 krzysz00 force-pushed the amdgpu-addrspace-7-mlir branch from 44d296f to 1ee8405 on February 18, 2025 20:42
…attr.

This commit adds support for casting memrefs into fat raw buffer
pointers to the AMDGPU dialect.

Fat raw buffer pointers - or, in LLVM terms, ptr addrspcae(7), allow
encapsulating a buffer descriptor (as produced by the make.buffer.rsrc
intrinsic or provided from some API) into a pointer that supports
ordinary pointer operations like load or store. This allows people to
take advantage of the additional semantics that buffer_load and
similar instructions provide without forcing the use of entirely
separate amdgpu.raw_buffer_* operations.

Operations on fat raw buffer pointers are translated to the
corresponding LLVM intrinsics by the backend.

This commit also goes and and defines a #amdgpu.address_space<>
attribute so that AMDGPU-specific memory spaces can be represented.
Only #amdgpu.address_space<fat_raw_buffer> will work correctly with
the memref dialect, but the other possible address spaces are included
for completeness.

Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Prashant Kumar <[email protected]>
@krzysz00 krzysz00 force-pushed the amdgpu-addrspace-7-mlir branch from 1ee8405 to ceed9d2 on February 20, 2025 18:10
Member

@kuhar kuhar left a comment


Looks good overall, just a couple of issues

Member

@kuhar kuhar left a comment


LGTM

@krzysz00 krzysz00 merged commit 42526d2 into llvm:main Feb 26, 2025
9 of 10 checks passed
kmpeng pushed a commit to kmpeng/llvm-project that referenced this pull request Feb 26, 2025
…attr. (llvm#125594)

This commit adds support for casting memrefs into fat raw buffer
pointers to the AMDGPU dialect.

Fat raw buffer pointers - or, in LLVM terms, ptr addrspcae(7), allow
encapsulating a buffer descriptor (as produced by the make.buffer.rsrc
intrinsic or provided from some API) into a pointer that supports
ordinary pointer operations like load or store. This allows people to
take advantage of the additional semantics that buffer_load and similar
instructions provide without forcing the use of entirely separate
amdgpu.raw_buffer_* operations.

Operations on fat raw buffer pointers are translated to the
corresponding LLVM intrinsics by the backend.

This commit also goes and and defines a #amdgpu.address_space<>
attribute so that AMDGPU-specific memory spaces can be represented. Only
#amdgpu.address_space<fat_raw_buffer> will work correctly with the
memref dialect, but the other possible address spaces are included for
completeness.

---------

Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Prashant Kumar <[email protected]>
kmpeng pushed a commit to kmpeng/llvm-project that referenced this pull request Feb 26, 2025
…attr. (llvm#125594)

This commit adds support for casting memrefs into fat raw buffer
pointers to the AMDGPU dialect.

Fat raw buffer pointers - or, in LLVM terms, ptr addrspcae(7), allow
encapsulating a buffer descriptor (as produced by the make.buffer.rsrc
intrinsic or provided from some API) into a pointer that supports
ordinary pointer operations like load or store. This allows people to
take advantage of the additional semantics that buffer_load and similar
instructions provide without forcing the use of entirely separate
amdgpu.raw_buffer_* operations.

Operations on fat raw buffer pointers are translated to the
corresponding LLVM intrinsics by the backend.

This commit also goes and and defines a #amdgpu.address_space<>
attribute so that AMDGPU-specific memory spaces can be represented. Only
#amdgpu.address_space<fat_raw_buffer> will work correctly with the
memref dialect, but the other possible address spaces are included for
completeness.

---------

Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Prashant Kumar <[email protected]>