
clean up the cudax __launch_transform code and document its purpose and design (#3526)

Co-authored-by: Michael Schellenberger Costa <[email protected]>
ericniebler and miscco authored Feb 6, 2025
1 parent 1c792ab commit 1faabf3
Showing 9 changed files with 87 additions and 64 deletions.
4 changes: 1 addition & 3 deletions cudax/examples/vector.cuh
@@ -109,9 +109,7 @@ private:
__v_.sync_device_to_host(__str_, _Kind);
}

-using __as_kernel_arg = ::cuda::std::span<_Ty>;
-
-operator ::cuda::std::span<_Ty>()
+::cuda::std::span<_Ty> kernel_transform() const
{
return {__v_.__d_.data().get(), __v_.__d_.size()};
}
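Note (not part of the diff): with this change the example vector customizes the launch transform via an explicit kernel_transform() member instead of a conversion operator. A minimal usage sketch; the kernel, the make_config values, and the Vector parameter are illustrative, not from this commit:

// Sketch: how a container with a kernel_transform() member is consumed.
// Assumes the cudax launch header; grid/block sizes are illustrative.
#include <cuda/experimental/launch.cuh>
#include <cuda/std/span>

namespace cudax = cuda::experimental;

__global__ void scale(cuda::std::span<float> v)
{
  // v is the device span produced by the vector's kernel_transform()
}

template <typename Vector>
void run(cuda::stream_ref stream, Vector& vec)
{
  // launch() applies __launch_transform(stream, vec), then calls
  // kernel_transform() on the result, so scale receives a span.
  cudax::launch(stream, cudax::make_config(cudax::grid_dims(1), cudax::block_dims<256>()), scale, vec);
}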
5 changes: 3 additions & 2 deletions cudax/include/cuda/experimental/__algorithm/common.cuh
@@ -32,7 +32,7 @@ namespace cuda::experimental
{

template <typename _Tp>
-_CCCL_CONCEPT __valid_1d_copy_fill_argument = _CUDA_VRANGES::contiguous_range<detail::__as_copy_arg_t<_Tp>>;
+_CCCL_CONCEPT __valid_1d_copy_fill_argument = _CUDA_VRANGES::contiguous_range<kernel_arg_t<_Tp>>;

template <typename _Tp, typename _Decayed = _CUDA_VSTD::decay_t<_Tp>>
using __as_mdspan_t =
@@ -50,7 +50,8 @@ inline constexpr bool
true;

template <typename _Tp>
-inline constexpr bool __valid_nd_copy_fill_argument = __convertible_to_mdspan<detail::__as_copy_arg_t<_Tp>>;
+inline constexpr bool __valid_nd_copy_fill_argument =
+__convertible_to_mdspan<__kernel_transform_result_t<__launch_transform_result_t<_Tp>>>;

} // namespace cuda::experimental
#endif //__CUDAX_ALGORITHM_COMMON
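Note (not part of the diff): the reworked predicates are expressed via kernel_arg_t and the transform-result aliases defined in launch_transform.cuh further down. A sketch of what kernel_arg_t computes, assuming the cudax launch header:

// Sketch: kernel_arg_t composes both transforms' result types, then decays.
#include <cuda/experimental/launch.cuh>
#include <cuda/std/type_traits>
#include <vector>

namespace cudax = cuda::experimental;

// A type with no customizations passes through both transforms unchanged:
static_assert(cuda::std::is_same_v<cudax::kernel_arg_t<std::vector<int>&>, std::vector<int>>);
// For the example vector above, whose kernel_transform() returns a span,
// kernel_arg_t would instead be cuda::std::span<float>.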
14 changes: 6 additions & 8 deletions cudax/include/cuda/experimental/__algorithm/copy.cuh
@@ -69,10 +69,8 @@ void copy_bytes(stream_ref __stream, _SrcTy&& __src, _DstTy&& __dst)
{
__copy_bytes_impl(
__stream,
-_CUDA_VSTD::span(static_cast<detail::__as_copy_arg_t<_SrcTy>>(
-detail::__launch_transform(__stream, _CUDA_VSTD::forward<_SrcTy>(__src)))),
-_CUDA_VSTD::span(static_cast<detail::__as_copy_arg_t<_DstTy>>(
-detail::__launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst)))));
+_CUDA_VSTD::span(__kernel_transform(__launch_transform(__stream, _CUDA_VSTD::forward<_SrcTy>(__src)))),
+_CUDA_VSTD::span(__kernel_transform(__launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst)))));
}

template <typename _SrcExtents, typename _DstExtents>
@@ -134,10 +132,10 @@ _CCCL_TEMPLATE(typename _SrcTy, typename _DstTy)
_CCCL_REQUIRES(__valid_nd_copy_fill_argument<_SrcTy> _CCCL_AND __valid_nd_copy_fill_argument<_DstTy>)
void copy_bytes(stream_ref __stream, _SrcTy&& __src, _DstTy&& __dst)
{
-decltype(auto) __src_transformed = detail::__launch_transform(__stream, _CUDA_VSTD::forward<_SrcTy>(__src));
-decltype(auto) __dst_transformed = detail::__launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst));
-decltype(auto) __src_as_arg = static_cast<detail::__as_copy_arg_t<_SrcTy>>(__src_transformed);
-decltype(auto) __dst_as_arg = static_cast<detail::__as_copy_arg_t<_DstTy>>(__dst_transformed);
+decltype(auto) __src_transformed = __launch_transform(__stream, _CUDA_VSTD::forward<_SrcTy>(__src));
+decltype(auto) __dst_transformed = __launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst));
+decltype(auto) __src_as_arg = __kernel_transform(__src_transformed);
+decltype(auto) __dst_as_arg = __kernel_transform(__dst_transformed);
__nd_copy_bytes_impl(
__stream, __as_mdspan_t<decltype(__src_as_arg)>(__src_as_arg), __as_mdspan_t<decltype(__dst_as_arg)>(__dst_as_arg));
}
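Note (not part of the diff): call sites of copy_bytes are unaffected by this cleanup. A hedged usage sketch, assuming the cudax algorithm header and host-side ranges:

// Sketch: both arguments flow through __launch_transform and
// __kernel_transform before __copy_bytes_impl wraps them in spans.
#include <cuda/experimental/algorithm.cuh>
#include <vector>

namespace cudax = cuda::experimental;

void copy_example(cuda::stream_ref stream)
{
  std::vector<int> src(256, 42);
  std::vector<int> dst(256);
  cudax::copy_bytes(stream, src, dst); // the 1-D overload shown above
  stream.wait(); // cuda::stream_ref::wait() synchronizes the stream
}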
12 changes: 6 additions & 6 deletions cudax/include/cuda/experimental/__algorithm/fill.cuh
@@ -55,10 +55,10 @@ _CCCL_TEMPLATE(typename _DstTy)
_CCCL_REQUIRES(__valid_1d_copy_fill_argument<_DstTy>)
void fill_bytes(stream_ref __stream, _DstTy&& __dst, uint8_t __value)
{
-__fill_bytes_impl(__stream,
-_CUDA_VSTD::span(static_cast<detail::__as_copy_arg_t<_DstTy>>(
-detail::__launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst)))),
-__value);
+__fill_bytes_impl(
+__stream,
+_CUDA_VSTD::span(__kernel_transform(__launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst)))),
+__value);
}

//! @brief Launches an operation to bytewise fill the memory into the provided stream.
@@ -77,8 +77,8 @@ _CCCL_TEMPLATE(typename _DstTy)
_CCCL_REQUIRES(__valid_nd_copy_fill_argument<_DstTy>)
void fill_bytes(stream_ref __stream, _DstTy&& __dst, uint8_t __value)
{
-decltype(auto) __dst_transformed = detail::__launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst));
-decltype(auto) __dst_as_arg = static_cast<detail::__as_copy_arg_t<_DstTy>>(__dst_transformed);
+decltype(auto) __dst_transformed = __launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst));
+decltype(auto) __dst_as_arg = __kernel_transform(__dst_transformed);
auto __dst_mdspan = __as_mdspan_t<decltype(__dst_as_arg)>(__dst_as_arg);

__fill_bytes_impl(
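Note (not part of the diff): fill_bytes follows the same two-step pipeline as copy_bytes. A minimal sketch under the same assumptions as the copy_bytes example above:

// Sketch: the destination goes through the same two-step transform before
// __fill_bytes_impl sees it as a span.
#include <cuda/experimental/algorithm.cuh>
#include <cstdint>
#include <vector>

namespace cudax = cuda::experimental;

void fill_example(cuda::stream_ref stream, std::vector<std::uint8_t>& buf)
{
  cudax::fill_bytes(stream, buf, std::uint8_t{0}); // the 1-D overload shown above
}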
20 changes: 8 additions & 12 deletions cudax/include/cuda/experimental/__launch/launch.cuh
@@ -126,27 +126,23 @@ void launch(
__ensure_current_device __dev_setter(stream);
cudaError_t status;
auto combined = conf.combine_with_default(kernel);
-if constexpr (::cuda::std::is_invocable_v<Kernel, kernel_config<Dimensions, Config...>, as_kernel_arg_t<Args>...>)
+if constexpr (::cuda::std::is_invocable_v<Kernel, kernel_config<Dimensions, Config...>, kernel_arg_t<Args>...>)
{
-auto launcher = detail::kernel_launcher<decltype(combined), Kernel, as_kernel_arg_t<Args>...>;
+auto launcher = detail::kernel_launcher<decltype(combined), Kernel, kernel_arg_t<Args>...>;
status = detail::launch_impl(
stream,
combined,
launcher,
combined,
kernel,
-static_cast<as_kernel_arg_t<Args>>(detail::__launch_transform(stream, std::forward<Args>(args)))...);
+__kernel_transform(__launch_transform(stream, std::forward<Args>(args)))...);
}
else
{
-static_assert(::cuda::std::is_invocable_v<Kernel, as_kernel_arg_t<Args>...>);
-auto launcher = detail::kernel_launcher_no_config<Kernel, as_kernel_arg_t<Args>...>;
+static_assert(::cuda::std::is_invocable_v<Kernel, kernel_arg_t<Args>...>);
+auto launcher = detail::kernel_launcher_no_config<Kernel, kernel_arg_t<Args>...>;
status = detail::launch_impl(
-stream,
-combined,
-launcher,
-kernel,
-static_cast<as_kernel_arg_t<Args>>(detail::__launch_transform(stream, std::forward<Args>(args)))...);
+stream, combined, launcher, kernel, __kernel_transform(__launch_transform(stream, std::forward<Args>(args)))...);
}
if (status != cudaSuccess)
{
@@ -206,7 +202,7 @@ void launch(::cuda::stream_ref stream,
conf,
kernel,
conf,
-static_cast<as_kernel_arg_t<ActArgs>>(detail::__launch_transform(stream, std::forward<ActArgs>(args)))...);
+__kernel_transform(__launch_transform(stream, std::forward<ActArgs>(args)))...);

if (status != cudaSuccess)
{
@@ -264,7 +260,7 @@ void launch(::cuda::stream_ref stream,
stream, //
conf,
kernel,
-static_cast<as_kernel_arg_t<ActArgs>>(detail::__launch_transform(stream, std::forward<ActArgs>(args)))...);
+__kernel_transform(__launch_transform(stream, std::forward<ActArgs>(args)))...);

if (status != cudaSuccess)
{
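Note (not part of the diff): spelled out, the per-argument expression that launch now evaluates is a two-step pipeline. A simplified sketch; do_launch is a hypothetical stand-in for detail::launch_impl, and the transform names are assumed to resolve as in launch_transform.cuh:

// Sketch of the per-argument pipeline inside launch() (simplified, not the
// verbatim implementation).
#include <utility>

template <typename... Ts>
void do_launch(Ts&&...); // hypothetical stand-in for detail::launch_impl

template <typename... Args>
void launch_sketch(::cuda::stream_ref stream, Args&&... args)
{
  // 1. __launch_transform may return a temporary whose constructor runs a
  //    pre-launch action (e.g. a host-to-device copy).
  // 2. __kernel_transform turns that temporary into the actual kernel
  //    argument, calling .kernel_transform() if the type provides it.
  do_launch(__kernel_transform(__launch_transform(stream, std::forward<Args>(args)))...);
  // 3. The temporaries from step 1 live until the end of the full expression
  //    above, so their destructors run post-launch actions right after the
  //    launch is enqueued.
}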
83 changes: 57 additions & 26 deletions cudax/include/cuda/experimental/__launch/launch_transform.cuh
@@ -29,20 +29,48 @@
#if _CCCL_STD_VER >= 2017
namespace cuda::experimental
{
-namespace detail
+namespace __transforms
{
-// Types should define overloads of __cudax_launch_transform that are find-able
-// by ADL in order to customize how cudax::launch handles that type.
+// Launch transform:
+//
+// The launch transform is a mechanism to transform arguments passed to the
+// cudax::launch API prior to actually launching a kernel. This is useful for
+// example, to automatically convert contiguous ranges into spans. It is also
+// useful for executing per-argument actions before and after the kernel launch.
+// A host_vector might want a pre-launch action to copy data from host to device
+// and a post-launch action to copy data back from device to host.
+//
+// The launch transform happens in two steps. First, `cudax::launch` calls
+// __launch_transform on each argument. If the argument has hooked the
+// __launch_transform customization point, this returns a temporary object that
+// has the pre-launch action in its constructor and the post-launch action in
+// its destructor. The temporaries are all constructed before launching the
+// kernel, and they are all destroyed immediately after, at the end of the full
+// expression that performs the launch. If the `cudax::launch` argument has not
+// hooked the __launch_transform customization point, then the argument is
+// passed through.
+//
+// The result of __launch_transform is not necessarily what is passed to the
+// kernel though. If __launch_transform returns an object with a
+// `.kernel_transform()` member function, then `cudax::launch` will call that
+// function. Its result is what gets passed as an argument to the kernel. If the
+// __launch_transform result does not have a `.kernel_transform()` member
+// function, then the __launch_transform result itself is passed to the kernel.
+
+void __cudax_launch_transform();
+
+// Types that want to customize `__launch_transform` should define overloads of
+// __cudax_launch_transform that are find-able by ADL.
template <typename _Arg>
using __launch_transform_direct_result_t =
decltype(__cudax_launch_transform(::cuda::stream_ref{}, _CUDA_VSTD::declval<_Arg>()));

-struct __fn
+struct __launch_fn
{
template <typename _Arg>
_CCCL_NODISCARD decltype(auto) operator()(::cuda::stream_ref __stream, _Arg&& __arg) const
{
-if constexpr (::cuda::std::_IsValidExpansion<__launch_transform_direct_result_t, _Arg>::value)
+if constexpr (_CUDA_VSTD::_IsValidExpansion<__launch_transform_direct_result_t, _Arg>::value)
{
// This call is unqualified to allow ADL
return __cudax_launch_transform(__stream, _CUDA_VSTD::forward<_Arg>(__arg));
Expand All @@ -56,37 +84,40 @@ struct __fn
};

template <typename _Arg>
-using __launch_transform_result_t = decltype(__fn{}(::cuda::stream_ref{}, _CUDA_VSTD::declval<_Arg>()));
+using __launch_transform_result_t = decltype(__launch_fn{}(::cuda::stream_ref{}, _CUDA_VSTD::declval<_Arg>()));

-template <typename _Arg, typename _Enable = void>
-struct __as_copy_arg
-{
-using type = __launch_transform_result_t<_Arg>;
-};

-// Copy needs to know if original value is a reference
-template <typename _Arg>
-struct __as_copy_arg<_Arg,
-_CUDA_VSTD::void_t<typename _CUDA_VSTD::decay_t<__launch_transform_result_t<_Arg>>::__as_kernel_arg>>
+using __kernel_transform_direct_result_t = decltype(_CUDA_VSTD::declval<_Arg>().kernel_transform());
+
+struct __kernel_fn
{
-using type = typename _CUDA_VSTD::decay_t<__launch_transform_result_t<_Arg>>::__as_kernel_arg;
+template <typename _Arg>
+_CCCL_NODISCARD decltype(auto) operator()(_Arg&& __arg) const
+{
+if constexpr (_CUDA_VSTD::_IsValidExpansion<__kernel_transform_direct_result_t, _Arg>::value)
+{
+return _CUDA_VSTD::forward<_Arg>(__arg).kernel_transform();
+}
+else
+{
+return _CUDA_VSTD::forward<_Arg>(__arg);
+}
+}
};

template <typename _Arg>
-using __as_copy_arg_t = typename detail::__as_copy_arg<_Arg>::type;
+using __kernel_transform_result_t = decltype(__kernel_fn{}(_CUDA_VSTD::declval<_Arg>()));

-// While kernel argument can't be a reference
-template <typename _Arg>
-struct __as_kernel_arg
-{
-using type = _CUDA_VSTD::decay_t<typename __as_copy_arg<_Arg>::type>;
-};
+} // namespace __transforms
+
+using __transforms::__kernel_transform_result_t;
+using __transforms::__launch_transform_result_t;

-_CCCL_GLOBAL_CONSTANT __fn __launch_transform{};
-} // namespace detail
+_CCCL_GLOBAL_CONSTANT __transforms::__launch_fn __launch_transform{};
+_CCCL_GLOBAL_CONSTANT __transforms::__kernel_fn __kernel_transform{};

template <typename _Arg>
-using as_kernel_arg_t = typename detail::__as_kernel_arg<_Arg>::type;
+using kernel_arg_t = _CUDA_VSTD::decay_t<__kernel_transform_result_t<__launch_transform_result_t<_Arg>>>;

} // namespace cuda::experimental

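Note (not part of the diff): putting the documented protocol together, a type opts in roughly as follows. Every name below except __cudax_launch_transform and kernel_transform is hypothetical:

// Sketch of the full customization protocol documented above.
#include <cuda/std/span>
#include <cuda/stream_ref>

struct staging_buffer; // hypothetical owner of host + device storage

// Temporary returned by the ADL hook: its constructor is the pre-launch
// action and its destructor, which runs at the end of the full launch
// expression, is the post-launch action.
struct staging_transform
{
  staging_transform(cuda::stream_ref stream, staging_buffer& buf)
      : stream_(stream)
      , buf_(buf)
  {
    // pre-launch: e.g. enqueue a host-to-device copy on stream_
  }

  ~staging_transform()
  {
    // post-launch: e.g. enqueue a device-to-host copy on stream_
  }

  // cudax::launch passes the result of this call to the kernel.
  cuda::std::span<int> kernel_transform() const
  {
    return {}; // would return the device-side span of buf_
  }

  cuda::stream_ref stream_;
  staging_buffer& buf_;
};

// Found by ADL from the unqualified call inside __launch_fn above.
staging_transform __cudax_launch_transform(cuda::stream_ref stream, staging_buffer& buf)
{
  return staging_transform(stream, buf);
}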
5 changes: 4 additions & 1 deletion cudax/test/algorithm/common.cuh
@@ -88,7 +88,10 @@ struct weird_buffer
int* data;
std::size_t size;

-using __as_kernel_arg = AsKernelArg;
+AsKernelArg kernel_transform()
+{
+return *this;
+};

operator cuda::std::span<int>()
{
4 changes: 1 addition & 3 deletions cudax/test/launch/launch_smoke.cu
@@ -116,10 +116,8 @@ struct launch_transform_to_int_convertible
CUDAX_CHECK(kernel_run_proof);
}

-using __as_kernel_arg = int;
-
// This is the value that will be passed to the kernel
-explicit operator int() const
+int kernel_transform() const
{
return value_;
}
4 changes: 1 addition & 3 deletions examples/cudax/vector_add/vector.cuh
@@ -109,9 +109,7 @@ private:
__v_.sync_device_to_host(__str_, _Kind);
}

-using __as_kernel_arg = ::cuda::std::span<_Ty>;
-
-operator ::cuda::std::span<_Ty>()
+::cuda::std::span<_Ty> kernel_transform()
{
return {__v_.__d_.data().get(), __v_.__d_.size()};
}
