Skip to content

Commit 1faabf3

Browse files
ericniebler and miscco authored
clean up the cudax __launch_transform code and document its purpose and design (NVIDIA#3526)
Co-authored-by: Michael Schellenberger Costa <[email protected]>
1 parent 1c792ab commit 1faabf3

File tree

9 files changed

+87
-64
lines changed

9 files changed

+87
-64
lines changed

cudax/examples/vector.cuh

+1-3
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,7 @@ private:
109109
__v_.sync_device_to_host(__str_, _Kind);
110110
}
111111

112-
using __as_kernel_arg = ::cuda::std::span<_Ty>;
113-
114-
operator ::cuda::std::span<_Ty>()
112+
::cuda::std::span<_Ty> kernel_transform() const
115113
{
116114
return {__v_.__d_.data().get(), __v_.__d_.size()};
117115
}

cudax/include/cuda/experimental/__algorithm/common.cuh

+3-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace cuda::experimental
3232
{
3333

3434
template <typename _Tp>
35-
_CCCL_CONCEPT __valid_1d_copy_fill_argument = _CUDA_VRANGES::contiguous_range<detail::__as_copy_arg_t<_Tp>>;
35+
_CCCL_CONCEPT __valid_1d_copy_fill_argument = _CUDA_VRANGES::contiguous_range<kernel_arg_t<_Tp>>;
3636

3737
template <typename _Tp, typename _Decayed = _CUDA_VSTD::decay_t<_Tp>>
3838
using __as_mdspan_t =
@@ -50,7 +50,8 @@ inline constexpr bool
5050
true;
5151

5252
template <typename _Tp>
53-
inline constexpr bool __valid_nd_copy_fill_argument = __convertible_to_mdspan<detail::__as_copy_arg_t<_Tp>>;
53+
inline constexpr bool __valid_nd_copy_fill_argument =
54+
__convertible_to_mdspan<__kernel_transform_result_t<__launch_transform_result_t<_Tp>>>;
5455

5556
} // namespace cuda::experimental
5657
#endif //__CUDAX_ALGORITHM_COMMON

cudax/include/cuda/experimental/__algorithm/copy.cuh

+6-8
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,8 @@ void copy_bytes(stream_ref __stream, _SrcTy&& __src, _DstTy&& __dst)
6969
{
7070
__copy_bytes_impl(
7171
__stream,
72-
_CUDA_VSTD::span(static_cast<detail::__as_copy_arg_t<_SrcTy>>(
73-
detail::__launch_transform(__stream, _CUDA_VSTD::forward<_SrcTy>(__src)))),
74-
_CUDA_VSTD::span(static_cast<detail::__as_copy_arg_t<_DstTy>>(
75-
detail::__launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst)))));
72+
_CUDA_VSTD::span(__kernel_transform(__launch_transform(__stream, _CUDA_VSTD::forward<_SrcTy>(__src)))),
73+
_CUDA_VSTD::span(__kernel_transform(__launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst)))));
7674
}
7775

7876
template <typename _SrcExtents, typename _DstExtents>
@@ -134,10 +132,10 @@ _CCCL_TEMPLATE(typename _SrcTy, typename _DstTy)
134132
_CCCL_REQUIRES(__valid_nd_copy_fill_argument<_SrcTy> _CCCL_AND __valid_nd_copy_fill_argument<_DstTy>)
135133
void copy_bytes(stream_ref __stream, _SrcTy&& __src, _DstTy&& __dst)
136134
{
137-
decltype(auto) __src_transformed = detail::__launch_transform(__stream, _CUDA_VSTD::forward<_SrcTy>(__src));
138-
decltype(auto) __dst_transformed = detail::__launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst));
139-
decltype(auto) __src_as_arg = static_cast<detail::__as_copy_arg_t<_SrcTy>>(__src_transformed);
140-
decltype(auto) __dst_as_arg = static_cast<detail::__as_copy_arg_t<_DstTy>>(__dst_transformed);
135+
decltype(auto) __src_transformed = __launch_transform(__stream, _CUDA_VSTD::forward<_SrcTy>(__src));
136+
decltype(auto) __dst_transformed = __launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst));
137+
decltype(auto) __src_as_arg = __kernel_transform(__src_transformed);
138+
decltype(auto) __dst_as_arg = __kernel_transform(__dst_transformed);
141139
__nd_copy_bytes_impl(
142140
__stream, __as_mdspan_t<decltype(__src_as_arg)>(__src_as_arg), __as_mdspan_t<decltype(__dst_as_arg)>(__dst_as_arg));
143141
}

cudax/include/cuda/experimental/__algorithm/fill.cuh

+6-6
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ _CCCL_TEMPLATE(typename _DstTy)
5555
_CCCL_REQUIRES(__valid_1d_copy_fill_argument<_DstTy>)
5656
void fill_bytes(stream_ref __stream, _DstTy&& __dst, uint8_t __value)
5757
{
58-
__fill_bytes_impl(__stream,
59-
_CUDA_VSTD::span(static_cast<detail::__as_copy_arg_t<_DstTy>>(
60-
detail::__launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst)))),
61-
__value);
58+
__fill_bytes_impl(
59+
__stream,
60+
_CUDA_VSTD::span(__kernel_transform(__launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst)))),
61+
__value);
6262
}
6363

6464
//! @brief Launches an operation to bytewise fill the memory into the provided stream.
@@ -77,8 +77,8 @@ _CCCL_TEMPLATE(typename _DstTy)
7777
_CCCL_REQUIRES(__valid_nd_copy_fill_argument<_DstTy>)
7878
void fill_bytes(stream_ref __stream, _DstTy&& __dst, uint8_t __value)
7979
{
80-
decltype(auto) __dst_transformed = detail::__launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst));
81-
decltype(auto) __dst_as_arg = static_cast<detail::__as_copy_arg_t<_DstTy>>(__dst_transformed);
80+
decltype(auto) __dst_transformed = __launch_transform(__stream, _CUDA_VSTD::forward<_DstTy>(__dst));
81+
decltype(auto) __dst_as_arg = __kernel_transform(__dst_transformed);
8282
auto __dst_mdspan = __as_mdspan_t<decltype(__dst_as_arg)>(__dst_as_arg);
8383

8484
__fill_bytes_impl(

cudax/include/cuda/experimental/__launch/launch.cuh

+8-12
Original file line numberDiff line numberDiff line change
@@ -126,27 +126,23 @@ void launch(
126126
__ensure_current_device __dev_setter(stream);
127127
cudaError_t status;
128128
auto combined = conf.combine_with_default(kernel);
129-
if constexpr (::cuda::std::is_invocable_v<Kernel, kernel_config<Dimensions, Config...>, as_kernel_arg_t<Args>...>)
129+
if constexpr (::cuda::std::is_invocable_v<Kernel, kernel_config<Dimensions, Config...>, kernel_arg_t<Args>...>)
130130
{
131-
auto launcher = detail::kernel_launcher<decltype(combined), Kernel, as_kernel_arg_t<Args>...>;
131+
auto launcher = detail::kernel_launcher<decltype(combined), Kernel, kernel_arg_t<Args>...>;
132132
status = detail::launch_impl(
133133
stream,
134134
combined,
135135
launcher,
136136
combined,
137137
kernel,
138-
static_cast<as_kernel_arg_t<Args>>(detail::__launch_transform(stream, std::forward<Args>(args)))...);
138+
__kernel_transform(__launch_transform(stream, std::forward<Args>(args)))...);
139139
}
140140
else
141141
{
142-
static_assert(::cuda::std::is_invocable_v<Kernel, as_kernel_arg_t<Args>...>);
143-
auto launcher = detail::kernel_launcher_no_config<Kernel, as_kernel_arg_t<Args>...>;
142+
static_assert(::cuda::std::is_invocable_v<Kernel, kernel_arg_t<Args>...>);
143+
auto launcher = detail::kernel_launcher_no_config<Kernel, kernel_arg_t<Args>...>;
144144
status = detail::launch_impl(
145-
stream,
146-
combined,
147-
launcher,
148-
kernel,
149-
static_cast<as_kernel_arg_t<Args>>(detail::__launch_transform(stream, std::forward<Args>(args)))...);
145+
stream, combined, launcher, kernel, __kernel_transform(__launch_transform(stream, std::forward<Args>(args)))...);
150146
}
151147
if (status != cudaSuccess)
152148
{
@@ -206,7 +202,7 @@ void launch(::cuda::stream_ref stream,
206202
conf,
207203
kernel,
208204
conf,
209-
static_cast<as_kernel_arg_t<ActArgs>>(detail::__launch_transform(stream, std::forward<ActArgs>(args)))...);
205+
__kernel_transform(__launch_transform(stream, std::forward<ActArgs>(args)))...);
210206

211207
if (status != cudaSuccess)
212208
{
@@ -264,7 +260,7 @@ void launch(::cuda::stream_ref stream,
264260
stream, //
265261
conf,
266262
kernel,
267-
static_cast<as_kernel_arg_t<ActArgs>>(detail::__launch_transform(stream, std::forward<ActArgs>(args)))...);
263+
__kernel_transform(__launch_transform(stream, std::forward<ActArgs>(args)))...);
268264

269265
if (status != cudaSuccess)
270266
{

cudax/include/cuda/experimental/__launch/launch_transform.cuh

+57-26
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,48 @@
2929
#if _CCCL_STD_VER >= 2017
3030
namespace cuda::experimental
3131
{
32-
namespace detail
32+
namespace __transforms
3333
{
34-
// Types should define overloads of __cudax_launch_transform that are find-able
35-
// by ADL in order to customize how cudax::launch handles that type.
34+
// Launch transform:
35+
//
36+
// The launch transform is a mechanism to transform arguments passed to the
37+
// cudax::launch API prior to actually launching a kernel. This is useful for
38+
// example, to automatically convert contiguous ranges into spans. It is also
39+
// useful for executing per-argument actions before and after the kernel launch.
40+
// A host_vector might want a pre-launch action to copy data from host to device
41+
// and a post-launch action to copy data back from device to host.
42+
//
43+
// The launch transform happens in two steps. First, `cudax::launch` calls
44+
// __launch_transform on each argument. If the argument has hooked the
45+
// __launch_transform customization point, this returns a temporary object that
46+
// has the pre-launch action in its constructor and the post-launch action in
47+
// its destructor. The temporaries are all constructed before launching the
48+
// kernel, and they are all destroyed immediately after, at the end of the full
49+
// expression that performs the launch. If the `cudax::launch` argument has not
50+
// hooked the __launch_transform customization point, then the argument is
51+
// passed through.
52+
//
53+
// The result of __launch_transform is not necessarily what is passed to the
54+
// kernel though. If __launch_transform returns an object with a
55+
// `.kernel_transform()` member function, then `cudax::launch` will call that
56+
// function. Its result is what gets passed as an argument to the kernel. If the
57+
// __launch_transform result does not have a `.kernel_transform()` member
58+
// function, then the __launch_transform result itself is passed to the kernel.
59+
60+
void __cudax_launch_transform();
61+
62+
// Types that want to customize `__launch_transform` should define overloads of
63+
// __cudax_launch_transform that are find-able by ADL.
3664
template <typename _Arg>
3765
using __launch_transform_direct_result_t =
3866
decltype(__cudax_launch_transform(::cuda::stream_ref{}, _CUDA_VSTD::declval<_Arg>()));
3967

40-
struct __fn
68+
struct __launch_fn
4169
{
4270
template <typename _Arg>
4371
_CCCL_NODISCARD decltype(auto) operator()(::cuda::stream_ref __stream, _Arg&& __arg) const
4472
{
45-
if constexpr (::cuda::std::_IsValidExpansion<__launch_transform_direct_result_t, _Arg>::value)
73+
if constexpr (_CUDA_VSTD::_IsValidExpansion<__launch_transform_direct_result_t, _Arg>::value)
4674
{
4775
// This call is unqualified to allow ADL
4876
return __cudax_launch_transform(__stream, _CUDA_VSTD::forward<_Arg>(__arg));
@@ -56,37 +84,40 @@ struct __fn
5684
};
5785

5886
template <typename _Arg>
59-
using __launch_transform_result_t = decltype(__fn{}(::cuda::stream_ref{}, _CUDA_VSTD::declval<_Arg>()));
87+
using __launch_transform_result_t = decltype(__launch_fn{}(::cuda::stream_ref{}, _CUDA_VSTD::declval<_Arg>()));
6088

61-
template <typename _Arg, typename _Enable = void>
62-
struct __as_copy_arg
63-
{
64-
using type = __launch_transform_result_t<_Arg>;
65-
};
66-
67-
// Copy needs to know if original value is a reference
6889
template <typename _Arg>
69-
struct __as_copy_arg<_Arg,
70-
_CUDA_VSTD::void_t<typename _CUDA_VSTD::decay_t<__launch_transform_result_t<_Arg>>::__as_kernel_arg>>
90+
using __kernel_transform_direct_result_t = decltype(_CUDA_VSTD::declval<_Arg>().kernel_transform());
91+
92+
struct __kernel_fn
7193
{
72-
using type = typename _CUDA_VSTD::decay_t<__launch_transform_result_t<_Arg>>::__as_kernel_arg;
94+
template <typename _Arg>
95+
_CCCL_NODISCARD decltype(auto) operator()(_Arg&& __arg) const
96+
{
97+
if constexpr (_CUDA_VSTD::_IsValidExpansion<__kernel_transform_direct_result_t, _Arg>::value)
98+
{
99+
return _CUDA_VSTD::forward<_Arg>(__arg).kernel_transform();
100+
}
101+
else
102+
{
103+
return _CUDA_VSTD::forward<_Arg>(__arg);
104+
}
105+
}
73106
};
74107

75108
template <typename _Arg>
76-
using __as_copy_arg_t = typename detail::__as_copy_arg<_Arg>::type;
109+
using __kernel_transform_result_t = decltype(__kernel_fn{}(_CUDA_VSTD::declval<_Arg>()));
77110

78-
// While kernel argument can't be a reference
79-
template <typename _Arg>
80-
struct __as_kernel_arg
81-
{
82-
using type = _CUDA_VSTD::decay_t<typename __as_copy_arg<_Arg>::type>;
83-
};
111+
} // namespace __transforms
112+
113+
using __transforms::__kernel_transform_result_t;
114+
using __transforms::__launch_transform_result_t;
84115

85-
_CCCL_GLOBAL_CONSTANT __fn __launch_transform{};
86-
} // namespace detail
116+
_CCCL_GLOBAL_CONSTANT __transforms::__launch_fn __launch_transform{};
117+
_CCCL_GLOBAL_CONSTANT __transforms::__kernel_fn __kernel_transform{};
87118

88119
template <typename _Arg>
89-
using as_kernel_arg_t = typename detail::__as_kernel_arg<_Arg>::type;
120+
using kernel_arg_t = _CUDA_VSTD::decay_t<__kernel_transform_result_t<__launch_transform_result_t<_Arg>>>;
90121

91122
} // namespace cuda::experimental
92123

cudax/test/algorithm/common.cuh

+4-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,10 @@ struct weird_buffer
8888
int* data;
8989
std::size_t size;
9090

91-
using __as_kernel_arg = AsKernelArg;
91+
AsKernelArg kernel_transform()
92+
{
93+
return *this;
94+
};
9295

9396
operator cuda::std::span<int>()
9497
{

cudax/test/launch/launch_smoke.cu

+1-3
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,8 @@ struct launch_transform_to_int_convertible
116116
CUDAX_CHECK(kernel_run_proof);
117117
}
118118

119-
using __as_kernel_arg = int;
120-
121119
// This is the value that will be passed to the kernel
122-
explicit operator int() const
120+
int kernel_transform() const
123121
{
124122
return value_;
125123
}

examples/cudax/vector_add/vector.cuh

+1-3
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,7 @@ private:
109109
__v_.sync_device_to_host(__str_, _Kind);
110110
}
111111

112-
using __as_kernel_arg = ::cuda::std::span<_Ty>;
113-
114-
operator ::cuda::std::span<_Ty>()
112+
::cuda::std::span<_Ty> kernel_transform()
115113
{
116114
return {__v_.__d_.data().get(), __v_.__d_.size()};
117115
}

0 commit comments

Comments
 (0)