diff --git a/unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp b/unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp index dc5e5ab77034d..439afeae99a26 100644 --- a/unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp +++ b/unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp @@ -51,6 +51,48 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices, return UR_RESULT_SUCCESS; } +ur_result_t urEnqueueUSMFill2DFallback(ur_queue_handle_t hQueue, void *pMem, + size_t pitch, size_t patternSize, + const void *pPattern, size_t width, + size_t height, + uint32_t numEventsInWaitList, + const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent) { + ur_result_t Result = getContext()->urDdiTable.Enqueue.pfnUSMFill2D( + hQueue, pMem, pitch, patternSize, pPattern, width, height, + numEventsInWaitList, phEventWaitList, phEvent); + if (Result == UR_RESULT_SUCCESS || + Result != UR_RESULT_ERROR_UNSUPPORTED_FEATURE) { + return Result; + } + + // fallback code + auto pfnUSMFill = getContext()->urDdiTable.Enqueue.pfnUSMFill; + + std::vector WaitEvents(numEventsInWaitList); + + for (size_t HeightIndex = 0; HeightIndex < height; HeightIndex++) { + ur_event_handle_t Event = nullptr; + + UR_CALL(pfnUSMFill(hQueue, (void *)((char *)pMem + pitch * HeightIndex), + patternSize, pPattern, width, WaitEvents.size(), + WaitEvents.data(), &Event)); + + WaitEvents.push_back(Event); + } + + if (phEvent) { + UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait( + hQueue, WaitEvents.size(), WaitEvents.data(), phEvent)); + } + + for (const auto Event : WaitEvents) { + UR_CALL(getContext()->urDdiTable.Event.pfnRelease(Event)); + } + + return UR_RESULT_SUCCESS; +} + } // namespace /////////////////////////////////////////////////////////////////////////////// @@ -1726,11 +1768,6 @@ ur_result_t urEnqueueUSMMemcpy2D( { auto pfnUSMMemcpy = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy; - std::vector WaitEvents(numEventsInWaitList); - for (uint32_t i = 0; i < numEventsInWaitList; i++) { - WaitEvents[i] = phEventWaitList[i]; - } - for (size_t HeightIndex = 0; HeightIndex < height; HeightIndex++) { ur_event_handle_t Event = nullptr; const auto DstOrigin = @@ -1742,8 +1779,8 @@ ur_result_t urEnqueueUSMMemcpy2D( width - 1) + MSAN_ORIGIN_GRANULARITY; pfnUSMMemcpy(hQueue, false, (void *)DstOrigin, (void *)SrcOrigin, - SrcOriginEnd - SrcOrigin, WaitEvents.size(), - WaitEvents.data(), &Event); + SrcOriginEnd - SrcOrigin, numEventsInWaitList, + phEventWaitList, &Event); Events.push_back(Event); } } @@ -1756,9 +1793,9 @@ ur_result_t urEnqueueUSMMemcpy2D( const auto DstShadow = DstDI->Shadow->MemToShadow((uptr)pDst); const char Pattern = 0; ur_event_handle_t Event = nullptr; - UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill2D( - hQueue, (void *)DstShadow, dstPitch, 1, &Pattern, width, height, 0, - nullptr, &Event)); + UR_CALL(urEnqueueUSMFill2DFallback(hQueue, (void *)DstShadow, dstPitch, 1, + &Pattern, width, height, 0, nullptr, + &Event)); Events.push_back(Event); } @@ -1767,7 +1804,7 @@ ur_result_t urEnqueueUSMMemcpy2D( hQueue, Events.size(), Events.data(), phEvent)); } - for (const auto &E : Events) + for (const auto E : Events) UR_CALL(getContext()->urDdiTable.Event.pfnRelease(E)); return UR_RESULT_SUCCESS;