We need an AllReduce/AllSum that performs a reduction and distributes the result across the entire block. I need this exposed in both CUB and cuda.cooperative. This pattern is fairly common in machine learning kernels; for example, you need it for softmax. Other programming frameworks have a primitive for this, such as MPI_Allreduce in MPI.
Today, CUB block reductions only return a meaningful result (the reduced value) in the first thread of the block.
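For illustration, here is a minimal sketch of today's behavior; SumKernel and the one-element-per-thread layout are placeholders, not code from CUB:

#include <cub/block/block_reduce.cuh>

template <int ThreadsPerBlock>
__global__ void SumKernel(const int* in, int* out)
{
  using BlockReduce = cub::BlockReduce<int, ThreadsPerBlock>;
  __shared__ typename BlockReduce::TempStorage tmp;
  int sum = BlockReduce(tmp).Sum(in[blockIdx.x * ThreadsPerBlock + threadIdx.x]);
  // `sum` holds the block-wide total only in thread 0; in every other
  // thread its value is unspecified, so only thread 0 may use it.
  if (threadIdx.x == 0)
    out[blockIdx.x] = sum;
}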
AllReduce/AllSum may need to be a new algorithm instead of a new method on BlockReduce, because you need a small amount of shared memory (one element's worth) to distribute the result. The shared memory already used for the reduction can be reused for this purpose, but reusing the temporary storage requires two barriers instead of one, so spending an additional element of shared memory may be worth it (see the variant sketch after the code below). If there are BlockReduce specializations that use zero shared memory, it would be particularly undesirable to make their shared memory usage non-zero. Since the extra shared memory would only be needed when AllReduce/AllSum is used, it may be better to have it as a separate algorithm.
Here's what a basic implementation would look like. If we know how the underlying BlockReduce works, we can do better; the final reduced value may already be in the temporary storage and the first barrier may be unnecessary.
#include <cub/block/block_reduce.cuh>

template <int ThreadsPerBlock, typename T>
__device__ T AllSum(T input)
{
  using BlockReduce = cub::BlockReduce<T, ThreadsPerBlock>;
  __shared__ union {
    typename BlockReduce::TempStorage reduce;
    T distribute;
  } tmp;
  auto sum = BlockReduce(tmp.reduce).Sum(input);
  __syncthreads(); // May be removed if we know how the reduce algorithm works.
  if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0)
    tmp.distribute = sum;
  __syncthreads(); // This is always necessary.
  return tmp.distribute;
}
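For comparison, here is a minimal sketch of the variant mentioned above that spends a dedicated shared element instead of reusing the temporary storage. Because the broadcast slot no longer aliases the reduction's storage, thread 0 can write it immediately, and only the barrier before the read remains (AllSumExtraStorage is an illustrative name):

#include <cub/block/block_reduce.cuh>

// Variant: one extra element of shared memory, but only one barrier.
template <int ThreadsPerBlock, typename T>
__device__ T AllSumExtraStorage(T input)
{
  using BlockReduce = cub::BlockReduce<T, ThreadsPerBlock>;
  __shared__ typename BlockReduce::TempStorage reduce;
  __shared__ T distribute; // The extra element; does not alias `reduce`.
  auto sum = BlockReduce(reduce).Sum(input);
  if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0)
    distribute = sum; // No prior barrier needed: nothing else touches `distribute`.
  __syncthreads(); // Every thread must see the write before reading.
  return distribute;
}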
My current workaround in my cuda.cooperative examples:
import math

import numpy as np
from numba import cuda, uint8, float32

import cuda.cooperative.experimental as cudax
from pynvjitlink import patch

patch.patch_numba_linker(lto=True)

def cudax_block_broadcast(dtype):
    # The broadcast needs enough shared memory to stage one element of `dtype`.
    context = cuda.descriptor.cuda_target.target_context
    size = context.get_value_type(dtype).get_abi_size(context.target_data)

    @cuda.jit
    def broadcast(temp_storage, red):
        # Wait until prior users of temp_storage are done before reusing it.
        cuda.syncthreads()
        acc = temp_storage.view(dtype)
        if cuda.threadIdx.x == 0 and cuda.threadIdx.y == 0 and cuda.threadIdx.z == 0:
            acc[0] = red
        # Every thread reads the value once thread 0 has written it.
        cuda.syncthreads()
        return acc[0]

    broadcast.temp_storage_bytes = size
    return broadcast

num_rows = 128
num_cols = 512
locals_per_thread = 8
threads_in_block = num_cols // locals_per_thread

def maximum(a, b):
    return a if a > b else b

block_load = cudax.block.load(float32, threads_in_block, locals_per_thread,
                              algorithm='warp_transpose')
block_store = cudax.block.store(float32, threads_in_block, locals_per_thread,
                                algorithm='warp_transpose')
block_broadcast = cudax_block_broadcast(float32)
block_max = cudax.block.reduce(float32, threads_in_block, maximum, locals_per_thread)
block_sum = cudax.block.sum(float32, threads_in_block, locals_per_thread)

# All algorithms share one shared-memory allocation, sized for the largest.
temp_storage_bytes = max(block_load.temp_storage_bytes,
                         block_store.temp_storage_bytes,
                         block_broadcast.temp_storage_bytes,
                         block_max.temp_storage_bytes,
                         block_sum.temp_storage_bytes)

@cuda.jit(fastmath=True,
          link=block_store.files + block_load.files + block_max.files + block_sum.files)
def softmax(input, output):
    temp_storage = cuda.shared.array(shape=temp_storage_bytes, dtype=uint8)
    locals = cuda.local.array(shape=locals_per_thread, dtype=float32)
    block_load(temp_storage, input[cuda.blockIdx.x, :], locals)
    cuda.syncthreads()
    # Reduce to the row maximum, then broadcast it to every thread.
    max = block_broadcast(temp_storage, block_max(temp_storage, locals))
    for i in range(locals_per_thread):
        locals[i] = math.exp(locals[i] - max)
    cuda.syncthreads()
    # Reduce to the row sum, then broadcast it to every thread.
    sum = block_broadcast(temp_storage, block_sum(temp_storage, locals))
    for i in range(locals_per_thread):
        locals[i] = locals[i] / sum
    cuda.syncthreads()
    block_store(temp_storage, output[cuda.blockIdx.x, :], locals)

input = cuda.to_device(np.random.default_rng(0).random(size=(num_rows, num_cols), dtype=np.float32))
output = cuda.to_device(np.full((num_rows, num_cols), -1, dtype=np.float32))
print(input.copy_to_host())
softmax[num_rows, threads_in_block](input, output)
result = output.copy_to_host()
print(input.copy_to_host())
print(result)
print(result.sum(axis=1))
assert np.all(np.isclose(result.sum(axis=1), 1.))
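For comparison with the Python workaround above, here is a rough CUDA C++ sketch of the same per-row softmax built on the union-based sketch from earlier, generalized to an arbitrary operator. AllReduce, Softmax, and the one-element-per-thread layout are illustrative assumptions, not an existing CUB API:

#include <cub/block/block_reduce.cuh>

// Generalization of the AllSum sketch above to an arbitrary operator.
template <int ThreadsPerBlock, typename T, typename Op>
__device__ T AllReduce(T input, Op op)
{
  using BlockReduce = cub::BlockReduce<T, ThreadsPerBlock>;
  __shared__ union {
    typename BlockReduce::TempStorage reduce;
    T distribute;
  } tmp;
  auto result = BlockReduce(tmp.reduce).Reduce(input, op);
  __syncthreads();
  if (threadIdx.x == 0) // Assumes a 1-D block for brevity.
    tmp.distribute = result;
  __syncthreads();
  return tmp.distribute;
}

// One row per block and one element per thread, for brevity.
template <int ThreadsPerBlock>
__global__ void Softmax(const float* in, float* out)
{
  float x = in[blockIdx.x * ThreadsPerBlock + threadIdx.x];
  float row_max = AllReduce<ThreadsPerBlock>(x, cub::Max{});
  float e = expf(x - row_max);
  float row_sum = AllReduce<ThreadsPerBlock>(e, cub::Sum{});
  out[blockIdx.x * ThreadsPerBlock + threadIdx.x] = e / row_sum;
}

With an AllReduce/AllSum primitive in the library, both the manual broadcast-through-shared-memory step and the extra barriers above would be the library's responsibility rather than the kernel author's.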