@@ -100,16 +100,12 @@ ur_result_t ur_exp_command_buffer_handle_t_::addWaitNodes(
100
100
return Err;
101
101
}
102
102
103
- kernel_command_handle::kernel_command_handle (
104
- ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
105
- CUgraphNode Node, CUDA_KERNEL_NODE_PARAMS Params, uint32_t WorkDim,
103
+ kernel_command_data::kernel_command_data (
104
+ ur_kernel_handle_t Kernel, CUDA_KERNEL_NODE_PARAMS Params, uint32_t WorkDim,
106
105
const size_t *GlobalWorkOffsetPtr, const size_t *GlobalWorkSizePtr,
107
106
const size_t *LocalWorkSizePtr, uint32_t NumKernelAlternatives,
108
- ur_kernel_handle_t *KernelAlternatives, CUgraphNode SignalNode,
109
- const std::vector<CUgraphNode> &WaitNodes)
110
- : ur_exp_command_buffer_command_handle_t_(CommandBuffer, Node, SignalNode,
111
- WaitNodes),
112
- Kernel(Kernel), Params(Params), WorkDim(WorkDim) {
107
+ ur_kernel_handle_t *KernelAlternatives)
108
+ : Kernel(Kernel), Params(Params), WorkDim(WorkDim) {
113
109
const size_t CopySize = sizeof (size_t ) * WorkDim;
114
110
std::memcpy (GlobalWorkOffset, GlobalWorkOffsetPtr, CopySize);
115
111
std::memcpy (GlobalWorkSize, GlobalWorkSizePtr, CopySize);
@@ -191,8 +187,8 @@ static void setCopyParams(const void *SrcPtr, const CUmemorytype_enum SrcType,
191
187
}
192
188
193
189
// Helper function for enqueuing memory fills. Templated on the CommandType
194
- // enum class for the type of fill being created.
195
- template <class T >
190
+ // enumerator identifying the type of fill being created.
191
+ template <CommandType CT >
196
192
static ur_result_t enqueueCommandBufferFillHelper (
197
193
ur_exp_command_buffer_handle_t CommandBuffer, void *DstDevice,
198
194
const CUmemorytype_enum DstType, const void *Pattern, size_t PatternSize,
@@ -331,8 +327,9 @@ static ur_result_t enqueueCommandBufferFillHelper(
331
327
332
328
std::vector<CUgraphNode> WaitNodes =
333
329
NumEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
334
- auto NewCommand = std::make_unique<T>(CommandBuffer, GraphNode, SignalNode,
335
- WaitNodes, std::move (DecomposedNodes));
330
+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
331
+ CT, CommandBuffer, GraphNode, SignalNode, WaitNodes,
332
+ fill_command_data{std::move (DecomposedNodes)});
336
333
if (RetCommand) {
337
334
*RetCommand = NewCommand.get ();
338
335
}
@@ -528,10 +525,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
528
525
529
526
std::vector<CUgraphNode> WaitNodes =
530
527
numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
531
- auto NewCommand = std::make_unique<kernel_command_handle>(
532
- hCommandBuffer, hKernel, GraphNode, NodeParams, workDim,
533
- pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
534
- numKernelAlternatives, phKernelAlternatives, SignalNode, WaitNodes);
528
+ auto KernelData = kernel_command_data{hKernel,
529
+ NodeParams,
530
+ workDim,
531
+ pGlobalWorkOffset,
532
+ pGlobalWorkSize,
533
+ pLocalWorkSize,
534
+ numKernelAlternatives,
535
+ phKernelAlternatives};
536
+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
537
+ CommandType::Kernel, hCommandBuffer, GraphNode, SignalNode, WaitNodes,
538
+ KernelData);
535
539
536
540
if (phCommand) {
537
541
*phCommand = NewCommand.get ();
@@ -585,8 +589,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
585
589
586
590
std::vector<CUgraphNode> WaitNodes =
587
591
numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
588
- auto NewCommand = std::make_unique<usm_memcpy_command_handle >(
589
- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
592
+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_ >(
593
+ CommandType::USMMemcpy, hCommandBuffer, GraphNode, SignalNode, WaitNodes);
590
594
if (phCommand) {
591
595
*phCommand = NewCommand.get ();
592
596
}
@@ -650,8 +654,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
650
654
651
655
std::vector<CUgraphNode> WaitNodes =
652
656
numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
653
- auto NewCommand = std::make_unique<buffer_copy_command_handle>(
654
- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
657
+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
658
+ CommandType::MemBufferCopy, hCommandBuffer, GraphNode, SignalNode,
659
+ WaitNodes);
655
660
656
661
if (phCommand) {
657
662
*phCommand = NewCommand.get ();
@@ -713,8 +718,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
713
718
714
719
std::vector<CUgraphNode> WaitNodes =
715
720
numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
716
- auto NewCommand = std::make_unique<buffer_copy_rect_command_handle>(
717
- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
721
+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
722
+ CommandType::MemBufferCopyRect, hCommandBuffer, GraphNode, SignalNode,
723
+ WaitNodes);
718
724
719
725
if (phCommand) {
720
726
*phCommand = NewCommand.get ();
@@ -772,8 +778,9 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
772
778
773
779
std::vector<CUgraphNode> WaitNodes =
774
780
numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
775
- auto NewCommand = std::make_unique<buffer_write_command_handle>(
776
- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
781
+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
782
+ CommandType::MemBufferWrite, hCommandBuffer, GraphNode, SignalNode,
783
+ WaitNodes);
777
784
if (phCommand) {
778
785
*phCommand = NewCommand.get ();
779
786
}
@@ -829,8 +836,9 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
829
836
830
837
std::vector<CUgraphNode> WaitNodes =
831
838
numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
832
- auto NewCommand = std::make_unique<buffer_read_command_handle>(
833
- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
839
+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
840
+ CommandType::MemBufferRead, hCommandBuffer, GraphNode, SignalNode,
841
+ WaitNodes);
834
842
if (phCommand) {
835
843
*phCommand = NewCommand.get ();
836
844
}
@@ -890,8 +898,9 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
890
898
891
899
std::vector<CUgraphNode> WaitNodes =
892
900
numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
893
- auto NewCommand = std::make_unique<buffer_write_rect_command_handle>(
894
- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
901
+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
902
+ CommandType::MemBufferWriteRect, hCommandBuffer, GraphNode, SignalNode,
903
+ WaitNodes);
895
904
896
905
if (phCommand) {
897
906
*phCommand = NewCommand.get ();
@@ -952,8 +961,9 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
952
961
953
962
std::vector<CUgraphNode> WaitNodes =
954
963
numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
955
- auto NewCommand = std::make_unique<buffer_read_rect_command_handle>(
956
- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
964
+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
965
+ CommandType::MemBufferReadRect, hCommandBuffer, GraphNode, SignalNode,
966
+ WaitNodes);
957
967
958
968
if (phCommand) {
959
969
*phCommand = NewCommand.get ();
@@ -1006,8 +1016,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
1006
1016
1007
1017
std::vector<CUgraphNode> WaitNodes =
1008
1018
numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
1009
- auto NewCommand = std::make_unique<usm_prefetch_command_handle>(
1010
- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
1019
+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
1020
+ CommandType::USMPrefetch, hCommandBuffer, GraphNode, SignalNode,
1021
+ WaitNodes);
1011
1022
1012
1023
if (phCommand) {
1013
1024
*phCommand = NewCommand.get ();
@@ -1060,8 +1071,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
1060
1071
1061
1072
std::vector<CUgraphNode> WaitNodes =
1062
1073
numEventsInWaitList ? std::move (DepsList) : std::vector<CUgraphNode>();
1063
- auto NewCommand = std::make_unique<usm_advise_command_handle >(
1064
- hCommandBuffer, GraphNode, SignalNode, WaitNodes);
1074
+ auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_ >(
1075
+ CommandType::USMAdvise, hCommandBuffer, GraphNode, SignalNode, WaitNodes);
1065
1076
1066
1077
if (phCommand) {
1067
1078
*phCommand = NewCommand.get ();
@@ -1096,7 +1107,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
1096
1107
auto DstDevice = std::get<BufferMem>(hBuffer->Mem )
1097
1108
.getPtrWithOffset (hCommandBuffer->Device , offset);
1098
1109
1099
- return enqueueCommandBufferFillHelper<buffer_fill_command_handle >(
1110
+ return enqueueCommandBufferFillHelper<CommandType::MemBufferFill >(
1100
1111
hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
1101
1112
size, numSyncPointsInWaitList, pSyncPointWaitList, numEventsInWaitList,
1102
1113
phEventWaitList, pSyncPoint, phEvent, phCommand);
@@ -1116,7 +1127,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
1116
1127
(patternSize > 0 ); // is a positive power of two
1117
1128
1118
1129
UR_ASSERT (PatternIsValid && PatternSizeIsValid, UR_RESULT_ERROR_INVALID_SIZE);
1119
- return enqueueCommandBufferFillHelper<usm_fill_command_handle >(
1130
+ return enqueueCommandBufferFillHelper<CommandType::USMFill >(
1120
1131
hCommandBuffer, pPtr, CU_MEMORYTYPE_UNIFIED, pPattern, patternSize, size,
1121
1132
numSyncPointsInWaitList, pSyncPointWaitList, numEventsInWaitList,
1122
1133
phEventWaitList, pSyncPoint, phEvent, phCommand);
@@ -1165,12 +1176,12 @@ ur_result_t
1165
1176
validateCommandDesc (ur_exp_command_buffer_handle_t CommandBuffer,
1166
1177
const ur_exp_command_buffer_update_kernel_launch_desc_t
1167
1178
&UpdateCommandDesc) {
1168
- if (UpdateCommandDesc.hCommand ->getCommandType () != CommandType::Kernel) {
1179
+ if (UpdateCommandDesc.hCommand ->Type != CommandType::Kernel) {
1169
1180
return UR_RESULT_ERROR_INVALID_VALUE;
1170
1181
}
1171
1182
1172
- auto Command =
1173
- static_cast <kernel_command_handle *>(UpdateCommandDesc. hCommand );
1183
+ auto * Command = UpdateCommandDesc. hCommand ;
1184
+ auto &KernelData = std::get<kernel_command_data>(Command-> CommandData );
1174
1185
if (CommandBuffer != Command->CommandBuffer ) {
1175
1186
return UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP;
1176
1187
}
@@ -1180,14 +1191,14 @@ validateCommandDesc(ur_exp_command_buffer_handle_t CommandBuffer,
1180
1191
return UR_RESULT_ERROR_INVALID_OPERATION;
1181
1192
}
1182
1193
1183
- if (UpdateCommandDesc.newWorkDim != Command-> WorkDim &&
1194
+ if (UpdateCommandDesc.newWorkDim != KernelData. WorkDim &&
1184
1195
(!UpdateCommandDesc.pNewGlobalWorkOffset ||
1185
1196
!UpdateCommandDesc.pNewGlobalWorkSize )) {
1186
1197
return UR_RESULT_ERROR_INVALID_VALUE;
1187
1198
}
1188
1199
1189
1200
if (UpdateCommandDesc.hNewKernel &&
1190
- !Command-> ValidKernelHandles .count (UpdateCommandDesc.hNewKernel )) {
1201
+ !KernelData. ValidKernelHandles .count (UpdateCommandDesc.hNewKernel )) {
1191
1202
return UR_RESULT_ERROR_INVALID_VALUE;
1192
1203
}
1193
1204
return UR_RESULT_SUCCESS;
@@ -1202,9 +1213,9 @@ validateCommandDesc(ur_exp_command_buffer_handle_t CommandBuffer,
1202
1213
ur_result_t
1203
1214
updateKernelArguments (const ur_exp_command_buffer_update_kernel_launch_desc_t
1204
1215
&UpdateCommandDesc) {
1205
- auto Command =
1206
- static_cast <kernel_command_handle *>(UpdateCommandDesc. hCommand );
1207
- ur_kernel_handle_t Kernel = Command-> Kernel ;
1216
+ auto * Command = UpdateCommandDesc. hCommand ;
1217
+ auto &KernelData = std::get<kernel_command_data>(Command-> CommandData );
1218
+ ur_kernel_handle_t Kernel = KernelData. Kernel ;
1208
1219
ur_device_handle_t Device = Command->CommandBuffer ->Device ;
1209
1220
1210
1221
// Update pointer arguments to the kernel
@@ -1284,29 +1295,29 @@ updateKernelArguments(const ur_exp_command_buffer_update_kernel_launch_desc_t
1284
1295
ur_result_t
1285
1296
updateCommand (const ur_exp_command_buffer_update_kernel_launch_desc_t
1286
1297
&UpdateCommandDesc) {
1287
- auto Command =
1288
- static_cast <kernel_command_handle *>(UpdateCommandDesc. hCommand );
1298
+ auto * Command = UpdateCommandDesc. hCommand ;
1299
+ auto &KernelData = std::get<kernel_command_data>(Command-> CommandData );
1289
1300
if (UpdateCommandDesc.hNewKernel ) {
1290
- Command-> Kernel = UpdateCommandDesc.hNewKernel ;
1301
+ KernelData. Kernel = UpdateCommandDesc.hNewKernel ;
1291
1302
}
1292
1303
1293
1304
if (UpdateCommandDesc.newWorkDim ) {
1294
- Command-> WorkDim = UpdateCommandDesc.newWorkDim ;
1305
+ KernelData. WorkDim = UpdateCommandDesc.newWorkDim ;
1295
1306
}
1296
1307
1297
1308
if (UpdateCommandDesc.pNewGlobalWorkOffset ) {
1298
- Command-> setGlobalOffset (UpdateCommandDesc.pNewGlobalWorkOffset );
1309
+ KernelData. setGlobalOffset (UpdateCommandDesc.pNewGlobalWorkOffset );
1299
1310
}
1300
1311
1301
1312
if (UpdateCommandDesc.pNewGlobalWorkSize ) {
1302
- Command-> setGlobalSize (UpdateCommandDesc.pNewGlobalWorkSize );
1313
+ KernelData. setGlobalSize (UpdateCommandDesc.pNewGlobalWorkSize );
1303
1314
if (!UpdateCommandDesc.pNewLocalWorkSize ) {
1304
- Command-> setNullLocalSize ();
1315
+ KernelData. setNullLocalSize ();
1305
1316
}
1306
1317
}
1307
1318
1308
1319
if (UpdateCommandDesc.pNewLocalWorkSize ) {
1309
- Command-> setLocalSize (UpdateCommandDesc.pNewLocalWorkSize );
1320
+ KernelData. setLocalSize (UpdateCommandDesc.pNewLocalWorkSize );
1310
1321
}
1311
1322
1312
1323
return UR_RESULT_SUCCESS;
@@ -1334,27 +1345,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1334
1345
1335
1346
// If no work-size is provided make sure we pass nullptr to setKernelParams
1336
1347
// so it can guess the local work size.
1337
- auto KernelCommandHandle =
1338
- static_cast <kernel_command_handle *>(UpdateCommandDesc.hCommand );
1339
- const bool ProvidedLocalSize = !KernelCommandHandle->isNullLocalSize ();
1348
+ auto *KernelCommandHandle = UpdateCommandDesc.hCommand ;
1349
+ auto &KernelData =
1350
+ std::get<kernel_command_data>(KernelCommandHandle->CommandData );
1351
+ const bool ProvidedLocalSize = !KernelData.isNullLocalSize ();
1340
1352
size_t *LocalWorkSize =
1341
- ProvidedLocalSize ? KernelCommandHandle-> LocalWorkSize : nullptr ;
1353
+ ProvidedLocalSize ? KernelData. LocalWorkSize : nullptr ;
1342
1354
1343
1355
// Set the number of threads per block to the number of threads per warp
1344
1356
// by default unless user has provided a better number.
1345
1357
size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
1346
1358
size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
1347
- CUfunction CuFunc = KernelCommandHandle-> Kernel ->get ();
1359
+ CUfunction CuFunc = KernelData. Kernel ->get ();
1348
1360
auto Result = setKernelParams (
1349
- hCommandBuffer->Context , hCommandBuffer->Device ,
1350
- KernelCommandHandle->WorkDim , KernelCommandHandle->GlobalWorkOffset ,
1351
- KernelCommandHandle->GlobalWorkSize , LocalWorkSize,
1352
- KernelCommandHandle->Kernel , CuFunc, ThreadsPerBlock, BlocksPerGrid);
1361
+ hCommandBuffer->Context , hCommandBuffer->Device , KernelData.WorkDim ,
1362
+ KernelData.GlobalWorkOffset , KernelData.GlobalWorkSize , LocalWorkSize,
1363
+ KernelData.Kernel , CuFunc, ThreadsPerBlock, BlocksPerGrid);
1353
1364
if (Result != UR_RESULT_SUCCESS) {
1354
1365
return Result;
1355
1366
}
1356
1367
1357
- CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle-> Params ;
1368
+ CUDA_KERNEL_NODE_PARAMS &Params = KernelData. Params ;
1358
1369
1359
1370
Params.func = CuFunc;
1360
1371
Params.gridDimX = BlocksPerGrid[0 ];
@@ -1363,9 +1374,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1363
1374
Params.blockDimX = ThreadsPerBlock[0 ];
1364
1375
Params.blockDimY = ThreadsPerBlock[1 ];
1365
1376
Params.blockDimZ = ThreadsPerBlock[2 ];
1366
- Params.sharedMemBytes = KernelCommandHandle-> Kernel ->getLocalSize ();
1367
- Params.kernelParams = const_cast < void **>(
1368
- KernelCommandHandle-> Kernel ->getArgPointers ().data ());
1377
+ Params.sharedMemBytes = KernelData. Kernel ->getLocalSize ();
1378
+ Params.kernelParams =
1379
+ const_cast < void **>(KernelData. Kernel ->getArgPointers ().data ());
1369
1380
1370
1381
CUgraphNode Node = KernelCommandHandle->Node ;
1371
1382
CUgraphExec CudaGraphExec = hCommandBuffer->CudaGraphExec ;
0 commit comments