Add CUB block reduce primitive that returns the reduced result in all threads #3917

Open
brycelelbach opened this issue Feb 23, 2025 · 1 comment
Labels: cub (For all items related to CUB)


brycelelbach commented Feb 23, 2025

We need an AllReduce/AllSum that performs a reduction and distributes the result across the entire block. I need this exposed in both CUB and cuda.cooperative. This pattern is fairly common in machine learning kernels; for example, softmax needs it. Other programming frameworks have a primitive for this, such as MPI_Allreduce in MPI.

Today, CUB block reductions only return a meaningful result (the reduced value) in the first thread of the block.
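
For illustration, here's that behavior with the existing API (the wrapper function name is just for this example); the value returned by BlockReduce::Sum is only defined in thread 0:

#include <cub/block/block_reduce.cuh>

template <int ThreadsPerBlock, typename T>
__device__ T BlockSumInThread0(T input)
{
  using BlockReduce = cub::BlockReduce<T, ThreadsPerBlock>;
  __shared__ typename BlockReduce::TempStorage tmp;

  // The block-wide sum is returned in thread 0 only;
  // in every other thread the return value is undefined.
  return BlockReduce(tmp).Sum(input);
}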

AllReduce/AllSum may need to be a new algorithm rather than a new method on BlockReduce, because distributing the result requires a small amount of shared memory (one element's worth). The shared memory already used for the reduction can be reused for this purpose, but reusing it costs two barriers instead of one, so spending an additional element of shared memory may be worth it. If there are BlockReduce specializations that use zero shared memory, it would be particularly undesirable to make their shared memory usage non-zero. Since the extra shared memory is only needed by AllReduce/AllSum, it may be better to put the functionality in a separate algorithm.

Here's what a basic implementation would look like. If we know how the underlying BlockReduce works, we can do better; the final reduced value may already be in the temporary storage and the first barrier may be unnecessary.

#include <cub/block/block_reduce.cuh>

template <int ThreadsPerBlock, typename T>
__device__ T AllSum(T input)
{
  using BlockReduce = cub::BlockReduce<T, ThreadsPerBlock>;
  __shared__ union {
    typename BlockReduce::TempStorage reduce;
    T distribute;
  } tmp;

  auto sum = BlockReduce(tmp.reduce).Sum(input);

  __syncthreads(); // May be removed if we know how the reduce algorithm works.

  if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0)
    tmp.distribute = sum;

  __syncthreads(); // This is always necessary.

  return tmp.distribute;
}
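
For comparison, here's a sketch (my assumption about the tradeoff above, not something I've measured) of the variant that spends an extra element of shared memory on a dedicated broadcast slot; because the reduction's temporary storage is never aliased, the first barrier goes away:

#include <cub/block/block_reduce.cuh>

template <int ThreadsPerBlock, typename T>
__device__ T AllSum(T input)
{
  using BlockReduce = cub::BlockReduce<T, ThreadsPerBlock>;
  __shared__ typename BlockReduce::TempStorage reduce;
  __shared__ T distribute; // The extra element; never aliases the reduce storage.

  auto sum = BlockReduce(reduce).Sum(input);

  // No barrier needed here: thread 0 writes `distribute`, which is distinct
  // from the temporary storage the reduction may still be using.
  if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0)
    distribute = sum;

  __syncthreads(); // Make thread 0's write visible before every thread reads it.

  return distribute;
}

Either way, usage is the same: every thread calls AllSum<ThreadsPerBlock>(x) and every thread gets the block-wide sum back.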

My current workaround in my cuda.cooperative examples:

import numpy as np
from numba import cuda
import cuda.cooperative.experimental as cudax
import math
from numba import uint8, float32
from pynvjitlink import patch
patch.patch_numba_linker(lto=True)

def cudax_block_broadcast(dtype):
  context = cuda.descriptor.cuda_target.target_context
  size = context.get_value_type(dtype).get_abi_size(context.target_data)

  @cuda.jit
  def broadcast(temp_storage, red):
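    # temp_storage may still hold state from the preceding block-wide reduce,
    # so synchronize before reusing it to broadcast the reduced value.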
    cuda.syncthreads()

    acc = temp_storage.view(dtype)

    if cuda.threadIdx.x == 0 and cuda.threadIdx.y == 0 and cuda.threadIdx.z == 0:
      acc[0] = red

    cuda.syncthreads()

    return acc[0]

  broadcast.temp_storage_bytes = size

  return broadcast

num_rows = 128
num_cols = 512

locals_per_thread = 8

threads_in_block = num_cols // locals_per_thread

def maximum(a, b):
  return a if a > b else b

block_load      = cudax.block.load(float32, threads_in_block, locals_per_thread, algorithm='warp_transpose')
block_store     = cudax.block.store(float32, threads_in_block, locals_per_thread, algorithm='warp_transpose')
block_broadcast = cudax_block_broadcast(float32)
block_max       = cudax.block.reduce(float32, threads_in_block, maximum, locals_per_thread)
block_sum       = cudax.block.sum(float32, threads_in_block, locals_per_thread)

temp_storage_bytes = max(block_load.temp_storage_bytes, block_store.temp_storage_bytes, block_broadcast.temp_storage_bytes, block_max.temp_storage_bytes, block_sum.temp_storage_bytes)

@cuda.jit(fastmath=True, link=block_store.files+block_load.files+block_max.files+block_sum.files)
def softmax(input, output):
  temp_storage = cuda.shared.array(shape=temp_storage_bytes, dtype=uint8)
  locals = cuda.local.array(shape=locals_per_thread, dtype=float32)

  block_load(temp_storage, input[cuda.blockIdx.x, :], locals)

  cuda.syncthreads()

  max = block_broadcast(temp_storage, block_max(temp_storage, locals))

  for i in range(locals_per_thread):
    locals[i] = math.exp(locals[i] - max)

  cuda.syncthreads()

  sum = block_broadcast(temp_storage, block_sum(temp_storage, locals))

  for i in range(locals_per_thread):
    locals[i] = locals[i] / sum

  cuda.syncthreads()

  block_store(temp_storage, output[cuda.blockIdx.x, :], locals)

input = cuda.to_device(np.random.default_rng(0).random(size=(num_rows, num_cols), dtype=np.float32))
output = cuda.to_device(np.full((num_rows, num_cols), -1, dtype=np.float32))

print(input.copy_to_host())

softmax[num_rows, threads_in_block](input, output)

result = output.copy_to_host()

print(input.copy_to_host())
print(result)
print(result.sum(axis=1))

assert np.all(np.isclose(result.sum(axis=1), 1.))