@@ -32,22 +32,24 @@ ur_usm_handle_t::ur_usm_handle_t(ur_context_handle_t hContext, size_t size,
32
32
: ur_mem_buffer_t (hContext, size, device_access_mode_t ::read_write),
33
33
ptr (const_cast <void *>(ptr)) {}
34
34
35
- void *ur_usm_handle_t ::getDevicePtr(
36
- ur_device_handle_t /* hDevice*/ , device_access_mode_t /* access*/ ,
37
- size_t offset, size_t /* size*/ ,
38
- std::function<void (void *src, void *dst, size_t )> /* migrate*/ ) {
35
+ void *ur_usm_handle_t ::getDevicePtr(ur_device_handle_t /* hDevice*/ ,
36
+ device_access_mode_t /* access*/ ,
37
+ size_t offset, size_t /* size*/ ,
38
+ ze_command_list_handle_t /* cmdList*/ ,
39
+ wait_list_view & /* waitListView*/ ) {
39
40
return ur_cast<char *>(ptr) + offset;
40
41
}
41
42
42
- void *
43
- ur_usm_handle_t ::mapHostPtr( ur_map_flags_t /* flags */ , size_t offset ,
44
- size_t /* size */ ,
45
- std::function< void ( void *src, void *dst, size_t )> ) {
43
+ void *ur_usm_handle_t ::mapHostPtr( ur_map_flags_t /* flags */ , size_t offset,
44
+ size_t /* size */ ,
45
+ ze_command_list_handle_t /* cmdList */ ,
46
+ wait_list_view & /* waitListView */ ) {
46
47
return ur_cast<char *>(ptr) + offset;
47
48
}
48
49
49
- void ur_usm_handle_t::unmapHostPtr (
50
- void * /* pMappedPtr*/ , std::function<void (void *src, void *dst, size_t )>) {
50
+ void ur_usm_handle_t::unmapHostPtr (void * /* pMappedPtr*/ ,
51
+ ze_command_list_handle_t cmdList,
52
+ wait_list_view & /* waitListView*/ ) {
51
53
/* nop */
52
54
}
53
55
@@ -106,14 +108,14 @@ ur_integrated_buffer_handle_t::~ur_integrated_buffer_handle_t() {
106
108
107
109
void *ur_integrated_buffer_handle_t ::getDevicePtr(
108
110
ur_device_handle_t /* hDevice*/ , device_access_mode_t /* access*/ ,
109
- size_t offset, size_t /* size*/ ,
110
- std::function< void ( void *src, void *dst, size_t )> /* migrate */ ) {
111
+ size_t offset, size_t /* size*/ , ze_command_list_handle_t /* cmdList */ ,
112
+ wait_list_view & /* waitListView */ ) {
111
113
return ur_cast<char *>(ptr.get ()) + offset;
112
114
}
113
115
114
116
void *ur_integrated_buffer_handle_t ::mapHostPtr(
115
117
ur_map_flags_t /* flags*/ , size_t offset, size_t /* size*/ ,
116
- std::function< void ( void *src, void *dst, size_t )> /* migrate */ ) {
118
+ ze_command_list_handle_t /* cmdList */ , wait_list_view & /* waitListView */ ) {
117
119
// TODO: if writeBackPtr is set, we should map to that pointer
118
120
// because that's what SYCL expects, SYCL will attempt to call free
119
121
// on the resulting pointer leading to double free with the current
@@ -122,7 +124,8 @@ void *ur_integrated_buffer_handle_t::mapHostPtr(
122
124
}
123
125
124
126
void ur_integrated_buffer_handle_t::unmapHostPtr (
125
- void * /* pMappedPtr*/ , std::function<void (void *src, void *dst, size_t )>) {
127
+ void * /* pMappedPtr*/ , ze_command_list_handle_t /* cmdList*/ ,
128
+ wait_list_view & /* waitListView*/ ) {
126
129
// TODO: if writeBackPtr is set, we should copy the data back
127
130
/* nop */
128
131
}
@@ -250,8 +253,8 @@ void *ur_discrete_buffer_handle_t::getActiveDeviceAlloc(size_t offset) {
250
253
251
254
void *ur_discrete_buffer_handle_t ::getDevicePtr(
252
255
ur_device_handle_t hDevice, device_access_mode_t /* access*/ , size_t offset,
253
- size_t /* size*/ ,
254
- std::function< void ( void *src, void *dst, size_t )> /* migrate */ ) {
256
+ size_t /* size*/ , ze_command_list_handle_t /* cmdList */ ,
257
+ wait_list_view & /* waitListView */ ) {
255
258
TRACK_SCOPE_LATENCY (" ur_discrete_buffer_handle_t::getDevicePtr" );
256
259
257
260
if (!activeAllocationDevice) {
@@ -283,9 +286,22 @@ void *ur_discrete_buffer_handle_t::getDevicePtr(
283
286
return getActiveDeviceAlloc (offset);
284
287
}
285
288
286
- void *ur_discrete_buffer_handle_t ::mapHostPtr(
287
- ur_map_flags_t flags, size_t offset, size_t size,
288
- std::function<void (void *src, void *dst, size_t )> migrate) {
289
+ static void migrateMemory (ze_command_list_handle_t cmdList, void *src,
290
+ void *dst, size_t size,
291
+ wait_list_view &waitListView) {
292
+ if (!cmdList) {
293
+ throw UR_RESULT_ERROR_INVALID_NULL_HANDLE;
294
+ }
295
+ ZE2UR_CALL_THROWS (zeCommandListAppendMemoryCopy,
296
+ (cmdList, dst, src, size, nullptr , waitListView.num ,
297
+ waitListView.handles ));
298
+ waitListView.clear ();
299
+ }
300
+
301
+ void *ur_discrete_buffer_handle_t ::mapHostPtr(ur_map_flags_t flags,
302
+ size_t offset, size_t size,
303
+ ze_command_list_handle_t cmdList,
304
+ wait_list_view &waitListView) {
289
305
TRACK_SCOPE_LATENCY (" ur_discrete_buffer_handle_t::mapHostPtr" );
290
306
// TODO: use async alloc?
291
307
@@ -309,15 +325,16 @@ void *ur_discrete_buffer_handle_t::mapHostPtr(
309
325
310
326
if (activeAllocationDevice && (flags & UR_MAP_FLAG_READ)) {
311
327
auto srcPtr = getActiveDeviceAlloc (offset);
312
- migrate (srcPtr, hostAllocations.back ().ptr .get (), size);
328
+ migrateMemory (cmdList, srcPtr, hostAllocations.back ().ptr .get (), size,
329
+ waitListView);
313
330
}
314
331
315
332
return hostAllocations.back ().ptr .get ();
316
333
}
317
334
318
- void ur_discrete_buffer_handle_t::unmapHostPtr (
319
- void *pMappedPtr ,
320
- std::function< void ( void *src, void *dst, size_t )> migrate ) {
335
+ void ur_discrete_buffer_handle_t::unmapHostPtr (void *pMappedPtr,
336
+ ze_command_list_handle_t cmdList ,
337
+ wait_list_view &waitListView ) {
321
338
TRACK_SCOPE_LATENCY (" ur_discrete_buffer_handle_t::unmapHostPtr" );
322
339
323
340
auto hostAlloc =
@@ -341,8 +358,9 @@ void ur_discrete_buffer_handle_t::unmapHostPtr(
341
358
// UR_MAP_FLAG_WRITE_INVALIDATE_REGION when there is an active device
342
359
// allocation. is this correct?
343
360
if (activeAllocationDevice) {
344
- migrate (hostAlloc->ptr .get (), getActiveDeviceAlloc (hostAlloc->offset ),
345
- hostAlloc->size );
361
+ migrateMemory (cmdList, hostAlloc->ptr .get (),
362
+ getActiveDeviceAlloc (hostAlloc->offset ), hostAlloc->size ,
363
+ waitListView);
346
364
}
347
365
348
366
hostAllocations.erase (hostAlloc);
@@ -361,18 +379,20 @@ ur_shared_buffer_handle_t::ur_shared_buffer_handle_t(
361
379
362
380
void *ur_shared_buffer_handle_t ::getDevicePtr(
363
381
ur_device_handle_t , device_access_mode_t , size_t offset, size_t ,
364
- std::function< void ( void *src, void *dst, size_t )> ) {
382
+ ze_command_list_handle_t /* cmdList */ , wait_list_view & /* waitListView */ ) {
365
383
return reinterpret_cast <char *>(ptr.get ()) + offset;
366
384
}
367
385
368
- void *ur_shared_buffer_handle_t ::mapHostPtr(
369
- ur_map_flags_t , size_t offset, size_t ,
370
- std::function<void (void *src, void *dst, size_t )>) {
386
+ void *
387
+ ur_shared_buffer_handle_t ::mapHostPtr(ur_map_flags_t , size_t offset, size_t ,
388
+ ze_command_list_handle_t /* cmdList*/ ,
389
+ wait_list_view & /* waitListView*/ ) {
371
390
return reinterpret_cast <char *>(ptr.get ()) + offset;
372
391
}
373
392
374
393
void ur_shared_buffer_handle_t::unmapHostPtr (
375
- void *, std::function<void (void *src, void *dst, size_t )>) {
394
+ void *, ze_command_list_handle_t /* cmdList*/ ,
395
+ wait_list_view & /* waitListView*/ ) {
376
396
// nop
377
397
}
378
398
@@ -403,24 +423,27 @@ ur_mem_sub_buffer_t::~ur_mem_sub_buffer_t() {
403
423
ur::level_zero::urMemRelease (hParent);
404
424
}
405
425
406
- void *ur_mem_sub_buffer_t ::getDevicePtr(
407
- ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
408
- size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
426
+ void *ur_mem_sub_buffer_t ::getDevicePtr(ur_device_handle_t hDevice,
427
+ device_access_mode_t access,
428
+ size_t offset, size_t size,
429
+ ze_command_list_handle_t cmdList,
430
+ wait_list_view &waitListView) {
409
431
return hParent->getBuffer ()->getDevicePtr (
410
- hDevice, access , offset + this ->offset , size, std::move (migrate) );
432
+ hDevice, access , offset + this ->offset , size, cmdList, waitListView );
411
433
}
412
434
413
- void *ur_mem_sub_buffer_t ::mapHostPtr(
414
- ur_map_flags_t flags, size_t offset, size_t size,
415
- std::function<void (void *src, void *dst, size_t )> migrate) {
435
+ void *ur_mem_sub_buffer_t ::mapHostPtr(ur_map_flags_t flags, size_t offset,
436
+ size_t size,
437
+ ze_command_list_handle_t cmdList,
438
+ wait_list_view &waitListView) {
416
439
return hParent->getBuffer ()->mapHostPtr (flags, offset + this ->offset , size,
417
- std::move (migrate) );
440
+ cmdList, waitListView );
418
441
}
419
442
420
- void ur_mem_sub_buffer_t::unmapHostPtr (
421
- void *pMappedPtr ,
422
- std::function< void ( void *src, void *dst, size_t )> migrate ) {
423
- return hParent->getBuffer ()->unmapHostPtr (pMappedPtr, std::move (migrate) );
443
+ void ur_mem_sub_buffer_t::unmapHostPtr (void *pMappedPtr,
444
+ ze_command_list_handle_t cmdList ,
445
+ wait_list_view &waitListView ) {
446
+ return hParent->getBuffer ()->unmapHostPtr (pMappedPtr, cmdList, waitListView );
424
447
}
425
448
426
449
ur_shared_mutex &ur_mem_sub_buffer_t ::getMutex() {
@@ -690,9 +713,10 @@ ur_result_t urMemGetNativeHandle(ur_mem_handle_t hMem,
690
713
691
714
std::scoped_lock<ur_shared_mutex> lock (hBuffer->getMutex ());
692
715
716
+ wait_list_view emptyWaitListView (nullptr , 0 );
693
717
auto ptr = hBuffer->getDevicePtr (
694
718
hDevice, ur_mem_buffer_t ::device_access_mode_t ::read_write, 0 ,
695
- hBuffer->getSize (), nullptr );
719
+ hBuffer->getSize (), nullptr , emptyWaitListView );
696
720
*phNativeMem = reinterpret_cast <ur_native_handle_t >(ptr);
697
721
return UR_RESULT_SUCCESS;
698
722
} catch (...) {
0 commit comments