Skip to content

Commit 010c310

Browse files
authored
[UR][HIP] Enable usm pools (#17972)
This patch fixes up and enables memory pools for the HIP adapter. It is based on oneapi-src/unified-runtime#1689 and on the CUDA adapter implementation. The initial patch caused segmentation faults in the CI that we couldn't reproduce locally. The same segfaults occurred with this patch, and I couldn't reproduce them locally either. However, I noticed that the failure was in `urUSMHostAlloc`, and that this entry point differed from the CUDA adapter version in that the HIP adapter was using a "helper" function. It turns out that the helper function was using a device pool instead of a host pool to do the allocation, which seemed obviously wrong. Replacing the helper with code similar to that used in the CUDA adapter fixes the crash in the CI.
1 parent b1051b6 commit 010c310

File tree

4 files changed

+41
-21
lines changed

4 files changed

+41
-21
lines changed

unified-runtime/source/adapters/hip/CMakeLists.txt

+6
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ else()
203203
message(FATAL_ERROR "Unspecified UR HIP platform please set UR_HIP_PLATFORM to 'AMD' or 'NVIDIA'")
204204
endif()
205205

206+
if(UMF_ENABLE_POOL_TRACKING)
207+
target_compile_definitions(${TARGET_NAME} PRIVATE UMF_ENABLE_POOL_TRACKING)
208+
else()
209+
message(WARNING "HIP adapter USM pools are disabled, set UMF_ENABLE_POOL_TRACKING to enable them")
210+
endif()
211+
206212
target_include_directories(${TARGET_NAME} PRIVATE
207213
"${CMAKE_CURRENT_SOURCE_DIR}/../../"
208214
)

unified-runtime/source/adapters/hip/usm.cpp

+34-17
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,13 @@ urUSMHostAlloc(ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc,
3737
return USMHostAllocImpl(ppMem, hContext, /* flags */ 0, size, alignment);
3838
}
3939

40-
return umfPoolMallocHelper(hPool, ppMem, size, alignment);
40+
auto UMFPool = hPool->HostMemPool.get();
41+
*ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment);
42+
if (*ppMem == nullptr) {
43+
auto umfErr = umfPoolGetLastAllocationError(UMFPool);
44+
return umf::umf2urResult(umfErr);
45+
}
46+
return UR_RESULT_SUCCESS;
4147
}
4248

4349
/// USM: Implements USM device allocations using a normal HIP device pointer
@@ -54,7 +60,13 @@ urUSMDeviceAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
5460
alignment);
5561
}
5662

57-
return umfPoolMallocHelper(hPool, ppMem, size, alignment);
63+
auto UMFPool = hPool->DeviceMemPool.get();
64+
*ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment);
65+
if (*ppMem == nullptr) {
66+
auto umfErr = umfPoolGetLastAllocationError(UMFPool);
67+
return umf::umf2urResult(umfErr);
68+
}
69+
return UR_RESULT_SUCCESS;
5870
}
5971

6072
/// USM: Implements USM Shared allocations using HIP Managed Memory
@@ -71,7 +83,13 @@ urUSMSharedAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
7183
/*device flags*/ 0, size, alignment);
7284
}
7385

74-
return umfPoolMallocHelper(hPool, ppMem, size, alignment);
86+
auto UMFPool = hPool->SharedMemPool.get();
87+
*ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment);
88+
if (*ppMem == nullptr) {
89+
auto umfErr = umfPoolGetLastAllocationError(UMFPool);
90+
return umf::umf2urResult(umfErr);
91+
}
92+
return UR_RESULT_SUCCESS;
7593
}
7694

7795
UR_APIEXPORT ur_result_t UR_APICALL
@@ -330,15 +348,25 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
330348
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
331349
ur_usm_pool_desc_t *PoolDesc)
332350
: Context(Context) {
333-
if (PoolDesc) {
334-
if (auto *Limits = find_stype_node<ur_usm_pool_limits_desc_t>(PoolDesc)) {
351+
352+
const void *pNext = PoolDesc->pNext;
353+
while (pNext != nullptr) {
354+
const ur_base_desc_t *BaseDesc = static_cast<const ur_base_desc_t *>(pNext);
355+
switch (BaseDesc->stype) {
356+
case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: {
357+
const ur_usm_pool_limits_desc_t *Limits =
358+
reinterpret_cast<const ur_usm_pool_limits_desc_t *>(BaseDesc);
335359
for (auto &config : DisjointPoolConfigs.Configs) {
336360
config.MaxPoolableSize = Limits->maxPoolableSize;
337361
config.SlabMinSize = Limits->minDriverAllocSize;
338362
}
339-
} else {
363+
break;
364+
}
365+
default: {
340366
throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT);
341367
}
368+
}
369+
pNext = BaseDesc->pNext;
342370
}
343371

344372
auto MemProvider =
@@ -468,17 +496,6 @@ bool checkUSMImplAlignment(uint32_t Alignment, void **ResultPtr) {
468496
reinterpret_cast<std::uintptr_t>(*ResultPtr) % Alignment == 0;
469497
}
470498

471-
ur_result_t umfPoolMallocHelper(ur_usm_pool_handle_t hPool, void **ppMem,
472-
size_t size, uint32_t alignment) {
473-
auto UMFPool = hPool->DeviceMemPool.get();
474-
*ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment);
475-
if (*ppMem == nullptr) {
476-
auto umfErr = umfPoolGetLastAllocationError(UMFPool);
477-
return umf::umf2urResult(umfErr);
478-
}
479-
return UR_RESULT_SUCCESS;
480-
}
481-
482499
UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreateExp(ur_context_handle_t,
483500
ur_device_handle_t,
484501
ur_usm_pool_desc_t *,

unified-runtime/source/adapters/hip/usm.hpp

-3
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,3 @@ ur_result_t USMHostAllocImpl(void **ResultPtr, ur_context_handle_t Context,
140140
bool checkUSMAlignment(uint32_t &alignment, const ur_usm_desc_t *pUSMDesc);
141141

142142
bool checkUSMImplAlignment(uint32_t Alignment, void **ResultPtr);
143-
144-
ur_result_t umfPoolMallocHelper(ur_usm_pool_handle_t hPool, void **ppMem,
145-
size_t size, uint32_t alignment);

unified-runtime/test/conformance/usm/urUSMPoolCreate.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ TEST_P(urUSMPoolCreateTest, Success) {
2929
}
3030

3131
TEST_P(urUSMPoolCreateTest, SuccessWithFlag) {
32-
UUR_KNOWN_FAILURE_ON(uur::CUDA{});
32+
UUR_KNOWN_FAILURE_ON(uur::CUDA{}, uur::HIP{});
3333

3434
ur_usm_pool_desc_t pool_desc{UR_STRUCTURE_TYPE_USM_POOL_DESC, nullptr,
3535
UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK};

0 commit comments

Comments
 (0)