Skip to content

[NFC][SYCL] Prefer to pass context_impl by raw ptr/ref #18936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: sycl
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sycl/include/sycl/interop_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,12 @@ class interop_handle {
friend class detail::DispatchHostTask;
using ReqToMem = std::pair<detail::AccessorImplHost *, ur_mem_handle_t>;

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
// Clean this up (no shared pointers). Not doing it right now because I expect
// there will be several iterations of simplifications possible and it would
// be hard to track which of them made their way into a minor public release
// and which didn't. Let's just clean it up once during ABI breaking window.
#endif
interop_handle(std::vector<ReqToMem> MemObjs,
const std::shared_ptr<detail::queue_impl> &Queue,
const std::shared_ptr<detail::device_impl> &Device,
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ kernel make_kernel(const context &TargetContext,

// Construct the SYCL queue from UR queue.
return detail::createSyclObjFromImpl<kernel>(
std::make_shared<kernel_impl>(UrKernel, ContextImpl, KernelBundleImpl));
std::make_shared<kernel_impl>(UrKernel, *ContextImpl, KernelBundleImpl));
}

kernel make_kernel(ur_native_handle_t NativeHandle,
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/async_alloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void *async_malloc(sycl::handler &h, sycl::usm::alloc kind, size_t size) {
sycl::make_error_code(sycl::errc::feature_not_supported),
"Only device backed asynchronous allocations are supported!");

auto &Adapter = h.getContextImplPtr()->getAdapter();
auto &Adapter = detail::getSyclObjImpl(h)->get_context().getAdapter();

// Get CG event dependencies for this allocation.
const auto &DepEvents = h.impl->CGData.MEvents;
Expand Down Expand Up @@ -117,7 +117,7 @@ __SYCL_EXPORT void *async_malloc(const sycl::queue &q, sycl::usm::alloc kind,
__SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
const memory_pool &pool) {

auto &Adapter = h.getContextImplPtr()->getAdapter();
auto &Adapter = detail::getSyclObjImpl(h)->get_context().getAdapter();
auto &memPoolImpl = sycl::detail::getSyclObjImpl(pool);

// Get CG event dependencies for this allocation.
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/backend_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ inline namespace _V1 {
namespace detail {

template <class T> backend getImplBackend(const T &Impl) {
return Impl->getContextImplPtr()->getBackend();
return Impl->getContextImpl().getBackend();
}

} // namespace detail
Expand Down
26 changes: 13 additions & 13 deletions sycl/source/detail/bindless_images.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -813,9 +813,9 @@ get_image_memory_support(const image_descriptor &imageDescriptor,
const sycl::context &syclContext) {
std::shared_ptr<sycl::detail::device_impl> DevImpl =
sycl::detail::getSyclObjImpl(syclDevice);
std::shared_ptr<sycl::detail::context_impl> CtxImpl =
sycl::detail::getSyclObjImpl(syclContext);
const sycl::detail::AdapterPtr &Adapter = CtxImpl->getAdapter();
sycl::detail::context_impl &CtxImpl =
*sycl::detail::getSyclObjImpl(syclContext);
const sycl::detail::AdapterPtr &Adapter = CtxImpl.getAdapter();

ur_image_desc_t urDesc;
ur_image_format_t urFormat;
Expand All @@ -825,15 +825,15 @@ get_image_memory_support(const image_descriptor &imageDescriptor,
Adapter->call<sycl::errc::runtime,
sycl::detail::UrApiKind::
urBindlessImagesGetImageMemoryHandleTypeSupportExp>(
CtxImpl->getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
CtxImpl.getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
ur_exp_image_mem_type_t::UR_EXP_IMAGE_MEM_TYPE_USM_POINTER,
&supportsPointerAllocation);

ur_bool_t supportsOpaqueAllocation{0};
Adapter->call<sycl::errc::runtime,
sycl::detail::UrApiKind::
urBindlessImagesGetImageMemoryHandleTypeSupportExp>(
CtxImpl->getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
CtxImpl.getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
ur_exp_image_mem_type_t::UR_EXP_IMAGE_MEM_TYPE_OPAQUE_HANDLE,
&supportsOpaqueAllocation);

Expand Down Expand Up @@ -864,9 +864,9 @@ __SYCL_EXPORT bool is_image_handle_supported<unsampled_image_handle>(
const sycl::device &syclDevice, const sycl::context &syclContext) {
std::shared_ptr<sycl::detail::device_impl> DevImpl =
sycl::detail::getSyclObjImpl(syclDevice);
std::shared_ptr<sycl::detail::context_impl> CtxImpl =
sycl::detail::getSyclObjImpl(syclContext);
const sycl::detail::AdapterPtr &Adapter = CtxImpl->getAdapter();
sycl::detail::context_impl &CtxImpl =
*sycl::detail::getSyclObjImpl(syclContext);
const sycl::detail::AdapterPtr &Adapter = CtxImpl.getAdapter();

ur_image_desc_t urDesc;
ur_image_format_t urFormat;
Expand All @@ -881,7 +881,7 @@ __SYCL_EXPORT bool is_image_handle_supported<unsampled_image_handle>(
Adapter->call<sycl::errc::runtime,
sycl::detail::UrApiKind::
urBindlessImagesGetImageUnsampledHandleSupportExp>(
CtxImpl->getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
CtxImpl.getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
memHandleType, &supportsUnsampledHandle);

return supportsUnsampledHandle;
Expand All @@ -904,9 +904,9 @@ __SYCL_EXPORT bool is_image_handle_supported<sampled_image_handle>(
const sycl::device &syclDevice, const sycl::context &syclContext) {
std::shared_ptr<sycl::detail::device_impl> DevImpl =
sycl::detail::getSyclObjImpl(syclDevice);
std::shared_ptr<sycl::detail::context_impl> CtxImpl =
sycl::detail::getSyclObjImpl(syclContext);
const sycl::detail::AdapterPtr &Adapter = CtxImpl->getAdapter();
sycl::detail::context_impl &CtxImpl =
*sycl::detail::getSyclObjImpl(syclContext);
const sycl::detail::AdapterPtr &Adapter = CtxImpl.getAdapter();

ur_image_desc_t urDesc;
ur_image_format_t urFormat;
Expand All @@ -921,7 +921,7 @@ __SYCL_EXPORT bool is_image_handle_supported<sampled_image_handle>(
Adapter->call<
sycl::errc::runtime,
sycl::detail::UrApiKind::urBindlessImagesGetImageSampledHandleSupportExp>(
CtxImpl->getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
CtxImpl.getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
memHandleType, &supportsSampledHandle);

return supportsSampledHandle;
Expand Down
46 changes: 21 additions & 25 deletions sycl/source/detail/device_image_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ constexpr uint8_t ImageOriginKernelCompiler = 1 << 2;
class ManagedDeviceGlobalsRegistry {
public:
ManagedDeviceGlobalsRegistry(
const std::shared_ptr<context_impl> &ContextImpl,
const std::string &Prefix, std::vector<std::string> &&DeviceGlobalNames,
context_impl &ContextImpl, const std::string &Prefix,
std::vector<std::string> &&DeviceGlobalNames,
std::vector<std::unique_ptr<std::byte[]>> &&DeviceGlobalAllocations)
: MContextImpl{ContextImpl}, MPrefix{Prefix},
: MContextImpl{ContextImpl.shared_from_this()}, MPrefix{Prefix},
MDeviceGlobalNames{std::move(DeviceGlobalNames)},
MDeviceGlobalAllocations{std::move(DeviceGlobalAllocations)} {}

Expand Down Expand Up @@ -570,13 +570,13 @@ class device_image_impl {

ur_native_handle_t getNative() const {
assert(MProgram);
const auto &ContextImplPtr = detail::getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImplPtr->getAdapter();
context_impl &ContextImpl = *detail::getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImpl.getAdapter();

ur_native_handle_t NativeProgram = 0;
Adapter->call<UrApiKind::urProgramGetNativeHandle>(MProgram,
&NativeProgram);
if (ContextImplPtr->getBackend() == backend::opencl)
if (ContextImpl.getBackend() == backend::opencl)
__SYCL_OCL_CALL(clRetainProgram, ur::cast<cl_program>(NativeProgram));

return NativeProgram;
Expand Down Expand Up @@ -638,7 +638,7 @@ class device_image_impl {
auto [UrKernel, CacheMutex, ArgMask] =
PM.getOrCreateKernel(Context, AdjustedName,
/*PropList=*/{}, UrProgram);
return std::make_shared<kernel_impl>(UrKernel, getSyclObjImpl(Context),
return std::make_shared<kernel_impl>(UrKernel, *getSyclObjImpl(Context),
Self, OwnerBundle, ArgMask,
UrProgram, CacheMutex);
}
Expand All @@ -653,7 +653,7 @@ class device_image_impl {
// Kernel created by urKernelCreate is implicitly retained.

return std::make_shared<kernel_impl>(
UrKernel, detail::getSyclObjImpl(Context), Self, OwnerBundle,
UrKernel, *detail::getSyclObjImpl(Context), Self, OwnerBundle,
/*ArgMask=*/nullptr, UrProgram, /*CacheMutex=*/nullptr);
}

Expand Down Expand Up @@ -704,12 +704,11 @@ class device_image_impl {
assert(MRTCBinInfo);
assert(MOrigins & ImageOriginKernelCompiler);

const std::shared_ptr<sycl::detail::context_impl> &ContextImpl =
getSyclObjImpl(MContext);
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext);

for (const auto &SyclDev : Devices) {
device_impl &DevImpl = *getSyclObjImpl(SyclDev);
if (!ContextImpl->hasDevice(DevImpl)) {
if (!ContextImpl.hasDevice(DevImpl)) {
throw sycl::exception(make_error_code(errc::invalid),
"device not part of kernel_bundle context");
}
Expand Down Expand Up @@ -742,7 +741,7 @@ class device_image_impl {
Devices, BuildOptions, *SourceStrPtr, UrProgram);
}

const AdapterPtr &Adapter = ContextImpl->getAdapter();
const AdapterPtr &Adapter = ContextImpl.getAdapter();

if (!FetchedFromCache)
UrProgram = createProgramFromSource(Devices, BuildOptions, LogPtr);
Expand All @@ -752,7 +751,7 @@ class device_image_impl {
UrProgram, DeviceVec.size(), DeviceVec.data(), XsFlags.c_str());
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
Res = Adapter->call_nocheck<UrApiKind::urProgramBuild>(
ContextImpl->getHandleRef(), UrProgram, XsFlags.c_str());
ContextImpl.getHandleRef(), UrProgram, XsFlags.c_str());
}
Adapter->checkUrResult<errc::build>(Res);

Expand Down Expand Up @@ -796,12 +795,11 @@ class device_image_impl {
"compile is only available for kernel_bundle<bundle_state::source> "
"when the source language was sycl.");

std::shared_ptr<sycl::detail::context_impl> ContextImpl =
getSyclObjImpl(MContext);
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext);

for (const auto &SyclDev : Devices) {
detail::device_impl &DevImpl = *getSyclObjImpl(SyclDev);
if (!ContextImpl->hasDevice(DevImpl)) {
if (!ContextImpl.hasDevice(DevImpl)) {
throw sycl::exception(make_error_code(errc::invalid),
"device not part of kernel_bundle context");
}
Expand Down Expand Up @@ -873,9 +871,8 @@ class device_image_impl {
const std::vector<device> Devices,
const std::vector<sycl::detail::string_view> &BuildOptions,
const std::string &SourceStr, ur_program_handle_t &UrProgram) const {
const std::shared_ptr<sycl::detail::context_impl> &ContextImpl =
getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImpl->getAdapter();
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImpl.getAdapter();

std::string UserArgs = syclex::detail::userArgsAsString(BuildOptions);

Expand Down Expand Up @@ -904,7 +901,7 @@ class device_image_impl {
Properties.pMetadatas = nullptr;

Adapter->call<UrApiKind::urProgramCreateWithBinary>(
ContextImpl->getHandleRef(), DeviceHandles.size(), DeviceHandles.data(),
ContextImpl.getHandleRef(), DeviceHandles.size(), DeviceHandles.data(),
Lengths.data(), Binaries.data(), &Properties, &UrProgram);

return true;
Expand Down Expand Up @@ -1132,7 +1129,7 @@ class device_image_impl {
}

auto DGRegs = std::make_shared<ManagedDeviceGlobalsRegistry>(
getSyclObjImpl(MContext), std::string{Prefix},
*getSyclObjImpl(MContext), std::string{Prefix},
std::move(DeviceGlobalNames), std::move(DeviceGlobalAllocations));

// Mark the image as input so the program manager will bring it into
Expand Down Expand Up @@ -1195,9 +1192,8 @@ class device_image_impl {
createProgramFromSource(const std::vector<device> Devices,
const std::vector<sycl::detail::string_view> &Options,
std::string *LogPtr) const {
const std::shared_ptr<sycl::detail::context_impl> &ContextImpl =
getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImpl->getAdapter();
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImpl.getAdapter();
const auto spirv = [&]() -> std::vector<uint8_t> {
switch (MRTCBinInfo->MLanguage) {
case syclex::source_language::opencl: {
Expand Down Expand Up @@ -1234,7 +1230,7 @@ class device_image_impl {
}();

ur_program_handle_t UrProgram = nullptr;
Adapter->call<UrApiKind::urProgramCreateWithIL>(ContextImpl->getHandleRef(),
Adapter->call<UrApiKind::urProgramCreateWithIL>(ContextImpl.getHandleRef(),
spirv.data(), spirv.size(),
nullptr, &UrProgram);
// program created by urProgramCreateWithIL is implicitly retained.
Expand Down
19 changes: 14 additions & 5 deletions sycl/source/detail/event_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,10 @@ void event_impl::setHandle(const ur_event_handle_t &UREvent) {
MEvent.store(UREvent);
}

const ContextImplPtr &event_impl::getContextImpl() {
context_impl &event_impl::getContextImpl() {
initContextIfNeeded();
return MContext;
assert(MContext && "Trying to get context from a host event!");
return *MContext;
}

const AdapterPtr &event_impl::getAdapter() {
Expand All @@ -152,9 +153,17 @@ const AdapterPtr &event_impl::getAdapter() {

void event_impl::setStateIncomplete() { MState = HES_NotComplete; }

void event_impl::setContextImpl(const ContextImplPtr &Context) {
void event_impl::setContextImpl(std::shared_ptr<context_impl> &&Context) {
MIsHostEvent = Context == nullptr;
MContext = Context;
MContext = std::move(Context);
}
void event_impl::setContextImpl(context_impl &Context) {
MIsHostEvent = false;
MContext = Context.shared_from_this();
}
void event_impl::setContextImpl(context_impl *Context) {
MIsHostEvent = Context == nullptr;
MContext = Context ? Context->shared_from_this() : nullptr;
}

event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext,
Expand All @@ -178,7 +187,7 @@ event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext,
event_impl::event_impl(queue_impl &Queue, private_tag)
: MQueue{Queue.weak_from_this()},
MIsProfilingEnabled{Queue.MIsProfilingEnabled} {
this->setContextImpl(Queue.getContextImplPtr());
this->setContextImpl(Queue.getContextImpl());
MState.store(HES_Complete);
}

Expand Down
12 changes: 5 additions & 7 deletions sycl/source/detail/event_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,21 +174,19 @@ class event_impl : public std::enable_shared_from_this<event_impl> {
void setHandle(const ur_event_handle_t &UREvent);

/// Returns context that is associated with this event.
///
/// \return a shared pointer to a valid context_impl.
const ContextImplPtr &getContextImpl();
context_impl &getContextImpl();

/// \return the Adapter associated with the context of this event.
/// Should be called when this is not a Host Event.
const AdapterPtr &getAdapter();

/// Associate event with the context.
///
/// Provided UrContext inside ContextImplPtr must be associated
/// Provided UrContext inside Context must be associated
/// with the UrEvent object stored in this class
///
/// @param Context is a shared pointer to an instance of valid context_impl.
void setContextImpl(const ContextImplPtr &Context);
void setContextImpl(std::shared_ptr<context_impl> &&Context);
void setContextImpl(context_impl &Context);
void setContextImpl(context_impl *Context);

/// Clear the event state
void setStateIncomplete();
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ exec_graph_impl::enqueue(sycl::detail::queue_impl &Queue,

auto CreateNewEvent([&]() {
auto NewEvent = sycl::detail::event_impl::create_device_event(Queue);
NewEvent->setContextImpl(Queue.getContextImplPtr());
NewEvent->setContextImpl(Queue.getContextImpl());
NewEvent->setStateIncomplete();
return NewEvent;
});
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/handler_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class handler_impl {
template <typename Self = handler_impl> context_impl &get_context() {
Self *self = this;
if (auto *Queue = self->get_queue_or_null())
return *Queue->getContextImplPtr();
return Queue->getContextImpl();
else
return *self->get_graph().getContextImplPtr();
}
Expand Down
11 changes: 5 additions & 6 deletions sycl/source/detail/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

namespace sycl {
inline namespace _V1 {
using ContextImplPtr = std::shared_ptr<sycl::detail::context_impl>;
namespace detail {
void waitEvents(std::vector<sycl::event> DepEvents) {
for (auto SyclEvent : DepEvents) {
Expand Down Expand Up @@ -59,10 +58,10 @@ retrieveKernelBinary(queue_impl &Queue, KernelNameStrRefT KernelName,
if (DeviceImage == DeviceImages.end()) {
return {nullptr, nullptr};
}
auto ContextImpl = Queue.getContextImplPtr();
context_impl &ContextImpl = Queue.getContextImpl();
ur_program_handle_t Program =
detail::ProgramManager::getInstance().createURProgram(
**DeviceImage, *ContextImpl, {createSyclObjFromImpl<device>(Dev)});
**DeviceImage, ContextImpl, {createSyclObjFromImpl<device>(Dev)});
return {*DeviceImage, Program};
}

Expand All @@ -80,11 +79,11 @@ retrieveKernelBinary(queue_impl &Queue, KernelNameStrRefT KernelName,
DeviceImage = SyclKernelImpl->getDeviceImage()->get_bin_image_ref();
Program = SyclKernelImpl->getDeviceImage()->get_ur_program_ref();
} else {
auto ContextImpl = Queue.getContextImplPtr();
context_impl &ContextImpl = Queue.getContextImpl();
DeviceImage = &detail::ProgramManager::getInstance().getDeviceImage(
KernelName, *ContextImpl, &Dev);
KernelName, ContextImpl, &Dev);
Program = detail::ProgramManager::getInstance().createURProgram(
*DeviceImage, *ContextImpl, {createSyclObjFromImpl<device>(Dev)});
*DeviceImage, ContextImpl, {createSyclObjFromImpl<device>(Dev)});
}
return {DeviceImage, Program};
}
Expand Down
Loading