Skip to content

[UR][L0] Refactor urDeviceSelectBinary, allow more fallbacks (NFC) #18645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions unified-runtime/source/adapters/level_zero/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <algorithm>
#include <climits>
#include <optional>
#include <vector>

// UR_L0_USE_COPY_ENGINE can be set to an integer value, or
// a pair of integer values of the form "lower_index:upper_index".
Expand Down Expand Up @@ -1456,7 +1457,7 @@ ur_result_t urDevicePartition(

ur_result_t urDeviceSelectBinary(
/// [in] handle of the device to select binary for.
ur_device_handle_t /*Device*/,
[[maybe_unused]] ur_device_handle_t Device,
/// [in] the array of binaries to select from.
const ur_device_binary_t *Binaries,
/// [in] the number of binaries passed in ppBinaries. Must greater than or
Expand Down Expand Up @@ -1486,21 +1487,34 @@ ur_result_t urDeviceSelectBinary(

uint32_t *SelectedBinaryInd = SelectedBinary;

// Find the appropriate device image, fallback to spirv if not found
constexpr uint32_t InvalidInd = (std::numeric_limits<uint32_t>::max)();
uint32_t Spirv = InvalidInd;
// Find the appropriate device image
// The order of elements is important, as it defines the priority:
std::vector<const char *> FallbackTargets = {UR_DEVICE_BINARY_TARGET_SPIRV64};

constexpr uint32_t InvalidInd = std::numeric_limits<uint32_t>::max();
uint32_t FallbackInd = InvalidInd;
uint32_t FallbackPriority = InvalidInd;

for (uint32_t i = 0; i < NumBinaries; ++i) {
if (strcmp(Binaries[i].pDeviceTargetSpec, BinaryTarget) == 0) {
*SelectedBinaryInd = i;
return UR_RESULT_SUCCESS;
}
if (strcmp(Binaries[i].pDeviceTargetSpec,
UR_DEVICE_BINARY_TARGET_SPIRV64) == 0)
Spirv = i;
for (uint32_t j = 0; j < FallbackTargets.size(); ++j) {
// We have a fall-back with the same or higher priority already
// no need to check the rest
if (FallbackPriority <= j)
break;

if (strcmp(Binaries[i].pDeviceTargetSpec, FallbackTargets[j]) == 0) {
FallbackInd = i;
FallbackPriority = j;
break;
}
}
}
// Points to a spirv image, if such indeed was found
if ((*SelectedBinaryInd = Spirv) != InvalidInd)
// We didn't find a primary target, try the highest-priority fall-back
if ((*SelectedBinaryInd = FallbackInd) != InvalidInd)
return UR_RESULT_SUCCESS;

// No image can be loaded for the given device
Expand Down
8 changes: 8 additions & 0 deletions unified-runtime/test/adapters/level_zero/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ function(add_adapter_tests adapter)
)
endif()

add_adapter_test(${adapter}_device_select_binary
FIXTURE DEVICES
SOURCES
urDeviceSelectBinary.cpp
ENVIRONMENT
"UR_ADAPTERS_FORCE_LOAD=\"$<TARGET_FILE:ur_adapter_${adapter}>\""
)

add_adapter_test(${adapter}_mem_buffer_map
FIXTURE DEVICES
SOURCES
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (C) 2025 Intel Corporation
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
// Exceptions. See LICENSE.TXT
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "ur_api.h"
#include <uur/fixtures.h>

#include <array>
#include <vector>

using urLevelZeroDeviceSelectBinaryTest = uur::urDeviceTest;
UUR_INSTANTIATE_DEVICE_TEST_SUITE(urLevelZeroDeviceSelectBinaryTest);

static ur_device_binary_t binary_for_tgt(const char *Target) {
return {UR_STRUCTURE_TYPE_DEVICE_BINARY, nullptr, Target};
}

TEST_P(urLevelZeroDeviceSelectBinaryTest, TargetPreference) {
std::vector<ur_device_binary_t> binaries = {
binary_for_tgt(UR_DEVICE_BINARY_TARGET_UNKNOWN),
binary_for_tgt(UR_DEVICE_BINARY_TARGET_SPIRV64),
binary_for_tgt(UR_DEVICE_BINARY_TARGET_SPIRV64_GEN)};

// Gen binary should be preferred over SPIR-V
{
uint32_t selected_binary = binaries.size(); // invalid index
ASSERT_SUCCESS(urDeviceSelectBinary(device, binaries.data(),
binaries.size(), &selected_binary));
ASSERT_EQ(selected_binary, binaries.size() - 1);
}

// Remove the Gen binary,
// SPIR-V should be selected
binaries.pop_back();
{
uint32_t selected_binary = binaries.size(); // invalid index
ASSERT_SUCCESS(urDeviceSelectBinary(device, binaries.data(),
binaries.size(), &selected_binary));
ASSERT_EQ(selected_binary, binaries.size() - 1);
}

// No supported binaries left, should return an error
binaries.pop_back();
{
uint32_t selected_binary = binaries.size(); // invalid index
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_BINARY,
urDeviceSelectBinary(device, binaries.data(),
binaries.size(), &selected_binary));
}
}

TEST_P(urLevelZeroDeviceSelectBinaryTest, FirstOfSupported) {
std::vector<const char *> SupportedTargets = {
UR_DEVICE_BINARY_TARGET_SPIRV64,
UR_DEVICE_BINARY_TARGET_SPIRV64_GEN,
};
for (const char *Target : SupportedTargets) {
std::array binaries = {
binary_for_tgt(UR_DEVICE_BINARY_TARGET_UNKNOWN),
binary_for_tgt(Target),
binary_for_tgt(UR_DEVICE_BINARY_TARGET_AMDGCN),
binary_for_tgt(Target),
};

uint32_t selected_binary = binaries.size(); // invalid index
ASSERT_SUCCESS(urDeviceSelectBinary(device, binaries.data(),
binaries.size(), &selected_binary));
ASSERT_EQ(selected_binary, 1u);
}
}