Skip to content

Commit c53c6fd

Browse files
committed
[UR][HIP] Enable usm pools
1 parent e620747 commit c53c6fd

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

unified-runtime/source/adapters/hip/CMakeLists.txt

+6
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ else()
203203
message(FATAL_ERROR "Unspecified UR HIP platform please set UR_HIP_PLATFORM to 'AMD' or 'NVIDIA'")
204204
endif()
205205

206+
if(UMF_ENABLE_POOL_TRACKING)
207+
target_compile_definitions(${TARGET_NAME} PRIVATE UMF_ENABLE_POOL_TRACKING)
208+
else()
209+
message(WARNING "HIP adapter USM pools are disabled, set UMF_ENABLE_POOL_TRACKING to enable them")
210+
endif()
211+
206212
target_include_directories(${TARGET_NAME} PRIVATE
207213
"${CMAKE_CURRENT_SOURCE_DIR}/../../"
208214
)

unified-runtime/source/adapters/hip/usm.cpp

+13-3
Original file line numberDiff line numberDiff line change
@@ -330,15 +330,25 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
330330
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
331331
ur_usm_pool_desc_t *PoolDesc)
332332
: Context(Context) {
333-
if (PoolDesc) {
334-
if (auto *Limits = find_stype_node<ur_usm_pool_limits_desc_t>(PoolDesc)) {
333+
334+
const void *pNext = PoolDesc->pNext;
335+
while (pNext != nullptr) {
336+
const ur_base_desc_t *BaseDesc = static_cast<const ur_base_desc_t *>(pNext);
337+
switch (BaseDesc->stype) {
338+
case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: {
339+
const ur_usm_pool_limits_desc_t *Limits =
340+
reinterpret_cast<const ur_usm_pool_limits_desc_t *>(BaseDesc);
335341
for (auto &config : DisjointPoolConfigs.Configs) {
336342
config.MaxPoolableSize = Limits->maxPoolableSize;
337343
config.SlabMinSize = Limits->minDriverAllocSize;
338344
}
339-
} else {
345+
break;
346+
}
347+
default: {
340348
throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT);
341349
}
350+
}
351+
pNext = BaseDesc->pNext;
342352
}
343353

344354
auto MemProvider =

unified-runtime/test/conformance/usm/urUSMPoolCreate.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ TEST_P(urUSMPoolCreateTest, Success) {
2929
}
3030

3131
TEST_P(urUSMPoolCreateTest, SuccessWithFlag) {
32-
UUR_KNOWN_FAILURE_ON(uur::CUDA{});
32+
UUR_KNOWN_FAILURE_ON(uur::CUDA{}, uur::HIP{});
3333

3434
ur_usm_pool_desc_t pool_desc{UR_STRUCTURE_TYPE_USM_POOL_DESC, nullptr,
3535
UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK};

0 commit comments

Comments
 (0)