diff --git a/unified-runtime/source/adapters/level_zero/device.cpp b/unified-runtime/source/adapters/level_zero/device.cpp index e2c7a5e4d6aa3..c372443d07ec2 100644 --- a/unified-runtime/source/adapters/level_zero/device.cpp +++ b/unified-runtime/source/adapters/level_zero/device.cpp @@ -16,6 +16,7 @@ #include #include #include +#include // UR_L0_USE_COPY_ENGINE can be set to an integer value, or // a pair of integer values of the form "lower_index:upper_index". @@ -1456,7 +1457,7 @@ ur_result_t urDevicePartition( ur_result_t urDeviceSelectBinary( /// [in] handle of the device to select binary for. - ur_device_handle_t /*Device*/, + [[maybe_unused]] ur_device_handle_t Device, /// [in] the array of binaries to select from. const ur_device_binary_t *Binaries, /// [in] the number of binaries passed in ppBinaries. Must greater than or @@ -1486,21 +1487,34 @@ ur_result_t urDeviceSelectBinary( uint32_t *SelectedBinaryInd = SelectedBinary; - // Find the appropriate device image, fallback to spirv if not found - constexpr uint32_t InvalidInd = (std::numeric_limits::max)(); - uint32_t Spirv = InvalidInd; + // Find the appropriate device image + // The order of elements is important, as it defines the priority: + std::vector FallbackTargets = {UR_DEVICE_BINARY_TARGET_SPIRV64}; + + constexpr uint32_t InvalidInd = std::numeric_limits::max(); + uint32_t FallbackInd = InvalidInd; + uint32_t FallbackPriority = InvalidInd; for (uint32_t i = 0; i < NumBinaries; ++i) { if (strcmp(Binaries[i].pDeviceTargetSpec, BinaryTarget) == 0) { *SelectedBinaryInd = i; return UR_RESULT_SUCCESS; } - if (strcmp(Binaries[i].pDeviceTargetSpec, - UR_DEVICE_BINARY_TARGET_SPIRV64) == 0) - Spirv = i; + for (uint32_t j = 0; j < FallbackTargets.size(); ++j) { + // We have a fall-back with the same or higher priority already + // no need to check the rest + if (FallbackPriority <= j) + break; + + if (strcmp(Binaries[i].pDeviceTargetSpec, FallbackTargets[j]) == 0) { + FallbackInd = i; + FallbackPriority = j; + break; + } + } } - // Points to a spirv image, if such indeed was found - if ((*SelectedBinaryInd = Spirv) != InvalidInd) + // We didn't find a primary target, try the highest-priority fall-back + if ((*SelectedBinaryInd = FallbackInd) != InvalidInd) return UR_RESULT_SUCCESS; // No image can be loaded for the given device diff --git a/unified-runtime/test/adapters/level_zero/CMakeLists.txt b/unified-runtime/test/adapters/level_zero/CMakeLists.txt index a834300bbc528..7c6b7ca1b2fa8 100644 --- a/unified-runtime/test/adapters/level_zero/CMakeLists.txt +++ b/unified-runtime/test/adapters/level_zero/CMakeLists.txt @@ -98,6 +98,14 @@ function(add_adapter_tests adapter) ) endif() + add_adapter_test(${adapter}_device_select_binary + FIXTURE DEVICES + SOURCES + urDeviceSelectBinary.cpp + ENVIRONMENT + "UR_ADAPTERS_FORCE_LOAD=\"$\"" + ) + add_adapter_test(${adapter}_mem_buffer_map FIXTURE DEVICES SOURCES diff --git a/unified-runtime/test/adapters/level_zero/urDeviceSelectBinary.cpp b/unified-runtime/test/adapters/level_zero/urDeviceSelectBinary.cpp new file mode 100644 index 0000000000000..5a4495f078343 --- /dev/null +++ b/unified-runtime/test/adapters/level_zero/urDeviceSelectBinary.cpp @@ -0,0 +1,71 @@ +// Copyright (C) 2025 Intel Corporation +// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM +// Exceptions. See LICENSE.TXT +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "ur_api.h" +#include + +#include +#include + +using urLevelZeroDeviceSelectBinaryTest = uur::urDeviceTest; +UUR_INSTANTIATE_DEVICE_TEST_SUITE(urLevelZeroDeviceSelectBinaryTest); + +static ur_device_binary_t binary_for_tgt(const char *Target) { + return {UR_STRUCTURE_TYPE_DEVICE_BINARY, nullptr, Target}; +} + +TEST_P(urLevelZeroDeviceSelectBinaryTest, TargetPreference) { + std::vector binaries = { + binary_for_tgt(UR_DEVICE_BINARY_TARGET_UNKNOWN), + binary_for_tgt(UR_DEVICE_BINARY_TARGET_SPIRV64), + binary_for_tgt(UR_DEVICE_BINARY_TARGET_SPIRV64_GEN)}; + + // Gen binary should be preferred over SPIR-V + { + uint32_t selected_binary = binaries.size(); // invalid index + ASSERT_SUCCESS(urDeviceSelectBinary(device, binaries.data(), + binaries.size(), &selected_binary)); + ASSERT_EQ(selected_binary, binaries.size() - 1); + } + + // Remove the Gen binary, + // SPIR-V should be selected + binaries.pop_back(); + { + uint32_t selected_binary = binaries.size(); // invalid index + ASSERT_SUCCESS(urDeviceSelectBinary(device, binaries.data(), + binaries.size(), &selected_binary)); + ASSERT_EQ(selected_binary, binaries.size() - 1); + } + + // No supported binaries left, should return an error + binaries.pop_back(); + { + uint32_t selected_binary = binaries.size(); // invalid index + ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_BINARY, + urDeviceSelectBinary(device, binaries.data(), + binaries.size(), &selected_binary)); + } +} + +TEST_P(urLevelZeroDeviceSelectBinaryTest, FirstOfSupported) { + std::vector SupportedTargets = { + UR_DEVICE_BINARY_TARGET_SPIRV64, + UR_DEVICE_BINARY_TARGET_SPIRV64_GEN, + }; + for (const char *Target : SupportedTargets) { + std::array binaries = { + binary_for_tgt(UR_DEVICE_BINARY_TARGET_UNKNOWN), + binary_for_tgt(Target), + binary_for_tgt(UR_DEVICE_BINARY_TARGET_AMDGCN), + binary_for_tgt(Target), + }; + + uint32_t selected_binary = binaries.size(); // invalid index + ASSERT_SUCCESS(urDeviceSelectBinary(device, binaries.data(), + binaries.size(), &selected_binary)); + ASSERT_EQ(selected_binary, 1u); + } +}