diff --git a/CHANGELOG.md b/CHANGELOG.md index f993dd9..10cbb41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,19 +1,18 @@ # Change log - + * New version compatible with JAX 0.4.30 + * jaxDecomp now works in a multi-host environment + * Added custom partitioning for FFTs + * Added custom partitioning for halo exchange + * Added custom partitioning for slice_pad and slice_unpad + * Add example for multi-host FFTs in `examples/jaxdecomp_lpt.py` -## jaxdecomp 0.0.1 +## jaxdecomp 0.0.1rc2 * Changes * Added utility to run autotuning diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e0d58b..fced2db 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,9 @@ cmake_minimum_required(VERSION 3.19...3.25) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) # Latest JAX v0.4.26 no longer supports cuda 11.8 -set(NVHPC_CUDA_VERSION 12.2) +# By default, build for CUDA 12.2, users can override this with -DNVHPC_CUDA_VERSION=11.8 +set(NVHPC_CUDA_VERSION 12.2 CACHE STRING "CUDA version to build for" ) + # Build debug # set(CMAKE_BUILD_TYPE Debug) add_subdirectory(third_party/cuDecomp) @@ -15,8 +17,8 @@ option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF) option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF) option(CUDECOMP_BUILD_EXTRAS "Build benchmark, examples, and tests" OFF) -set(CUDECOMP_CUDA_CC_LIST "70;80" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.") - +# 70: Volta, 80: Ampere, 89: RTX 4060 +set(CUDECOMP_CUDA_CC_LIST "70;80;89" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.") # Add pybind11 and cuDecomp subdirectories add_subdirectory(pybind11) diff --git a/README.md b/README.md index d266c5d..bbeb096 100644 --- a/README.md +++ b/README.md @@ -54,13 +54,12 @@ with mesh: # Add halo regions to our array padding_width = ((32,32),(32,32),(32,32)) # Has to a tuple of tuples padded_array = jaxdecomp.slice_pad(recarray, padding_width , pdims) - # Perform a halo exchange + reduce - exchanged_reduced = jaxdecomp.halo_exchange(padded_array, + # Perform a halo exchange + exchanged_array = jaxdecomp.halo_exchange(padded_array, halo_extents=(32,32,32), - halo_periods=(True,True,True), - reduce_halo=True) + halo_periods=(True,True,True)) # Remove the halo regions - recarray = jaxdecomp.slice_unpad(exchanged_reduced, padding_width, pdims) + recarray = jaxdecomp.slice_unpad(exchanged_array, padding_width, pdims) # Gather the results (only if it fits on CPU memory) gathered_array = multihost_utils.process_allgather(recarray, tiled=True) diff --git a/include/halo.h b/include/halo.h index 70bd68f..aa40710 100644 --- a/include/halo.h +++ b/include/halo.h @@ -22,10 +22,13 @@ class haloDescriptor_t { ~haloDescriptor_t() = default; bool operator==(const haloDescriptor_t& other) const { - return (double_precision == other.double_precision && halo_extents == other.halo_extents && - halo_periods == other.halo_periods && axis == other.axis && config.gdims[0] == other.config.gdims[0] && - config.gdims[1] == other.config.gdims[1] && config.gdims[2] == other.config.gdims[2] && - config.pdims[0] == other.config.pdims[0] && config.pdims[1] == other.config.pdims[1]); + return (double_precision == other.double_precision && halo_extents[0] == other.halo_extents[0] && + halo_extents[1] == other.halo_extents[1] && halo_extents[2] == other.halo_extents[2] && + halo_periods[0] == other.halo_periods[0] && halo_periods[1] == other.halo_periods[1] && + halo_periods[2] == other.halo_periods[2] && axis == other.axis && + config.gdims[0] == other.config.gdims[0] && config.gdims[1] == other.config.gdims[1] && + config.gdims[2] == other.config.gdims[2] && config.pdims[0] == other.config.pdims[0] && + config.pdims[1] == other.config.pdims[1]); } }; diff --git a/jaxdecomp/_src/fft.py b/jaxdecomp/_src/fft.py index 2e4e0f4..013921f 100644 --- a/jaxdecomp/_src/fft.py +++ b/jaxdecomp/_src/fft.py @@ -1,29 +1,47 @@ from functools import partial -from typing import Union +from typing import Tuple, Union import jax import jaxlib.mlir.ir as ir import numpy as np -from jax import jit -from jax._src.api import jit +from jax import ShapeDtypeStruct from jax._src.interpreters import mlir from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.util import promote_dtypes_complex -from jax.core import Primitive -from jax.interpreters import ad, xla +from jax._src.typing import Array +from jax.core import Primitive, ShapedArray from jax.lib import xla_client +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P from jaxlib.hlo_helpers import custom_call import jaxdecomp from jaxdecomp._src import _jaxdecomp +from jaxdecomp._src.spmd_ops import (BasePrimitive, get_axis_size, + register_primitive) FftType = xla_client.FftType -from jax.experimental.custom_partitioning import custom_partitioning -from jax.sharding import NamedSharding -from jax.sharding import PartitionSpec as P def _str_to_fft_type(s: str) -> xla_client.FftType: + """ + Convert a string to an FFT type enum. + + Parameters + ---------- + s : str + String representation of FFT type. + + Returns + ------- + xla_client.FftType + Corresponding FFT type enum. + + Raises + ------ + ValueError + If the string `s` does not match known FFT types. + """ if s in ("fft", "FFT"): return xla_client.FftType.FFT elif s in ("ifft", "IFFT"): @@ -36,290 +54,423 @@ def _str_to_fft_type(s: str) -> xla_client.FftType: raise ValueError(f"Unknown FFT type '{s}'") -# Note : This must no longer be jitted because it will have single device abstract shapes -# The actual jit is done in the pfft function in lower_fn -# This means that sfft should never be lowered as is and only be lowered in the context of pfft -#@partial(jit, static_argnums=(1, 2, 3, 4)) -def sfft(x, - fft_type: Union[xla_client.FftType, str], - adjoint=False, - pdims=[1, 1], - global_shape=[1024, 1024, 1024]): - - #TODO(wassim) : find a way to prevent user from using the primitive directly - - if isinstance(fft_type, str): - typ = _str_to_fft_type(fft_type) - elif isinstance(fft_type, xla_client.FftType): - typ = fft_type - else: - raise TypeError(f"Unknown FFT type value '{fft_type}'") - - if typ in [xla_client.FftType.RFFT, xla_client.FftType.IRFFT]: - raise TypeError("only complex FFTs are currently supported through pfft.") +class FFTPrimitive(BasePrimitive): + """ + Custom primitive for FFT operations. + """ - (x,) = promote_dtypes_complex(x) + name = "fft" + multiple_results = False + impl_static_args = (1, 2) + inner_primitive = None + outer_primitive = None - return sfft_p.bind( - x, fft_type=typ, pdims=pdims, global_shape=global_shape, adjoint=adjoint) - - -def sfft_abstract_eval(x, fft_type, pdims, global_shape, adjoint): - - # TODO(Wassim) : this only handles cube shapes - # This function is called twice once with the global array and once with the local slice shape - - # Figure out what is the pencil decomposition at the output - axis = 0 - if fft_type in [xla_client.FftType.RFFT, xla_client.FftType.FFT]: - axis = 2 - - output_shape = None - match fft_type: - case xla_client.FftType.FFT: - # FFT is X to Y to Z so Z-Pencil is returned - # Except if we are doing a YZ slab in which case we return a Y-Pencil - transpose_shape = (1, 2, 0) - transposed_pdims = pdims - case xla_client.FftType.IFFT: - # IFFT is Z to X to Y so X-Pencil is returned - # In YZ slab case we only need one transposition back to get the X-Pencil - transpose_shape = (2, 0, 1) - transposed_pdims = pdims - case _: - raise TypeError("only complex FFTs are currently supported through pfft.") + @staticmethod + def abstract(x: Array, fft_type: xla_client.FftType, pdims: Tuple[int, int], + global_shape: Tuple[int, int, + int], adjoint: bool) -> ShapedArray: + """ + Abstract function to compute the shape of FFT output. + + Parameters + ---------- + x : Array + Input array. + fft_type : xla_client.FftType + Type of FFT operation. + pdims : Tuple[int, int] + Parallel dimensions. + global_shape : Tuple[int, int, int] + Global shape of the array. + adjoint : bool + Whether to compute the adjoint FFT. + + Returns + ------- + ShapedArray + Shape of the output array. + """ + if global_shape == x.shape: + return FFTPrimitive.outer_abstract(x, fft_type=fft_type, adjoint=adjoint) + + match fft_type: + case xla_client.FftType.FFT: + # FFT is X to Y to Z so Z-Pencil is returned + # Except if we are doing a YZ slab in which case we return a Y-Pencil + transpose_shape = (1, 2, 0) + transposed_pdims = pdims + case xla_client.FftType.IFFT: + # IFFT is Z to X to Y so X-Pencil is returned + # In YZ slab case we only need one transposition back to get the X-Pencil + transpose_shape = (2, 0, 1) + transposed_pdims = pdims + case _: + raise TypeError( + "only complex FFTs are currently supported through pfft.") - # Are we operating on the global array? - # This is called when the abstract_eval of the custom partitioning is called _custom_partitioning_abstract_eval in https://github.com/google/jax/blob/main/jax/experimental/custom_partitioning.py#L223 - if x.shape == global_shape: - shape = tuple([global_shape[i] for i in transpose_shape]) - output_shape = shape - # Or are we operating on a local slice? - # this is called JAX calls make_jaxpr(lower_fn) in https://github.com/google/jax/blob/main/jax/experimental/custom_partitioning.py#L142C5-L142C35 - else: output_shape = (global_shape[transpose_shape[0]] // transposed_pdims[1], global_shape[transpose_shape[1]] // transposed_pdims[0], global_shape[transpose_shape[2]]) - # Sanity check - assert (output_shape is not None) - return x.update(shape=output_shape, dtype=x.dtype) - - -def sfft_lowering(ctx, a, *, fft_type, pdims, global_shape, adjoint): - (x_aval,) = ctx.avals_in - (aval_out,) = ctx.avals_out - dtype = x_aval.dtype - a_type = ir.RankedTensorType(a.type) - # We currently only support complex FFTs through this interface, so let's check the fft type - assert fft_type in (FftType.FFT, - FftType.IFFT), "Only complex FFTs are currently supported" - - # Figure out which fft we want - forward = fft_type in (FftType.FFT,) - is_double = np.finfo(dtype).dtype == np.float64 - - # Get original global shape - match fft_type: - case xla_client.FftType.FFT: - transpose_back_shape = (0, 1, 2) - case xla_client.FftType.IFFT: - transpose_back_shape = (2, 0, 1) - case _: - raise TypeError("only complex FFTs are currently supported through pfft.") - # Make sure to get back the original shape of the X-Pencil - global_shape = tuple([global_shape[i] for i in transpose_back_shape]) - # Compute the descriptor for our FFT - config = _jaxdecomp.GridConfig() - - config.pdims = pdims - config.gdims = global_shape[::-1] - config.halo_comm_backend = jaxdecomp.config.halo_comm_backend - config.transpose_comm_backend = jaxdecomp.config.transpose_comm_backend - workspace_size, opaque = _jaxdecomp.build_fft_descriptor( - config, forward, is_double, adjoint) - - n = len(a_type.shape) - layout = tuple(range(n - 1, -1, -1)) - - # We ask XLA to allocate a workspace for this operation. - # TODO: check that the memory is not used all the time, just when needed - workspace = mlir.full_like_aval( - ctx, 0, jax.core.ShapedArray(shape=[workspace_size], dtype=np.byte)) - - # Run the custom op with same input and output shape, so that we can perform operations - # inplace. - result = custom_call( - "pfft3d", - result_types=[a_type], - operands=[a, workspace], - operand_layouts=[layout, (0,)], - result_layouts=[layout], - has_side_effect=True, - operand_output_aliases={0: 0}, - backend_config=opaque, - ) - - # Finally we reshape the arry to the expected shape. - out_type = ir.RankedTensorType.get(aval_out.shape, a_type.element_type) - return hlo.ReshapeOp(out_type, result).results - - -def _fft_transpose_rule(x, operand, fft_type, pdims, global_shape, adjoint): - assert fft_type in [FftType.FFT, FftType.IFFT] - if fft_type == FftType.FFT: - result = sfft(x, FftType.IFFT, ~adjoint, pdims, global_shape) - elif fft_type == FftType.IFFT: - result = sfft(x, FftType.FFT, ~adjoint, pdims, global_shape) - else: - raise NotImplementedError - - return (result,) - - -def get_axis_size(sharding, index): - axis_name = sharding.spec[index] - if axis_name == None: - return 1 - else: - return sharding.mesh.shape[sharding.spec[index]] - - -# Only named sharding have a spec -# this function is actually useless because positional sharding do not have a spec -# in case the user does not use a context mesh this will fail -# this is a placeholder function for the future -# the spec needs to be carried by a custom object that we create ourselfs -# to get inspired : https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuFFTMp/JAX_FFT/src/xfft/xfft.py#L20 -def to_named_sharding(sharding): - return NamedSharding(sharding.mesh, P(*sharding.spec)) - - -def partition(fft_type, adjoint, mesh, arg_shapes, result_shape): - """ - Tells XLA how to partition the primitive - - Args: - mesh (Mesh): The contextual mesh - - arg_shapes (tuple): A tuple of ShapeDtypeStruct that contains the shape and the sharding of each input operand + return ShapedArray(output_shape, x.dtype) - result_shape (ShapeDtypeStruct) : a ShapeDtypeStruct reprsenting a single output - - Returns: - Mesh (Mesh) : The mesh. + @staticmethod + def outer_abstract(x: Array, fft_type: xla_client.FftType, + adjoint: bool) -> ShapedArray: + """ + Abstract function for outer FFT operation. + + Parameters + ---------- + x : Array + Input array. + fft_type : xla_client.FftType + Type of FFT operation. + adjoint : bool + Whether to compute the adjoint FFT. + + Returns + ------- + ShapedArray + Shape of the output array. + """ + match fft_type: + case xla_client.FftType.FFT: + # FFT is X to Y to Z so Z-Pencil is returned + # Except if we are doing a YZ slab in which case we return a Y-Pencil + transpose_shape = (1, 2, 0) + case xla_client.FftType.IFFT: + # IFFT is Z to X to Y so X-Pencil is returned + # In YZ slab case we only need one transposition back to get the X-Pencil + transpose_shape = (2, 0, 1) + case _: + raise TypeError( + "only complex FFTs are currently supported through pfft.") + + output_shape = tuple([x.shape[i] for i in transpose_shape]) + return ShapedArray(output_shape, x.dtype) + + @staticmethod + def lowering(ctx, a: Array, *, fft_type: xla_client.FftType, + pdims: Tuple[int, int], global_shape: Tuple[int, int, + int], adjoint: bool): + """ + Lowering function for FFT primitive. + + Parameters + ---------- + ctx + Context. + a : Primitive + Input primitive. + fft_type : xla_client.FftType + Type of FFT operation. + pdims : Tuple[int, int] + Parallel dimensions. + global_shape : Tuple[int, int, int] + Global shape of the array. + adjoint : bool + Whether to compute the adjoint FFT. + + Returns + ------- + list + List of results from the operation. + """ + (x_aval,) = ctx.avals_in + (aval_out,) = ctx.avals_out + dtype = x_aval.dtype + a_type = ir.RankedTensorType(a.type) + # We currently only support complex FFTs through this interface, so let's check the fft type + assert fft_type in ( + FftType.FFT, FftType.IFFT), "Only complex FFTs are currently supported" + + # Figure out which fft we want + forward = fft_type in (FftType.FFT,) + is_double = np.finfo(dtype).dtype == np.float64 + + # Get original global shape + match fft_type: + case xla_client.FftType.FFT: + transpose_back_shape = (0, 1, 2) + case xla_client.FftType.IFFT: + transpose_back_shape = (2, 0, 1) + case _: + raise TypeError( + "only complex FFTs are currently supported through pfft.") + # Make sure to get back the original shape of the X-Pencil + global_shape = tuple([global_shape[i] for i in transpose_back_shape]) + # Compute the descriptor for our FFT + config = _jaxdecomp.GridConfig() + config.pdims = pdims + config.gdims = global_shape[::-1] + config.halo_comm_backend = jaxdecomp.config.halo_comm_backend + config.transpose_comm_backend = jaxdecomp.config.transpose_comm_backend + workspace_size, opaque = _jaxdecomp.build_fft_descriptor( + config, forward, is_double, adjoint) + + n = len(a_type.shape) + layout = tuple(range(n - 1, -1, -1)) + + # We ask XLA to allocate a workspace for this operation. + workspace = mlir.full_like_aval( + ctx, 0, jax.core.ShapedArray(shape=[workspace_size], dtype=np.byte)) + + # Run the custom op with same input and output shape, so that we can perform operations + # inplace. + result = custom_call( + "pfft3d", + result_types=[a_type], + operands=[a, workspace], + operand_layouts=[layout, (0,)], + result_layouts=[layout], + has_side_effect=True, + operand_output_aliases={0: 0}, + backend_config=opaque, + ) + + # Finally we reshape the arry to the expected shape. + return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), result).results + + @staticmethod + def impl(x: Array, fft_type: Union[str, xla_client.FftType], adjoint: bool): + """ + Implementation function for FFT primitive. + + Parameters + ---------- + x : Array + Input array. + fft_type : Union[str, xla_client.FftType] + Type of FFT operation. + adjoint : bool + Whether to compute the adjoint FFT. + + Returns + ------- + Primitive + Result of the operation. + """ + if isinstance(fft_type, str): + typ = _str_to_fft_type(fft_type) + elif isinstance(fft_type, xla_client.FftType): + typ = fft_type + else: + raise TypeError(f"Unknown FFT type value '{fft_type}'") + + if typ in [xla_client.FftType.RFFT, xla_client.FftType.IRFFT]: + raise TypeError("only complex FFTs are currently supported through pfft.") - function: The lowered function, to allow the user to redefine how the primitive is called in a context of a specific sharding + pdims = (1, jax.device_count()) + global_shape = x.shape - result_sharding (XLACompatibleSharding): The sharding result for example a NamedSharding. + return FFTPrimitive.inner_primitive.bind( + x, + fft_type=typ, + pdims=pdims, + global_shape=global_shape, + adjoint=adjoint) - arg_shardings (tuple): a tuple of all XLACompatibleSharding of the input operands + @staticmethod + def per_shard_impl(x: Array, fft_type: xla_client.FftType, pdims: Tuple[int, + int], + global_shape: Tuple[int, int, int], adjoint: bool): """ + Implementation function for per-shard FFT primitive. + + Parameters + ---------- + x : Array + Input array. + fft_type : xla_client.FftType + Type of FFT operation. + pdims : Tuple[int, int] + Parallel dimensions. + global_shape : Tuple[int, int, int] + Global shape of the array. + adjoint : bool + Whether to compute the adjoint FFT. + + Returns + ------- + Primitive + Result of the operation. + """ + return FFTPrimitive.inner_primitive.bind( + x, + fft_type=fft_type, + pdims=pdims, + global_shape=global_shape, + adjoint=adjoint) + + @staticmethod + def infer_sharding_from_operands( + fft_type: xla_client.FftType, adjoint: bool, mesh: Mesh, + arg_infos: Tuple[ShapeDtypeStruct], + result_infos: Tuple[ShapedArray]) -> NamedSharding: + """ + Infer sharding for FFT primitive based on operands. + + Parameters + ---------- + fft_type : xla_client.FftType + Type of FFT operation. + adjoint : bool + Whether to compute the adjoint FFT. + mesh : Mesh + Contextual mesh for sharding. + arg_infos : Tuple[ShapeDtypeStruct] + Shape and sharding information of input operands. + result_infos : Tuple[ShapedArray] + Shape information of output. + + Returns + ------- + NamedSharding + Sharding information for the result. + """ + input_sharding = arg_infos[0].sharding + return NamedSharding(mesh, P(*input_sharding.spec)) + + @staticmethod + def partition( + fft_type: xla_client.FftType, adjoint: bool, mesh: Mesh, + arg_shapes: Tuple[ShapeDtypeStruct], result_shape: ShapeDtypeStruct + ) -> Tuple[Mesh, partial, NamedSharding, Tuple[NamedSharding]]: + """ + Partition the FFT primitive for XLA. + + Parameters + ---------- + fft_type : xla_client.FftType + Type of FFT operation. + adjoint : bool + Whether to compute the adjoint FFT. + mesh : Mesh + Contextual mesh for sharding. + arg_shapes : Tuple[ShapeDtypeStruct] + Shape and sharding information of input operands. + result_shape : ShapeDtypeStruct + Shape and sharding information of output. + + Returns + ------- + Tuple[Mesh, partial, NamedSharding, Tuple[NamedSharding]] + Mesh, lowered function, output sharding, and input operand sharding. + """ + input_sharding = NamedSharding(mesh, P(*arg_shapes[0].sharding.spec)) + output_sharding = NamedSharding(mesh, P(*result_shape.sharding.spec)) - # pfft only has one operand - input_sharding = arg_shapes[0].sharding - - def lower_fn(operand): - # Operand is a local slice and arg_shapes contains the global shape - # No need to retranpose in the relowered function because abstract eval understands sliced input - # and in the original lowering we use aval.out - # it cannot work any other way because custom partition compares the output of the lower_fn with the abs eval (after comparing the global one) - # this means that the abs eval should handle both global shapes and slice shape - - global_shape = arg_shapes[0].shape pdims = (get_axis_size(input_sharding, 1), get_axis_size(input_sharding, 0)) + global_shape = arg_shapes[0].shape - output = sfft(operand, fft_type, adjoint, pdims, global_shape) - - # This is supposed to let us avoid making an extra transpose in the YZ case - # it does not work - # # In case of YZ slab the cuda code tranposes only once - # # We transpose again to give back the Z-Pencil to the user in case of FFT and the X-Pencil in case of IFFT - # # this transposition is supposed to compiled out by XLA when doing a gradient (forward followed by backward) - # if get_axis_size(input_sharding, 0) == 1: - # if fft_type == FftType.FFT: - # output = output.transpose((1, 2, 0)) - # elif fft_type == FftType.IFFT: - # output = output.transpose((2, 0, 1)) - return output - - return mesh, lower_fn, \ - to_named_sharding(result_shape.sharding), \ - (to_named_sharding(arg_shapes[0].sharding),) + impl = partial( + FFTPrimitive.per_shard_impl, + fft_type=fft_type, + pdims=pdims, + global_shape=global_shape, + adjoint=adjoint) + return mesh, impl, output_sharding, (input_sharding,) -def infer_sharding_from_operands(fft_type, adjoint, mesh, arg_shapes, - result_shape): - # Static arguments fft_type adjoint are carried along - """ - Tell XLA how to infer the sharding of the output from the input sharding. - Args: - mesh (Mesh): The contextual mesh +register_primitive(FFTPrimitive) - arg_shapes (tuple): A tuple of ShapeDtypeStruct that contains the shape and the sharding of each input operand - result_shape (ShapedArray) : a single ShapedArray reprsenting a single output without the sharding information +def pfft_p_lower(x: Array, fft_type: Union[str, xla_client.FftType], + adjoint: bool) -> Primitive: + """ + Lowering function for pfft primitive. + + Parameters + ---------- + x : Array + Input array. + fft_type : Union[str, xla_client.FftType] + Type of FFT operation. + adjoint : bool + Whether to compute the adjoint FFT. + + Returns + ------- + Primitive + Result of the operation. + """ + (x,) = promote_dtypes_complex(x) - Returns: + return FFTPrimitive.outer_primitive.bind( + x, fft_type=fft_type, adjoint=adjoint) - result_sharding (XLACompatibleSharding): The sharding result for example a NamedSharding. - """ - # only one operand is used in pfft - input_sharding = arg_shapes[0].sharding - return NamedSharding(mesh, P(*input_sharding.spec)) - - -@partial(custom_partitioning, static_argnums=(1, 2)) -def pfft_p_lower(x, fft_type, adjoint=False): - # the product of the fake dim has to be equal to the product of the global shape - # Fake dims and shape values are irrelevant because they are never used as concrete values only as Traced values - # their shapes however are used in the abstract eval of the custom partitioning - - size = jax.device_count() - # The pdims product must be equal to the number of devices because this is checked both in the abstract eval and in cudecomp - dummy_pdims = (1, size) - dummy_global = x.shape - return sfft(x, fft_type, adjoint, dummy_pdims, dummy_global) - - -sfft_p = Primitive("pfft") -sfft_p.def_impl(partial(xla.apply_primitive, sfft_p)) -sfft_p.def_abstract_eval(sfft_abstract_eval) -ad.deflinear2(sfft_p, _fft_transpose_rule) -mlir.register_lowering(sfft_p, sfft_lowering, platform="gpu") - -# Define the partitioning for the primitive -pfft_p_lower.def_partition( - partition=partition, - infer_sharding_from_operands=infer_sharding_from_operands) - -# declaring a differentiable SPMD primitive -# Inspired from -# https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/cpp_extensions.py#L188 -# https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/cpp_extensions.py#L694 -# https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/layernorm.py#L49 -# Note TE does a cleaner way of defining the primitive by using register_primitive(cls): that declares two primitives -# An inner which is represented here by sfft -# And an outer which by analogy shoud be represented here by pfft - - -# Do not jit this -# the jit is happening in jaxdecomp/fft.py: _do_pfft @partial(jax.custom_vjp, nondiff_argnums=(1, 2)) -def pfft(x, fft_type, adjoint=False): +def pfft(x: Array, + fft_type: Union[str, xla_client.FftType], + adjoint: bool = False) -> Primitive: + """ + Custom VJP definition for pfft. + + Parameters + ---------- + x : Array + Input array. + fft_type : Union[str, xla_client.FftType] + Type of FFT operation. + adjoint : bool, optional + Whether to compute the adjoint FFT. Defaults to False. + + Returns + ------- + Primitive + Result of the operation. + """ output, _ = _pfft_fwd_rule(x, fft_type=fft_type, adjoint=adjoint) return output -def _pfft_fwd_rule(x, fft_type: str, adjoint: bool = False): - # Linear function has no residuals +def _pfft_fwd_rule(x: Array, + fft_type: Union[str, xla_client.FftType], + adjoint: bool = False) -> Tuple[Primitive, None]: + """ + Forward rule for pfft. + + Parameters + ---------- + x : Array + Input array. + fft_type : Union[str, xla_client.FftType] + Type of FFT operation. + adjoint : bool, optional + Whether to compute the adjoint FFT. Defaults to False. + + Returns + ------- + Tuple[Primitive, None] + Result of the operation and None (no residuals). + """ return pfft_p_lower(x, fft_type=fft_type, adjoint=adjoint), None -def _pfft_bwd_rule(fft_type, adjoint, ctx, g): - +def _pfft_bwd_rule(fft_type: Union[str, xla_client.FftType], adjoint: bool, ctx, + g: Primitive) -> Tuple[Primitive]: + """ + Backward rule for pfft. + + Parameters + ---------- + fft_type : Union[str, xla_client.FftType] + Type of FFT operation. + adjoint : bool + Whether to compute the adjoint FFT. + ctx + Context. + g : Primitive + Gradient value. + + Returns + ------- + Tuple[Primitive] + Result of the operation. + """ assert fft_type in [FftType.FFT, FftType.IFFT] if fft_type == FftType.FFT: fft_type = FftType.IFFT diff --git a/jaxdecomp/_src/halo.py b/jaxdecomp/_src/halo.py index c526fb1..ddcf5cc 100644 --- a/jaxdecomp/_src/halo.py +++ b/jaxdecomp/_src/halo.py @@ -4,188 +4,385 @@ import jax import jaxlib.mlir.ir as ir import numpy as np +from jax import ShapeDtypeStruct from jax._src.interpreters import mlir -from jax.core import Primitive -from jax.experimental.custom_partitioning import custom_partitioning -from jax.interpreters import ad, xla +from jax._src.typing import Array +from jax.core import Primitive, ShapedArray from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P from jaxlib.hlo_helpers import custom_call import jaxdecomp from jaxdecomp._src import _jaxdecomp +from jaxdecomp._src.spmd_ops import (BasePrimitive, get_axis_size, + register_primitive) + + +class HaloPrimitive(BasePrimitive): + """ + Custom primitive for performing halo exchange operation. + """ + + name = "halo_exchange" + multiple_results = False + impl_static_args = (1, 2, 3) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool], pdims: Tuple[int, int], + global_shape: Tuple[int, int, int]) -> Array: + """ + Abstract function for determining the shape and dtype after the halo exchange operation. + + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + pdims : Tuple[int, int] + Processor dimensions. + global_shape : Tuple[int, int, int] + Global shape of the array. + + Returns + ------- + Array + Abstract array after the halo exchange operation. + """ + return x.update(shape=x.shape, dtype=x.dtype) + + @staticmethod + def outer_abstract(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool]) -> Array: + """ + Abstract function for determining the shape and dtype without considering inner details. + + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + + Returns + ------- + Array + Abstract array after the halo exchange operation. + """ + return x.update(shape=x.shape, dtype=x.dtype) + + @staticmethod + def lowering(ctx, x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool], pdims: Tuple[int, int], + global_shape: Tuple[int, int, int]) -> Array: + """ + Lowering function to generate the MLIR representation for halo exchange. + + Parameters + ---------- + ctx + Context for the operation. + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + pdims : Tuple[int, int] + Processor dimensions. + global_shape : Tuple[int, int, int] + Global shape of the array. + + Returns + ------- + Array + Resulting array after the halo exchange operation. + """ + (x_aval,) = ctx.avals_in + x_type = ir.RankedTensorType(x.type) + n = len(x_type.shape) + + is_double = np.finfo(x_aval.dtype).dtype == np.float64 + + # Compute the descriptor for the halo exchange operation + config = _jaxdecomp.GridConfig() + config.pdims = pdims + config.gdims = global_shape[::-1] + config.halo_comm_backend = jaxdecomp.config.halo_comm_backend + config.transpose_comm_backend = jaxdecomp.config.transpose_comm_backend + + workspace_size, opaque = _jaxdecomp.build_halo_descriptor( + config, is_double, halo_extents[::-1], halo_periods[::-1], 0) + layout = tuple(range(n - 1, -1, -1)) + + workspace = mlir.full_like_aval( + ctx, 0, jax.core.ShapedArray(shape=[workspace_size], dtype=np.byte)) + + # Perform custom call for halo exchange + out = custom_call( + "halo", + result_types=[x_type], + operands=[x, workspace], + operand_layouts=[layout, (0,)], + result_layouts=[layout], + has_side_effect=True, + operand_output_aliases={0: 0}, + backend_config=opaque, + ) + return out.results + + @staticmethod + def impl(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool]) -> Primitive: + """ + Implementation function for performing halo exchange. + + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + + Returns + ------- + Primitive + Inner primitive bound with input parameters. + """ + pdims = (1, jax.device_count()) + global_shape = x.shape + + return HaloPrimitive.inner_primitive.bind( + x, + halo_extents=halo_extents, + halo_periods=halo_periods, + pdims=pdims, + global_shape=global_shape, + ) + + @staticmethod + def per_shard_impl(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool], + pdims: Tuple[int, int], global_shape: Tuple[int, int, + int]) -> Array: + """ + Implementation function for performing halo exchange per shard. + + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + pdims : Tuple[int, int] + Processor dimensions. + global_shape : Tuple[int, int, int] + Global shape of the array. + + Returns + ------- + Array + Resulting array after the halo exchange operation. + """ + output = HaloPrimitive.inner_primitive.bind( + x, + halo_extents=halo_extents, + halo_periods=halo_periods, + pdims=pdims, + global_shape=global_shape, + ) + return output -# This is the inner primitive and it should not be used directly -def inner_halo_exchange(x, - *, - halo_extents: Tuple[int, int, int], - halo_periods: Tuple[bool, bool, bool], - reduce_halo: bool = True, - pdims: Tuple[int, int] = (1, 1), - global_shape: Tuple[int, int, int] = [1024, 1024, - 1024]): - # TODO: check float or real - return inner_halo_p.bind( - x, - halo_extents=halo_extents, - halo_periods=halo_periods, - reduce_halo=reduce_halo, - pdims=pdims, - global_shape=global_shape) - - -def inner_halo_abstract_eval(x, halo_extents, halo_periods, reduce_halo, pdims, - global_shape): - # The return shape is equal to the global shape for the inner primitive (the one that is not exposed) - # The return shape is equal to the slice shape for the outer primitive (the one that is exposed) - # in all cases it is x.shape - return x.update(shape=x.shape, dtype=x.dtype) - - -def inner_halo_lowering(ctx, x, *, halo_extents, halo_periods, reduce_halo, - pdims, global_shape): - (x_aval,) = ctx.avals_in - x_type = ir.RankedTensorType(x.type) - n = len(x_type.shape) - - is_double = np.finfo(x_aval.dtype).dtype == np.float64 - - # Compute the descriptor for our FFT - config = _jaxdecomp.GridConfig() - config.pdims = pdims - config.gdims = global_shape[::-1] - config.halo_comm_backend = jaxdecomp.config.halo_comm_backend - config.transpose_comm_backend = jaxdecomp.config.transpose_comm_backend - - workspace_size, opaque = _jaxdecomp.build_halo_descriptor( - config, is_double, halo_extents[::-1], halo_periods[::-1], 0) - layout = tuple(range(n - 1, -1, -1)) - - workspace = mlir.full_like_aval( - ctx, 0, jax.core.ShapedArray(shape=[workspace_size], dtype=np.byte)) - - # reduce_halo is not used in the inner primitive - out = custom_call( - "halo", - result_types=[x_type], - operands=[x, workspace], - operand_layouts=[layout, (0,)], - result_layouts=[layout], - has_side_effect=True, - operand_output_aliases={0: 0}, - backend_config=opaque, - ) - return out.results - - -def inner_halo_transpose_rule(x, operand, halo_extents, halo_periods, pdims, - global_shape): - result = halo_exchange(x, halo_extents, halo_periods, pdims, global_shape) - return (result,) - - -# Custom Primitive - - -def get_axis_size(sharding, index): - axis_name = sharding.spec[index] - if axis_name == None: - return 1 - else: - return sharding.mesh.shape[sharding.spec[index]] - - -def partition(halo_extents, halo_periods, reduce_halo, mesh, arg_shapes, - result_shape): - - # halo_exchange has three operands - # x, halo_extents, halo_periods - # (sanity check) - #assert len(arg_shapes) == 3 , "halo_exchange must have only 3 operands in the partitioning lower function" - # only x is sharded the other are fully replicated - halo_exchange_sharding = arg_shapes[0].sharding - - def lower_fn(operand): - + @staticmethod + def infer_sharding_from_operands( + halo_extents: Tuple[int, int, int], halo_periods: Tuple[bool, bool, bool], + mesh: NamedSharding, arg_infos: Tuple[ShapeDtypeStruct], + result_infos: Tuple[ShapedArray]) -> NamedSharding: + """ + Infer sharding information for halo exchange operation. + + Parameters + ---------- + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + mesh : NamedSharding + Mesh object for sharding. + arg_shapes : Tuple[ir.ShapeDtypeStruct] + Shapes and dtypes of input operands. + result_shape : ir.ShapedArray + Shape and dtype of the output result. + + Returns + ------- + NamedSharding + Sharding information for halo exchange operation. + """ + halo_exchange_sharding = arg_infos[0].sharding + return NamedSharding(mesh, P(*halo_exchange_sharding.spec)) + + @staticmethod + def partition( + halo_extents: Tuple[int, int, int], halo_periods: Tuple[bool, bool, bool], + mesh: NamedSharding, arg_shapes: Tuple[ShapeDtypeStruct], + result_shape: ShapeDtypeStruct) -> Tuple[NamedSharding, partial]: + """ + Partition function for halo exchange operation. + + Parameters + ---------- + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + mesh : NamedSharding + Mesh object for sharding. + arg_shapes : Tuple[ir.ShapeDtypeStruct] + Shapes and dtypes of input operands. + result_shape : ir.ShapedArray + Shape and dtype of the output result. + + Returns + ------- + Tuple[NamedSharding, partial] + Mesh object, implementation function, sharding information, and its tuple. + """ + halo_exchange_sharding = NamedSharding(mesh, + P(*arg_shapes[0].sharding.spec)) global_shape = arg_shapes[0].shape pdims = (get_axis_size(halo_exchange_sharding, 1), get_axis_size(halo_exchange_sharding, 0)) - shape_without_halo = (global_shape[0] - 2 * pdims[1] * halo_extents[0],\ - global_shape[1] - 2 * pdims[0] * halo_extents[1],\ + shape_without_halo = (global_shape[0] - 2 * pdims[1] * halo_extents[0], + global_shape[1] - 2 * pdims[0] * halo_extents[1], global_shape[2] - 2 * halo_extents[2]) - output = inner_halo_exchange(operand, halo_extents=halo_extents, \ - halo_periods=halo_periods, pdims=pdims, global_shape=shape_without_halo) - - if reduce_halo: - ## Apply correction along x - output = output.at[halo_extents[0]:2 * halo_extents[0]].add( - output[:halo_extents[0]]) - output = output.at[-2 * halo_extents[0]:-halo_extents[0]].add( - output[-halo_extents[0]:]) - ## Apply correction along y - output = output.at[:, halo_extents[1]:2 * halo_extents[1]].add( - output[:, :halo_extents[1]]) - output = output.at[:, -2 * halo_extents[1]:-halo_extents[1]].add( - output[:, -halo_extents[1]:]) - - return output - - return mesh, lower_fn, \ - result_shape.sharding, \ - (halo_exchange_sharding,) + impl = partial( + HaloPrimitive.per_shard_impl, + halo_extents=halo_extents, + halo_periods=halo_periods, + pdims=pdims, + global_shape=shape_without_halo) + return mesh, impl, halo_exchange_sharding, (halo_exchange_sharding,) -def infer_sharding_from_operands(halo_extents, halo_periods, reduce_halo, mesh, - arg_shapes, result_shape): - # Sharding is the same here aswell because halo_exchange is a pointwise operation - halo_exchange_sharding = arg_shapes[0].sharding - return halo_exchange_sharding +register_primitive(HaloPrimitive) -@partial(custom_partitioning, static_argnums=(1, 2, 3)) -def halo_p_lower(x, halo_extents, halo_periods, reduce_halo): - size = jax.device_count() - # The pdims product must be equal to the number of devices because this is checked both in the abstract eval and in cudecomp - dummy_pdims = (1, size) - dummy_global = x.shape - return inner_halo_exchange(x, halo_extents=halo_extents, halo_periods=halo_periods,reduce_halo = reduce_halo,\ - pdims=dummy_pdims, global_shape=dummy_global) +def halo_p_lower(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool]) -> Primitive: + """ + Lowering function for the halo exchange operation. + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. -# declare primitive - -inner_halo_p = Primitive("halo_exchange") -inner_halo_p.def_impl(partial(xla.apply_primitive, inner_halo_p)) -inner_halo_p.def_abstract_eval(inner_halo_abstract_eval) -ad.deflinear2(inner_halo_p, inner_halo_transpose_rule) -mlir.register_lowering(inner_halo_p, inner_halo_lowering, platform="gpu") - -# Define the partitioning for the primitive -halo_p_lower.def_partition( - partition=partition, - infer_sharding_from_operands=infer_sharding_from_operands) + Returns + ------- + Primitive + Inner primitive bound with input parameters. + """ + return HaloPrimitive.outer_primitive.bind( + x, + halo_extents=halo_extents, + halo_periods=halo_periods, + ) # Custom Partitioning @partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3)) -def halo_exchange(x, halo_extents, halo_periods, reduce_halo=False): - output, _ = _halo_fwd_rule(x, halo_extents, halo_periods, reduce_halo) +def halo_exchange(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool]) -> Array: + """ + Halo exchange operation with custom VJP. + + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + + Returns + ------- + Array + Output array after the halo exchange operation. + """ + output, _ = _halo_fwd_rule(x, halo_extents, halo_periods) return output -def _halo_fwd_rule(x, halo_extents, halo_periods, reduce_halo): - # Linear function has no residuals - return halo_p_lower(x, halo_extents, halo_periods, reduce_halo), None - - -def _halo_bwd_rule(halo_extents, halo_periods, reduce_halo, ctx, g): - return halo_p_lower(g, halo_extents, halo_periods, reduce_halo), - - +def _halo_fwd_rule(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool]) -> Tuple[Array, None]: + """ + Forward rule for the halo exchange operation. + + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + + Returns + ------- + Tuple[Array, None] + Output array after the halo exchange operation and None for no residuals. + """ + return halo_p_lower(x, halo_extents, halo_periods), None + + +def _halo_bwd_rule(halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, + bool], ctx, g: Array) -> Tuple[Array]: + """ + Backward rule for the halo exchange operation. + + Parameters + ---------- + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + ctx + Context for the operation. + g : Array + Gradient array. + + Returns + ------- + Tuple[Array] + Gradient array after the halo exchange operation. + """ + return halo_p_lower(g, halo_extents, halo_periods), + + +# Define VJP for custom halo_exchange operation halo_exchange.defvjp(_halo_fwd_rule, _halo_bwd_rule) +# JIT compile the halo_exchange operation halo_exchange = jax.jit(halo_exchange, static_argnums=(1, 2, 3)) diff --git a/jaxdecomp/_src/padding.py b/jaxdecomp/_src/padding.py index 473e799..c541cf5 100644 --- a/jaxdecomp/_src/padding.py +++ b/jaxdecomp/_src/padding.py @@ -1,5 +1,3 @@ -from abc import ABCMeta, abstractmethod -from dataclasses import dataclass from functools import partial from typing import Tuple @@ -8,20 +6,32 @@ from jax._src.api import ShapeDtypeStruct from jax._src.core import ShapedArray from jax._src.typing import Array, ArrayLike -from jax.experimental.custom_partitioning import custom_partitioning from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P from jaxdecomp._src.spmd_ops import CustomParPrimitive, register_primitive -## Padding Custom SPMD lowering class SlicePaddingPrimitive(CustomParPrimitive): - - name = "slice_pad" - multiple_results = False - impl_static_args = (1, 2, 3) - outer_pritimive = None + """ + Custom primitive for slice padding operation. + + Attributes + ---------- + name : str + The name of the primitive operation. + multiple_results : bool + Whether the operation produces multiple results. + impl_static_args : tuple + Static arguments for the implementation. + outer_pritimive : object + Outer primitive used for the operation. + """ + + name: str = "slice_pad" + multiple_results: bool = False + impl_static_args: Tuple[int, int, int] = (1, 2, 3) + outer_pritimive: object = None # Global array implementation is used purely for its abstract eval # at jit time, the shape of the global array output is infered from this function @@ -33,7 +43,25 @@ def impl(arr: ArrayLike, padding_width: int | tuple[int], pdims: tuple[int], mode: str = 'constant') -> Array: - + """ + Implementation of the slice padding operation. + + Parameters + ---------- + arr : ArrayLike + Input array to be padded. + padding_width : int | tuple[int] + Width of padding to apply. + pdims : tuple[int] + Dimensions for padding. + mode : str, optional + Padding mode ('constant' by default). + + Returns + ------- + Array + Padded array. + """ assert arr.ndim == 3, "Only 3D arrays are supported" assert len(pdims) == 2, "Only 2D pdims are supported" @@ -83,11 +111,27 @@ def impl(arr: ArrayLike, return arr - # Actual per slice implementation of the primitive @staticmethod def per_shard_impl(arr: ArrayLike, padding_width: int | tuple[int], mode: str = 'constant') -> Array: + """ + Per-shard implementation of the slice padding operation. + + Parameters + ---------- + arr : ArrayLike + Input array to be padded. + padding_width : int | tuple[int] + Width of padding to apply. + mode : str, optional + Padding mode ('constant' by default). + + Returns + ------- + Array + Padded array. + """ return jnp.pad(arr, padding_width, mode=mode) @staticmethod @@ -95,6 +139,29 @@ def infer_sharding_from_operands(padding_width: int | tuple[int], pdims: tuple[int], mode: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], result_infos: Tuple[ShapedArray]): + """ + Infers sharding information from operands for the slice padding operation. + + Parameters + ---------- + padding_width : int | tuple[int] + Width of padding to apply. + pdims : tuple[int] + Dimensions for padding. + mode : str + Padding mode. + mesh : Mesh + Computational mesh. + arg_infos : Tuple[ShapeDtypeStruct] + Information about operands. + result_infos : Tuple[ShapedArray] + Information about results. + + Returns + ------- + NamedSharding + Sharding information. + """ input_sharding = arg_infos[0].sharding return NamedSharding(input_sharding.mesh, P(*input_sharding.spec)) @@ -102,8 +169,29 @@ def infer_sharding_from_operands(padding_width: int | tuple[int], def partition(padding_width: int | tuple[int], pdims: tuple[int], mode: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], result_infos: Tuple[ShapedArray]): - - # Only one non static input and one output + """ + Partitions the slice padding operation across a computational mesh. + + Parameters + ---------- + padding_width : int | tuple[int] + Width of padding to apply. + pdims : tuple[int] + Dimensions for padding. + mode : str + Padding mode. + mesh : Mesh + Computational mesh. + arg_infos : Tuple[ShapeDtypeStruct] + Information about operands. + result_infos : Tuple[ShapedArray] + Information about results. + + Returns + ------- + Tuple + Mesh, implementation, output sharding, and input sharding. + """ input_sharding = NamedSharding(mesh, P(*arg_infos[0].sharding.spec)) output_sharding = NamedSharding(mesh, P(*result_infos.sharding.spec)) impl = partial( @@ -121,25 +209,69 @@ def slice_pad(x: ArrayLike, padding_width: int | tuple[int], pdims: tuple[int], mode: str = 'constant') -> Array: + """ + JIT-compiled function for slice padding operation. + + Parameters + ---------- + x : ArrayLike + Input array to be padded. + padding_width : int | tuple[int] + Width of padding to apply. + pdims : tuple[int] + Dimensions for padding. + mode : str, optional + Padding mode ('constant' by default). + + Returns + ------- + Array + Padded array. + """ return SlicePaddingPrimitive.outer_lowering(x, padding_width, pdims, mode) -## Unpadding Custom SPMD lowering - - class SliceUnPaddingPrimitive(CustomParPrimitive): + """ + Custom primitive for slice unpading operation. + + Attributes + ---------- + name : str + The name of the primitive operation. + multiple_results : bool + Whether the operation produces multiple results. + impl_static_args : tuple + Static arguments for the implementation. + outer_pritimive : object + Outer primitive used for the operation. + """ + + name: str = "slice_unpad" + multiple_results: bool = False + impl_static_args: Tuple[int, int] = (1, 2) + outer_pritimive: object = None - name = "slice_unpad" - multiple_results = False - impl_static_args = (1, 2) - outer_pritimive = None - - # Same as padding, the global array implementation is used purely for its abstract eval @staticmethod def impl(arr: ArrayLike, padding_width: int | tuple[int], pdims: tuple[int]) -> Array: - - # If padding width is an integer then unpad the entire array + """ + Implementation of the slice unpading operation. + + Parameters + ---------- + arr : ArrayLike + Input array to be unpadded. + padding_width : int | tuple[int] + Width of padding to remove. + pdims : tuple[int] + Dimensions for unpadding. + + Returns + ------- + Array + Unpadded array. + """ if isinstance(padding_width, int): unpadding_width = ((padding_width, padding_width),) * arr.ndim elif isinstance(padding_width, tuple): @@ -175,10 +307,23 @@ def impl(arr: ArrayLike, padding_width: int | tuple[int], return arr - # Actual per slice implementation of the primitive @staticmethod def per_shard_impl(arr: ArrayLike, padding_width: int | tuple[int]) -> Array: - # If padding width is an integer then unpad the entire array + """ + Per-shard implementation of the slice unpading operation. + + Parameters + ---------- + arr : ArrayLike + Input array to be unpadded. + padding_width : int | tuple[int] + Width of padding to remove. + + Returns + ------- + Array + Unpadded array. + """ if isinstance(padding_width, int): unpadding_width = ((padding_width, padding_width),) * arr.ndim elif isinstance(padding_width, tuple): @@ -200,6 +345,27 @@ def infer_sharding_from_operands(padding_width: int | tuple[int], pdims: tuple[int], mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], result_infos: Tuple[ShapedArray]): + """ + Infers sharding information from operands for the slice unpading operation. + + Parameters + ---------- + padding_width : int | tuple[int] + Width of padding to remove. + pdims : tuple[int] + Dimensions for unpadding. + mesh : Mesh + Computational mesh. + arg_infos : Tuple[ShapeDtypeStruct] + Information about operands. + result_infos : Tuple[ShapedArray] + Information about results. + + Returns + ------- + NamedSharding + Sharding information. + """ input_sharding = arg_infos[0].sharding return NamedSharding(input_sharding.mesh, P(*input_sharding.spec)) @@ -207,8 +373,27 @@ def infer_sharding_from_operands(padding_width: int | tuple[int], def partition(padding_width: int | tuple[int], pdims: tuple[int], mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], result_infos: Tuple[ShapedArray]): - - # Only one non static input and one output + """ + Partitions the slice unpading operation across a computational mesh. + + Parameters + ---------- + padding_width : int | tuple[int] + Width of padding to remove. + pdims : tuple[int] + Dimensions for unpadding. + mesh : Mesh + Computational mesh. + arg_infos : Tuple[ShapeDtypeStruct] + Information about operands. + result_infos : Tuple[ShapedArray] + Information about results. + + Returns + ------- + Tuple + Mesh, implementation, output sharding, and input sharding. + """ input_sharding = NamedSharding(mesh, P(*arg_infos[0].sharding.spec)) output_sharding = NamedSharding(mesh, P(*result_infos.sharding.spec)) impl = partial( @@ -222,4 +407,21 @@ def partition(padding_width: int | tuple[int], pdims: tuple[int], mesh: Mesh, @partial(jit, static_argnums=(1, 2)) def slice_unpad(arr: ArrayLike, padding_width: int | tuple[int], pdims: tuple[int]) -> Array: + """ + JIT-compiled function for slice unpading operation. + + Parameters + ---------- + arr : ArrayLike + Input array to be unpadded. + padding_width : int | tuple[int] + Width of padding to remove. + pdims : tuple[int] + Dimensions for unpadding. + + Returns + ------- + Array + Unpadded array. + """ return SliceUnPaddingPrimitive.outer_lowering(arr, padding_width, pdims) diff --git a/jaxdecomp/_src/transpose.py b/jaxdecomp/_src/transpose.py index 3b90313..d6b8eba 100644 --- a/jaxdecomp/_src/transpose.py +++ b/jaxdecomp/_src/transpose.py @@ -1,17 +1,14 @@ from functools import partial -from os import name from typing import Tuple import jax import jaxlib.mlir.ir as ir import numpy as np -from jax._src.api import ShapeDtypeStruct -from jax._src.core import ShapedArray +from jax import ShapeDtypeStruct from jax._src.interpreters import mlir from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array, ArrayLike -from jax.core import Primitive, ShapedArray -from jax.interpreters import ad, xla +from jax.core import ShapedArray from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P from jaxlib.hlo_helpers import custom_call @@ -21,22 +18,53 @@ from jaxdecomp._src.spmd_ops import (BasePrimitive, get_axis_size, register_primitive) -_out_axes = {'x_y': 1, 'y_z': 2, 'z_y': 1, 'y_x': 0} - -import traceback - class TransposePrimitive(BasePrimitive): - - name = "transpose" - multiple_results = False - impl_static_args = (1,) - inner_primitive = None - outer_primitive = None + """ + JAX primitive for transposing arrays with different partitioning strategies. + + Attributes + ---------- + name : str + Name of the primitive ("transpose"). + multiple_results : bool + Boolean indicating if the primitive returns multiple results (False). + impl_static_args : tuple + Static arguments for the implementation (tuple containing (1,)). + inner_primitive : object + Inner core.Primitive object for the primitive. + outer_primitive : object + Outer core.Primitive object for the primitive. + """ + + name: str = "transpose" + multiple_results: bool = False + impl_static_args: Tuple[int] = (1,) + inner_primitive: object = None + outer_primitive: object = None @staticmethod - def abstract(x, kind, pdims, global_shape): - + def abstract(x: ArrayLike, kind: str, pdims: Tuple[int], + global_shape: Tuple[int]) -> ShapedArray: + """ + Abstract method to describe the shape of the output array after transposition. + + Parameters + ---------- + x : ArrayLike + Input array. + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + pdims : tuple[int] + Partition dimensions. + global_shape : tuple[int] + Global shape of the input array. + + Returns + ------- + ShapedArray + Abstract shape of the output array after transposition. + """ if global_shape == x.shape: return TransposePrimitive.outer_abstract(x, kind) # Make sure that global_shape is divisible by pdims and equals to slice @@ -64,8 +92,22 @@ def abstract(x, kind, pdims, global_shape): return ShapedArray(shape, x.dtype) @staticmethod - def outer_abstract(x, kind): - + def outer_abstract(x: ArrayLike, kind: str) -> ShapedArray: + """ + Abstract method for transposition that does not require knowledge of global shape. + + Parameters + ---------- + x : ArrayLike + Input array. + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + + Returns + ------- + ShapedArray + Abstract shape of the output array after transposition. + """ assert kind in ['x_y', 'y_z', 'z_y', 'y_x'] match kind: # From X to Y the axis are rolled by 1 and pdims are swapped wrt to the input pdims @@ -80,7 +122,29 @@ def outer_abstract(x, kind): return ShapedArray(shape, x.dtype) @staticmethod - def lowering(ctx, x, *, kind, pdims, global_shape): + def lowering(ctx, x: ArrayLike, *, kind: str, pdims: Tuple[int], + global_shape: Tuple[int]): + """ + Method to lower the transposition operation to MLIR. + + Parameters + ---------- + ctx : object + Context for the operation. + x : ArrayLike + Input array. + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + pdims : tuple[int] + Partition dimensions. + global_shape : tuple[int] + Global shape of the input array. + + Returns + ------- + List + List of lowered results. + """ assert kind in ['x_y', 'y_z', 'z_y', 'y_x'] (aval_in,) = ctx.avals_in (aval_out,) = ctx.avals_out @@ -136,23 +200,76 @@ def lowering(ctx, x, *, kind, pdims, global_shape): return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), result).results @staticmethod - def impl(x, kind): + def impl(x: ArrayLike, kind: str): + """ + Implementation method for the transposition primitive. + + Parameters + ---------- + x : ArrayLike + Input array. + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + + Returns + ------- + Object + Result of binding the inner primitive with input arguments. + """ size = jax.device_count() - # The pdims product must be equal to the number of devices because this is checked both in the abstract eval and in cudecomp - pdims = (size, 1) + pdims = (size, 1) # pdims product must be equal to the number of devices global_shape = x.shape return TransposePrimitive.inner_primitive.bind( x, kind=kind, pdims=pdims, global_shape=global_shape) @staticmethod - def per_shard_impl(x, kind, pdims, global_shape): + def per_shard_impl(x: ArrayLike, kind: str, pdims: Tuple[int], + global_shape: Tuple[int]): + """ + Per-shard implementation method for the transposition primitive. + + Parameters + ---------- + x : ArrayLike + Input array. + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + pdims : tuple[int] + Partition dimensions. + global_shape : tuple[int] + Global shape of the input array. + + Returns + ------- + Object + Result of binding the inner primitive with input arguments. + """ return TransposePrimitive.inner_primitive.bind( x, kind=kind, pdims=pdims, global_shape=global_shape) @staticmethod - def infer_sharding_from_operands(kind: str, mesh: Mesh, - arg_infos: Tuple[ShapeDtypeStruct], - result_infos: Tuple[ShapedArray]): + def infer_sharding_from_operands( + kind: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], + result_infos: Tuple[ShapedArray]) -> NamedSharding: + """ + Method to infer sharding information from operands for custom partitioning. + + Parameters + ---------- + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + mesh : Mesh + Sharding mesh information. + arg_infos : Tuple[ShapeDtypeStruct] + Tuple of ShapeDtypeStruct for input operands. + result_infos : Tuple[ShapedArray] + Tuple of ShapedArray for result information. + + Returns + ------- + NamedSharding + Named sharding information. + """ input_sharding = arg_infos[0].sharding tranposed_pdims = (input_sharding.spec[1], input_sharding.spec[0], None) @@ -160,9 +277,29 @@ def infer_sharding_from_operands(kind: str, mesh: Mesh, return NamedSharding(input_sharding.mesh, P(*tranposed_pdims)) @staticmethod - def partition(kind: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], - result_infos: Tuple[ShapedArray]): - + def partition( + kind: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], + result_infos: Tuple[ShapedArray] + ) -> Tuple[Mesh, callable, NamedSharding, Tuple[NamedSharding]]: + """ + Method to partition the transposition operation for custom partitioning. + + Parameters + ---------- + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + mesh : Mesh + Sharding mesh information. + arg_infos : Tuple[ShapeDtypeStruct] + Tuple of ShapeDtypeStruct for input operands. + result_infos : Tuple[ShapedArray] + Tuple of ShapedArray for result information. + + Returns + ------- + Tuple + Tuple containing mesh, implementation function, output sharding, and input sharding. + """ input_sharding = NamedSharding(mesh, P(*arg_infos[0].sharding.spec)) output_sharding = NamedSharding(mesh, P(*result_infos.sharding.spec)) global_shape = arg_infos[0].shape @@ -186,14 +323,44 @@ def partition(kind: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], register_primitive(TransposePrimitive) -@partial(jax.jit, static_argnums=(1)) -def transpose(x, kind: str) -> Array: +@partial(jax.jit, static_argnums=(1,)) +def transpose(x: ArrayLike, kind: str) -> Array: + """ + JIT-compiled function for performing transposition using the outer primitive. + + Parameters + ---------- + x : ArrayLike + Input array. + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + + Returns + ------- + Array + Transposed array. + """ return TransposePrimitive.outer_primitive.bind(x, kind=kind) -# X to Y +# Custom transposition functions + + @jax.custom_vjp def transposeXtoY(x: ArrayLike) -> Array: + """ + Custom JAX transposition function for X to Y. + + Parameters + ---------- + x : ArrayLike + Input array. + + Returns + ------- + Array + Transposed array. + """ return transpose(x, 'x_y') @@ -205,9 +372,21 @@ def transposeXtoY_bwd(_, g): return transpose(g, 'y_x'), -# Y to X @jax.custom_vjp def transposeYtoZ(x: ArrayLike) -> Array: + """ + Custom JAX transposition function for Y to Z. + + Parameters + ---------- + x : ArrayLike + Input array. + + Returns + ------- + Array + Transposed array. + """ return transpose(x, 'y_z') @@ -219,9 +398,21 @@ def transposeYtoZ_bwd(_, g): return transpose(g, 'z_y'), -# Z to Y @jax.custom_vjp def transposeZtoY(x: ArrayLike) -> Array: + """ + Custom JAX transposition function for Z to Y. + + Parameters + ---------- + x : ArrayLike + Input array. + + Returns + ------- + Array + Transposed array. + """ return transpose(x, 'z_y') @@ -233,9 +424,21 @@ def transposeZtoY_bwd(_, g): return transpose(g, 'y_z'), -# Y to X @jax.custom_vjp def transposeYtoX(x: ArrayLike) -> Array: + """ + Custom JAX transposition function for Y to X. + + Parameters + ---------- + x : ArrayLike + Input array. + + Returns + ------- + Array + Transposed array. + """ return transpose(x, 'y_x') @@ -247,6 +450,7 @@ def transposeYtoX_bwd(_, g): return transpose(g, 'x_y'), +# Define VJPs for custom transposition functions transposeXtoY.defvjp(transposeXtoY_fwd, transposeXtoY_bwd) transposeYtoZ.defvjp(transposeYtoZ_fwd, transposeYtoZ_bwd) transposeZtoY.defvjp(transposeZtoY_fwd, transposeZtoY_bwd) diff --git a/jaxdecomp/fft.py b/jaxdecomp/fft.py index 137c842..e7938b6 100644 --- a/jaxdecomp/fft.py +++ b/jaxdecomp/fft.py @@ -17,6 +17,28 @@ def _fft_norm(s: Array, func_name: str, norm: str) -> Array: + """ + Compute the normalization factor for FFT operations. + + Parameters + ---------- + s : Array + Shape of the input array. + func_name : str + Name of the FFT function ("fft" or "ifft"). + norm : str + Type of normalization ("backward", "ortho", or "forward"). + + Returns + ------- + Array + Normalization factor. + + Raises + ------ + ValueError + If an invalid norm value is provided. + """ if norm == "backward": return 1 / jnp.prod(s) if func_name.startswith("i") else jnp.array(1) elif norm == "ortho": @@ -37,8 +59,25 @@ def _do_pfft( arr: ArrayLike, norm: Optional[str], ) -> Array: - # this is not allowed in a multi host setup - # arr = jnp.asarray(a) + """ + Perform 3D FFT or inverse 3D FFT on the input array. + + Parameters + ---------- + func_name : str + Name of the FFT function ("fft" or "ifft"). + fft_type : xla_client.FftType + Type of FFT operation. + arr : ArrayLike + Input array to transform. + norm : Optional[str] + Type of normalization ("backward", "ortho", or "forward"). + + Returns + ------- + Array + Transformed array after FFT or inverse FFT. + """ transformed = _pfft(arr, fft_type) transformed *= _fft_norm( jnp.array(arr.shape, dtype=transformed.dtype), func_name, norm) @@ -46,8 +85,38 @@ def _do_pfft( def pfft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array: + """ + Perform 3D FFT on the input array. + + Parameters + ---------- + a : ArrayLike + Input array to transform. + norm : Optional[str], optional + Type of normalization ("backward", "ortho", or "forward"), by default "backward". + + Returns + ------- + Array + Transformed array after 3D FFT. + """ return _do_pfft("fft", xla_client.FftType.FFT, a, norm=norm) def pifft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array: + """ + Perform inverse 3D FFT on the input array. + + Parameters + ---------- + a : ArrayLike + Input array to transform. + norm : Optional[str], optional + Type of normalization ("backward", "ortho", or "forward"), by default "backward". + + Returns + ------- + Array + Transformed array after inverse 3D FFT. + """ return _do_pfft("ifft", xla_client.FftType.IFFT, a, norm=norm) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1bbd47e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,32 @@ +import jax +import pytest + +import jaxdecomp + +setup_done = False + + +def initialize_distributed(): + global setup_done + if not setup_done: + jax.distributed.initialize() + setup_done = True + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_teardown_session(): + # Code to run at the start of the session + print("Starting session...") + initialize_distributed() + # Setup code here + # e.g., connecting to a database, initializing some resources, etc. + + yield + + # Code to run at the end of the session + print("Ending session...") + jaxdecomp.finalize() + jax.distributed.shutdown() + + # Teardown code here + # e.g., closing connections, cleaning up resources, etc. diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh new file mode 100755 index 0000000..c6db0da --- /dev/null +++ b/tests/run_all_tests.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python -m pytest -s -v test_transpose.py test_padding.py test_halo.py test_fft.py test_allgather.py &> validation_${SLURM_PROCID}.log diff --git a/tests/test_allgather.py b/tests/test_allgather.py index 9dfb265..3d1d70f 100644 --- a/tests/test_allgather.py +++ b/tests/test_allgather.py @@ -6,6 +6,7 @@ from math import prod import pytest +from conftest import initialize_distributed from jax.experimental import mesh_utils, multihost_utils from jax.experimental.shard_map import shard_map from jax.sharding import Mesh @@ -13,8 +14,7 @@ from numpy.testing import assert_array_equal # Initialize jax distributed to instruct jax local process which GPU to use -jaxdecomp.init() -jax.distributed.initialize() +initialize_distributed() rank = jax.process_index() size = jax.process_count() diff --git a/tests/test_fft.py b/tests/test_fft.py index 3a8de08..293f77a 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -6,6 +6,7 @@ import jax.numpy as jnp import pytest +from conftest import initialize_distributed from jax.experimental import mesh_utils, multihost_utils from jax.experimental.shard_map import shard_map from jax.sharding import Mesh @@ -15,7 +16,7 @@ import jaxdecomp # Initialize cuDecomp -jax.distributed.initialize() +initialize_distributed() rank = jax.process_index() size = jax.process_count() @@ -205,13 +206,3 @@ def test_vmap(): atol=1e-10) # Check the reverse FFT assert_allclose(array, rec_array, rtol=1e-10, atol=1e-10) - - -# find a way to finalize pytest -def test_end(): - # Make sure that it is cleaned up - # This has to be this way because pytest runs the "global code" before running the tests - # There are other solutions https://stackoverflow.com/questions/41871278/pytest-run-a-function-at-the-end-of-the-tests - # but this require the least amount of work - jaxdecomp.finalize() - jax.distributed.shutdown() diff --git a/tests/test_halo.py b/tests/test_halo.py index a34c53b..c2ace74 100644 --- a/tests/test_halo.py +++ b/tests/test_halo.py @@ -7,6 +7,7 @@ from math import prod import jax.numpy as jnp +from conftest import initialize_distributed from jax.experimental import mesh_utils, multihost_utils from jax.experimental.shard_map import shard_map from jax.sharding import Mesh @@ -17,8 +18,7 @@ from jaxdecomp import slice_pad, slice_unpad # Initialize jax distributed to instruct jax local process which GPU to use -jaxdecomp.init() -jax.distributed.initialize() +initialize_distributed() rank = jax.process_index() size = jax.process_count() @@ -134,13 +134,6 @@ def sharded_add_multiply(arr): updated_array, halo_extents=(halo_size, 0, 0), halo_periods=(False, False, False)) - exchanged_reduced_array = jaxdecomp.halo_exchange( - updated_array, - halo_extents=(halo_size, 0, 0), - halo_periods=(True, True, True), - reduce_halo=True) - unpadded_reduced_array = slice_unpad(exchanged_reduced_array, padding, - pdims) unpadded_updated_array = slice_unpad(updated_array, padding, pdims) # Gather array from all processes @@ -151,8 +144,6 @@ def sharded_add_multiply(arr): periodic_exchanged_array, tiled=True) updated_gathered_array = multihost_utils.process_allgather( updated_array, tiled=True) - gathered_exchanged_reduced_array = multihost_utils.process_allgather( - unpadded_reduced_array, tiled=True) gathered_unpadded_updated_array = multihost_utils.process_allgather( unpadded_updated_array, tiled=True) # Get the slices using array_split @@ -162,17 +153,14 @@ def sharded_add_multiply(arr): updated_gathered_array, size, axis=0) gathered_periodic_exchanged_slices = jnp.array_split( periodic_exchanged_gathered_array, size, axis=0) - gathered_reduced_exchange_slices = jnp.array_split( - gathered_exchanged_reduced_array, size, axis=0) gathered_unpadded_updated_slices = jnp.array_split( gathered_unpadded_updated_array, size, axis=0) gathered_arrays = zip(gathered_periodic_exchanged_slices, gathered_array_slices\ - , gathered_reduced_exchange_slices, gathered_updated_slices \ + , gathered_updated_slices \ , gathered_unpadded_updated_slices) - for slice_indx, (periodic_exchanged_slice, exchanged_slice, reduced_slice, - original_slice, + for slice_indx, (periodic_exchanged_slice, exchanged_slice, original_slice, unpadded_slice) in enumerate(gathered_arrays): next_indx = slice_indx + 1 @@ -208,29 +196,3 @@ def sharded_add_multiply(arr): periodic_exchanged_slice[:halo_size]) assert_array_equal(exchanged_slice[-halo_size:], periodic_exchanged_slice[-halo_size:]) - - # Test reduced halo - - # Lower center of the previous slice - previous_halo_extension = prev_slice[-2 * halo_size:-halo_size] - # Upper center of the next slice - next_halo_extension = next_slice[halo_size:2 * halo_size] - # Upper and lower center of the reduced slice - upper_halo_reduced = reduced_slice[:halo_size] - lower_halo_reduced = reduced_slice[-halo_size:] - # Upper and lower center of the original slice (after update no exchange and halo reduction) - upper_halo_original = unpadded_slice[:halo_size] - lower_halo_original = unpadded_slice[-halo_size:] - - # Upper slice should be equal to original upper slice + lower center of the previous slice - assert_array_equal(upper_halo_reduced, - (previous_halo_extension + upper_halo_original)) - # Lower slice should be equal to original lower slice + upper center of the next slice - assert_array_equal(lower_halo_reduced, - (next_halo_extension + lower_halo_original)) - - -def test_end(): - # fake test to finalize the MPI processes - jaxdecomp.finalize() - jax.distributed.shutdown() diff --git a/tests/test_padding.py b/tests/test_padding.py index 0f4917f..819ddc6 100644 --- a/tests/test_padding.py +++ b/tests/test_padding.py @@ -6,6 +6,7 @@ jax.config.update("jax_enable_x64", True) import jax.numpy as jnp +from conftest import initialize_distributed from jax import lax from jax.experimental import mesh_utils, multihost_utils from jax.experimental.shard_map import shard_map @@ -17,8 +18,7 @@ from jaxdecomp._src.padding import slice_pad, slice_unpad # Initialize jax distributed to instruct jax local process which GPU to use -jaxdecomp.init() -jax.distributed.initialize() +initialize_distributed() rank = jax.process_index() size = jax.process_count() @@ -203,9 +203,3 @@ def test_complex_unpad(pdims, global_shape): # Make sure the unpadded arrays is equal to the original array assert_array_equal(gathered_original, gathered_unpadded) - - -def test_end(): - # fake test to finalize the MPI processes - jaxdecomp.finalize() - jax.distributed.shutdown() diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 561e376..5c88862 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -9,6 +9,7 @@ import jax.numpy as jnp import numpy as np +from conftest import initialize_distributed from jax.experimental import mesh_utils, multihost_utils from jax.experimental.shard_map import shard_map from jax.sharding import Mesh @@ -19,8 +20,7 @@ from jaxdecomp import (transposeXtoY, transposeYtoX, transposeYtoZ, transposeZtoY) -jaxdecomp.init() -jax.distributed.initialize() +initialize_distributed() rank = jax.process_index() size = jax.process_count() @@ -193,9 +193,3 @@ def jax_transpose(global_array): print(f"Shape of JAX array {jax_grad.shape}") # Check the gradients assert_allclose(jax_grad, gathered_grads, rtol=1e-5, atol=1e-5) - - -def test_end(): - # fake test to finalize the MPI processes - jaxdecomp.finalize() - jax.distributed.shutdown()