diff --git a/unified-runtime/source/adapters/offload/device.cpp b/unified-runtime/source/adapters/offload/device.cpp index ebe0405b8917e..8ff5f3d53ddcd 100644 --- a/unified-runtime/source/adapters/offload/device.cpp +++ b/unified-runtime/source/adapters/offload/device.cpp @@ -43,6 +43,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet); ol_device_info_t olInfo; + bool isVec3{false}; switch (propName) { case UR_DEVICE_INFO_NAME: olInfo = OL_DEVICE_INFO_NAME; @@ -76,6 +77,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, return ReturnValue(uint32_t{1}); case UR_DEVICE_INFO_MAX_WORK_ITEM_DIMENSIONS: return ReturnValue(uint32_t{3}); + case UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES: + olInfo = OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE; + isVec3 = true; + break; // Unimplemented features case UR_DEVICE_INFO_PROGRAM_SET_SPECIALIZATION_CONSTANTS: case UR_DEVICE_INFO_GLOBAL_VARIABLE_SUPPORT: @@ -93,6 +98,26 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; } + // OL dimensions are uint32_t while UR is size_t, so they need to be mapped + if (isVec3) { + if (pPropSizeRet) { + *pPropSizeRet = sizeof(size_t) * 3; + } + + if (pPropValue) { + ol_dimensions_t olVec; + size_t *urVec = reinterpret_cast(pPropValue); + OL_RETURN_ON_ERR(olGetDeviceInfo(hDevice->OffloadDevice, olInfo, + sizeof(olVec), &olVec)); + + urVec[0] = olVec.x; + urVec[1] = olVec.y; + urVec[2] = olVec.z; + } + + return UR_RESULT_SUCCESS; + } + if (pPropSizeRet) { OL_RETURN_ON_ERR( olGetDeviceInfoSize(hDevice->OffloadDevice, olInfo, pPropSizeRet));