Skip to content

Commit 0500f4c

Browse files
committed
Add non command-buffer test
1 parent 762348c commit 0500f4c

File tree

1 file changed

+200
-0
lines changed

1 file changed

+200
-0
lines changed

test/conformance/kernel/urKernelSetArgLocal.cpp

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See LICENSE.TXT
44
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
55

6+
#include <cstring>
67
#include <uur/fixtures.h>
78

89
struct urKernelSetArgLocalTest : uur::urKernelTest {
@@ -32,3 +33,202 @@ TEST_P(urKernelSetArgLocalTest, InvalidKernelArgumentIndex) {
3233
urKernelSetArgLocal(kernel, num_kernel_args + 1,
3334
local_mem_size, nullptr));
3435
}
36+
37+
// Test launching kernels with multiple local arguments return the expected
38+
// outputs
39+
struct urKernelSetArgLocalMultiTest : uur::urKernelExecutionTest {
40+
void SetUp() override {
41+
program_name = "saxpy_usm_local_mem";
42+
UUR_RETURN_ON_FATAL_FAILURE(urKernelExecutionTest::SetUp());
43+
44+
ASSERT_SUCCESS(urPlatformGetInfo(platform, UR_PLATFORM_INFO_BACKEND,
45+
sizeof(backend), &backend, nullptr));
46+
47+
// HIP has extra args for local memory so we define an offset for arg indices here for updating
48+
hip_arg_offset = backend == UR_PLATFORM_BACKEND_HIP ? 3 : 0;
49+
ur_device_usm_access_capability_flags_t shared_usm_flags;
50+
ASSERT_SUCCESS(
51+
uur::GetDeviceUSMSingleSharedSupport(device, shared_usm_flags));
52+
if (!(shared_usm_flags & UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS)) {
53+
GTEST_SKIP() << "Shared USM is not supported.";
54+
}
55+
56+
const size_t allocation_size =
57+
sizeof(uint32_t) * global_size * local_size;
58+
for (auto &shared_ptr : shared_ptrs) {
59+
ASSERT_SUCCESS(urUSMSharedAlloc(context, device, nullptr, nullptr,
60+
allocation_size, &shared_ptr));
61+
ASSERT_NE(shared_ptr, nullptr);
62+
63+
std::vector<uint8_t> pattern(allocation_size);
64+
uur::generateMemFillPattern(pattern);
65+
std::memcpy(shared_ptr, pattern.data(), allocation_size);
66+
}
67+
size_t current_index = 0;
68+
// Index 0 is local_mem_a arg
69+
ASSERT_SUCCESS(urKernelSetArgLocal(kernel, current_index++,
70+
local_mem_a_size, nullptr));
71+
72+
// Hip has extra args for local mem at index 1-3
73+
if (backend == UR_PLATFORM_BACKEND_HIP) {
74+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++,
75+
sizeof(local_size), nullptr,
76+
&local_size));
77+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++,
78+
sizeof(local_size), nullptr,
79+
&local_size));
80+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++,
81+
sizeof(local_size), nullptr,
82+
&local_size));
83+
}
84+
85+
// Index 1 is local_mem_b arg
86+
ASSERT_SUCCESS(urKernelSetArgLocal(kernel, current_index++,
87+
local_mem_b_size, nullptr));
88+
if (backend == UR_PLATFORM_BACKEND_HIP) {
89+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++,
90+
sizeof(local_size), nullptr,
91+
&local_size));
92+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++,
93+
sizeof(local_size), nullptr,
94+
&local_size));
95+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++,
96+
sizeof(local_size), nullptr,
97+
&local_size));
98+
}
99+
100+
// Index 2 is output
101+
ASSERT_SUCCESS(urKernelSetArgPointer(kernel, current_index++, nullptr,
102+
shared_ptrs[0]));
103+
// Index 3 is A
104+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++, sizeof(A),
105+
nullptr, &A));
106+
// Index 4 is X
107+
ASSERT_SUCCESS(urKernelSetArgPointer(kernel, current_index++, nullptr,
108+
shared_ptrs[1]));
109+
// Index 5 is Y
110+
ASSERT_SUCCESS(urKernelSetArgPointer(kernel, current_index++, nullptr,
111+
shared_ptrs[2]));
112+
}
113+
114+
void Validate(uint32_t *output, uint32_t *X, uint32_t *Y, uint32_t A,
115+
size_t length, size_t local_size) {
116+
for (size_t i = 0; i < length; i++) {
117+
uint32_t result = A * X[i] + Y[i] + local_size;
118+
ASSERT_EQ(result, output[i]);
119+
}
120+
}
121+
122+
virtual void TearDown() override {
123+
for (auto &shared_ptr : shared_ptrs) {
124+
if (shared_ptr) {
125+
EXPECT_SUCCESS(urUSMFree(context, shared_ptr));
126+
}
127+
}
128+
129+
UUR_RETURN_ON_FATAL_FAILURE(urKernelExecutionTest::TearDown());
130+
}
131+
132+
static constexpr size_t local_size = 4;
133+
static constexpr size_t local_mem_a_size = local_size * sizeof(uint32_t);
134+
static constexpr size_t local_mem_b_size = local_mem_a_size * 2;
135+
static constexpr size_t global_size = 16;
136+
static constexpr size_t global_offset = 0;
137+
static constexpr size_t n_dimensions = 1;
138+
static constexpr uint32_t A = 42;
139+
std::array<void *, 5> shared_ptrs = {nullptr, nullptr, nullptr, nullptr,
140+
nullptr};
141+
142+
uint32_t hip_arg_offset = 0;
143+
ur_platform_backend_t backend{};
144+
};
145+
UUR_INSTANTIATE_KERNEL_TEST_SUITE_P(urKernelSetArgLocalMultiTest);
146+
147+
TEST_P(urKernelSetArgLocalMultiTest, Basic) {
148+
ASSERT_SUCCESS(urEnqueueKernelLaunch(queue, kernel, n_dimensions,
149+
&global_offset, &global_size,
150+
&local_size, 0, nullptr, nullptr));
151+
ASSERT_SUCCESS(urQueueFinish(queue));
152+
153+
uint32_t *output = (uint32_t *)shared_ptrs[0];
154+
uint32_t *X = (uint32_t *)shared_ptrs[1];
155+
uint32_t *Y = (uint32_t *)shared_ptrs[2];
156+
Validate(output, X, Y, A, global_size, local_size);
157+
}
158+
159+
TEST_P(urKernelSetArgLocalMultiTest, ReLaunch) {
160+
ASSERT_SUCCESS(urEnqueueKernelLaunch(queue, kernel, n_dimensions,
161+
&global_offset, &global_size,
162+
&local_size, 0, nullptr, nullptr));
163+
ASSERT_SUCCESS(urQueueFinish(queue));
164+
165+
uint32_t *output = (uint32_t *)shared_ptrs[0];
166+
uint32_t *X = (uint32_t *)shared_ptrs[1];
167+
uint32_t *Y = (uint32_t *)shared_ptrs[2];
168+
Validate(output, X, Y, A, global_size, local_size);
169+
170+
// Relaunch with new arguments
171+
ASSERT_SUCCESS(urEnqueueKernelLaunch(queue, kernel, n_dimensions,
172+
&global_offset, &global_size,
173+
&local_size, 0, nullptr, nullptr));
174+
ASSERT_SUCCESS(urQueueFinish(queue));
175+
uint32_t *new_output = (uint32_t *)shared_ptrs[0];
176+
uint32_t *new_X = (uint32_t *)shared_ptrs[3];
177+
uint32_t *new_Y = (uint32_t *)shared_ptrs[4];
178+
Validate(new_output, new_X, new_Y, A, global_size, local_size);
179+
}
180+
181+
// Overwrite local args to a larger value, then reset back to original
182+
TEST_P(urKernelSetArgLocalMultiTest, Overwrite) {
183+
ASSERT_SUCCESS(urEnqueueKernelLaunch(queue, kernel, n_dimensions,
184+
&global_offset, &global_size,
185+
&local_size, 0, nullptr, nullptr));
186+
ASSERT_SUCCESS(urQueueFinish(queue));
187+
188+
uint32_t *output = (uint32_t *)shared_ptrs[0];
189+
uint32_t *X = (uint32_t *)shared_ptrs[1];
190+
uint32_t *Y = (uint32_t *)shared_ptrs[2];
191+
Validate(output, X, Y, A, global_size, local_size);
192+
193+
size_t new_local_size = 2;
194+
size_t new_local_mem_a_size = new_local_size * sizeof(uint32_t);
195+
size_t new_local_mem_b_size = new_local_size * sizeof(uint32_t) * 2;
196+
size_t current_index = 0;
197+
ASSERT_SUCCESS(urKernelSetArgLocal(kernel, current_index++,
198+
new_local_mem_a_size, nullptr));
199+
200+
// Hip has extra args for local mem at index 1-3
201+
if (backend == UR_PLATFORM_BACKEND_HIP) {
202+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++,
203+
sizeof(new_local_size), nullptr,
204+
&new_local_size));
205+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++,
206+
sizeof(new_local_size), nullptr,
207+
&new_local_size));
208+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++,
209+
sizeof(new_local_size), nullptr,
210+
&new_local_size));
211+
}
212+
213+
// Index 1 is local_mem_b arg
214+
ASSERT_SUCCESS(urKernelSetArgLocal(kernel, current_index++,
215+
new_local_mem_b_size, nullptr));
216+
if (backend == UR_PLATFORM_BACKEND_HIP) {
217+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++,
218+
sizeof(new_local_size), nullptr,
219+
&new_local_size));
220+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++,
221+
sizeof(new_local_size), nullptr,
222+
&new_local_size));
223+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, current_index++,
224+
sizeof(new_local_size), nullptr,
225+
&new_local_size));
226+
}
227+
228+
ASSERT_SUCCESS(urEnqueueKernelLaunch(queue, kernel, n_dimensions,
229+
&global_offset, &global_size,
230+
&new_local_size, 0, nullptr, nullptr));
231+
ASSERT_SUCCESS(urQueueFinish(queue));
232+
233+
Validate(output, X, Y, A, global_size, new_local_size);
234+
}

0 commit comments

Comments
 (0)