Skip to content

Commit

Permalink
Add dynamic CUB dispatch for unique_by_key (#3816)
Browse files Browse the repository at this point in the history
* Make vsmem_helper a template on the dispatch class so that we can pass it from the c.parallel layer. Also we don't need it to have any run-time state

* Pass VSMemHelper as template parameter to unique by key kernel

* Move the template parameters of VSMemHelper to its methods to workaround an nvcc 12.0 template instantiation issue with vsmem_helper_fallback_policy_t

* Make KeyT and ValueT templates for DispatchUniqueByKey
  • Loading branch information
NaderAlAwar authored Feb 28, 2025
1 parent b048cb7 commit 957eae9
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 68 deletions.
175 changes: 118 additions & 57 deletions cub/cub/device/dispatch/dispatch_unique_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#endif // no system header

#include <cub/device/dispatch/dispatch_scan.cuh>
#include <cub/device/dispatch/kernels/scan.cuh>
#include <cub/device/dispatch/kernels/unique_by_key.cuh>
#include <cub/device/dispatch/tuning/tuning_unique_by_key.cuh>
#include <cub/util_device.cuh>
Expand All @@ -51,6 +52,43 @@

CUB_NAMESPACE_BEGIN

namespace detail::unique_by_key
{
template <typename MaxPolicyT,
typename KeyInputIteratorT,
typename ValueInputIteratorT,
typename KeyOutputIteratorT,
typename ValueOutputIteratorT,
typename NumSelectedIteratorT,
typename ScanTileStateT,
typename EqualityOpT,
typename OffsetT>

struct DeviceUniqueByKeyKernelSource
{
CUB_DEFINE_KERNEL_GETTER(CompactInitKernel,
detail::scan::DeviceCompactInitKernel<ScanTileStateT, NumSelectedIteratorT>);

CUB_DEFINE_KERNEL_GETTER(
UniqueByKeySweepKernel,
DeviceUniqueByKeySweepKernel<
MaxPolicyT,
KeyInputIteratorT,
ValueInputIteratorT,
KeyOutputIteratorT,
ValueOutputIteratorT,
NumSelectedIteratorT,
ScanTileStateT,
EqualityOpT,
OffsetT>);

CUB_RUNTIME_FUNCTION ScanTileStateT TileState()
{
return ScanTileStateT();
}
};
} // namespace detail::unique_by_key

/******************************************************************************
* Dispatch
******************************************************************************/
Expand Down Expand Up @@ -88,7 +126,21 @@ template <
typename EqualityOpT,
typename OffsetT,
typename PolicyHub =
detail::unique_by_key::policy_hub<detail::it_value_t<KeyInputIteratorT>, detail::it_value_t<ValueInputIteratorT>>>
detail::unique_by_key::policy_hub<detail::it_value_t<KeyInputIteratorT>, detail::it_value_t<ValueInputIteratorT>>,
typename KernelSource = detail::unique_by_key::DeviceUniqueByKeyKernelSource<
typename PolicyHub::MaxPolicy,
KeyInputIteratorT,
ValueInputIteratorT,
KeyOutputIteratorT,
ValueOutputIteratorT,
NumSelectedIteratorT,
ScanTileState<OffsetT>,
EqualityOpT,
OffsetT>,
typename KernelLauncherFactory = detail::TripleChevronFactory,
typename VSMemHelperT = detail::unique_by_key::VSMemHelper,
typename KeyT = detail::it_value_t<KeyInputIteratorT>,
typename ValueT = detail::it_value_t<ValueInputIteratorT>>
struct DispatchUniqueByKey
{
/******************************************************************************
Expand All @@ -100,13 +152,6 @@ struct DispatchUniqueByKey
INIT_KERNEL_THREADS = 128,
};

// The input key and value type
using KeyT = detail::it_value_t<KeyInputIteratorT>;
using ValueT = detail::it_value_t<ValueInputIteratorT>;

// Tile status descriptor interface type
using ScanTileStateT = ScanTileState<OffsetT>;

/// Device-accessible allocation of temporary storage. When nullptr, the required allocation size
/// is written to `temp_storage_bytes` and no work is done.
void* d_temp_storage;
Expand Down Expand Up @@ -139,6 +184,10 @@ struct DispatchUniqueByKey
/// **[optional]** CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
cudaStream_t stream;

KernelSource kernel_source;

KernelLauncherFactory launcher_factory;

/**
* @param[in] d_temp_storage
* Device-accessible allocation of temporary storage.
Expand Down Expand Up @@ -184,7 +233,9 @@ struct DispatchUniqueByKey
NumSelectedIteratorT d_num_selected_out,
EqualityOpT equality_op,
OffsetT num_items,
cudaStream_t stream)
cudaStream_t stream,
KernelSource kernel_source = {},
KernelLauncherFactory launcher_factory = {})
: d_temp_storage(d_temp_storage)
, temp_storage_bytes(temp_storage_bytes)
, d_keys_in(d_keys_in)
Expand All @@ -195,27 +246,18 @@ struct DispatchUniqueByKey
, equality_op(equality_op)
, num_items(num_items)
, stream(stream)
, kernel_source(kernel_source)
, launcher_factory(launcher_factory)
{}

/******************************************************************************
* Dispatch entrypoints
******************************************************************************/

template <typename ActivePolicyT, typename InitKernel, typename ScanKernel>
CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t Invoke(InitKernel init_kernel, ScanKernel scan_kernel)
template <typename ActivePolicyT, typename InitKernelT, typename UniqueByKeySweepKernelT>
CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t
Invoke(InitKernelT init_kernel, UniqueByKeySweepKernelT sweep_kernel, ActivePolicyT policy = {})
{
using Policy = typename ActivePolicyT::UniqueByKeyPolicyT;

using VsmemHelperT = cub::detail::vsmem_helper_default_fallback_policy_t<
Policy,
detail::unique_by_key::AgentUniqueByKey,
KeyInputIteratorT,
ValueInputIteratorT,
KeyOutputIteratorT,
ValueOutputIteratorT,
EqualityOpT,
OffsetT>;

cudaError error = cudaSuccess;
do
{
Expand All @@ -228,17 +270,42 @@ struct DispatchUniqueByKey
}

// Number of input tiles
constexpr auto block_threads = VsmemHelperT::agent_policy_t::BLOCK_THREADS;
constexpr auto items_per_thread = VsmemHelperT::agent_policy_t::ITEMS_PER_THREAD;
int tile_size = block_threads * items_per_thread;
int num_tiles = static_cast<int>(::cuda::ceil_div(num_items, tile_size));
const auto vsmem_size = num_tiles * VsmemHelperT::vsmem_per_block;
const auto block_threads = VSMemHelperT::template BlockThreads<
typename ActivePolicyT::UniqueByKeyPolicyT,
KeyInputIteratorT,
ValueInputIteratorT,
KeyOutputIteratorT,
ValueOutputIteratorT,
EqualityOpT,
OffsetT>(policy.UniqueByKey());
const auto items_per_thread = VSMemHelperT::template ItemsPerThread<
typename ActivePolicyT::UniqueByKeyPolicyT,
KeyInputIteratorT,
ValueInputIteratorT,
KeyOutputIteratorT,
ValueOutputIteratorT,
EqualityOpT,
OffsetT>(policy.UniqueByKey());
int tile_size = block_threads * items_per_thread;
int num_tiles = static_cast<int>(::cuda::ceil_div(num_items, tile_size));
const auto vsmem_size =
num_tiles
* VSMemHelperT::template VSMemPerBlock<
typename ActivePolicyT::UniqueByKeyPolicyT,
KeyInputIteratorT,
ValueInputIteratorT,
KeyOutputIteratorT,
ValueOutputIteratorT,
EqualityOpT,
OffsetT>(policy.UniqueByKey());

// Specify temporary storage allocation requirements
size_t allocation_sizes[2] = {0, vsmem_size};

auto tile_state = kernel_source.TileState();

// Bytes needed for tile status descriptors
error = CubDebug(ScanTileStateT::AllocationSize(num_tiles, allocation_sizes[0]));
error = CubDebug(tile_state.AllocationSize(num_tiles, allocation_sizes[0]));
if (cudaSuccess != error)
{
break;
Expand All @@ -259,8 +326,6 @@ struct DispatchUniqueByKey
break;
}

// Construct the tile status interface
ScanTileStateT tile_state;
error = CubDebug(tile_state.Init(num_tiles, allocations[0], allocation_sizes[0]));
if (cudaSuccess != error)
{
Expand All @@ -276,7 +341,7 @@ struct DispatchUniqueByKey
#endif // CUB_DEBUG_LOG

// Invoke init_kernel to initialize tile descriptors
THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(init_grid_size, INIT_KERNEL_THREADS, 0, stream)
launcher_factory(init_grid_size, INIT_KERNEL_THREADS, 0, stream)
.doit(init_kernel, tile_state, num_tiles, d_num_selected_out);

// Check for failure to launch
Expand Down Expand Up @@ -313,13 +378,13 @@ struct DispatchUniqueByKey
scan_grid_size.y = ::cuda::ceil_div(num_tiles, max_dim_x);
scan_grid_size.x = CUB_MIN(num_tiles, max_dim_x);

// Log select_if_kernel configuration
// Log select_if_kernel configuration
#ifdef CUB_DEBUG_LOG
{
// Get SM occupancy for unique_by_key_kernel
int scan_sm_occupancy;
error = CubDebug(MaxSmOccupancy(scan_sm_occupancy, // out
scan_kernel,
int sweep_sm_occupancy;
error = CubDebug(MaxSmOccupancy(sweep_sm_occupancy, // out
sweep_kernel,
block_threads));
if (cudaSuccess != error)
{
Expand All @@ -334,14 +399,14 @@ struct DispatchUniqueByKey
block_threads,
(long long) stream,
items_per_thread,
scan_sm_occupancy);
sweep_sm_occupancy);
}
#endif // CUB_DEBUG_LOG

// Invoke select_if_kernel
error =
THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(scan_grid_size, block_threads, 0, stream)
.doit(scan_kernel,
launcher_factory(scan_grid_size, block_threads, 0, stream)
.doit(sweep_kernel,
d_keys_in,
d_values_in,
d_keys_out,
Expand Down Expand Up @@ -372,21 +437,11 @@ struct DispatchUniqueByKey
}

template <typename ActivePolicyT>
CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t Invoke()
CUB_RUNTIME_FUNCTION _CCCL_HOST _CCCL_FORCEINLINE cudaError_t Invoke(ActivePolicyT active_policy = {})
{
// Ensure kernels are instantiated.
return Invoke<ActivePolicyT>(
detail::scan::DeviceCompactInitKernel<ScanTileStateT, NumSelectedIteratorT>,
detail::unique_by_key::DeviceUniqueByKeySweepKernel<
typename PolicyHub::MaxPolicy,
KeyInputIteratorT,
ValueInputIteratorT,
KeyOutputIteratorT,
ValueOutputIteratorT,
NumSelectedIteratorT,
ScanTileStateT,
EqualityOpT,
OffsetT>);
auto wrapped_policy = detail::unique_by_key::MakeUniqueByKeyPolicyWrapper(active_policy);

return Invoke(kernel_source.CompactInitKernel(), kernel_source.UniqueByKeySweepKernel(), wrapped_policy);
}

/**
Expand Down Expand Up @@ -426,6 +481,7 @@ struct DispatchUniqueByKey
* **[optional]** CUDA stream to launch kernels within.
* Default is stream<sub>0</sub>.
*/
template <typename MaxPolicyT = typename PolicyHub::MaxPolicy>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Dispatch(
void* d_temp_storage,
size_t& temp_storage_bytes,
Expand All @@ -436,14 +492,17 @@ struct DispatchUniqueByKey
NumSelectedIteratorT d_num_selected_out,
EqualityOpT equality_op,
OffsetT num_items,
cudaStream_t stream)
cudaStream_t stream,
KernelSource kernel_source = {},
KernelLauncherFactory launcher_factory = {},
MaxPolicyT max_policy = {})
{
cudaError_t error;
do
{
// Get PTX version
int ptx_version = 0;
error = CubDebug(PtxVersion(ptx_version));
error = CubDebug(launcher_factory.PtxVersion(ptx_version));
if (cudaSuccess != error)
{
break;
Expand All @@ -460,10 +519,12 @@ struct DispatchUniqueByKey
d_num_selected_out,
equality_op,
num_items,
stream);
stream,
kernel_source,
launcher_factory);

// Dispatch to chained policy
error = CubDebug(PolicyHub::MaxPolicy::Invoke(ptx_version, dispatch));
error = CubDebug(max_policy.Invoke(ptx_version, dispatch));
if (cudaSuccess != error)
{
break;
Expand Down
38 changes: 33 additions & 5 deletions cub/cub/device/dispatch/kernels/unique_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,35 @@ CUB_NAMESPACE_BEGIN

namespace detail::unique_by_key
{

// TODO: this class should be templated on `typename... Ts` to avoid repetition,
// but due to an issue with NVCC 12.0 we currently template each member function
// individually instead.
struct VSMemHelper
{
template <typename ActivePolicyT, typename... Ts>
using VSMemHelperDefaultFallbackPolicyT =
vsmem_helper_default_fallback_policy_t<ActivePolicyT, detail::unique_by_key::AgentUniqueByKey, Ts...>;

template <typename ActivePolicyT, typename... Ts>
_CCCL_HOST_DEVICE static constexpr int BlockThreads(ActivePolicyT /*policy*/)
{
return VSMemHelperDefaultFallbackPolicyT<ActivePolicyT, Ts...>::agent_policy_t::BLOCK_THREADS;
}

template <typename ActivePolicyT, typename... Ts>
_CCCL_HOST_DEVICE static constexpr int ItemsPerThread(ActivePolicyT /*policy*/)
{
return VSMemHelperDefaultFallbackPolicyT<ActivePolicyT, Ts...>::agent_policy_t::ITEMS_PER_THREAD;
}

template <typename ActivePolicyT, typename... Ts>
_CCCL_HOST_DEVICE static constexpr ::cuda::std::size_t VSMemPerBlock(ActivePolicyT /*policy*/)
{
return VSMemHelperDefaultFallbackPolicyT<ActivePolicyT, Ts...>::vsmem_per_block;
}
};

/**
* @brief Unique by key kernel entry point (multi-block)
*
Expand Down Expand Up @@ -93,11 +122,11 @@ template <typename ChainedPolicyT,
typename NumSelectedIteratorT,
typename ScanTileStateT,
typename EqualityOpT,
typename OffsetT>
typename OffsetT,
typename VSMemHelperT = VSMemHelper>
__launch_bounds__(int(
vsmem_helper_default_fallback_policy_t<
VSMemHelperT::template VSMemHelperDefaultFallbackPolicyT<
typename ChainedPolicyT::ActivePolicy::UniqueByKeyPolicyT,
AgentUniqueByKey,
KeyInputIteratorT,
ValueInputIteratorT,
KeyOutputIteratorT,
Expand All @@ -116,9 +145,8 @@ __launch_bounds__(int(
int num_tiles,
vsmem_t vsmem)
{
using VsmemHelperT = vsmem_helper_default_fallback_policy_t<
using VsmemHelperT = typename VSMemHelperT::template VSMemHelperDefaultFallbackPolicyT<
typename ChainedPolicyT::ActivePolicy::UniqueByKeyPolicyT,
AgentUniqueByKey,
KeyInputIteratorT,
ValueInputIteratorT,
KeyOutputIteratorT,
Expand Down
Loading

0 comments on commit 957eae9

Please sign in to comment.