29
29
#if _CCCL_STD_VER >= 2017
30
30
namespace cuda ::experimental
31
31
{
32
- namespace detail
32
+ namespace __transforms
33
33
{
34
- // Types should define overloads of __cudax_launch_transform that are find-able
35
- // by ADL in order to customize how cudax::launch handles that type.
34
+ // Launch transform:
35
+ //
36
+ // The launch transform is a mechanism to transform arguments passed to the
37
+ // cudax::launch API prior to actually launching a kernel. This is useful for
38
+ // example, to automatically convert contiguous ranges into spans. It is also
39
+ // useful for executing per-argument actions before and after the kernel launch.
40
+ // A host_vector might want a pre-launch action to copy data from host to device
41
+ // and a post-launch action to copy data back from device to host.
42
+ //
43
+ // The launch transform happens in two steps. First, `cudax::launch` calls
44
+ // __launch_transform on each argument. If the argument has hooked the
45
+ // __launch_transform customization point, this returns a temporary object that
46
+ // has the pre-launch action in its constructor and the post-launch action in
47
+ // its destructor. The temporaries are all constructed before launching the
48
+ // kernel, and they are all destroyed immediately after, at the end of the full
49
+ // expression that performs the launch. If the `cudax::launch` argument has not
50
+ // hooked the __launch_transform customization point, then the argument is
51
+ // passed through.
52
+ //
53
+ // The result of __launch_transform is not necessarily what is passed to the
54
+ // kernel though. If __launch_transform returns an object with a
55
+ // `.kernel_transform()` member function, then `cudax::launch` will call that
56
+ // function. Its result is what gets passed as an argument to the kernel. If the
57
+ // __launch_transform result does not have a `.kernel_transform()` member
58
+ // function, then the __launch_transform result itself is passed to the kernel.
59
+
60
+ void __cudax_launch_transform ();
61
+
62
+ // Types that want to customize `__launch_transform` should define overloads of
63
+ // __cudax_launch_transform that are find-able by ADL.
36
64
template <typename _Arg>
37
65
using __launch_transform_direct_result_t =
38
66
decltype (__cudax_launch_transform(::cuda::stream_ref{}, _CUDA_VSTD::declval<_Arg>()));
39
67
40
- struct __fn
68
+ struct __launch_fn
41
69
{
42
70
template <typename _Arg>
43
71
_CCCL_NODISCARD decltype (auto ) operator()(::cuda::stream_ref __stream, _Arg&& __arg) const
44
72
{
45
- if constexpr (::cuda::std ::_IsValidExpansion<__launch_transform_direct_result_t , _Arg>::value)
73
+ if constexpr (_CUDA_VSTD ::_IsValidExpansion<__launch_transform_direct_result_t , _Arg>::value)
46
74
{
47
75
// This call is unqualified to allow ADL
48
76
return __cudax_launch_transform (__stream, _CUDA_VSTD::forward<_Arg>(__arg));
@@ -56,37 +84,40 @@ struct __fn
56
84
};
57
85
58
86
template <typename _Arg>
59
- using __launch_transform_result_t = decltype(__fn {}(::cuda::stream_ref{}, _CUDA_VSTD::declval<_Arg>()));
87
+ using __launch_transform_result_t = decltype(__launch_fn {}(::cuda::stream_ref{}, _CUDA_VSTD::declval<_Arg>()));
60
88
61
- template <typename _Arg, typename _Enable = void >
62
- struct __as_copy_arg
63
- {
64
- using type = __launch_transform_result_t <_Arg>;
65
- };
66
-
67
- // Copy needs to know if original value is a reference
68
89
template <typename _Arg>
69
- struct __as_copy_arg <_Arg,
70
- _CUDA_VSTD::void_t <typename _CUDA_VSTD::decay_t <__launch_transform_result_t <_Arg>>::__as_kernel_arg>>
90
+ using __kernel_transform_direct_result_t = decltype(_CUDA_VSTD::declval<_Arg>().kernel_transform());
91
+
92
+ struct __kernel_fn
71
93
{
72
- using type = typename _CUDA_VSTD::decay_t <__launch_transform_result_t <_Arg>>::__as_kernel_arg;
94
+ template <typename _Arg>
95
+ _CCCL_NODISCARD decltype (auto ) operator()(_Arg&& __arg) const
96
+ {
97
+ if constexpr (_CUDA_VSTD::_IsValidExpansion<__kernel_transform_direct_result_t , _Arg>::value)
98
+ {
99
+ return _CUDA_VSTD::forward<_Arg>(__arg).kernel_transform ();
100
+ }
101
+ else
102
+ {
103
+ return _CUDA_VSTD::forward<_Arg>(__arg);
104
+ }
105
+ }
73
106
};
74
107
75
108
template <typename _Arg>
76
- using __as_copy_arg_t = typename detail::__as_copy_arg <_Arg>::type ;
109
+ using __kernel_transform_result_t = decltype(__kernel_fn{}(_CUDA_VSTD::declval <_Arg>())) ;
77
110
78
- // While kernel argument can't be a reference
79
- template <typename _Arg>
80
- struct __as_kernel_arg
81
- {
82
- using type = _CUDA_VSTD::decay_t <typename __as_copy_arg<_Arg>::type>;
83
- };
111
+ } // namespace __transforms
112
+
113
+ using __transforms::__kernel_transform_result_t ;
114
+ using __transforms::__launch_transform_result_t ;
84
115
85
- _CCCL_GLOBAL_CONSTANT __fn __launch_transform{};
86
- } // namespace detail
116
+ _CCCL_GLOBAL_CONSTANT __transforms::__launch_fn __launch_transform{};
117
+ _CCCL_GLOBAL_CONSTANT __transforms::__kernel_fn __kernel_transform{};
87
118
88
119
template <typename _Arg>
89
- using as_kernel_arg_t = typename detail::__as_kernel_arg< _Arg>::type ;
120
+ using kernel_arg_t = _CUDA_VSTD:: decay_t < __kernel_transform_result_t < __launch_transform_result_t < _Arg>>> ;
90
121
91
122
} // namespace cuda::experimental
92
123
0 commit comments