Skip to content

Commit d63672b

Browse files
authored
[UR][L0] Refactor urDeviceSelectBinary, allow more fallbacks (NFC) (#18645)
We would like to extend urDeviceSelectBinary downstream to allow for device-specific binary targets. This commit refactors the current handling to make this easier. Additionally L0 specific tests are added for urDeviceSelectBinary to verify that the fallback logic works as expected.
1 parent d877bb9 commit d63672b

File tree

3 files changed

+102
-9
lines changed

3 files changed

+102
-9
lines changed

unified-runtime/source/adapters/level_zero/device.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <algorithm>
1717
#include <climits>
1818
#include <optional>
19+
#include <vector>
1920

2021
// UR_L0_USE_COPY_ENGINE can be set to an integer value, or
2122
// a pair of integer values of the form "lower_index:upper_index".
@@ -1456,7 +1457,7 @@ ur_result_t urDevicePartition(
14561457

14571458
ur_result_t urDeviceSelectBinary(
14581459
/// [in] handle of the device to select binary for.
1459-
ur_device_handle_t /*Device*/,
1460+
[[maybe_unused]] ur_device_handle_t Device,
14601461
/// [in] the array of binaries to select from.
14611462
const ur_device_binary_t *Binaries,
14621463
/// [in] the number of binaries passed in ppBinaries. Must greater than or
@@ -1486,21 +1487,34 @@ ur_result_t urDeviceSelectBinary(
14861487

14871488
uint32_t *SelectedBinaryInd = SelectedBinary;
14881489

1489-
// Find the appropriate device image, fallback to spirv if not found
1490-
constexpr uint32_t InvalidInd = (std::numeric_limits<uint32_t>::max)();
1491-
uint32_t Spirv = InvalidInd;
1490+
// Find the appropriate device image
1491+
// The order of elements is important, as it defines the priority:
1492+
std::vector<const char *> FallbackTargets = {UR_DEVICE_BINARY_TARGET_SPIRV64};
1493+
1494+
constexpr uint32_t InvalidInd = std::numeric_limits<uint32_t>::max();
1495+
uint32_t FallbackInd = InvalidInd;
1496+
uint32_t FallbackPriority = InvalidInd;
14921497

14931498
for (uint32_t i = 0; i < NumBinaries; ++i) {
14941499
if (strcmp(Binaries[i].pDeviceTargetSpec, BinaryTarget) == 0) {
14951500
*SelectedBinaryInd = i;
14961501
return UR_RESULT_SUCCESS;
14971502
}
1498-
if (strcmp(Binaries[i].pDeviceTargetSpec,
1499-
UR_DEVICE_BINARY_TARGET_SPIRV64) == 0)
1500-
Spirv = i;
1503+
for (uint32_t j = 0; j < FallbackTargets.size(); ++j) {
1504+
// We have a fall-back with the same or higher priority already
1505+
// no need to check the rest
1506+
if (FallbackPriority <= j)
1507+
break;
1508+
1509+
if (strcmp(Binaries[i].pDeviceTargetSpec, FallbackTargets[j]) == 0) {
1510+
FallbackInd = i;
1511+
FallbackPriority = j;
1512+
break;
1513+
}
1514+
}
15011515
}
1502-
// Points to a spirv image, if such indeed was found
1503-
if ((*SelectedBinaryInd = Spirv) != InvalidInd)
1516+
// We didn't find a primary target, try the highest-priority fall-back
1517+
if ((*SelectedBinaryInd = FallbackInd) != InvalidInd)
15041518
return UR_RESULT_SUCCESS;
15051519

15061520
// No image can be loaded for the given device

unified-runtime/test/adapters/level_zero/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ function(add_adapter_tests adapter)
9898
)
9999
endif()
100100

101+
add_adapter_test(${adapter}_device_select_binary
102+
FIXTURE DEVICES
103+
SOURCES
104+
urDeviceSelectBinary.cpp
105+
ENVIRONMENT
106+
"UR_ADAPTERS_FORCE_LOAD=\"$<TARGET_FILE:ur_adapter_${adapter}>\""
107+
)
108+
101109
add_adapter_test(${adapter}_mem_buffer_map
102110
FIXTURE DEVICES
103111
SOURCES
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright (C) 2025 Intel Corporation
2+
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
3+
// Exceptions. See LICENSE.TXT
4+
//
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
#include "ur_api.h"
7+
#include <uur/fixtures.h>
8+
9+
#include <array>
10+
#include <vector>
11+
12+
using urLevelZeroDeviceSelectBinaryTest = uur::urDeviceTest;
13+
UUR_INSTANTIATE_DEVICE_TEST_SUITE(urLevelZeroDeviceSelectBinaryTest);
14+
15+
static ur_device_binary_t binary_for_tgt(const char *Target) {
16+
return {UR_STRUCTURE_TYPE_DEVICE_BINARY, nullptr, Target};
17+
}
18+
19+
TEST_P(urLevelZeroDeviceSelectBinaryTest, TargetPreference) {
20+
std::vector<ur_device_binary_t> binaries = {
21+
binary_for_tgt(UR_DEVICE_BINARY_TARGET_UNKNOWN),
22+
binary_for_tgt(UR_DEVICE_BINARY_TARGET_SPIRV64),
23+
binary_for_tgt(UR_DEVICE_BINARY_TARGET_SPIRV64_GEN)};
24+
25+
// Gen binary should be preferred over SPIR-V
26+
{
27+
uint32_t selected_binary = binaries.size(); // invalid index
28+
ASSERT_SUCCESS(urDeviceSelectBinary(device, binaries.data(),
29+
binaries.size(), &selected_binary));
30+
ASSERT_EQ(selected_binary, binaries.size() - 1);
31+
}
32+
33+
// Remove the Gen binary,
34+
// SPIR-V should be selected
35+
binaries.pop_back();
36+
{
37+
uint32_t selected_binary = binaries.size(); // invalid index
38+
ASSERT_SUCCESS(urDeviceSelectBinary(device, binaries.data(),
39+
binaries.size(), &selected_binary));
40+
ASSERT_EQ(selected_binary, binaries.size() - 1);
41+
}
42+
43+
// No supported binaries left, should return an error
44+
binaries.pop_back();
45+
{
46+
uint32_t selected_binary = binaries.size(); // invalid index
47+
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_BINARY,
48+
urDeviceSelectBinary(device, binaries.data(),
49+
binaries.size(), &selected_binary));
50+
}
51+
}
52+
53+
TEST_P(urLevelZeroDeviceSelectBinaryTest, FirstOfSupported) {
54+
std::vector<const char *> SupportedTargets = {
55+
UR_DEVICE_BINARY_TARGET_SPIRV64,
56+
UR_DEVICE_BINARY_TARGET_SPIRV64_GEN,
57+
};
58+
for (const char *Target : SupportedTargets) {
59+
std::array binaries = {
60+
binary_for_tgt(UR_DEVICE_BINARY_TARGET_UNKNOWN),
61+
binary_for_tgt(Target),
62+
binary_for_tgt(UR_DEVICE_BINARY_TARGET_AMDGCN),
63+
binary_for_tgt(Target),
64+
};
65+
66+
uint32_t selected_binary = binaries.size(); // invalid index
67+
ASSERT_SUCCESS(urDeviceSelectBinary(device, binaries.data(),
68+
binaries.size(), &selected_binary));
69+
ASSERT_EQ(selected_binary, 1u);
70+
}
71+
}

0 commit comments

Comments
 (0)