Skip to content

Commit 1cb6d33

Browse files
committed
[UR][CUDA] Remove vtable from ur_exp_command_buffer_command_handle_t_
With the handle removal, we will start requiring handle types to not have vtables. This changes the `ur_exp_command_buffer_command_handle_t_` type so that the kind of command is encoded as a `CommandType` enum value stored in the handle, rather than as a distinct derived class. Three of the command kinds (kernel launch, USM fill and buffer fill) require extra data, which is now stored in a separate `CommandData` field as a `std::variant`.
1 parent dbd2815 commit 1cb6d33

File tree

2 files changed

+110
-232
lines changed

2 files changed

+110
-232
lines changed

unified-runtime/source/adapters/cuda/command_buffer.cpp

+76-65
Original file line numberDiff line numberDiff line change
@@ -100,16 +100,12 @@ ur_result_t ur_exp_command_buffer_handle_t_::addWaitNodes(
100100
return Err;
101101
}
102102

103-
kernel_command_handle::kernel_command_handle(
104-
ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
105-
CUgraphNode Node, CUDA_KERNEL_NODE_PARAMS Params, uint32_t WorkDim,
103+
kernel_command_data::kernel_command_data(
104+
ur_kernel_handle_t Kernel, CUDA_KERNEL_NODE_PARAMS Params, uint32_t WorkDim,
106105
const size_t *GlobalWorkOffsetPtr, const size_t *GlobalWorkSizePtr,
107106
const size_t *LocalWorkSizePtr, uint32_t NumKernelAlternatives,
108-
ur_kernel_handle_t *KernelAlternatives, CUgraphNode SignalNode,
109-
const std::vector<CUgraphNode> &WaitNodes)
110-
: ur_exp_command_buffer_command_handle_t_(CommandBuffer, Node, SignalNode,
111-
WaitNodes),
112-
Kernel(Kernel), Params(Params), WorkDim(WorkDim) {
107+
ur_kernel_handle_t *KernelAlternatives)
108+
: Kernel(Kernel), Params(Params), WorkDim(WorkDim) {
113109
const size_t CopySize = sizeof(size_t) * WorkDim;
114110
std::memcpy(GlobalWorkOffset, GlobalWorkOffsetPtr, CopySize);
115111
std::memcpy(GlobalWorkSize, GlobalWorkSizePtr, CopySize);
@@ -191,8 +187,8 @@ static void setCopyParams(const void *SrcPtr, const CUmemorytype_enum SrcType,
191187
}
192188

193189
// Helper function for enqueuing memory fills. Templated on the CommandType
194-
// enum class for the type of fill being created.
195-
template <class T>
190+
// enumerator identifying the type of fill being created.
191+
template <CommandType CT>
196192
static ur_result_t enqueueCommandBufferFillHelper(
197193
ur_exp_command_buffer_handle_t CommandBuffer, void *DstDevice,
198194
const CUmemorytype_enum DstType, const void *Pattern, size_t PatternSize,
@@ -331,8 +327,9 @@ static ur_result_t enqueueCommandBufferFillHelper(
331327

332328
std::vector<CUgraphNode> WaitNodes =
333329
NumEventsInWaitList ? std::move(DepsList) : std::vector<CUgraphNode>();
334-
auto NewCommand = std::make_unique<T>(CommandBuffer, GraphNode, SignalNode,
335-
WaitNodes, std::move(DecomposedNodes));
330+
auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
331+
CT, CommandBuffer, GraphNode, SignalNode, WaitNodes,
332+
fill_command_data{std::move(DecomposedNodes)});
336333
if (RetCommand) {
337334
*RetCommand = NewCommand.get();
338335
}
@@ -528,10 +525,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
528525

529526
std::vector<CUgraphNode> WaitNodes =
530527
numEventsInWaitList ? std::move(DepsList) : std::vector<CUgraphNode>();
531-
auto NewCommand = std::make_unique<kernel_command_handle>(
532-
hCommandBuffer, hKernel, GraphNode, NodeParams, workDim,
533-
pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
534-
numKernelAlternatives, phKernelAlternatives, SignalNode, WaitNodes);
528+
auto KernelData = kernel_command_data{hKernel,
529+
NodeParams,
530+
workDim,
531+
pGlobalWorkOffset,
532+
pGlobalWorkSize,
533+
pLocalWorkSize,
534+
numKernelAlternatives,
535+
phKernelAlternatives};
536+
auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
537+
CommandType::Kernel, hCommandBuffer, GraphNode, SignalNode, WaitNodes,
538+
KernelData);
535539

536540
if (phCommand) {
537541
*phCommand = NewCommand.get();
@@ -585,8 +589,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
585589

586590
std::vector<CUgraphNode> WaitNodes =
587591
numEventsInWaitList ? std::move(DepsList) : std::vector<CUgraphNode>();
588-
auto NewCommand = std::make_unique<usm_memcpy_command_handle>(
589-
hCommandBuffer, GraphNode, SignalNode, WaitNodes);
592+
auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
593+
CommandType::USMMemcpy, hCommandBuffer, GraphNode, SignalNode, WaitNodes);
590594
if (phCommand) {
591595
*phCommand = NewCommand.get();
592596
}
@@ -650,8 +654,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
650654

651655
std::vector<CUgraphNode> WaitNodes =
652656
numEventsInWaitList ? std::move(DepsList) : std::vector<CUgraphNode>();
653-
auto NewCommand = std::make_unique<buffer_copy_command_handle>(
654-
hCommandBuffer, GraphNode, SignalNode, WaitNodes);
657+
auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
658+
CommandType::MemBufferCopy, hCommandBuffer, GraphNode, SignalNode,
659+
WaitNodes);
655660

656661
if (phCommand) {
657662
*phCommand = NewCommand.get();
@@ -713,8 +718,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
713718

714719
std::vector<CUgraphNode> WaitNodes =
715720
numEventsInWaitList ? std::move(DepsList) : std::vector<CUgraphNode>();
716-
auto NewCommand = std::make_unique<buffer_copy_rect_command_handle>(
717-
hCommandBuffer, GraphNode, SignalNode, WaitNodes);
721+
auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
722+
CommandType::MemBufferCopyRect, hCommandBuffer, GraphNode, SignalNode,
723+
WaitNodes);
718724

719725
if (phCommand) {
720726
*phCommand = NewCommand.get();
@@ -772,8 +778,9 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
772778

773779
std::vector<CUgraphNode> WaitNodes =
774780
numEventsInWaitList ? std::move(DepsList) : std::vector<CUgraphNode>();
775-
auto NewCommand = std::make_unique<buffer_write_command_handle>(
776-
hCommandBuffer, GraphNode, SignalNode, WaitNodes);
781+
auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
782+
CommandType::MemBufferWrite, hCommandBuffer, GraphNode, SignalNode,
783+
WaitNodes);
777784
if (phCommand) {
778785
*phCommand = NewCommand.get();
779786
}
@@ -829,8 +836,9 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
829836

830837
std::vector<CUgraphNode> WaitNodes =
831838
numEventsInWaitList ? std::move(DepsList) : std::vector<CUgraphNode>();
832-
auto NewCommand = std::make_unique<buffer_read_command_handle>(
833-
hCommandBuffer, GraphNode, SignalNode, WaitNodes);
839+
auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
840+
CommandType::MemBufferRead, hCommandBuffer, GraphNode, SignalNode,
841+
WaitNodes);
834842
if (phCommand) {
835843
*phCommand = NewCommand.get();
836844
}
@@ -890,8 +898,9 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
890898

891899
std::vector<CUgraphNode> WaitNodes =
892900
numEventsInWaitList ? std::move(DepsList) : std::vector<CUgraphNode>();
893-
auto NewCommand = std::make_unique<buffer_write_rect_command_handle>(
894-
hCommandBuffer, GraphNode, SignalNode, WaitNodes);
901+
auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
902+
CommandType::MemBufferWriteRect, hCommandBuffer, GraphNode, SignalNode,
903+
WaitNodes);
895904

896905
if (phCommand) {
897906
*phCommand = NewCommand.get();
@@ -952,8 +961,9 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
952961

953962
std::vector<CUgraphNode> WaitNodes =
954963
numEventsInWaitList ? std::move(DepsList) : std::vector<CUgraphNode>();
955-
auto NewCommand = std::make_unique<buffer_read_rect_command_handle>(
956-
hCommandBuffer, GraphNode, SignalNode, WaitNodes);
964+
auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
965+
CommandType::MemBufferReadRect, hCommandBuffer, GraphNode, SignalNode,
966+
WaitNodes);
957967

958968
if (phCommand) {
959969
*phCommand = NewCommand.get();
@@ -1006,8 +1016,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
10061016

10071017
std::vector<CUgraphNode> WaitNodes =
10081018
numEventsInWaitList ? std::move(DepsList) : std::vector<CUgraphNode>();
1009-
auto NewCommand = std::make_unique<usm_prefetch_command_handle>(
1010-
hCommandBuffer, GraphNode, SignalNode, WaitNodes);
1019+
auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
1020+
CommandType::USMPrefetch, hCommandBuffer, GraphNode, SignalNode,
1021+
WaitNodes);
10111022

10121023
if (phCommand) {
10131024
*phCommand = NewCommand.get();
@@ -1060,8 +1071,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
10601071

10611072
std::vector<CUgraphNode> WaitNodes =
10621073
numEventsInWaitList ? std::move(DepsList) : std::vector<CUgraphNode>();
1063-
auto NewCommand = std::make_unique<usm_advise_command_handle>(
1064-
hCommandBuffer, GraphNode, SignalNode, WaitNodes);
1074+
auto NewCommand = std::make_unique<ur_exp_command_buffer_command_handle_t_>(
1075+
CommandType::USMAdvise, hCommandBuffer, GraphNode, SignalNode, WaitNodes);
10651076

10661077
if (phCommand) {
10671078
*phCommand = NewCommand.get();
@@ -1096,7 +1107,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
10961107
auto DstDevice = std::get<BufferMem>(hBuffer->Mem)
10971108
.getPtrWithOffset(hCommandBuffer->Device, offset);
10981109

1099-
return enqueueCommandBufferFillHelper<buffer_fill_command_handle>(
1110+
return enqueueCommandBufferFillHelper<CommandType::MemBufferFill>(
11001111
hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
11011112
size, numSyncPointsInWaitList, pSyncPointWaitList, numEventsInWaitList,
11021113
phEventWaitList, pSyncPoint, phEvent, phCommand);
@@ -1116,7 +1127,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
11161127
(patternSize > 0); // is a positive power of two
11171128

11181129
UR_ASSERT(PatternIsValid && PatternSizeIsValid, UR_RESULT_ERROR_INVALID_SIZE);
1119-
return enqueueCommandBufferFillHelper<usm_fill_command_handle>(
1130+
return enqueueCommandBufferFillHelper<CommandType::USMFill>(
11201131
hCommandBuffer, pPtr, CU_MEMORYTYPE_UNIFIED, pPattern, patternSize, size,
11211132
numSyncPointsInWaitList, pSyncPointWaitList, numEventsInWaitList,
11221133
phEventWaitList, pSyncPoint, phEvent, phCommand);
@@ -1165,12 +1176,12 @@ ur_result_t
11651176
validateCommandDesc(ur_exp_command_buffer_handle_t CommandBuffer,
11661177
const ur_exp_command_buffer_update_kernel_launch_desc_t
11671178
&UpdateCommandDesc) {
1168-
if (UpdateCommandDesc.hCommand->getCommandType() != CommandType::Kernel) {
1179+
if (UpdateCommandDesc.hCommand->Type != CommandType::Kernel) {
11691180
return UR_RESULT_ERROR_INVALID_VALUE;
11701181
}
11711182

1172-
auto Command =
1173-
static_cast<kernel_command_handle *>(UpdateCommandDesc.hCommand);
1183+
auto *Command = UpdateCommandDesc.hCommand;
1184+
auto &KernelData = std::get<kernel_command_data>(Command->CommandData);
11741185
if (CommandBuffer != Command->CommandBuffer) {
11751186
return UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP;
11761187
}
@@ -1180,14 +1191,14 @@ validateCommandDesc(ur_exp_command_buffer_handle_t CommandBuffer,
11801191
return UR_RESULT_ERROR_INVALID_OPERATION;
11811192
}
11821193

1183-
if (UpdateCommandDesc.newWorkDim != Command->WorkDim &&
1194+
if (UpdateCommandDesc.newWorkDim != KernelData.WorkDim &&
11841195
(!UpdateCommandDesc.pNewGlobalWorkOffset ||
11851196
!UpdateCommandDesc.pNewGlobalWorkSize)) {
11861197
return UR_RESULT_ERROR_INVALID_VALUE;
11871198
}
11881199

11891200
if (UpdateCommandDesc.hNewKernel &&
1190-
!Command->ValidKernelHandles.count(UpdateCommandDesc.hNewKernel)) {
1201+
!KernelData.ValidKernelHandles.count(UpdateCommandDesc.hNewKernel)) {
11911202
return UR_RESULT_ERROR_INVALID_VALUE;
11921203
}
11931204
return UR_RESULT_SUCCESS;
@@ -1202,9 +1213,9 @@ validateCommandDesc(ur_exp_command_buffer_handle_t CommandBuffer,
12021213
ur_result_t
12031214
updateKernelArguments(const ur_exp_command_buffer_update_kernel_launch_desc_t
12041215
&UpdateCommandDesc) {
1205-
auto Command =
1206-
static_cast<kernel_command_handle *>(UpdateCommandDesc.hCommand);
1207-
ur_kernel_handle_t Kernel = Command->Kernel;
1216+
auto *Command = UpdateCommandDesc.hCommand;
1217+
auto &KernelData = std::get<kernel_command_data>(Command->CommandData);
1218+
ur_kernel_handle_t Kernel = KernelData.Kernel;
12081219
ur_device_handle_t Device = Command->CommandBuffer->Device;
12091220

12101221
// Update pointer arguments to the kernel
@@ -1284,29 +1295,29 @@ updateKernelArguments(const ur_exp_command_buffer_update_kernel_launch_desc_t
12841295
ur_result_t
12851296
updateCommand(const ur_exp_command_buffer_update_kernel_launch_desc_t
12861297
&UpdateCommandDesc) {
1287-
auto Command =
1288-
static_cast<kernel_command_handle *>(UpdateCommandDesc.hCommand);
1298+
auto *Command = UpdateCommandDesc.hCommand;
1299+
auto &KernelData = std::get<kernel_command_data>(Command->CommandData);
12891300
if (UpdateCommandDesc.hNewKernel) {
1290-
Command->Kernel = UpdateCommandDesc.hNewKernel;
1301+
KernelData.Kernel = UpdateCommandDesc.hNewKernel;
12911302
}
12921303

12931304
if (UpdateCommandDesc.newWorkDim) {
1294-
Command->WorkDim = UpdateCommandDesc.newWorkDim;
1305+
KernelData.WorkDim = UpdateCommandDesc.newWorkDim;
12951306
}
12961307

12971308
if (UpdateCommandDesc.pNewGlobalWorkOffset) {
1298-
Command->setGlobalOffset(UpdateCommandDesc.pNewGlobalWorkOffset);
1309+
KernelData.setGlobalOffset(UpdateCommandDesc.pNewGlobalWorkOffset);
12991310
}
13001311

13011312
if (UpdateCommandDesc.pNewGlobalWorkSize) {
1302-
Command->setGlobalSize(UpdateCommandDesc.pNewGlobalWorkSize);
1313+
KernelData.setGlobalSize(UpdateCommandDesc.pNewGlobalWorkSize);
13031314
if (!UpdateCommandDesc.pNewLocalWorkSize) {
1304-
Command->setNullLocalSize();
1315+
KernelData.setNullLocalSize();
13051316
}
13061317
}
13071318

13081319
if (UpdateCommandDesc.pNewLocalWorkSize) {
1309-
Command->setLocalSize(UpdateCommandDesc.pNewLocalWorkSize);
1320+
KernelData.setLocalSize(UpdateCommandDesc.pNewLocalWorkSize);
13101321
}
13111322

13121323
return UR_RESULT_SUCCESS;
@@ -1334,27 +1345,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
13341345

13351346
// If no work-size is provided make sure we pass nullptr to setKernelParams
13361347
// so it can guess the local work size.
1337-
auto KernelCommandHandle =
1338-
static_cast<kernel_command_handle *>(UpdateCommandDesc.hCommand);
1339-
const bool ProvidedLocalSize = !KernelCommandHandle->isNullLocalSize();
1348+
auto *KernelCommandHandle = UpdateCommandDesc.hCommand;
1349+
auto &KernelData =
1350+
std::get<kernel_command_data>(KernelCommandHandle->CommandData);
1351+
const bool ProvidedLocalSize = !KernelData.isNullLocalSize();
13401352
size_t *LocalWorkSize =
1341-
ProvidedLocalSize ? KernelCommandHandle->LocalWorkSize : nullptr;
1353+
ProvidedLocalSize ? KernelData.LocalWorkSize : nullptr;
13421354

13431355
// Set the number of threads per block to the number of threads per warp
13441356
// by default unless user has provided a better number.
13451357
size_t ThreadsPerBlock[3] = {32u, 1u, 1u};
13461358
size_t BlocksPerGrid[3] = {1u, 1u, 1u};
1347-
CUfunction CuFunc = KernelCommandHandle->Kernel->get();
1359+
CUfunction CuFunc = KernelData.Kernel->get();
13481360
auto Result = setKernelParams(
1349-
hCommandBuffer->Context, hCommandBuffer->Device,
1350-
KernelCommandHandle->WorkDim, KernelCommandHandle->GlobalWorkOffset,
1351-
KernelCommandHandle->GlobalWorkSize, LocalWorkSize,
1352-
KernelCommandHandle->Kernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
1361+
hCommandBuffer->Context, hCommandBuffer->Device, KernelData.WorkDim,
1362+
KernelData.GlobalWorkOffset, KernelData.GlobalWorkSize, LocalWorkSize,
1363+
KernelData.Kernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
13531364
if (Result != UR_RESULT_SUCCESS) {
13541365
return Result;
13551366
}
13561367

1357-
CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle->Params;
1368+
CUDA_KERNEL_NODE_PARAMS &Params = KernelData.Params;
13581369

13591370
Params.func = CuFunc;
13601371
Params.gridDimX = BlocksPerGrid[0];
@@ -1363,9 +1374,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
13631374
Params.blockDimX = ThreadsPerBlock[0];
13641375
Params.blockDimY = ThreadsPerBlock[1];
13651376
Params.blockDimZ = ThreadsPerBlock[2];
1366-
Params.sharedMemBytes = KernelCommandHandle->Kernel->getLocalSize();
1367-
Params.kernelParams = const_cast<void **>(
1368-
KernelCommandHandle->Kernel->getArgPointers().data());
1377+
Params.sharedMemBytes = KernelData.Kernel->getLocalSize();
1378+
Params.kernelParams =
1379+
const_cast<void **>(KernelData.Kernel->getArgPointers().data());
13691380

13701381
CUgraphNode Node = KernelCommandHandle->Node;
13711382
CUgraphExec CudaGraphExec = hCommandBuffer->CudaGraphExec;

0 commit comments

Comments
 (0)