diff --git a/scripts/core/CUDA.rst b/scripts/core/CUDA.rst index 9771693113..08b61bf9dc 100644 --- a/scripts/core/CUDA.rst +++ b/scripts/core/CUDA.rst @@ -148,6 +148,39 @@ take the extra global offset argument. Use of the global offset is not recommended for non SYCL compiler toolchains. This parameter can be ignored if the user does not wish to use the global offset. +Local Memory Arguments +---------------------- + +In UR local memory is a region of memory shared by all the work-items in +a work-group. A kernel function signature can include local memory address +space pointer arguments, which are set by the user with +``urKernelSetArgLocal`` with the number of bytes of local memory to allocate +and make available from the pointer argument. + +The CUDA adapter implements local memory in a kernel as a single ``__shared__`` +memory allocation, and each individual local memory argument is a ``u32`` byte +offset kernel parameter which is combined inside the kernel with the +``__shared__`` memory allocation. Therefore for ``N`` local arguments that need +set on a kernel with ``urKernelSetArgLocal``, the total aligned size across the +``N`` calls to ``urKernelSetArgLocal`` is calculated for the ``__shared__`` +memory allocation by the CUDA adapter and passed as the ``sharedMemBytes`` +argument to ``cuLaunchKernel`` (or variants like ``cuLaunchCooperativeKernel`` +or ``cuGraphAddKernelNode``). + +For each kernel ``u32`` local memory offset parameter, aligned offsets into the +single memory location are calculated and passed at runtime by the adapter via +``kernelParams`` when launching the kernel (or adding the kernel as a graph +node). When a user calls ``urKernelSetArgLocal`` with an argument index that +has already been set on the kernel, the adapter recalculates the size of the +``__shared__`` memory allocation and offset for the index, as well as the +offsets of any local memory arguments at following indices. + +.. warning:: + + The CUDA UR adapter implementation of local memory assumes the kernel created + has been created by DPC++, instrumenting the device code so that local memory + arguments are offsets rather than pointers. + Other Notes =========== @@ -164,4 +197,5 @@ Contributors ------------ * Hugh Delaney `hugh.delaney@codeplay.com `_ +* Ewan Crawford `ewan@codeplay.com `_ diff --git a/scripts/core/HIP.rst b/scripts/core/HIP.rst index 3ded0138ff..920a5f5a3e 100644 --- a/scripts/core/HIP.rst +++ b/scripts/core/HIP.rst @@ -91,6 +91,46 @@ take the extra global offset argument. Use of the global offset is not recommended for non SYCL compiler toolchains. This parameter can be ignored if the user does not wish to use the global offset. +Local Memory Arguments +---------------------- + +In UR local memory is a region of memory shared by all the work-items in +a work-group. A kernel function signature can include local memory address +space pointer arguments, which are set by the user with +``urKernelSetArgLocal`` with the number of bytes of local memory to allocate +and make available from the pointer argument. + +The HIP adapter implements local memory in a kernel as a single ``__shared__`` +memory allocation, and each individual local memory argument is a ``u32`` byte +offset kernel parameter which is combined inside the kernel with the +``__shared__`` memory allocation. Therefore for ``N`` local arguments that need +set on a kernel with ``urKernelSetArgLocal``, the total aligned size across the +``N`` calls to ``urKernelSetArgLocal`` is calculated for the ``__shared__`` +memory allocation by the HIP adapter and passed as the ``sharedMemBytes`` +argument to ``hipModuleLaunchKernel`` or ``hipGraphAddKernelNode``. + +For each kernel ``u32`` local memory offset parameter, aligned offsets into the +single memory location are calculated and passed at runtime by the adapter via +``kernelParams`` when launching the kernel (or adding the kernel as a graph +node). When a user calls ``urKernelSetArgLocal`` with an argument index that +has already been set on the kernel, the adapter recalculates the size of the +``__shared__`` memory allocation and offset for the index, as well as the +offsets of any local memory arguments at following indices. + +.. warning:: + + The HIP UR adapter implementation of local memory assumes the kernel created + has been created by DPC++, instrumenting the device code so that local memory + arguments are offsets rather than pointers. + + +HIP kernels that are generated for DPC++ kernels with SYCL local accessors +contain extra value arguments on top of the local memory argument for the +local accessor. For each ``urKernelSetArgLocal`` argument, a user needs +to make 3 calls to ``urKernelSetArgValue`` with each of the next 3 consecutive +argument indexes. This represents a 3 dimensional offset into the local +accessor. + Other Notes =========== @@ -100,4 +140,5 @@ Contributors ------------ * Hugh Delaney `hugh.delaney@codeplay.com `_ +* Ewan Crawford `ewan@codeplay.com `_ diff --git a/source/adapters/cuda/command_buffer.cpp b/source/adapters/cuda/command_buffer.cpp index 527c339783..4b4b2cffe5 100644 --- a/source/adapters/cuda/command_buffer.cpp +++ b/source/adapters/cuda/command_buffer.cpp @@ -522,9 +522,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( DepsList.data(), DepsList.size(), &NodeParams)); - if (LocalSize != 0) - hKernel->clearLocalSize(); - // Add signal node if external return event is used. CUgraphNode SignalNode = nullptr; if (phEvent) { diff --git a/source/adapters/cuda/enqueue.cpp b/source/adapters/cuda/enqueue.cpp index fc3d0220e8..54a0f778fb 100644 --- a/source/adapters/cuda/enqueue.cpp +++ b/source/adapters/cuda/enqueue.cpp @@ -493,9 +493,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( ThreadsPerBlock[0], ThreadsPerBlock[1], ThreadsPerBlock[2], LocalSize, CuStream, const_cast(ArgIndices.data()), nullptr)); - if (LocalSize != 0) - hKernel->clearLocalSize(); - if (phEvent) { UR_CHECK_ERROR(RetImplEvent->record()); *phEvent = RetImplEvent.release(); @@ -673,9 +670,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp( const_cast(ArgIndices.data()), nullptr)); - if (LocalSize != 0) - hKernel->clearLocalSize(); - if (phEvent) { UR_CHECK_ERROR(RetImplEvent->record()); *phEvent = RetImplEvent.release(); diff --git a/source/adapters/cuda/kernel.hpp b/source/adapters/cuda/kernel.hpp index 7ad20a4f0e..2b04dfba43 100644 --- a/source/adapters/cuda/kernel.hpp +++ b/source/adapters/cuda/kernel.hpp @@ -61,10 +61,22 @@ struct ur_kernel_handle_t_ { using args_t = std::array; using args_size_t = std::vector; using args_index_t = std::vector; + /// Storage shared by all args which is mem copied into when adding a new + /// argument. args_t Storage; + /// Aligned size of each parameter, including padding. args_size_t ParamSizes; + /// Byte offset into /p Storage allocation for each parameter. args_index_t Indices; - args_size_t OffsetPerIndex; + /// Aligned size in bytes for each local memory parameter after padding has + /// been added. Zero if the argument at the index isn't a local memory + /// argument. + args_size_t AlignedLocalMemSize; + /// Original size in bytes for each local memory parameter, prior to being + /// padded to appropriate alignment. Zero if the argument at the index + /// isn't a local memory argument. + args_size_t OriginalLocalMemSize; + // A struct to keep track of memargs so that we can do dependency analysis // at urEnqueueKernelLaunch struct mem_obj_arg { @@ -93,7 +105,8 @@ struct ur_kernel_handle_t_ { Indices.resize(Index + 2, Indices.back()); // Ensure enough space for the new argument ParamSizes.resize(Index + 1); - OffsetPerIndex.resize(Index + 1); + AlignedLocalMemSize.resize(Index + 1); + OriginalLocalMemSize.resize(Index + 1); } ParamSizes[Index] = Size; // calculate the insertion point on the array @@ -102,28 +115,81 @@ struct ur_kernel_handle_t_ { // Update the stored value for the argument std::memcpy(&Storage[InsertPos], Arg, Size); Indices[Index] = &Storage[InsertPos]; - OffsetPerIndex[Index] = LocalSize; + AlignedLocalMemSize[Index] = LocalSize; } - void addLocalArg(size_t Index, size_t Size) { - size_t LocalOffset = this->getLocalSize(); + /// Returns the padded size and offset of a local memory argument. + /// Local memory arguments need to be padded if the alignment for the size + /// doesn't match the current offset into the kernel local data. + /// @param Index Kernel arg index. + /// @param Size User passed size of local parameter. + /// @return Tuple of (Aligned size, Aligned offset into local data). + std::pair calcAlignedLocalArgument(size_t Index, + size_t Size) { + // Store the unpadded size of the local argument + if (Index + 2 > Indices.size()) { + AlignedLocalMemSize.resize(Index + 1); + OriginalLocalMemSize.resize(Index + 1); + } + OriginalLocalMemSize[Index] = Size; + + // Calculate the current starting offset into local data + const size_t LocalOffset = std::accumulate( + std::begin(AlignedLocalMemSize), + std::next(std::begin(AlignedLocalMemSize), Index), size_t{0}); - // maximum required alignment is the size of the largest vector type + // Maximum required alignment is the size of the largest vector type const size_t MaxAlignment = sizeof(double) * 16; - // for arguments smaller than the maximum alignment simply align to the + // For arguments smaller than the maximum alignment simply align to the // size of the argument const size_t Alignment = std::min(MaxAlignment, Size); - // align the argument + // Align the argument size_t AlignedLocalOffset = LocalOffset; - size_t Pad = LocalOffset % Alignment; + const size_t Pad = LocalOffset % Alignment; if (Pad != 0) { AlignedLocalOffset += Alignment - Pad; } + const size_t AlignedLocalSize = Size + (AlignedLocalOffset - LocalOffset); + return std::make_pair(AlignedLocalSize, AlignedLocalOffset); + } + + void addLocalArg(size_t Index, size_t Size) { + // Get the aligned argument size and offset into local data + auto [AlignedLocalSize, AlignedLocalOffset] = + calcAlignedLocalArgument(Index, Size); + + // Store argument details addArg(Index, sizeof(size_t), (const void *)&(AlignedLocalOffset), - Size + (AlignedLocalOffset - LocalOffset)); + AlignedLocalSize); + + // For every existing local argument which follows at later argument + // indices, update the offset and pointer into the kernel local memory. + // Required as padding will need to be recalculated. + const size_t NumArgs = Indices.size() - 1; // Accounts for implicit arg + for (auto SuccIndex = Index + 1; SuccIndex < NumArgs; SuccIndex++) { + const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex]; + if (OriginalLocalSize == 0) { + // Skip if successor argument isn't a local memory arg + continue; + } + + // Recalculate alignment + auto [SuccAlignedLocalSize, SuccAlignedLocalOffset] = + calcAlignedLocalArgument(SuccIndex, OriginalLocalSize); + + // Store new local memory size + AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize; + + // Store new offset into local data + const size_t InsertPos = + std::accumulate(std::begin(ParamSizes), + std::begin(ParamSizes) + SuccIndex, size_t{0}); + std::memcpy(&Storage[InsertPos], &SuccAlignedLocalOffset, + sizeof(size_t)); + } } void addMemObjArg(int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) { @@ -145,15 +211,11 @@ struct ur_kernel_handle_t_ { std::memcpy(ImplicitOffsetArgs, ImplicitOffset, Size); } - void clearLocalSize() { - std::fill(std::begin(OffsetPerIndex), std::end(OffsetPerIndex), 0); - } - const args_index_t &getIndices() const noexcept { return Indices; } uint32_t getLocalSize() const { - return std::accumulate(std::begin(OffsetPerIndex), - std::end(OffsetPerIndex), 0); + return std::accumulate(std::begin(AlignedLocalMemSize), + std::end(AlignedLocalMemSize), 0); } } Args; @@ -240,7 +302,5 @@ struct ur_kernel_handle_t_ { uint32_t getLocalSize() const noexcept { return Args.getLocalSize(); } - void clearLocalSize() { Args.clearLocalSize(); } - size_t getRegsPerThread() const noexcept { return RegsPerThread; }; }; diff --git a/source/adapters/hip/command_buffer.cpp b/source/adapters/hip/command_buffer.cpp index 9fed5db2f8..538c2ff85a 100644 --- a/source/adapters/hip/command_buffer.cpp +++ b/source/adapters/hip/command_buffer.cpp @@ -396,9 +396,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( DepsList.data(), DepsList.size(), &NodeParams)); - if (LocalSize != 0) - hKernel->clearLocalSize(); - // Get sync point and register the node with it. auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode); if (pSyncPoint) { diff --git a/source/adapters/hip/enqueue.cpp b/source/adapters/hip/enqueue.cpp index 025a3f41f4..b9aa097848 100644 --- a/source/adapters/hip/enqueue.cpp +++ b/source/adapters/hip/enqueue.cpp @@ -324,8 +324,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( ThreadsPerBlock[0], ThreadsPerBlock[1], ThreadsPerBlock[2], hKernel->getLocalSize(), HIPStream, ArgIndices.data(), nullptr)); - hKernel->clearLocalSize(); - if (phEvent) { UR_CHECK_ERROR(RetImplEvent->record()); *phEvent = RetImplEvent.release(); diff --git a/source/adapters/hip/kernel.hpp b/source/adapters/hip/kernel.hpp index afea69832b..c6d30e81ad 100644 --- a/source/adapters/hip/kernel.hpp +++ b/source/adapters/hip/kernel.hpp @@ -56,10 +56,22 @@ struct ur_kernel_handle_t_ { using args_t = std::array; using args_size_t = std::vector; using args_index_t = std::vector; + /// Storage shared by all args which is mem copied into when adding a new + /// argument. args_t Storage; + /// Aligned size of each parameter, including padding. args_size_t ParamSizes; + /// Byte offset into /p Storage allocation for each parameter. args_index_t Indices; - args_size_t OffsetPerIndex; + /// Aligned size in bytes for each local memory parameter after padding has + /// been added. Zero if the argument at the index isn't a local memory + /// argument. + args_size_t AlignedLocalMemSize; + /// Original size in bytes for each local memory parameter, prior to being + /// padded to appropriate alignment. Zero if the argument at the index + /// isn't a local memory argument. + args_size_t OriginalLocalMemSize; + // A struct to keep track of memargs so that we can do dependency analysis // at urEnqueueKernelLaunch struct mem_obj_arg { @@ -88,7 +100,8 @@ struct ur_kernel_handle_t_ { Indices.resize(Index + 2, Indices.back()); // Ensure enough space for the new argument ParamSizes.resize(Index + 1); - OffsetPerIndex.resize(Index + 1); + AlignedLocalMemSize.resize(Index + 1); + OriginalLocalMemSize.resize(Index + 1); } ParamSizes[Index] = Size; // calculate the insertion point on the array @@ -97,28 +110,81 @@ struct ur_kernel_handle_t_ { // Update the stored value for the argument std::memcpy(&Storage[InsertPos], Arg, Size); Indices[Index] = &Storage[InsertPos]; - OffsetPerIndex[Index] = LocalSize; + AlignedLocalMemSize[Index] = LocalSize; } - void addLocalArg(size_t Index, size_t Size) { - size_t LocalOffset = this->getLocalSize(); + /// Returns the padded size and offset of a local memory argument. + /// Local memory arguments need to be padded if the alignment for the size + /// doesn't match the current offset into the kernel local data. + /// @param Index Kernel arg index. + /// @param Size User passed size of local parameter. + /// @return Tuple of (Aligned size, Aligned offset into local data). + std::pair calcAlignedLocalArgument(size_t Index, + size_t Size) { + // Store the unpadded size of the local argument + if (Index + 2 > Indices.size()) { + AlignedLocalMemSize.resize(Index + 1); + OriginalLocalMemSize.resize(Index + 1); + } + OriginalLocalMemSize[Index] = Size; - // maximum required alignment is the size of the largest vector type + // Calculate the current starting offset into local data + const size_t LocalOffset = std::accumulate( + std::begin(AlignedLocalMemSize), + std::next(std::begin(AlignedLocalMemSize), Index), size_t{0}); + + // Maximum required alignment is the size of the largest vector type const size_t MaxAlignment = sizeof(double) * 16; - // for arguments smaller than the maximum alignment simply align to the + // For arguments smaller than the maximum alignment simply align to the // size of the argument const size_t Alignment = std::min(MaxAlignment, Size); - // align the argument + // Align the argument size_t AlignedLocalOffset = LocalOffset; - size_t Pad = LocalOffset % Alignment; + const size_t Pad = LocalOffset % Alignment; if (Pad != 0) { AlignedLocalOffset += Alignment - Pad; } - addArg(Index, sizeof(size_t), (const void *)&AlignedLocalOffset, - Size + AlignedLocalOffset - LocalOffset); + const size_t AlignedLocalSize = Size + (AlignedLocalOffset - LocalOffset); + return std::make_pair(AlignedLocalSize, AlignedLocalOffset); + } + + void addLocalArg(size_t Index, size_t Size) { + // Get the aligned argument size and offset into local data + auto [AlignedLocalSize, AlignedLocalOffset] = + calcAlignedLocalArgument(Index, Size); + + // Store argument details + addArg(Index, sizeof(size_t), (const void *)&(AlignedLocalOffset), + AlignedLocalSize); + + // For every existing local argument which follows at later argument + // indices, update the offset and pointer into the kernel local memory. + // Required as padding will need to be recalculated. + const size_t NumArgs = Indices.size() - 1; // Accounts for implicit arg + for (auto SuccIndex = Index + 1; SuccIndex < NumArgs; SuccIndex++) { + const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex]; + if (OriginalLocalSize == 0) { + // Skip if successor argument isn't a local memory arg + continue; + } + + // Recalculate alignment + auto [SuccAlignedLocalSize, SuccAlignedLocalOffset] = + calcAlignedLocalArgument(SuccIndex, OriginalLocalSize); + + // Store new local memory size + AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize; + + // Store new offset into local data + const size_t InsertPos = + std::accumulate(std::begin(ParamSizes), + std::begin(ParamSizes) + SuccIndex, size_t{0}); + std::memcpy(&Storage[InsertPos], &SuccAlignedLocalOffset, + sizeof(size_t)); + } } void addMemObjArg(int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) { @@ -140,15 +206,11 @@ struct ur_kernel_handle_t_ { std::memcpy(ImplicitOffsetArgs, ImplicitOffset, Size); } - void clearLocalSize() { - std::fill(std::begin(OffsetPerIndex), std::end(OffsetPerIndex), 0); - } - const args_index_t &getIndices() const noexcept { return Indices; } uint32_t getLocalSize() const { - return std::accumulate(std::begin(OffsetPerIndex), - std::end(OffsetPerIndex), 0); + return std::accumulate(std::begin(AlignedLocalMemSize), + std::end(AlignedLocalMemSize), 0); } } Args; @@ -220,6 +282,4 @@ struct ur_kernel_handle_t_ { } uint32_t getLocalSize() const noexcept { return Args.getLocalSize(); } - - void clearLocalSize() { Args.clearLocalSize(); } }; diff --git a/test/conformance/device_code/saxpy_usm_local_mem.cpp b/test/conformance/device_code/saxpy_usm_local_mem.cpp index 7ef17e59b5..c2bc3adc5e 100644 --- a/test/conformance/device_code/saxpy_usm_local_mem.cpp +++ b/test/conformance/device_code/saxpy_usm_local_mem.cpp @@ -15,15 +15,22 @@ int main() { uint32_t A = 42; sycl_queue.submit([&](sycl::handler &cgh) { - sycl::local_accessor local_mem(local_size, cgh); + sycl::local_accessor local_mem_A(local_size, cgh); + sycl::local_accessor local_mem_B(local_size * 2, cgh); + cgh.parallel_for( sycl::nd_range<1>{{array_size}, {local_size}}, [=](sycl::nd_item<1> itemId) { auto i = itemId.get_global_linear_id(); auto local_id = itemId.get_local_linear_id(); - local_mem[local_id] = i; - Z[i] = A * X[i] + Y[i] + local_mem[local_id] + - itemId.get_local_range(0); + + local_mem_A[local_id] = i; + local_mem_B[local_id * 2] = -i; + local_mem_B[(local_id * 2) + 1] = itemId.get_local_range(0); + + Z[i] = A * X[i] + Y[i] + local_mem_A[local_id] + + local_mem_B[local_id * 2] + + local_mem_B[(local_id * 2) + 1]; }); }); return 0; diff --git a/test/conformance/exp_command_buffer/exp_command_buffer_adapter_native_cpu.match b/test/conformance/exp_command_buffer/exp_command_buffer_adapter_native_cpu.match index c6fe7ad962..3588eaea82 100644 --- a/test/conformance/exp_command_buffer/exp_command_buffer_adapter_native_cpu.match +++ b/test/conformance/exp_command_buffer/exp_command_buffer_adapter_native_cpu.match @@ -37,7 +37,11 @@ {{OPT}}KernelCommandEventSyncUpdateTest.TwoWaitEvents/* {{OPT}}KernelCommandEventSyncUpdateTest.InvalidWaitUpdate/* {{OPT}}KernelCommandEventSyncUpdateTest.InvalidSignalUpdate/* -{{OPT}}LocalMemoryUpdateTest.UpdateParameters/* -{{OPT}}LocalMemoryUpdateTest.UpdateParametersAndLocalSize/* +{{OPT}}LocalMemoryUpdateTest.UpdateParametersSameLocalSize/* +{{OPT}}LocalMemoryUpdateTest.UpdateLocalOnly/* +{{OPT}}LocalMemoryUpdateTest.UpdateParametersEmptyLocalSize/* +{{OPT}}LocalMemoryUpdateTest.UpdateParametersSmallerLocalSize/* +{{OPT}}LocalMemoryUpdateTest.UpdateParametersLargerLocalSize/* +{{OPT}}LocalMemoryUpdateTest.UpdateParametersPartialLocalSize/* {{OPT}}LocalMemoryMultiUpdateTest.UpdateParameters/* {{OPT}}LocalMemoryMultiUpdateTest.UpdateWithoutBlocking/* diff --git a/test/conformance/exp_command_buffer/update/local_memory_update.cpp b/test/conformance/exp_command_buffer/update/local_memory_update.cpp index c295556fdb..c467c9783a 100644 --- a/test/conformance/exp_command_buffer/update/local_memory_update.cpp +++ b/test/conformance/exp_command_buffer/update/local_memory_update.cpp @@ -8,8 +8,7 @@ #include // Test that updating a command-buffer with a single kernel command -// taking a local memory argument works correctly. - +// taking local memory arguments works correctly. struct LocalMemoryUpdateTestBase : uur::command_buffer::urUpdatableCommandBufferExpExecutionTest { virtual void SetUp() override { @@ -17,7 +16,13 @@ struct LocalMemoryUpdateTestBase UUR_RETURN_ON_FATAL_FAILURE( urUpdatableCommandBufferExpExecutionTest::SetUp()); - // HIP has extra args for local memory so we define an offset for arg indices here for updating + if (backend == UR_PLATFORM_BACKEND_LEVEL_ZERO) { + GTEST_SKIP() + << "Local memory argument update not supported on Level Zero."; + } + + // HIP has extra args for local memory so we define an offset for arg + // indices here for updating hip_arg_offset = backend == UR_PLATFORM_BACKEND_HIP ? 3 : 0; ur_device_usm_access_capability_flags_t shared_usm_flags; ASSERT_SUCCESS( @@ -38,33 +43,48 @@ struct LocalMemoryUpdateTestBase std::memcpy(shared_ptr, pattern.data(), allocation_size); } size_t current_index = 0; - // Index 0 is local_mem arg + // Index 0 is local_mem_a arg ASSERT_SUCCESS(urKernelSetArgLocal(kernel, current_index++, - local_mem_size, nullptr)); + local_mem_a_size, nullptr)); + + // Hip has extra args for local mem at index 1-3 + if (backend == UR_PLATFORM_BACKEND_HIP) { + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), + nullptr, &hip_local_offset)); + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), + nullptr, &hip_local_offset)); + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), + nullptr, &hip_local_offset)); + } - //Hip has extr args for local mem at index 1-3 + // Index 1 is local_mem_b arg + ASSERT_SUCCESS(urKernelSetArgLocal(kernel, current_index++, + local_mem_b_size, nullptr)); if (backend == UR_PLATFORM_BACKEND_HIP) { ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, - sizeof(local_size), nullptr, - &local_size)); + sizeof(hip_local_offset), + nullptr, &hip_local_offset)); ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, - sizeof(local_size), nullptr, - &local_size)); + sizeof(hip_local_offset), + nullptr, &hip_local_offset)); ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, - sizeof(local_size), nullptr, - &local_size)); + sizeof(hip_local_offset), + nullptr, &hip_local_offset)); } - // Index 1 is output + // Index 2 is output ASSERT_SUCCESS(urKernelSetArgPointer(kernel, current_index++, nullptr, shared_ptrs[0])); - // Index 2 is A + // Index 3 is A ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, sizeof(A), nullptr, &A)); - // Index 3 is X + // Index 4 is X ASSERT_SUCCESS(urKernelSetArgPointer(kernel, current_index++, nullptr, shared_ptrs[1])); - // Index 4 is Y + // Index 5 is Y ASSERT_SUCCESS(urKernelSetArgPointer(kernel, current_index++, nullptr, shared_ptrs[2])); } @@ -72,7 +92,7 @@ struct LocalMemoryUpdateTestBase void Validate(uint32_t *output, uint32_t *X, uint32_t *Y, uint32_t A, size_t length, size_t local_size) { for (size_t i = 0; i < length; i++) { - uint32_t result = A * X[i] + Y[i] + i + local_size; + uint32_t result = A * X[i] + Y[i] + local_size; ASSERT_EQ(result, output[i]); } } @@ -89,7 +109,8 @@ struct LocalMemoryUpdateTestBase } static constexpr size_t local_size = 4; - static constexpr size_t local_mem_size = local_size * sizeof(uint32_t); + static constexpr size_t local_mem_a_size = local_size * sizeof(uint32_t); + static constexpr size_t local_mem_b_size = local_mem_a_size * 2; static constexpr size_t global_size = 16; static constexpr size_t global_offset = 0; static constexpr size_t n_dimensions = 1; @@ -98,6 +119,7 @@ struct LocalMemoryUpdateTestBase nullptr}; uint32_t hip_arg_offset = 0; + static constexpr uint64_t hip_local_offset = 0; }; struct LocalMemoryUpdateTest : LocalMemoryUpdateTestBase { @@ -127,7 +149,9 @@ struct LocalMemoryUpdateTest : LocalMemoryUpdateTestBase { UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(LocalMemoryUpdateTest); -TEST_P(LocalMemoryUpdateTest, UpdateParameters) { +// Test updating A,X,Y parameters to new values and local memory parameters +// to original values. +TEST_P(LocalMemoryUpdateTest, UpdateParametersSameLocalSize) { // Run command-buffer prior to update an verify output ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); @@ -139,63 +163,218 @@ TEST_P(LocalMemoryUpdateTest, UpdateParameters) { Validate(output, X, Y, A, global_size, local_size); // Update inputs - ur_exp_command_buffer_update_pointer_arg_desc_t new_input_descs[2]; - ur_exp_command_buffer_update_value_arg_desc_t new_value_descs[2]; + std::array + new_input_descs; + std::array + new_value_descs; - // New local_mem at index 0 + // New local_mem_a at index 0 new_value_descs[0] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype nullptr, // pNext 0, // argIndex - local_mem_size, // argSize + local_mem_a_size, // argSize + nullptr, // pProperties + nullptr, // hArgValue + }; + + // New local_mem_b at index 1 + new_value_descs[1] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 1 + hip_arg_offset, // argIndex + local_mem_b_size, // argSize nullptr, // pProperties nullptr, // hArgValue }; - // New A at index 2 + // New A at index 3 uint32_t new_A = 33; + new_value_descs[2] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 3 + (2 * hip_arg_offset), // argIndex + sizeof(new_A), // argSize + nullptr, // pProperties + &new_A, // hArgValue + }; + + // New X at index 4 + new_input_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 4 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[3], // pArgValue + }; + + // New Y at index 5 + new_input_descs[1] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 5 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[4], // pArgValue + }; + + // Update kernel inputs + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + kernel, // hNewKernel + 0, // numNewMemObjArgs + new_input_descs.size(), // numNewPointerArgs + new_value_descs.size(), // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + new_input_descs.data(), // pNewPointerArgList + new_value_descs.data(), // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + + // Update kernel and enqueue command-buffer again + ASSERT_SUCCESS( + urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + // Verify that update occurred correctly + uint32_t *new_output = (uint32_t *)shared_ptrs[0]; + uint32_t *new_X = (uint32_t *)shared_ptrs[3]; + uint32_t *new_Y = (uint32_t *)shared_ptrs[4]; + Validate(new_output, new_X, new_Y, new_A, global_size, local_size); +} + +// Test only passing local memory parameters to update with the original values. +TEST_P(LocalMemoryUpdateTest, UpdateLocalOnly) { + // Run command-buffer prior to update an verify output + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + uint32_t *output = (uint32_t *)shared_ptrs[0]; + uint32_t *X = (uint32_t *)shared_ptrs[1]; + uint32_t *Y = (uint32_t *)shared_ptrs[2]; + Validate(output, X, Y, A, global_size, local_size); + + // Update inputs + std::array + new_value_descs; + + // New local_mem_a at index 0 + new_value_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 0, // argIndex + local_mem_a_size, // argSize + nullptr, // pProperties + nullptr, // hArgValue + }; + + // New local_mem_b at index 1 new_value_descs[1] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype nullptr, // pNext - 2 + hip_arg_offset, // argIndex + 1 + hip_arg_offset, // argIndex + local_mem_b_size, // argSize + nullptr, // pProperties + nullptr, // hArgValue + }; + + // Update kernel inputs + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + kernel, // hNewKernel + 0, // numNewMemObjArgs + 0, // numNewPointerArgs + new_value_descs.size(), // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + nullptr, // pNewPointerArgList + new_value_descs.data(), // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + + // Update kernel and enqueue command-buffer again + ASSERT_SUCCESS( + urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + // Verify that update occurred correctly + Validate(output, X, Y, A, global_size, local_size); +} + +// Test updating A,X,Y parameters to new values and omitting local memory parameters +// from the update. +TEST_P(LocalMemoryUpdateTest, UpdateParametersEmptyLocalSize) { + // Run command-buffer prior to update and verify output + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + uint32_t *output = (uint32_t *)shared_ptrs[0]; + uint32_t *X = (uint32_t *)shared_ptrs[1]; + uint32_t *Y = (uint32_t *)shared_ptrs[2]; + Validate(output, X, Y, A, global_size, local_size); + + // Update inputs + std::array + new_input_descs; + std::array + new_value_descs; + + // New A at index 3 + uint32_t new_A = 33; + new_value_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 3 + (2 * hip_arg_offset), // argIndex sizeof(new_A), // argSize nullptr, // pProperties &new_A, // hArgValue }; - // New X at index 3 + // New X at index 4 new_input_descs[0] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype nullptr, // pNext - 3 + hip_arg_offset, // argIndex - nullptr, // pProperties - &shared_ptrs[3], // pArgValue + 4 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[3], // pArgValue }; - // New Y at index 4 + // New Y at index 5 new_input_descs[1] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype nullptr, // pNext - 4 + hip_arg_offset, // argIndex - nullptr, // pProperties - &shared_ptrs[4], // pArgValue + 5 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[4], // pArgValue }; // Update kernel inputs ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext - kernel, // hNewKernel - 0, // numNewMemObjArgs - 2, // numNewPointerArgs - 2, // numNewValueArgs - n_dimensions, // newWorkDim - nullptr, // pNewMemObjArgList - new_input_descs, // pNewPointerArgList - new_value_descs, // pNewValueArgList - nullptr, // pNewGlobalWorkOffset - nullptr, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize + kernel, // hNewKernel + 0, // numNewMemObjArgs + new_input_descs.size(), // numNewPointerArgs + new_value_descs.size(), // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + new_input_descs.data(), // pNewPointerArgList + new_value_descs.data(), // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize }; // Update kernel and enqueue command-buffer again @@ -212,7 +391,9 @@ TEST_P(LocalMemoryUpdateTest, UpdateParameters) { Validate(new_output, new_X, new_Y, new_A, global_size, local_size); } -TEST_P(LocalMemoryUpdateTest, UpdateParametersAndLocalSize) { +// Test updating A,X,Y parameters to new values and local memory parameters +// to new smaller values. +TEST_P(LocalMemoryUpdateTest, UpdateParametersSmallerLocalSize) { // Run command-buffer prior to update an verify output ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); @@ -228,14 +409,14 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersAndLocalSize) { std::vector new_value_descs{}; - size_t new_local_size = local_size * 2; - size_t new_local_mem_size = new_local_size * sizeof(uint32_t); - // New local_mem at index 0 + size_t new_local_size = 2; + size_t new_local_mem_a_size = new_local_size * sizeof(uint32_t); + // New local_mem_a at index 0 new_value_descs.push_back({ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype nullptr, // pNext 0, // argIndex - new_local_mem_size, // argSize + new_local_mem_a_size, // argSize nullptr, // pProperties nullptr, // hArgValue }); @@ -244,56 +425,94 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersAndLocalSize) { new_value_descs.push_back({ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype nullptr, // pNext - 1, // argIndex - sizeof(new_local_size), // argSize - nullptr, // pProperties - &new_local_size, // hArgValue + 1, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue }); new_value_descs.push_back({ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype nullptr, // pNext - 2, // argIndex - sizeof(new_local_size), // argSize - nullptr, // pProperties - &new_local_size, // hArgValue + 2, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue }); new_value_descs.push_back({ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype nullptr, // pNext - 3, // argIndex - sizeof(new_local_size), // argSize - nullptr, // pProperties - &new_local_size, // hArgValue + 3, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue }); } - // New A at index 2 + // New local_mem_b at index 1 + size_t new_local_mem_b_size = new_local_size * sizeof(uint32_t) * 2; + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 1 + hip_arg_offset, // argIndex + new_local_mem_b_size, // argSize + nullptr, // pProperties + nullptr, // hArgValue + }); + + if (backend == UR_PLATFORM_BACKEND_HIP) { + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 5, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 6, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 7, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + } + + // New A at index 3 uint32_t new_A = 33; new_value_descs.push_back({ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype nullptr, // pNext - 2 + hip_arg_offset, // argIndex + 3 + (2 * hip_arg_offset), // argIndex sizeof(new_A), // argSize nullptr, // pProperties &new_A, // hArgValue }); - // New X at index 3 + // New X at index 4 new_input_descs[0] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype nullptr, // pNext - 3 + hip_arg_offset, // argIndex - nullptr, // pProperties - &shared_ptrs[3], // pArgValue + 4 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[3], // pArgValue }; - // New Y at index 4 + // New Y at index 5 new_input_descs[1] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype nullptr, // pNext - 4 + hip_arg_offset, // argIndex - nullptr, // pProperties - &shared_ptrs[4], // pArgValue + 5 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[4], // pArgValue }; // Update kernel inputs @@ -327,16 +546,345 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersAndLocalSize) { Validate(new_output, new_X, new_Y, new_A, global_size, new_local_size); } +// Test updating A,X,Y parameters to new values and local memory parameters +// to new larger values. +TEST_P(LocalMemoryUpdateTest, UpdateParametersLargerLocalSize) { + // Run command-buffer prior to update and verify output + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + uint32_t *output = (uint32_t *)shared_ptrs[0]; + uint32_t *X = (uint32_t *)shared_ptrs[1]; + uint32_t *Y = (uint32_t *)shared_ptrs[2]; + Validate(output, X, Y, A, global_size, local_size); + + // Update inputs + ur_exp_command_buffer_update_pointer_arg_desc_t new_input_descs[2]; + std::vector + new_value_descs{}; + + size_t new_local_size = local_size * 4; + size_t new_local_mem_a_size = new_local_size * sizeof(uint32_t); + // New local_mem_a at index 0 + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 0, // argIndex + new_local_mem_a_size, // argSize + nullptr, // pProperties + nullptr, // hArgValue + }); + + if (backend == UR_PLATFORM_BACKEND_HIP) { + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 1, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 2, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 3, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + } + + // New local_mem_b at index 1 + size_t new_local_mem_b_size = new_local_size * sizeof(uint32_t) * 2; + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 1 + hip_arg_offset, // argIndex + new_local_mem_b_size, // argSize + nullptr, // pProperties + nullptr, // hArgValue + }); + + if (backend == UR_PLATFORM_BACKEND_HIP) { + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 5, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 6, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 7, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + } + + // New A at index 3 + uint32_t new_A = 33; + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 3 + (2 * hip_arg_offset), // argIndex + sizeof(new_A), // argSize + nullptr, // pProperties + &new_A, // hArgValue + }); + + // New X at index 4 + new_input_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 4 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[3], // pArgValue + }; + + // New Y at index 5 + new_input_descs[1] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 5 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[4], // pArgValue + }; + + // Update kernel inputs + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + kernel, // hNewKernel + 0, // numNewMemObjArgs + 2, // numNewPointerArgs + static_cast(new_value_descs.size()), // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + new_input_descs, // pNewPointerArgList + new_value_descs.data(), // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + &new_local_size, // pNewLocalWorkSize + }; + + // Update kernel and enqueue command-buffer again + ASSERT_SUCCESS( + urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + // Verify that update occurred correctly + uint32_t *new_output = (uint32_t *)shared_ptrs[0]; + uint32_t *new_X = (uint32_t *)shared_ptrs[3]; + uint32_t *new_Y = (uint32_t *)shared_ptrs[4]; + Validate(new_output, new_X, new_Y, new_A, global_size, new_local_size); +} + +// Test updating A,X,Y parameters to new values and only one of the local memory +// parameters, which is set to a new value. Then a separate update call for +// the other local memory argument. +TEST_P(LocalMemoryUpdateTest, UpdateParametersPartialLocalSize) { + // Run command-buffer prior to update and verify output + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + uint32_t *output = (uint32_t *)shared_ptrs[0]; + uint32_t *X = (uint32_t *)shared_ptrs[1]; + uint32_t *Y = (uint32_t *)shared_ptrs[2]; + Validate(output, X, Y, A, global_size, local_size); + + // Update inputs + ur_exp_command_buffer_update_pointer_arg_desc_t new_input_descs[2]; + std::vector + new_value_descs{}; + + size_t new_local_size = local_size * 4; + size_t new_local_mem_a_size = new_local_size * sizeof(uint32_t); + // New local_mem_a at index 0 + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 0, // argIndex + new_local_mem_a_size, // argSize + nullptr, // pProperties + nullptr, // hArgValue + }); + + if (backend == UR_PLATFORM_BACKEND_HIP) { + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 1, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 2, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 3, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + } + + // New A at index 3 + uint32_t new_A = 33; + new_value_descs.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 3 + (2 * hip_arg_offset), // argIndex + sizeof(new_A), // argSize + nullptr, // pProperties + &new_A, // hArgValue + }); + + // New X at index 4 + new_input_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 4 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[3], // pArgValue + }; + + // New Y at index 5 + new_input_descs[1] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype + nullptr, // pNext + 5 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[4], // pArgValue + }; + + // Update kernel inputs + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + kernel, // hNewKernel + 0, // numNewMemObjArgs + 2, // numNewPointerArgs + static_cast(new_value_descs.size()), // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + new_input_descs, // pNewPointerArgList + new_value_descs.data(), // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + &new_local_size, // pNewLocalWorkSize + }; + + // Update kernel and enqueue command-buffer again + ASSERT_SUCCESS( + urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + + std::vector + second_update_value_args{}; + + size_t new_local_mem_b_size = new_local_size * sizeof(uint32_t) * 2; + // New local_mem_b at index 1 + second_update_value_args.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 1 + hip_arg_offset, // argIndex + new_local_mem_b_size, // argSize + nullptr, // pProperties + nullptr, // hArgValue + }); + + if (backend == UR_PLATFORM_BACKEND_HIP) { + second_update_value_args.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 5, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + second_update_value_args.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 6, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + second_update_value_args.push_back({ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 7, // argIndex + sizeof(hip_local_offset), // argSize + nullptr, // pProperties + &hip_local_offset, // hArgValue + }); + } + + ur_exp_command_buffer_update_kernel_launch_desc_t second_update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + kernel, // hNewKernel + 0, // numNewMemObjArgs + 0, // numNewPointerArgs + static_cast( + second_update_value_args.size()), // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + nullptr, // pNewPointerArgList + second_update_value_args.data(), // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(command_handle, + &second_update_desc)); + + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + // Verify that update occurred correctly + uint32_t *new_output = (uint32_t *)shared_ptrs[0]; + uint32_t *new_X = (uint32_t *)shared_ptrs[3]; + uint32_t *new_Y = (uint32_t *)shared_ptrs[4]; + Validate(new_output, new_X, new_Y, new_A, global_size, new_local_size); +} + struct LocalMemoryMultiUpdateTest : LocalMemoryUpdateTestBase { void SetUp() override { UUR_RETURN_ON_FATAL_FAILURE(LocalMemoryUpdateTestBase::SetUp()); - // Append kernel command to command-buffer and close command-buffer for (unsigned node = 0; node < nodes; node++) { - // We need to set the local memory arg each time because it is - // cleared in the kernel handle after being used. - ASSERT_SUCCESS( - urKernelSetArgLocal(kernel, 0, local_mem_size, nullptr)); ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp( updatable_cmd_buf_handle, kernel, n_dimensions, &global_offset, &global_size, &local_size, 0, nullptr, 0, nullptr, 0, nullptr, @@ -363,6 +911,8 @@ struct LocalMemoryMultiUpdateTest : LocalMemoryUpdateTestBase { UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(LocalMemoryMultiUpdateTest); +// Test updating A,X,Y parameters to new values and local memory parameters +// to original values. TEST_P(LocalMemoryMultiUpdateTest, UpdateParameters) { // Run command-buffer prior to update an verify output ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, @@ -375,63 +925,75 @@ TEST_P(LocalMemoryMultiUpdateTest, UpdateParameters) { Validate(output, X, Y, A, global_size, local_size); // Update inputs - ur_exp_command_buffer_update_pointer_arg_desc_t new_input_descs[2]; - ur_exp_command_buffer_update_value_arg_desc_t new_value_descs[2]; + std::array + new_input_descs; + std::array + new_value_descs; - // New local_mem at index 0 + // New local_mem_a at index 0 new_value_descs[0] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype nullptr, // pNext 0, // argIndex - local_mem_size, // argSize + local_mem_a_size, // argSize nullptr, // pProperties nullptr, // hArgValue }; - // New A at index 2 - uint32_t new_A = 33; + // New local_mem_b at index 1 new_value_descs[1] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype nullptr, // pNext - 2 + hip_arg_offset, // argIndex + 1 + hip_arg_offset, // argIndex + local_mem_b_size, // argSize + nullptr, // pProperties + nullptr, // hArgValue + }; + + // New A at index 3 + uint32_t new_A = 33; + new_value_descs[2] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 3 + (2 * hip_arg_offset), // argIndex sizeof(new_A), // argSize nullptr, // pProperties &new_A, // hArgValue }; - // New X at index 3 + // New X at index 4 new_input_descs[0] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype nullptr, // pNext - 3 + hip_arg_offset, // argIndex - nullptr, // pProperties - &shared_ptrs[3], // pArgValue + 4 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[3], // pArgValue }; - // New Y at index 4 + // New Y at index 5 new_input_descs[1] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype nullptr, // pNext - 4 + hip_arg_offset, // argIndex - nullptr, // pProperties - &shared_ptrs[4], // pArgValue + 5 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[4], // pArgValue }; // Update kernel inputs ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext - kernel, // hNewKernel - 0, // numNewMemObjArgs - 2, // numNewPointerArgs - 2, // numNewValueArgs - n_dimensions, // newWorkDim - nullptr, // pNewMemObjArgList - new_input_descs, // pNewPointerArgList - new_value_descs, // pNewValueArgList - nullptr, // pNewGlobalWorkOffset - nullptr, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize + kernel, // hNewKernel + 0, // numNewMemObjArgs + new_input_descs.size(), // numNewPointerArgs + new_value_descs.size(), // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + new_input_descs.data(), // pNewPointerArgList + new_value_descs.data(), // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize }; // Update kernel and enqueue command-buffer again @@ -450,65 +1012,79 @@ TEST_P(LocalMemoryMultiUpdateTest, UpdateParameters) { Validate(new_output, new_X, new_Y, new_A, global_size, local_size); } +// Test updating A,X,Y parameters to new values and local memory parameters +// to original values, but without doing a blocking wait. TEST_P(LocalMemoryMultiUpdateTest, UpdateWithoutBlocking) { // Update inputs - ur_exp_command_buffer_update_pointer_arg_desc_t new_input_descs[2]; - ur_exp_command_buffer_update_value_arg_desc_t new_value_descs[2]; + std::array + new_input_descs; + std::array + new_value_descs; - // New local_mem at index 0 + // New local_mem_a at index 0 new_value_descs[0] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype nullptr, // pNext 0, // argIndex - local_mem_size, // argSize + local_mem_a_size, // argSize nullptr, // pProperties nullptr, // hArgValue }; - // New A at index 2 - uint32_t new_A = 33; + // New local_mem_a at index 1 new_value_descs[1] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype nullptr, // pNext - 2 + hip_arg_offset, // argIndex + 1 + hip_arg_offset, // argIndex + local_mem_b_size, // argSize + nullptr, // pProperties + nullptr, // hArgValue + }; + + // New A at index 3 + uint32_t new_A = 33; + new_value_descs[2] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 3 + (2 * hip_arg_offset), // argIndex sizeof(new_A), // argSize nullptr, // pProperties &new_A, // hArgValue }; - // New X at index 3 + // New X at index 4 new_input_descs[0] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype nullptr, // pNext - 3 + hip_arg_offset, // argIndex - nullptr, // pProperties - &shared_ptrs[3], // pArgValue + 4 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[3], // pArgValue }; - // New Y at index 4 + // New Y at index 5 new_input_descs[1] = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype nullptr, // pNext - 4 + hip_arg_offset, // argIndex - nullptr, // pProperties - &shared_ptrs[4], // pArgValue + 5 + (2 * hip_arg_offset), // argIndex + nullptr, // pProperties + &shared_ptrs[4], // pArgValue }; // Update kernel inputs ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext - kernel, // hNewKernel - 0, // numNewMemObjArgs - 2, // numNewPointerArgs - 2, // numNewValueArgs - n_dimensions, // newWorkDim - nullptr, // pNewMemObjArgList - new_input_descs, // pNewPointerArgList - new_value_descs, // pNewValueArgList - nullptr, // pNewGlobalWorkOffset - nullptr, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize + kernel, // hNewKernel + 0, // numNewMemObjArgs + new_input_descs.size(), // numNewPointerArgs + new_value_descs.size(), // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + new_input_descs.data(), // pNewPointerArgList + new_value_descs.data(), // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize }; // Enqueue without calling urQueueFinish after ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, diff --git a/test/conformance/kernel/kernel_adapter_native_cpu.match b/test/conformance/kernel/kernel_adapter_native_cpu.match index 7ca10ec3d2..bd5333c609 100644 --- a/test/conformance/kernel/kernel_adapter_native_cpu.match +++ b/test/conformance/kernel/kernel_adapter_native_cpu.match @@ -38,6 +38,9 @@ urKernelRetainTest.InvalidNullHandleKernel/* urKernelSetArgLocalTest.Success/* urKernelSetArgLocalTest.InvalidNullHandleKernel/* urKernelSetArgLocalTest.InvalidKernelArgumentIndex/* +urKernelSetArgLocalMultiTest.Basic/* +urKernelSetArgLocalMultiTest.ReLaunch/* +urKernelSetArgLocalMultiTest.Overwrite/* urKernelSetArgMemObjTest.Success/* urKernelSetArgMemObjTest.InvalidNullHandleKernel/* urKernelSetArgMemObjTest.InvalidKernelArgumentIndex/* diff --git a/test/conformance/kernel/urKernelSetArgLocal.cpp b/test/conformance/kernel/urKernelSetArgLocal.cpp index 1d3789bf3a..380085bd16 100644 --- a/test/conformance/kernel/urKernelSetArgLocal.cpp +++ b/test/conformance/kernel/urKernelSetArgLocal.cpp @@ -3,6 +3,7 @@ // See LICENSE.TXT // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include #include struct urKernelSetArgLocalTest : uur::urKernelTest { @@ -32,3 +33,203 @@ TEST_P(urKernelSetArgLocalTest, InvalidKernelArgumentIndex) { urKernelSetArgLocal(kernel, num_kernel_args + 1, local_mem_size, nullptr)); } + +// Test launching kernels with multiple local arguments return the expected +// outputs +struct urKernelSetArgLocalMultiTest : uur::urKernelExecutionTest { + void SetUp() override { + program_name = "saxpy_usm_local_mem"; + UUR_RETURN_ON_FATAL_FAILURE(urKernelExecutionTest::SetUp()); + + ASSERT_SUCCESS(urPlatformGetInfo(platform, UR_PLATFORM_INFO_BACKEND, + sizeof(backend), &backend, nullptr)); + + // HIP has extra args for local memory so we define an offset for arg indices here for updating + hip_arg_offset = backend == UR_PLATFORM_BACKEND_HIP ? 3 : 0; + ur_device_usm_access_capability_flags_t shared_usm_flags; + ASSERT_SUCCESS( + uur::GetDeviceUSMSingleSharedSupport(device, shared_usm_flags)); + if (!(shared_usm_flags & UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS)) { + GTEST_SKIP() << "Shared USM is not supported."; + } + + const size_t allocation_size = + sizeof(uint32_t) * global_size * local_size; + for (auto &shared_ptr : shared_ptrs) { + ASSERT_SUCCESS(urUSMSharedAlloc(context, device, nullptr, nullptr, + allocation_size, &shared_ptr)); + ASSERT_NE(shared_ptr, nullptr); + + std::vector pattern(allocation_size); + uur::generateMemFillPattern(pattern); + std::memcpy(shared_ptr, pattern.data(), allocation_size); + } + size_t current_index = 0; + // Index 0 is local_mem_a arg + ASSERT_SUCCESS(urKernelSetArgLocal(kernel, current_index++, + local_mem_a_size, nullptr)); + + // Hip has extra args for local mem at index 1-3 + if (backend == UR_PLATFORM_BACKEND_HIP) { + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), + nullptr, &hip_local_offset)); + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), + nullptr, &hip_local_offset)); + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), + nullptr, &hip_local_offset)); + } + + // Index 1 is local_mem_b arg + ASSERT_SUCCESS(urKernelSetArgLocal(kernel, current_index++, + local_mem_b_size, nullptr)); + if (backend == UR_PLATFORM_BACKEND_HIP) { + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), + nullptr, &hip_local_offset)); + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), + nullptr, &hip_local_offset)); + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), + nullptr, &hip_local_offset)); + } + + // Index 2 is output + ASSERT_SUCCESS(urKernelSetArgPointer(kernel, current_index++, nullptr, + shared_ptrs[0])); + // Index 3 is A + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, sizeof(A), + nullptr, &A)); + // Index 4 is X + ASSERT_SUCCESS(urKernelSetArgPointer(kernel, current_index++, nullptr, + shared_ptrs[1])); + // Index 5 is Y + ASSERT_SUCCESS(urKernelSetArgPointer(kernel, current_index++, nullptr, + shared_ptrs[2])); + } + + void Validate(uint32_t *output, uint32_t *X, uint32_t *Y, uint32_t A, + size_t length, size_t local_size) { + for (size_t i = 0; i < length; i++) { + uint32_t result = A * X[i] + Y[i] + local_size; + ASSERT_EQ(result, output[i]); + } + } + + virtual void TearDown() override { + for (auto &shared_ptr : shared_ptrs) { + if (shared_ptr) { + EXPECT_SUCCESS(urUSMFree(context, shared_ptr)); + } + } + + UUR_RETURN_ON_FATAL_FAILURE(urKernelExecutionTest::TearDown()); + } + + static constexpr size_t local_size = 4; + static constexpr size_t local_mem_a_size = local_size * sizeof(uint32_t); + static constexpr size_t local_mem_b_size = local_mem_a_size * 2; + static constexpr size_t global_size = 16; + static constexpr size_t global_offset = 0; + static constexpr size_t n_dimensions = 1; + static constexpr uint32_t A = 42; + std::array shared_ptrs = {nullptr, nullptr, nullptr, nullptr, + nullptr}; + + uint32_t hip_arg_offset = 0; + static constexpr uint64_t hip_local_offset = 0; + ur_platform_backend_t backend{}; +}; +UUR_INSTANTIATE_KERNEL_TEST_SUITE_P(urKernelSetArgLocalMultiTest); + +TEST_P(urKernelSetArgLocalMultiTest, Basic) { + ASSERT_SUCCESS(urEnqueueKernelLaunch(queue, kernel, n_dimensions, + &global_offset, &global_size, + &local_size, 0, nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + uint32_t *output = (uint32_t *)shared_ptrs[0]; + uint32_t *X = (uint32_t *)shared_ptrs[1]; + uint32_t *Y = (uint32_t *)shared_ptrs[2]; + Validate(output, X, Y, A, global_size, local_size); +} + +TEST_P(urKernelSetArgLocalMultiTest, ReLaunch) { + ASSERT_SUCCESS(urEnqueueKernelLaunch(queue, kernel, n_dimensions, + &global_offset, &global_size, + &local_size, 0, nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + uint32_t *output = (uint32_t *)shared_ptrs[0]; + uint32_t *X = (uint32_t *)shared_ptrs[1]; + uint32_t *Y = (uint32_t *)shared_ptrs[2]; + Validate(output, X, Y, A, global_size, local_size); + + // Relaunch with new arguments + ASSERT_SUCCESS(urEnqueueKernelLaunch(queue, kernel, n_dimensions, + &global_offset, &global_size, + &local_size, 0, nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + uint32_t *new_output = (uint32_t *)shared_ptrs[0]; + uint32_t *new_X = (uint32_t *)shared_ptrs[3]; + uint32_t *new_Y = (uint32_t *)shared_ptrs[4]; + Validate(new_output, new_X, new_Y, A, global_size, local_size); +} + +// Overwrite local args to a larger value, then reset back to original +TEST_P(urKernelSetArgLocalMultiTest, Overwrite) { + ASSERT_SUCCESS(urEnqueueKernelLaunch(queue, kernel, n_dimensions, + &global_offset, &global_size, + &local_size, 0, nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + uint32_t *output = (uint32_t *)shared_ptrs[0]; + uint32_t *X = (uint32_t *)shared_ptrs[1]; + uint32_t *Y = (uint32_t *)shared_ptrs[2]; + Validate(output, X, Y, A, global_size, local_size); + + size_t new_local_size = 2; + size_t new_local_mem_a_size = new_local_size * sizeof(uint32_t); + size_t new_local_mem_b_size = new_local_size * sizeof(uint32_t) * 2; + size_t current_index = 0; + ASSERT_SUCCESS(urKernelSetArgLocal(kernel, current_index++, + new_local_mem_a_size, nullptr)); + + // Hip has extra args for local mem at index 1-3 + if (backend == UR_PLATFORM_BACKEND_HIP) { + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), nullptr, + &hip_local_offset)); + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), nullptr, + &hip_local_offset)); + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), nullptr, + &hip_local_offset)); + } + + // Index 1 is local_mem_b arg + ASSERT_SUCCESS(urKernelSetArgLocal(kernel, current_index++, + new_local_mem_b_size, nullptr)); + if (backend == UR_PLATFORM_BACKEND_HIP) { + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), nullptr, + &hip_local_offset)); + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), nullptr, + &hip_local_offset)); + ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, + sizeof(hip_local_offset), nullptr, + &hip_local_offset)); + } + + ASSERT_SUCCESS(urEnqueueKernelLaunch(queue, kernel, n_dimensions, + &global_offset, &global_size, + &new_local_size, 0, nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + Validate(output, X, Y, A, global_size, new_local_size); +}