Skip to content

[UR][HIP] Enable usm pools #17972

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions unified-runtime/source/adapters/hip/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ else()
message(FATAL_ERROR "Unspecified UR HIP platform please set UR_HIP_PLATFORM to 'AMD' or 'NVIDIA'")
endif()

# Forward UMF_ENABLE_POOL_TRACKING to the adapter target as a private compile
# definition. When the option is not set, emit a configure-time warning that
# USM pools are disabled for the HIP adapter.
if(NOT UMF_ENABLE_POOL_TRACKING)
  message(WARNING "HIP adapter USM pools are disabled, set UMF_ENABLE_POOL_TRACKING to enable them")
else()
  target_compile_definitions(${TARGET_NAME} PRIVATE UMF_ENABLE_POOL_TRACKING)
endif()

# Add the directory two levels above this CMakeLists.txt to the target's
# private include search path.
target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/../../")
51 changes: 34 additions & 17 deletions unified-runtime/source/adapters/hip/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ urUSMHostAlloc(ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc,
return USMHostAllocImpl(ppMem, hContext, /* flags */ 0, size, alignment);
}

return umfPoolMallocHelper(hPool, ppMem, size, alignment);
auto UMFPool = hPool->HostMemPool.get();
*ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment);
if (*ppMem == nullptr) {
auto umfErr = umfPoolGetLastAllocationError(UMFPool);
return umf::umf2urResult(umfErr);
}
return UR_RESULT_SUCCESS;
}

/// USM: Implements USM device allocations using a normal HIP device pointer
Expand All @@ -54,7 +60,13 @@ urUSMDeviceAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
alignment);
}

return umfPoolMallocHelper(hPool, ppMem, size, alignment);
auto UMFPool = hPool->DeviceMemPool.get();
*ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment);
if (*ppMem == nullptr) {
auto umfErr = umfPoolGetLastAllocationError(UMFPool);
return umf::umf2urResult(umfErr);
}
return UR_RESULT_SUCCESS;
}

/// USM: Implements USM Shared allocations using HIP Managed Memory
Expand All @@ -71,7 +83,13 @@ urUSMSharedAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
/*device flags*/ 0, size, alignment);
}

return umfPoolMallocHelper(hPool, ppMem, size, alignment);
auto UMFPool = hPool->SharedMemPool.get();
*ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment);
if (*ppMem == nullptr) {
auto umfErr = umfPoolGetLastAllocationError(UMFPool);
return umf::umf2urResult(umfErr);
}
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
Expand Down Expand Up @@ -330,15 +348,25 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
ur_usm_pool_desc_t *PoolDesc)
: Context(Context) {
if (PoolDesc) {
if (auto *Limits = find_stype_node<ur_usm_pool_limits_desc_t>(PoolDesc)) {

const void *pNext = PoolDesc->pNext;
while (pNext != nullptr) {
const ur_base_desc_t *BaseDesc = static_cast<const ur_base_desc_t *>(pNext);
switch (BaseDesc->stype) {
case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: {
const ur_usm_pool_limits_desc_t *Limits =
reinterpret_cast<const ur_usm_pool_limits_desc_t *>(BaseDesc);
for (auto &config : DisjointPoolConfigs.Configs) {
config.MaxPoolableSize = Limits->maxPoolableSize;
config.SlabMinSize = Limits->minDriverAllocSize;
}
} else {
break;
}
default: {
throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT);
}
}
pNext = BaseDesc->pNext;
}

auto MemProvider =
Expand Down Expand Up @@ -468,17 +496,6 @@ bool checkUSMImplAlignment(uint32_t Alignment, void **ResultPtr) {
reinterpret_cast<std::uintptr_t>(*ResultPtr) % Alignment == 0;
}

/// Allocates `size` bytes aligned to `alignment` from a UMF pool held by
/// `hPool` and writes the resulting pointer to `*ppMem`.
///
/// NOTE(review): the allocation is always served from `hPool->DeviceMemPool`;
/// the helper takes no argument that selects the host or shared pool, so every
/// caller receives memory from the device pool regardless of the kind of
/// allocation it is implementing.
///
/// \return UR_RESULT_SUCCESS when a non-null pointer was obtained; otherwise
///         the pool's last allocation error converted via umf::umf2urResult.
ur_result_t umfPoolMallocHelper(ur_usm_pool_handle_t hPool, void **ppMem,
                                size_t size, uint32_t alignment) {
  auto Pool = hPool->DeviceMemPool.get();
  void *Allocation = umfPoolAlignedMalloc(Pool, size, alignment);
  // The out-parameter is written unconditionally, so it is null on failure.
  *ppMem = Allocation;
  if (Allocation != nullptr) {
    return UR_RESULT_SUCCESS;
  }
  return umf::umf2urResult(umfPoolGetLastAllocationError(Pool));
}

UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreateExp(ur_context_handle_t,
ur_device_handle_t,
ur_usm_pool_desc_t *,
Expand Down
3 changes: 0 additions & 3 deletions unified-runtime/source/adapters/hip/usm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,3 @@ ur_result_t USMHostAllocImpl(void **ResultPtr, ur_context_handle_t Context,
bool checkUSMAlignment(uint32_t &alignment, const ur_usm_desc_t *pUSMDesc);

bool checkUSMImplAlignment(uint32_t Alignment, void **ResultPtr);

ur_result_t umfPoolMallocHelper(ur_usm_pool_handle_t hPool, void **ppMem,
size_t size, uint32_t alignment);
2 changes: 1 addition & 1 deletion unified-runtime/test/conformance/usm/urUSMPoolCreate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ TEST_P(urUSMPoolCreateTest, Success) {
}

TEST_P(urUSMPoolCreateTest, SuccessWithFlag) {
UUR_KNOWN_FAILURE_ON(uur::CUDA{});
UUR_KNOWN_FAILURE_ON(uur::CUDA{}, uur::HIP{});

ur_usm_pool_desc_t pool_desc{UR_STRUCTURE_TYPE_USM_POOL_DESC, nullptr,
UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK};
Expand Down