radix_bits parameter in BlockRadixRank::RankKeys #3964
-
Question
I am trying to understand the purpose of the `radix_bits` parameter in `BlockRadixRank::RankKeys`. Can you please clarify if my testing is correct, or if there's something else I've misunderstood entirely?
Kernel
#define NUM_POINTS 4
// One output record per input element: a key together with the rank computed for it.
// Filled both by the GPU kernel (one record per thread) and by the CPU reference.
struct Result {
// Key value the rank refers to (copied from the input element).
uint32_t key;
// Rank assigned to the key; in the CPU reference this is the key's 0-based
// position in ascending sorted order.
int rank;
};
// Ranks one key per thread with a single cub::BlockRadixRank::RankKeys call and
// writes each thread's (key, rank) pair to results[threadIdx.x].
//
// Template parameter:
//   RADIX_BITS - passed both to cub::BlockRadixRank and, as the bit count, to the
//                digit extractor, which starts at bit 0. Only that one digit of
//                each key is therefore handed to RankKeys.
// Parameters:
//   input   - device array read at index threadIdx.x.
//   results - device array written at index threadIdx.x.
//
// Launch layout: indexing uses threadIdx.x only and the CUB block size is fixed to
// NUM_POINTS, so the kernel expects a single 1D block of NUM_POINTS threads.
// Shared memory: one statically declared BlockRadixRank::TempStorage.
template <int RADIX_BITS>
__global__ void Rank_Kernel(uint32_t* input, Result* results) {
// --- CUB Templates
constexpr int block_threads = NUM_POINTS;
// NOTE(review): the third template argument (false) presumably selects ascending
// order (IS_DESCENDING) - confirm against the CUB documentation.
using block_radix_rank = cub::BlockRadixRank<block_threads, RADIX_BITS, false>;
// Extract RADIX_BITS bits starting at bit 0 of each key.
cub::BFEDigitExtractor<uint32_t> extractor(0, RADIX_BITS);
// --- Allocate shared memory for BlockRadixRank
using storage_t = typename block_radix_rank::TempStorage;
__shared__ storage_t temp_storage;
// --- Thread Data (one key and one rank per thread)
uint32_t keys[1];
int ranks[1];
const auto tid = threadIdx.x;
keys[0] = input[tid];
// --- Rank Keys
block_radix_rank(temp_storage).RankKeys(keys, ranks, extractor);
__syncthreads();
// --- Write Results
// keys[0] is still this thread's input key; ranks[0] is the rank RankKeys produced for it.
results[tid] = {.key = keys[0],
.rank = ranks[0]};
} Output:.......................[ CPU ].................GPU Radix Bits: System
Main
#include <cub/cub.cuh>

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <exception>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
// Runs Rank_Kernel<radix_bits> on the device input and returns the per-element results.
//
// Parameters:
//   d_input    - device pointer to NUM_POINTS keys (the kernel reads one key per thread).
//   radix_bits - compile-time RADIX_BITS instantiation to dispatch to; must be in [1, 9].
// Returns:
//   NUM_POINTS Result records copied back from the device, indexed like the input.
// Throws:
//   std::invalid_argument if radix_bits has no matching kernel instantiation
//   (previously nothing was launched and uninitialised device memory was returned).
//   std::runtime_error if any CUDA call, the launch, or the kernel execution fails.
//
// The temporary device buffer is released on every exit path.
std::vector<Result> launch_rank_kernel(uint32_t* d_input, int radix_bits) {
    // Turn a failed CUDA call into an exception that names the failing step.
    const auto check = [](cudaError_t status, const char* what) {
        if (status != cudaSuccess) {
            throw std::runtime_error(std::string(what) + ": " + cudaGetErrorString(status));
        }
    };

    if (radix_bits < 1 || radix_bits > 9) {
        throw std::invalid_argument("launch_rank_kernel: radix_bits must be in [1, 9], got " +
                                    std::to_string(radix_bits));
    }

    // --- Allocate Device Memory (freed by the guard's destructor, also on throw)
    struct DeviceResults {
        Result* ptr = nullptr;
        ~DeviceResults() { cudaFree(ptr); }  // cudaFree(nullptr) is a no-op
    } d_results;
    const size_t result_bytes = NUM_POINTS * sizeof(Result);
    check(cudaMalloc(&d_results.ptr, result_bytes), "cudaMalloc(d_results)");

    // --- Kernel Dims: a single block with one thread per point
    dim3 block(NUM_POINTS);
    dim3 grid(1);

    // --- Launch Template Kernel (RADIX_BITS is a template argument, hence the dispatch)
    switch (radix_bits) {
        case 1: Rank_Kernel<1><<<grid, block>>>(d_input, d_results.ptr); break;
        case 2: Rank_Kernel<2><<<grid, block>>>(d_input, d_results.ptr); break;
        case 3: Rank_Kernel<3><<<grid, block>>>(d_input, d_results.ptr); break;
        case 4: Rank_Kernel<4><<<grid, block>>>(d_input, d_results.ptr); break;
        case 5: Rank_Kernel<5><<<grid, block>>>(d_input, d_results.ptr); break;
        case 6: Rank_Kernel<6><<<grid, block>>>(d_input, d_results.ptr); break;
        case 7: Rank_Kernel<7><<<grid, block>>>(d_input, d_results.ptr); break;
        case 8: Rank_Kernel<8><<<grid, block>>>(d_input, d_results.ptr); break;
        case 9: Rank_Kernel<9><<<grid, block>>>(d_input, d_results.ptr); break;
        default:
            // Unreachable after the range check above; kept so a future change to the
            // range cannot silently skip the launch.
            throw std::invalid_argument("launch_rank_kernel: unsupported radix_bits");
    }
    check(cudaGetLastError(), "Rank_Kernel launch");          // bad launch configuration
    check(cudaDeviceSynchronize(), "Rank_Kernel execution");  // faults inside the kernel

    // --- Copy to Host
    std::vector<Result> h_gpu_results(NUM_POINTS);
    check(cudaMemcpy(h_gpu_results.data(), d_results.ptr, result_bytes, cudaMemcpyDeviceToHost),
          "cudaMemcpy(d_results -> host)");
    return h_gpu_results;
}
void cpu_rank(const std::vector<uint32_t>& input, std::vector<Result>& results) {
// --- Initialize [Key|Rank] Pairs
std::vector<std::pair<uint32_t, int>> key_rank_pairs;
for (size_t i=0; i < input.size(); i++) {
const auto& val = input[i];
key_rank_pairs.emplace_back(val, i);
}
// --- Sort by Key
std::sort(key_rank_pairs.begin(), key_rank_pairs.end());
// --- Map sorted keys to their original position
results.resize(input.size());
for (size_t i=0; i < key_rank_pairs.size(); i++) {
auto index = key_rank_pairs[i].second;
auto rank = static_cast<int>(i);
results[index] = {key_rank_pairs[i].first, rank};
}
}
// Returns the number of bits needed to represent the largest value in `h_points`,
// i.e. the 1-based position of that value's highest set bit.
// Returns 0 for an empty vector or when the largest value is 0.
int get_max_significant_bits(const std::vector<uint32_t>& h_points) {
    int width = 0;
    if (!h_points.empty()) {
        // Shift the maximum right until nothing is left, counting the shifts.
        for (uint32_t remaining = *std::max_element(h_points.begin(), h_points.end());
             remaining != 0;
             remaining >>= 1) {
            ++width;
        }
    }
    return width;
}
// Compares the CPU reference ranks with the ranks produced by Rank_Kernel for
// RADIX_BITS = 1..9 and prints them side by side, one row per input element.
//
// Returns 0 on success and 1 if the input size does not match NUM_POINTS or a
// CUDA call / kernel run fails.
int main()
{
    // --- Host Data
    //std::vector<uint32_t> h_points = {34, 33, 18, 25};
    std::vector<uint32_t> h_points = {16, 10, 9, 11};

    // The device copy below and the kernel both assume exactly NUM_POINTS keys.
    if (h_points.size() != static_cast<size_t>(NUM_POINTS)) {
        fprintf(stderr, "Expected exactly %d input points, got %zu\n", NUM_POINTS, h_points.size());
        return 1;
    }

    std::vector<Result> cpu_results;
    cpu_rank(h_points, cpu_results);

    // --- Allocate Device Memory
    uint32_t* d_points = nullptr;
    const size_t input_bytes = NUM_POINTS * sizeof(uint32_t);
    cudaError_t status = cudaMalloc(&d_points, input_bytes);
    if (status == cudaSuccess) {
        status = cudaMemcpy(d_points, h_points.data(), input_bytes, cudaMemcpyHostToDevice);
    }
    if (status != cudaSuccess) {
        fprintf(stderr, "Failed to upload input: %s\n", cudaGetErrorString(status));
        cudaFree(d_points);  // no-op when the allocation itself failed
        return 1;
    }

    // --- Run the kernel once per radix width
    std::map<int, std::vector<Result>> all_gpu_results;
    try {
        for (int bits = 1; bits < 10; bits++) {
            all_gpu_results[bits] = launch_rank_kernel(d_points, bits);
        }
    } catch (const std::exception& error) {
        fprintf(stderr, "GPU ranking failed: %s\n", error.what());
        cudaFree(d_points);
        return 1;
    }
    // All results are on the host now; the device input is no longer needed.
    cudaFree(d_points);

    // --- Report
    printf("\t\t\t\t[ CPU ] \t\t GPU Radix Bits:\n");
    printf("\t\t\t\t\t\t\t1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 \n");
    printf("-------------------------------------------------------------\n");
    for (size_t i = 0; i < h_points.size(); i++) {
        const auto cpu_res = cpu_results[i];
        // key is uint32_t, so it is printed with an unsigned conversion.
        printf("Input: [%3u] = Rank [ %d ] ", cpu_res.key, cpu_res.rank);
        for (int bits = 1; bits < 10; bits++) {
            const auto& gpu_res = all_gpu_results[bits][i];
            printf("| %d ", gpu_res.rank);
        }
        printf("\n");
    }
    printf("Max significant bits = %d\n", get_max_significant_bits(h_points));
    return 0;
}
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 2 replies
-
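Hello @JohnMansell! For a while, radix rank was an implementation detail of radix sort that just happened to be in the public namespace. This makes its interface not exactly user-friendly. The missing context here is that radix rank doesn't rank all bits in the keys. Instead, it ranks up to `RADIX_BITS` bits of the keys; in other words, it does one pass only. To rank more than `RADIX_BITS` bits and get output ranks in the same thread as the one having the original input key, you'd have to add a loop that also exchanges keys along with values to preserve original order. Something along the following lines:
// --- CUB Templates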
// Kernel-body fragment: ranks full 32-bit keys by running one BlockRadixRank pass
// per RADIX_BITS-wide digit, from bit 0 upwards, and physically moving keys and
// their original thread ids to the ranked position after every pass.
// Uses `input`, `results`, RADIX_BITS and NUM_POINTS from the enclosing kernel.
constexpr int block_threads = NUM_POINTS;
using block_radix_rank = cub::BlockRadixRank<block_threads, RADIX_BITS, false>;
// One item per thread is exchanged between threads after each pass.
using block_exchange = cub::BlockExchange<uint32_t, block_threads, 1>;
// --- Allocate shared memory: separate storage for ranking and for exchanging
using radix_rank_storage_t = typename block_radix_rank::TempStorage;
using exchange_storage_t = typename block_exchange::TempStorage;
__shared__ radix_rank_storage_t radix_rank_temp_storage;
__shared__ exchange_storage_t exchange_temp_storage;
// --- Thread Data
uint32_t keys[1];
uint32_t vals[1];
int ranks[1];
const auto tid = threadIdx.x;
keys[0] = input[tid];
// vals carries the id of the thread that originally owned the key.
vals[0] = tid;
// Bit range still to be processed: the whole 32-bit key.
int begin_bit = 0;
int end_bit = sizeof(uint32_t) * 8;
while (true)
{
// The last pass may cover fewer than RADIX_BITS bits.
int pass_bits = cuda::std::min(RADIX_BITS, end_bit - begin_bit);
cub::BFEDigitExtractor<uint32_t> extractor(begin_bit, pass_bits);
// --- Rank Keys on the current digit only
block_radix_rank(radix_rank_temp_storage).RankKeys(keys, ranks, extractor);
// Move each key to the thread given by its rank for this digit.
block_exchange(exchange_temp_storage).ScatterToBlocked(keys, ranks);
// Barrier before exchange_temp_storage is reused by the second scatter.
__syncthreads();
// Move the original thread ids along with their keys.
block_exchange(exchange_temp_storage).ScatterToBlocked(vals, ranks);
begin_bit += RADIX_BITS;
if (begin_bit >= end_bit)
{
break;
}
}
// Barrier before exchange_temp_storage is reused once more below.
__syncthreads();
// Invert the permutation: send this thread's position (tid) back to the thread
// whose id travelled here in vals[0], so vals[0] ends up as the rank of that
// thread's original key.
ranks[0] = vals[0];
vals[0] = tid;
block_exchange(exchange_temp_storage).ScatterToBlocked(vals, ranks);
// --- Write Results
// NOTE(review): keys[0] was moved by the scatters above, so at this point it looks
// like the key in position tid of the reordered sequence rather than input[tid],
// while vals[0] is the rank of input[tid] - confirm the statement that follows
// pairs the intended key with this rank.
results[tid] = {.key = keys[0], .rank = static_cast<int>(vals[0])};
If you need to rank more than `RADIX_BITS` bits, `cub::BlockRadixSort` can run the whole multi-pass sort for you:
constexpr int block_threads = NUM_POINTS;
// Kernel-body fragment: obtains ranks by sorting (key, original thread id) pairs
// with cub::BlockRadixSort, one pair per thread.
// Template arguments: uint32_t keys, block_threads threads, 1 item per thread,
// uint32_t values, RADIX_BITS bits per pass.
using block_radix_sort = cub::BlockRadixSort<uint32_t, block_threads, 1, uint32_t, RADIX_BITS>;
using radix_sort_storage_t = typename block_radix_sort::TempStorage;
__shared__ radix_sort_storage_t radix_sort_temp_storage;
uint32_t keys[1] = {input[threadIdx.x]};
// The value is the id of the thread that supplied the key.
uint32_t vals[1] = {threadIdx.x};
block_radix_sort(radix_sort_temp_storage).Sort(keys, vals);
// After the sort vals[0] names the thread the key came from and threadIdx.x is
// the key's position in the sorted order, so the record is written to the
// original element's slot with that position as its rank.
// NOTE(review): presumably the sort is ascending with a blocked output
// arrangement - confirm against the CUB documentation.
results[vals[0]] = {.key = keys[0], .rank = static_cast<int>(threadIdx.x)}; |
Beta Was this translation helpful? Give feedback.
Hello @JohnMansell!
For a while, radix rank was an implementation detail of radix sort that just happened to be in the public namespace. This makes its interface not exactly user-friendly. The missing context here is that radix rank doesn't rank all bits in the keys. Instead, it ranks up to `RADIX_BITS` bits of the keys; in other words, it does one pass only. To rank more than `RADIX_BITS` bits and get output ranks in the same thread as the one having the original input key, you'd have to add a loop that also exchanges keys along with values to preserve original order. Something along the following lines: