Skip to content

Commit

Permalink
Add segmented_reduce python api (#3906)
Browse files Browse the repository at this point in the history
* Add algorithms.segmented_reduce Python API

Also avoid recomputing cccl_value of init in both segmented_reduce
and in reduce

* Change to input_array fixture

1. Include np.complex64
2. Device output size in a variable and reuse it to avoid repeated
   occurrances of literal values
3. Generate real/imag values for complex arrays in a single
   call to sampling function for efficiency
4. Change range of generated integral arrays based on the signness
   of the integral data type. For unsigned types we continue to
   sample in interval [0, 10), for signed we sample from [-5, 5].

* Corrected docstring of segmented_reduce function

* Add initial tests for segmented_reduce

* Improve readability of test_segmented_reduce_api example

* TransformIteratorKind need not override __eq__/__hash__ methods of the base

Additionally, changed the __hash__ of IteratorKind to mix the hash of
its value with hash of self.__class__.

* Add AdvancedIterator(it, offset=1) function

This is used to advance a given iterator `it` the `offset` steps without
running into multiple definitions of the advance/derefence methods.

* Add example for summing rows of a matrix using segmented_reduce

* Implement IteratorBase.__add__(self, offset : int) using make_advanced_iterator

* Use end_offsets = start_offsets + 1

This calls IteratorBase.__add__ to produce an iterator whose state
is advanced by 1, but which shares the same advance/dereference methods.

* Add a test for segmented_reduce on gpu_struct

* Change hash of transform iterator to mix its kind

* Rename variable n to sample_size

Also make generation of complex array in test_reduce.py more
efficient by genering real and imaginary components in a single
call to np.random.random instead of using two calls.

* Remove __hash__ and __eq__ special methods from some iterator classes

These were only defined for TransformIterator and AdvancedIterator classes,
but not for other classes.

Implemented review suggestion to type type(self) instead of self.__class__

* Tweak test_scan_array_input to avoid integer overflows during host accumulation

For short range data types we take a small slice of the input array to
avoid running into the overflow problem. This works because input_array
fixture samples from uniform discrete distribution with small upper range (8),
hence using 31 uint8 elements can run up to 31 * 7  = 217 ( < 255) and fits
in the type.

* Add cccl.set_cccl_iterator_state utility function and use in segmented_reduce.py

* Introduce _bindings.call_build utility

This finds compute capability and include paths and appends them
to the algorithm-specific arguments. Used the utility in segmented_reduce.

* Make call_build take *args, **kwargs
  • Loading branch information
oleksandr-pavlyk authored Feb 28, 2025
1 parent 83ba38c commit 0183959
Show file tree
Hide file tree
Showing 11 changed files with 490 additions and 27 deletions.
24 changes: 23 additions & 1 deletion python/cuda_parallel/cuda/parallel/experimental/_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import sys
from functools import lru_cache
from pathlib import Path
from typing import List
from typing import Callable, List

from numba import cuda

from cuda.cccl import get_include_paths # type: ignore[import-not-found]

Expand Down Expand Up @@ -41,3 +43,23 @@ def get_paths() -> List[bytes]:
if path is not None
]
return paths


def call_build(build_impl_fn: Callable, *args, **kwargs):
"""Calls given build_impl_fn callable while providing compute capability and paths
Returns result of the call.
"""
cc_major, cc_minor = cuda.get_current_device().compute_capability
cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()
error = build_impl_fn(
*args,
cc_major,
cc_minor,
ctypes.c_char_p(cub_path),
ctypes.c_char_p(thrust_path),
ctypes.c_char_p(libcudacxx_path),
ctypes.c_char_p(cuda_include_path),
**kwargs,
)
return error
21 changes: 20 additions & 1 deletion python/cuda_parallel/cuda/parallel/experimental/_cccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np
from numba import cuda, types

from ._utils.protocols import get_dtype, is_contiguous
from ._utils.protocols import get_data_pointer, get_dtype, is_contiguous
from .iterators._iterators import IteratorBase
from .typing import DeviceArrayLike, GpuStruct

Expand Down Expand Up @@ -121,6 +121,18 @@ class DeviceScanBuildResult(ctypes.Structure):
]


# MUST match `cccl_device_segmented_reduce_build_result_t` in c/include/cccl/c/segmented_reduce.h
class DeviceSegmentedReduceBuildResult(ctypes.Structure):
_fields_ = [
("cc", ctypes.c_int),
("cubin", ctypes.c_void_p),
("cubin_size", ctypes.c_size_t),
("library", ctypes.c_void_p),
("accumulator_size", ctypes.c_ulonglong),
("segmented_reduce_kernel", ctypes.c_void_p),
]


# MUST match `cccl_value_t` in c/include/cccl/c/types.h
class Value(ctypes.Structure):
_fields_ = [("type", TypeInfo), ("state", ctypes.c_void_p)]
Expand Down Expand Up @@ -283,3 +295,10 @@ def to_cccl_op(op: Callable, sig) -> Op:
None,
_data=(ltoir, name), # keep a reference to these in a _data attribute
)


def set_cccl_iterator_state(cccl_it: Iterator, input_it):
if cccl_it.type.value == IteratorKind.POINTER:
cccl_it.state = get_data_pointer(input_it)
else:
cccl_it.state = input_it.state
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@
from .merge_sort import merge_sort as merge_sort
from .reduce import reduce_into as reduce_into
from .scan import scan as scan
from .segmented_reduce import segmented_reduce

__all__ = ["merge_sort", "reduce_into", "scan", "segmented_reduce"]
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
self.d_in_cccl,
self.d_out_cccl,
self.op_wrapper,
cccl.to_cccl_value(h_init),
self.h_init_cccl,
cc_major,
cc_minor,
ctypes.c_char_p(cub_path),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import ctypes
from typing import Callable

import numba
import numpy as np
from numba.cuda.cudadrv import enums

from .. import _cccl as cccl
from .._bindings import call_build, get_bindings
from .._caching import CachableFunction, cache_with_key
from .._utils import protocols
from ..iterators._iterators import IteratorBase
from ..typing import DeviceArrayLike, GpuStruct


class _SegmentedReduce:
def __del__(self):
if self.build_result is None:
return
bindings = get_bindings()
bindings.cccl_device_segmented_reduce_cleanup(ctypes.byref(self.build_result))

def __init__(
self,
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike,
start_offsets_in: DeviceArrayLike | IteratorBase,
end_offsets_in: DeviceArrayLike | IteratorBase,
op: Callable,
h_init: np.ndarray | GpuStruct,
):
self.build_result = None
self.d_in_cccl = cccl.to_cccl_iter(d_in)
self.d_out_cccl = cccl.to_cccl_iter(d_out)
self.start_offsets_in_cccl = cccl.to_cccl_iter(start_offsets_in)
self.end_offsets_in_cccl = cccl.to_cccl_iter(end_offsets_in)
self.h_init_cccl = cccl.to_cccl_value(h_init)
if isinstance(h_init, np.ndarray):
value_type = numba.from_dtype(h_init.dtype)
else:
value_type = numba.typeof(h_init)
sig = (value_type, value_type)
self.op_wrapper = cccl.to_cccl_op(op, sig)
self.build_result = cccl.DeviceSegmentedReduceBuildResult()
self.bindings = get_bindings()
error = call_build(
self.bindings.cccl_device_segmented_reduce_build,
ctypes.byref(self.build_result),
self.d_in_cccl,
self.d_out_cccl,
self.start_offsets_in_cccl,
self.end_offsets_in_cccl,
self.op_wrapper,
self.h_init_cccl,
)
if error != enums.CUDA_SUCCESS:
raise ValueError("Error building reduce")

def __call__(
self,
temp_storage,
d_in,
d_out,
num_segments: int,
start_offsets_in,
end_offsets_in,
h_init,
stream=None,
):
set_state_fn = cccl.set_cccl_iterator_state
set_state_fn(self.d_in_cccl, d_in)
set_state_fn(self.d_out_cccl, d_out)
set_state_fn(self.start_offsets_in_cccl, start_offsets_in)
set_state_fn(self.end_offsets_in_cccl, end_offsets_in)
self.h_init_cccl.state = h_init.__array_interface__["data"][0]

stream_handle = protocols.validate_and_get_stream(stream)

if temp_storage is None:
temp_storage_bytes = ctypes.c_size_t()
d_temp_storage = None
else:
temp_storage_bytes = ctypes.c_size_t(temp_storage.nbytes)
d_temp_storage = protocols.get_data_pointer(temp_storage)

error = self.bindings.cccl_device_segmented_reduce(
self.build_result,
ctypes.c_void_p(d_temp_storage),
ctypes.byref(temp_storage_bytes),
self.d_in_cccl,
self.d_out_cccl,
ctypes.c_ulonglong(num_segments),
self.start_offsets_in_cccl,
self.end_offsets_in_cccl,
self.op_wrapper,
self.h_init_cccl,
ctypes.c_void_p(stream_handle),
)

if error != enums.CUDA_SUCCESS:
raise ValueError("Error reducing")

return temp_storage_bytes.value


def _to_key(d_in: DeviceArrayLike | IteratorBase):
"Return key for an input array-like argument or an iterator"
d_in_key = (
d_in.kind if isinstance(d_in, IteratorBase) else protocols.get_dtype(d_in)
)
return d_in_key


def make_cache_key(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike,
start_offsets_in: DeviceArrayLike | IteratorBase,
end_offsets_in: DeviceArrayLike | IteratorBase,
op: Callable,
h_init: np.ndarray,
):
d_in_key = _to_key(d_in)
d_out_key = protocols.get_dtype(d_out)
start_offsets_in_key = _to_key(start_offsets_in)
end_offsets_in_key = _to_key(end_offsets_in)
op_key = CachableFunction(op)
h_init_key = h_init.dtype
return (
d_in_key,
d_out_key,
start_offsets_in_key,
end_offsets_in_key,
op_key,
h_init_key,
)


@cache_with_key(make_cache_key)
def segmented_reduce(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike,
start_offsets_in: DeviceArrayLike | IteratorBase,
end_offsets_in: DeviceArrayLike | IteratorBase,
op: Callable,
h_init: np.ndarray,
):
"""Computes a device-wide segmented reduction using the specified binary ``op`` and initial value ``init``.
Example:
Below, ``segmented_reduce`` is used to compute the minimum value of a sequence of integers.
.. literalinclude:: ../../python/cuda_parallel/tests/test_segmented_reduce_api.py
:language: python
:dedent:
:start-after: example-begin segmented-reduce-min
:end-before: example-end segmented-reduce-min
Args:
d_in: Device array or iterator containing the input sequence of data items
d_out: Device array that will store the result of the reduction
start_offsets_in: Device array or iterator containing offsets to start of segments
end_offsets_in: Device array or iterator containing offsets to end of segments
op: Callable representing the binary operator to apply
init: Numpy array storing initial value of the reduction
Returns:
A callable object that can be used to perform the reduction
"""
return _SegmentedReduce(d_in, d_out, start_offsets_in, end_offsets_in, op, h_init)
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __eq__(self, other):
return type(self) is type(other) and self.value_type == other.value_type

def __hash__(self):
return hash(self.value_type)
return hash((type(self), self.value_type))


@lru_cache(maxsize=None)
Expand Down Expand Up @@ -128,6 +128,9 @@ def advance(state, distance):
def dereference(state):
raise NotImplementedError("Subclasses must override dereference staticmethod")

def __add__(self, offset: int):
return make_advanced_iterator(self, offset=offset)


def sizeof_pointee(context, ptr):
size = context.get_abi_sizeof(ptr.type.pointee)
Expand Down Expand Up @@ -286,11 +289,7 @@ def dereference(state):


class TransformIteratorKind(IteratorKind):
def __eq__(self, other):
return type(self) is type(other) and self.value_type == other.value_type

def __hash__(self):
return hash(self.value_type)
pass


def make_transform_iterator(it, op: Callable):
Expand Down Expand Up @@ -336,12 +335,40 @@ def advance(state, distance):
def dereference(state):
return op(it_dereference(state))

def __hash__(self):
return hash((self._it, self._op))
return TransformIterator(it, op)

def __eq__(self, other):
if not isinstance(other.kind, TransformIteratorKind):
return NotImplemented
return self._it == other._it and self._op == other._op

return TransformIterator(it, op)
def make_advanced_iterator(it: IteratorBase, /, *, offset: int = 1):
it_advance = cuda.jit(type(it).advance, device=True)
it_dereference = cuda.jit(type(it).dereference, device=True)

class AdvancedIteratorKind(IteratorKind):
pass

class AdvancedIterator(IteratorBase):
iterator_kind_type = AdvancedIteratorKind

def __init__(self, it: IteratorBase, advance_steps: int):
self._it = it
cvalue_advanced = to_ctypes(it.value_type)(
it.cvalue + it.value_type(advance_steps)
)
super().__init__(
cvalue=cvalue_advanced,
numba_type=it.numba_type,
value_type=it.value_type,
)

@property
def kind(self):
return self.__class__.iterator_kind_type(self._it.kind)

@staticmethod
def advance(state, distance):
return it_advance(state, distance)

@staticmethod
def dereference(state):
return it_dereference(state)

return AdvancedIterator(it, offset)
20 changes: 15 additions & 5 deletions python/cuda_parallel/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,33 @@
np.uint64,
np.float32,
np.float64,
np.complex64,
np.complex128,
]
)
def input_array(request):
dtype = request.param
dtype = np.dtype(request.param)
sample_size = 1000

# Generate random values based on the dtype
if np.issubdtype(dtype, np.integer):
is_unsigned = dtype.kind == "u"
# For integer types, use np.random.randint for random integers
array = cp.random.randint(low=0, high=10, size=1000, dtype=dtype)
if is_unsigned:
low_inclusive, high_exclusive = 0, 8
else:
low_inclusive, high_exclusive = -5, 6
array = cp.random.randint(
low=low_inclusive, high=high_exclusive, size=sample_size, dtype=dtype
)
elif np.issubdtype(dtype, np.floating):
# For floating-point types, use np.random.random and cast to the required dtype
array = cp.random.random(1000).astype(dtype)
array = cp.random.random(sample_size).astype(dtype)
elif np.issubdtype(dtype, np.complexfloating):
# For complex types, generate random real and imaginary parts
real_part = cp.random.random(1000)
imag_part = cp.random.random(1000)
packed = cp.random.random(2 * sample_size)
real_part = packed[:sample_size]
imag_part = packed[sample_size:]
array = (real_part + 1j * imag_part).astype(dtype)

return array
Expand Down
8 changes: 5 additions & 3 deletions python/cuda_parallel/tests/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ def op(a, b):
reduce_into = algorithms.reduce_into(d_output, d_output, op, h_init)

for num_items in [42, 420000]:
h_input = np.random.random(num_items) + 1j * np.random.random(num_items)
real_imag = np.random.random((2, num_items))
h_input = real_imag[0] + 1j * real_imag[1]
d_input = numba.cuda.to_device(h_input)
temp_storage_bytes = reduce_into(None, d_input, d_output, d_input.size, h_init)
assert d_input.size == num_items
temp_storage_bytes = reduce_into(None, d_input, d_output, num_items, h_init)
d_temp_storage = numba.cuda.device_array(temp_storage_bytes, np.uint8)
reduce_into(d_temp_storage, d_input, d_output, d_input.size, h_init)
reduce_into(d_temp_storage, d_input, d_output, num_items, h_init)

result = d_output.copy_to_host()[0]
expected = np.sum(h_input, initial=h_init[0])
Expand Down
Loading

0 comments on commit 0183959

Please sign in to comment.