Skip to content

Commit 62095a5

Browse files
committed
[SYCL][UR][L0 v2] get rid of std::function in memory.hpp
There migration logic is always the same, there's no need to pass callback to every getDevicePtr/map/unmap function.
1 parent d8a66b8 commit 62095a5

File tree

8 files changed

+144
-162
lines changed

8 files changed

+144
-162
lines changed

unified-runtime/source/adapters/level_zero/v2/command_buffer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ ur_result_t getMemPtr(ur_mem_handle_t memObj,
3939
urAccessMode =
4040
ur_mem_buffer_t::getDeviceAccessMode(properties->memoryAccess);
4141
}
42-
ptr = ur_cast<char *>(
43-
memBuffer->getDevicePtr(device, urAccessMode, 0, memBuffer->getSize(),
44-
[&](void *, void *, size_t) {}));
42+
wait_list_view emptyWaitList(nullptr, 0);
43+
ptr = ur_cast<char *>(memBuffer->getDevicePtr(
44+
device, urAccessMode, 0, memBuffer->getSize(), nullptr, emptyWaitList));
4545
}
4646
assert(ptrStorage != nullptr);
4747
ptrStorage->push_back(std::make_unique<char *>(ptr));

unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "../ur_interface_loader.hpp"
1515
#include "context.hpp"
1616
#include "kernel.hpp"
17+
#include "memory.hpp"
1718

1819
ur_command_list_manager::ur_command_list_manager(
1920
ur_context_handle_t context, ur_device_handle_t device,
@@ -44,12 +45,7 @@ ur_result_t ur_command_list_manager::appendGenericFillUnlocked(
4445

4546
auto pDst = ur_cast<char *>(dst->getDevicePtr(
4647
device, ur_mem_buffer_t::device_access_mode_t::read_only, offset, size,
47-
[&](void *src, void *dst, size_t size) {
48-
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
49-
(zeCommandList.get(), dst, src, size, nullptr,
50-
waitListView.num, waitListView.handles));
51-
waitListView.clear();
52-
}));
48+
zeCommandList.get(), waitListView));
5349

5450
// PatternSize must be a power of two for zeCommandListAppendMemoryFill.
5551
// When it's not, the fill is emulated with zeCommandListAppendMemoryCopy.
@@ -87,21 +83,11 @@ ur_result_t ur_command_list_manager::appendGenericCopyUnlocked(
8783

8884
auto pSrc = ur_cast<char *>(src->getDevicePtr(
8985
device, ur_mem_buffer_t::device_access_mode_t::read_only, srcOffset, size,
90-
[&](void *src, void *dst, size_t size) {
91-
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
92-
(zeCommandList.get(), dst, src, size, nullptr,
93-
waitListView.num, waitListView.handles));
94-
waitListView.clear();
95-
}));
86+
zeCommandList.get(), waitListView));
9687

9788
auto pDst = ur_cast<char *>(dst->getDevicePtr(
9889
device, ur_mem_buffer_t::device_access_mode_t::write_only, dstOffset,
99-
size, [&](void *src, void *dst, size_t size) {
100-
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
101-
(zeCommandList.get(), dst, src, size, nullptr,
102-
waitListView.num, waitListView.handles));
103-
waitListView.clear();
104-
}));
90+
size, zeCommandList.get(), waitListView));
10591

10692
ZE2UR_CALL(zeCommandListAppendMemoryCopy,
10793
(zeCommandList.get(), pDst, pSrc, size, zeSignalEvent,
@@ -130,20 +116,10 @@ ur_result_t ur_command_list_manager::appendRegionCopyUnlocked(
130116

131117
auto pSrc = ur_cast<char *>(src->getDevicePtr(
132118
device, ur_mem_buffer_t::device_access_mode_t::read_only, 0,
133-
src->getSize(), [&](void *src, void *dst, size_t size) {
134-
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
135-
(zeCommandList.get(), dst, src, size, nullptr,
136-
waitListView.num, waitListView.handles));
137-
waitListView.clear();
138-
}));
119+
src->getSize(), zeCommandList.get(), waitListView));
139120
auto pDst = ur_cast<char *>(dst->getDevicePtr(
140121
device, ur_mem_buffer_t::device_access_mode_t::write_only, 0,
141-
dst->getSize(), [&](void *src, void *dst, size_t size) {
142-
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
143-
(zeCommandList.get(), dst, src, size, nullptr,
144-
waitListView.num, waitListView.handles));
145-
waitListView.clear();
146-
}));
122+
dst->getSize(), zeCommandList.get(), waitListView));
147123

148124
ZE2UR_CALL(zeCommandListAppendMemoryCopyRegion,
149125
(zeCommandList.get(), pDst, &zeParams.dstRegion, zeParams.dstPitch,
@@ -213,13 +189,6 @@ ur_result_t ur_command_list_manager::appendKernelLaunch(
213189

214190
auto waitListView = getWaitListView(phEventWaitList, numEventsInWaitList);
215191

216-
auto memoryMigrate = [&](void *src, void *dst, size_t size) {
217-
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
218-
(zeCommandList.get(), dst, src, size, nullptr,
219-
waitListView.num, waitListView.handles));
220-
waitListView.clear();
221-
};
222-
223192
// If the offset is {0, 0, 0}, pass NULL instead.
224193
// This allows us to skip setting the offset.
225194
bool hasOffset = false;
@@ -232,7 +201,7 @@ ur_result_t ur_command_list_manager::appendKernelLaunch(
232201

233202
UR_CALL(hKernel->prepareForSubmission(context, device, pGlobalWorkOffset,
234203
workDim, WG[0], WG[1], WG[2],
235-
memoryMigrate));
204+
zeCommandList.get(), waitListView));
236205

237206
TRACK_SCOPE_LATENCY(
238207
"ur_command_list_manager::zeCommandListAppendLaunchKernel");

unified-runtime/source/adapters/level_zero/v2/command_list_manager.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
#include "common.hpp"
1414
#include "context.hpp"
1515
#include "event_pool_cache.hpp"
16-
#include "memory.hpp"
1716
#include "queue_api.hpp"
1817
#include <ze_api.h>
1918

19+
struct ur_mem_buffer_t;
20+
2021
struct wait_list_view {
2122
ze_event_handle_t *handles;
2223
uint32_t num;

unified-runtime/source/adapters/level_zero/v2/kernel.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
277277
ur_context_handle_t hContext, ur_device_handle_t hDevice,
278278
const size_t *pGlobalWorkOffset, uint32_t workDim, uint32_t groupSizeX,
279279
uint32_t groupSizeY, uint32_t groupSizeZ,
280-
std::function<void(void *, void *, size_t)> migrate) {
280+
ze_command_list_handle_t commandList, wait_list_view &waitListView) {
281281
auto hZeKernel = getZeHandle(hDevice);
282282

283283
if (pGlobalWorkOffset != NULL) {
@@ -293,8 +293,9 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
293293
if (pending.hMem) {
294294
if (!pending.hMem->isImage()) {
295295
auto hBuffer = pending.hMem->getBuffer();
296-
zePtr = hBuffer->getDevicePtr(hDevice, pending.mode, 0,
297-
hBuffer->getSize(), migrate);
296+
zePtr =
297+
hBuffer->getDevicePtr(hDevice, pending.mode, 0, hBuffer->getSize(),
298+
commandList, waitListView);
298299
} else {
299300
auto hImage = static_cast<ur_mem_image_t *>(pending.hMem->getImage());
300301
zePtr = reinterpret_cast<void *>(hImage->getZeImage());

unified-runtime/source/adapters/level_zero/v2/kernel.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,13 @@ struct ur_kernel_handle_t_ : ur_object {
8383

8484
// Set all required values for the kernel before submission (including pending
8585
// memory allocations).
86-
ur_result_t
87-
prepareForSubmission(ur_context_handle_t hContext, ur_device_handle_t hDevice,
88-
const size_t *pGlobalWorkOffset, uint32_t workDim,
89-
uint32_t groupSizeX, uint32_t groupSizeY,
90-
uint32_t groupSizeZ,
91-
std::function<void(void *, void *, size_t)> migrate);
86+
ur_result_t prepareForSubmission(ur_context_handle_t hContext,
87+
ur_device_handle_t hDevice,
88+
const size_t *pGlobalWorkOffset,
89+
uint32_t workDim, uint32_t groupSizeX,
90+
uint32_t groupSizeY, uint32_t groupSizeZ,
91+
ze_command_list_handle_t cmdList,
92+
wait_list_view &waitListView);
9293

9394
private:
9495
// Keep the program of the kernel.

unified-runtime/source/adapters/level_zero/v2/memory.cpp

Lines changed: 67 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,24 @@ ur_usm_handle_t::ur_usm_handle_t(ur_context_handle_t hContext, size_t size,
3232
: ur_mem_buffer_t(hContext, size, device_access_mode_t::read_write),
3333
ptr(const_cast<void *>(ptr)) {}
3434

35-
void *ur_usm_handle_t::getDevicePtr(
36-
ur_device_handle_t /*hDevice*/, device_access_mode_t /*access*/,
37-
size_t offset, size_t /*size*/,
38-
std::function<void(void *src, void *dst, size_t)> /*migrate*/) {
35+
void *ur_usm_handle_t::getDevicePtr(ur_device_handle_t /*hDevice*/,
36+
device_access_mode_t /*access*/,
37+
size_t offset, size_t /*size*/,
38+
ze_command_list_handle_t /*cmdList*/,
39+
wait_list_view & /*waitListView*/) {
3940
return ur_cast<char *>(ptr) + offset;
4041
}
4142

42-
void *
43-
ur_usm_handle_t::mapHostPtr(ur_map_flags_t /*flags*/, size_t offset,
44-
size_t /*size*/,
45-
std::function<void(void *src, void *dst, size_t)>) {
43+
void *ur_usm_handle_t::mapHostPtr(ur_map_flags_t /*flags*/, size_t offset,
44+
size_t /*size*/,
45+
ze_command_list_handle_t /*cmdList*/,
46+
wait_list_view & /*waitListView*/) {
4647
return ur_cast<char *>(ptr) + offset;
4748
}
4849

49-
void ur_usm_handle_t::unmapHostPtr(
50-
void * /*pMappedPtr*/, std::function<void(void *src, void *dst, size_t)>) {
50+
void ur_usm_handle_t::unmapHostPtr(void * /*pMappedPtr*/,
51+
ze_command_list_handle_t cmdList,
52+
wait_list_view & /*waitListView*/) {
5153
/* nop */
5254
}
5355

@@ -106,14 +108,14 @@ ur_integrated_buffer_handle_t::~ur_integrated_buffer_handle_t() {
106108

107109
void *ur_integrated_buffer_handle_t::getDevicePtr(
108110
ur_device_handle_t /*hDevice*/, device_access_mode_t /*access*/,
109-
size_t offset, size_t /*size*/,
110-
std::function<void(void *src, void *dst, size_t)> /*migrate*/) {
111+
size_t offset, size_t /*size*/, ze_command_list_handle_t /*cmdList*/,
112+
wait_list_view & /*waitListView*/) {
111113
return ur_cast<char *>(ptr.get()) + offset;
112114
}
113115

114116
void *ur_integrated_buffer_handle_t::mapHostPtr(
115117
ur_map_flags_t /*flags*/, size_t offset, size_t /*size*/,
116-
std::function<void(void *src, void *dst, size_t)> /*migrate*/) {
118+
ze_command_list_handle_t /*cmdList*/, wait_list_view & /*waitListView*/) {
117119
// TODO: if writeBackPtr is set, we should map to that pointer
118120
// because that's what SYCL expects, SYCL will attempt to call free
119121
// on the resulting pointer leading to double free with the current
@@ -122,7 +124,8 @@ void *ur_integrated_buffer_handle_t::mapHostPtr(
122124
}
123125

124126
void ur_integrated_buffer_handle_t::unmapHostPtr(
125-
void * /*pMappedPtr*/, std::function<void(void *src, void *dst, size_t)>) {
127+
void * /*pMappedPtr*/, ze_command_list_handle_t /*cmdList*/,
128+
wait_list_view & /*waitListView*/) {
126129
// TODO: if writeBackPtr is set, we should copy the data back
127130
/* nop */
128131
}
@@ -250,8 +253,8 @@ void *ur_discrete_buffer_handle_t::getActiveDeviceAlloc(size_t offset) {
250253

251254
void *ur_discrete_buffer_handle_t::getDevicePtr(
252255
ur_device_handle_t hDevice, device_access_mode_t /*access*/, size_t offset,
253-
size_t /*size*/,
254-
std::function<void(void *src, void *dst, size_t)> /*migrate*/) {
256+
size_t /*size*/, ze_command_list_handle_t /*cmdList*/,
257+
wait_list_view & /*waitListView*/) {
255258
TRACK_SCOPE_LATENCY("ur_discrete_buffer_handle_t::getDevicePtr");
256259

257260
if (!activeAllocationDevice) {
@@ -283,9 +286,22 @@ void *ur_discrete_buffer_handle_t::getDevicePtr(
283286
return getActiveDeviceAlloc(offset);
284287
}
285288

286-
void *ur_discrete_buffer_handle_t::mapHostPtr(
287-
ur_map_flags_t flags, size_t offset, size_t size,
288-
std::function<void(void *src, void *dst, size_t)> migrate) {
289+
static void migrateMemory(ze_command_list_handle_t cmdList, void *src,
290+
void *dst, size_t size,
291+
wait_list_view &waitListView) {
292+
if (!cmdList) {
293+
throw UR_RESULT_ERROR_INVALID_NULL_HANDLE;
294+
}
295+
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
296+
(cmdList, dst, src, size, nullptr, waitListView.num,
297+
waitListView.handles));
298+
waitListView.clear();
299+
}
300+
301+
void *ur_discrete_buffer_handle_t::mapHostPtr(ur_map_flags_t flags,
302+
size_t offset, size_t size,
303+
ze_command_list_handle_t cmdList,
304+
wait_list_view &waitListView) {
289305
TRACK_SCOPE_LATENCY("ur_discrete_buffer_handle_t::mapHostPtr");
290306
// TODO: use async alloc?
291307

@@ -309,15 +325,16 @@ void *ur_discrete_buffer_handle_t::mapHostPtr(
309325

310326
if (activeAllocationDevice && (flags & UR_MAP_FLAG_READ)) {
311327
auto srcPtr = getActiveDeviceAlloc(offset);
312-
migrate(srcPtr, hostAllocations.back().ptr.get(), size);
328+
migrateMemory(cmdList, srcPtr, hostAllocations.back().ptr.get(), size,
329+
waitListView);
313330
}
314331

315332
return hostAllocations.back().ptr.get();
316333
}
317334

318-
void ur_discrete_buffer_handle_t::unmapHostPtr(
319-
void *pMappedPtr,
320-
std::function<void(void *src, void *dst, size_t)> migrate) {
335+
void ur_discrete_buffer_handle_t::unmapHostPtr(void *pMappedPtr,
336+
ze_command_list_handle_t cmdList,
337+
wait_list_view &waitListView) {
321338
TRACK_SCOPE_LATENCY("ur_discrete_buffer_handle_t::unmapHostPtr");
322339

323340
auto hostAlloc =
@@ -341,8 +358,9 @@ void ur_discrete_buffer_handle_t::unmapHostPtr(
341358
// UR_MAP_FLAG_WRITE_INVALIDATE_REGION when there is an active device
342359
// allocation. is this correct?
343360
if (activeAllocationDevice) {
344-
migrate(hostAlloc->ptr.get(), getActiveDeviceAlloc(hostAlloc->offset),
345-
hostAlloc->size);
361+
migrateMemory(cmdList, hostAlloc->ptr.get(),
362+
getActiveDeviceAlloc(hostAlloc->offset), hostAlloc->size,
363+
waitListView);
346364
}
347365

348366
hostAllocations.erase(hostAlloc);
@@ -361,18 +379,20 @@ ur_shared_buffer_handle_t::ur_shared_buffer_handle_t(
361379

362380
void *ur_shared_buffer_handle_t::getDevicePtr(
363381
ur_device_handle_t, device_access_mode_t, size_t offset, size_t,
364-
std::function<void(void *src, void *dst, size_t)>) {
382+
ze_command_list_handle_t /*cmdList*/, wait_list_view & /*waitListView*/) {
365383
return reinterpret_cast<char *>(ptr.get()) + offset;
366384
}
367385

368-
void *ur_shared_buffer_handle_t::mapHostPtr(
369-
ur_map_flags_t, size_t offset, size_t,
370-
std::function<void(void *src, void *dst, size_t)>) {
386+
void *
387+
ur_shared_buffer_handle_t::mapHostPtr(ur_map_flags_t, size_t offset, size_t,
388+
ze_command_list_handle_t /*cmdList*/,
389+
wait_list_view & /*waitListView*/) {
371390
return reinterpret_cast<char *>(ptr.get()) + offset;
372391
}
373392

374393
void ur_shared_buffer_handle_t::unmapHostPtr(
375-
void *, std::function<void(void *src, void *dst, size_t)>) {
394+
void *, ze_command_list_handle_t /*cmdList*/,
395+
wait_list_view & /*waitListView*/) {
376396
// nop
377397
}
378398

@@ -403,24 +423,27 @@ ur_mem_sub_buffer_t::~ur_mem_sub_buffer_t() {
403423
ur::level_zero::urMemRelease(hParent);
404424
}
405425

406-
void *ur_mem_sub_buffer_t::getDevicePtr(
407-
ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
408-
size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
426+
void *ur_mem_sub_buffer_t::getDevicePtr(ur_device_handle_t hDevice,
427+
device_access_mode_t access,
428+
size_t offset, size_t size,
429+
ze_command_list_handle_t cmdList,
430+
wait_list_view &waitListView) {
409431
return hParent->getBuffer()->getDevicePtr(
410-
hDevice, access, offset + this->offset, size, std::move(migrate));
432+
hDevice, access, offset + this->offset, size, cmdList, waitListView);
411433
}
412434

413-
void *ur_mem_sub_buffer_t::mapHostPtr(
414-
ur_map_flags_t flags, size_t offset, size_t size,
415-
std::function<void(void *src, void *dst, size_t)> migrate) {
435+
void *ur_mem_sub_buffer_t::mapHostPtr(ur_map_flags_t flags, size_t offset,
436+
size_t size,
437+
ze_command_list_handle_t cmdList,
438+
wait_list_view &waitListView) {
416439
return hParent->getBuffer()->mapHostPtr(flags, offset + this->offset, size,
417-
std::move(migrate));
440+
cmdList, waitListView);
418441
}
419442

420-
void ur_mem_sub_buffer_t::unmapHostPtr(
421-
void *pMappedPtr,
422-
std::function<void(void *src, void *dst, size_t)> migrate) {
423-
return hParent->getBuffer()->unmapHostPtr(pMappedPtr, std::move(migrate));
443+
void ur_mem_sub_buffer_t::unmapHostPtr(void *pMappedPtr,
444+
ze_command_list_handle_t cmdList,
445+
wait_list_view &waitListView) {
446+
return hParent->getBuffer()->unmapHostPtr(pMappedPtr, cmdList, waitListView);
424447
}
425448

426449
ur_shared_mutex &ur_mem_sub_buffer_t::getMutex() {
@@ -690,9 +713,10 @@ ur_result_t urMemGetNativeHandle(ur_mem_handle_t hMem,
690713

691714
std::scoped_lock<ur_shared_mutex> lock(hBuffer->getMutex());
692715

716+
wait_list_view emptyWaitListView(nullptr, 0);
693717
auto ptr = hBuffer->getDevicePtr(
694718
hDevice, ur_mem_buffer_t::device_access_mode_t::read_write, 0,
695-
hBuffer->getSize(), nullptr);
719+
hBuffer->getSize(), nullptr, emptyWaitListView);
696720
*phNativeMem = reinterpret_cast<ur_native_handle_t>(ptr);
697721
return UR_RESULT_SUCCESS;
698722
} catch (...) {

0 commit comments

Comments
 (0)