Skip to content

Commit

Permalink
Move Cuda device-specific resource limit checking logic into the adap…
Browse files Browse the repository at this point in the history
…ter backend from the sycl runtime

This change is required in order to implement per-device semantics for the
urKernelSuggestMaxCooperativeGroupCountExp query.
  • Loading branch information
GeorgeWeb committed Sep 2, 2024
1 parent 77da3fa commit 9dcdc62
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions source/adapters/cuda/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,36 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
ur_device_handle_t Device = hKernel->getProgram()->getDevice();
ScopedContext Active(Device);
try {
// We need to calculate max num of work-groups using per-device semantics.

int MaxNumActiveGroupsPerCU{0};
UR_CHECK_ERROR(cuOccupancyMaxActiveBlocksPerMultiprocessor(
&MaxNumActiveGroupsPerCU, hKernel->get(), localWorkSize,
dynamicSharedMemorySize));
detail::ur::assertion(MaxNumActiveGroupsPerCU >= 0);

// Multiply by the number of SMs (CUs = compute units) on the device in
// order to retrieve the total number of groups/blocks that can be launched.
*pGroupCountRet = Device->getNumComputeUnits() * MaxNumActiveGroupsPerCU;
// Handle the case where we can't have all SMs active with at least 1 group
// per SM. In that case, the device may still be able to run 1 work-group, so
// we manually check whether that is possible with the available HW resources.
if (MaxNumActiveGroupsPerCU == 0) {
size_t MaxWorkGroupSize{};
urKernelGetGroupInfo(
hKernel, Device, UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE,
sizeof(MaxWorkGroupSize), &MaxWorkGroupSize, nullptr);
size_t MaxLocalSizeBytes{};
urDeviceGetInfo(Device, UR_DEVICE_INFO_LOCAL_MEM_SIZE,
sizeof(MaxLocalSizeBytes), &MaxLocalSizeBytes, nullptr);
if (localWorkSize > MaxWorkGroupSize ||
dynamicSharedMemorySize > MaxLocalSizeBytes ||
hasExceededMaxRegistersPerBlock(Device, hKernel, localWorkSize))
*pGroupCountRet = 0;
else
*pGroupCountRet = 1;
} else {
// Multiply by the number of SMs (CUs = compute units) on the device in
// order to retrieve the total number of groups/blocks that can be
// launched.
*pGroupCountRet = Device->getNumComputeUnits() * MaxNumActiveGroupsPerCU;
}
} catch (ur_result_t Err) {
return Err;
}
Expand Down

0 comments on commit 9dcdc62

Please sign in to comment.