From 131b03a9c4a3219f59c49c59d61ae3b1b953ab4c Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Wed, 19 Jun 2024 16:44:47 +0200 Subject: [PATCH 01/17] Clean-up transpose --- jaxdecomp/_src/transpose.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/jaxdecomp/_src/transpose.py b/jaxdecomp/_src/transpose.py index 3b90313..8f5eba0 100644 --- a/jaxdecomp/_src/transpose.py +++ b/jaxdecomp/_src/transpose.py @@ -5,13 +5,13 @@ import jax import jaxlib.mlir.ir as ir import numpy as np -from jax._src.api import ShapeDtypeStruct -from jax._src.core import ShapedArray +from jax import ShapeDtypeStruct from jax._src.interpreters import mlir from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array, ArrayLike -from jax.core import Primitive, ShapedArray -from jax.interpreters import ad, xla +from jax.core import ShapedArray +from jax.interpreters import mlir +from jax.interpreters.mlir import hlo from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P from jaxlib.hlo_helpers import custom_call @@ -21,10 +21,6 @@ from jaxdecomp._src.spmd_ops import (BasePrimitive, get_axis_size, register_primitive) -_out_axes = {'x_y': 1, 'y_z': 2, 'z_y': 1, 'y_x': 0} - -import traceback - class TransposePrimitive(BasePrimitive): From 003557cb61f5ae082e9fe683a02c30b428cfa189 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Wed, 19 Jun 2024 16:45:14 +0200 Subject: [PATCH 02/17] Clean-up FFT --- jaxdecomp/_src/fft.py | 422 ++++++++++++++++++------------------------ 1 file changed, 184 insertions(+), 238 deletions(-) diff --git a/jaxdecomp/_src/fft.py b/jaxdecomp/_src/fft.py index 2e4e0f4..3a83b1f 100644 --- a/jaxdecomp/_src/fft.py +++ b/jaxdecomp/_src/fft.py @@ -1,16 +1,16 @@ from functools import partial -from typing import Union +from typing import Tuple, Union import jax import jaxlib.mlir.ir as ir import numpy as np -from jax import jit -from jax._src.api import jit +from jax import ShapeDtypeStruct from jax._src.interpreters import mlir from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.util import promote_dtypes_complex -from jax.core import Primitive -from jax.interpreters import ad, xla +from jax.core import Primitive, ShapedArray +from jax.interpreters import ad, mlir, xla +from jax.interpreters.mlir import hlo from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call @@ -19,9 +19,12 @@ FftType = xla_client.FftType from jax.experimental.custom_partitioning import custom_partitioning -from jax.sharding import NamedSharding +from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P +from jaxdecomp._src.spmd_ops import (BasePrimitive, get_axis_size, + register_primitive) + def _str_to_fft_type(s: str) -> xla_client.FftType: if s in ("fft", "FFT"): @@ -36,279 +39,222 @@ def _str_to_fft_type(s: str) -> xla_client.FftType: raise ValueError(f"Unknown FFT type '{s}'") -# Note : This must no longer be jitted because it will have single device abstract shapes -# The actual jit is done in the pfft function in lower_fn -# This means that sfft should never be lowered as is and only be lowered in the context of pfft -#@partial(jit, static_argnums=(1, 2, 3, 4)) -def sfft(x, - fft_type: Union[xla_client.FftType, str], - adjoint=False, - pdims=[1, 1], - global_shape=[1024, 1024, 1024]): - - #TODO(wassim) : find a way to prevent user from using the primitive directly - - if isinstance(fft_type, str): - typ = _str_to_fft_type(fft_type) - elif isinstance(fft_type, xla_client.FftType): - 
typ = fft_type - else: - raise TypeError(f"Unknown FFT type value '{fft_type}'") - - if typ in [xla_client.FftType.RFFT, xla_client.FftType.IRFFT]: - raise TypeError("only complex FFTs are currently supported through pfft.") - - (x,) = promote_dtypes_complex(x) - - return sfft_p.bind( - x, fft_type=typ, pdims=pdims, global_shape=global_shape, adjoint=adjoint) +class FFTPrimitive(BasePrimitive): + name = "fft" + multiple_results = False + impl_static_args = (1,) + inner_primitive = None + outer_primitive = None -def sfft_abstract_eval(x, fft_type, pdims, global_shape, adjoint): + @staticmethod + def abstract(x, fft_type, pdims, global_shape, adjoint): - # TODO(Wassim) : this only handles cube shapes - # This function is called twice once with the global array and once with the local slice shape + if global_shape == x.shape: + return FFTPrimitive.outer_abstract(x, fft_type=fft_type, adjoint=adjoint) - # Figure out what is the pencil decomposition at the output - axis = 0 - if fft_type in [xla_client.FftType.RFFT, xla_client.FftType.FFT]: - axis = 2 + match fft_type: + case xla_client.FftType.FFT: + # FFT is X to Y to Z so Z-Pencil is returned + # Except if we are doing a YZ slab in which case we return a Y-Pencil + transpose_shape = (1, 2, 0) + transposed_pdims = pdims + case xla_client.FftType.IFFT: + # IFFT is Z to X to Y so X-Pencil is returned + # In YZ slab case we only need one transposition back to get the X-Pencil + transpose_shape = (2, 0, 1) + transposed_pdims = pdims + case _: + raise TypeError( + "only complex FFTs are currently supported through pfft.") - output_shape = None - match fft_type: - case xla_client.FftType.FFT: - # FFT is X to Y to Z so Z-Pencil is returned - # Except if we are doing a YZ slab in which case we return a Y-Pencil - transpose_shape = (1, 2, 0) - transposed_pdims = pdims - case xla_client.FftType.IFFT: - # IFFT is Z to X to Y so X-Pencil is returned - # In YZ slab case we only need one transposition back to get the X-Pencil - transpose_shape = (2, 0, 1) - transposed_pdims = pdims - case _: - raise TypeError("only complex FFTs are currently supported through pfft.") - - # Are we operating on the global array? - # This is called when the abstract_eval of the custom partitioning is called _custom_partitioning_abstract_eval in https://github.com/google/jax/blob/main/jax/experimental/custom_partitioning.py#L223 - if x.shape == global_shape: - shape = tuple([global_shape[i] for i in transpose_shape]) - output_shape = shape - # Or are we operating on a local slice? 
- # this is called JAX calls make_jaxpr(lower_fn) in https://github.com/google/jax/blob/main/jax/experimental/custom_partitioning.py#L142C5-L142C35 - else: output_shape = (global_shape[transpose_shape[0]] // transposed_pdims[1], global_shape[transpose_shape[1]] // transposed_pdims[0], global_shape[transpose_shape[2]]) - # Sanity check - assert (output_shape is not None) - return x.update(shape=output_shape, dtype=x.dtype) - - -def sfft_lowering(ctx, a, *, fft_type, pdims, global_shape, adjoint): - (x_aval,) = ctx.avals_in - (aval_out,) = ctx.avals_out - dtype = x_aval.dtype - a_type = ir.RankedTensorType(a.type) - # We currently only support complex FFTs through this interface, so let's check the fft type - assert fft_type in (FftType.FFT, - FftType.IFFT), "Only complex FFTs are currently supported" - - # Figure out which fft we want - forward = fft_type in (FftType.FFT,) - is_double = np.finfo(dtype).dtype == np.float64 - - # Get original global shape - match fft_type: - case xla_client.FftType.FFT: - transpose_back_shape = (0, 1, 2) - case xla_client.FftType.IFFT: - transpose_back_shape = (2, 0, 1) - case _: + return ShapedArray(output_shape, x.dtype) + + @staticmethod + def outer_abstract(x, fft_type, adjoint): + + match fft_type: + case xla_client.FftType.FFT: + # FFT is X to Y to Z so Z-Pencil is returned + # Except if we are doing a YZ slab in which case we return a Y-Pencil + transpose_shape = (1, 2, 0) + case xla_client.FftType.IFFT: + # IFFT is Z to X to Y so X-Pencil is returned + # In YZ slab case we only need one transposition back to get the X-Pencil + transpose_shape = (2, 0, 1) + case _: + raise TypeError( + "only complex FFTs are currently supported through pfft.") + + output_shape = tuple([x.shape[i] for i in transpose_shape]) + return ShapedArray(output_shape, x.dtype) + + @staticmethod + def lowering(ctx, a, *, fft_type, pdims, global_shape, adjoint): + (x_aval,) = ctx.avals_in + (aval_out,) = ctx.avals_out + dtype = x_aval.dtype + a_type = ir.RankedTensorType(a.type) + # We currently only support complex FFTs through this interface, so let's check the fft type + assert fft_type in ( + FftType.FFT, FftType.IFFT), "Only complex FFTs are currently supported" + + # Figure out which fft we want + forward = fft_type in (FftType.FFT,) + is_double = np.finfo(dtype).dtype == np.float64 + + # Get original global shape + match fft_type: + case xla_client.FftType.FFT: + transpose_back_shape = (0, 1, 2) + case xla_client.FftType.IFFT: + transpose_back_shape = (2, 0, 1) + case _: + raise TypeError( + "only complex FFTs are currently supported through pfft.") + # Make sure to get back the original shape of the X-Pencil + global_shape = tuple([global_shape[i] for i in transpose_back_shape]) + # Compute the descriptor for our FFT + config = _jaxdecomp.GridConfig() + + config.pdims = pdims + config.gdims = global_shape[::-1] + config.halo_comm_backend = jaxdecomp.config.halo_comm_backend + config.transpose_comm_backend = jaxdecomp.config.transpose_comm_backend + workspace_size, opaque = _jaxdecomp.build_fft_descriptor( + config, forward, is_double, adjoint) + + n = len(a_type.shape) + layout = tuple(range(n - 1, -1, -1)) + + # We ask XLA to allocate a workspace for this operation. + # TODO: check that the memory is not used all the time, just when needed + workspace = mlir.full_like_aval( + ctx, 0, jax.core.ShapedArray(shape=[workspace_size], dtype=np.byte)) + + # Run the custom op with same input and output shape, so that we can perform operations + # inplace. 
+ result = custom_call( + "pfft3d", + result_types=[a_type], + operands=[a, workspace], + operand_layouts=[layout, (0,)], + result_layouts=[layout], + has_side_effect=True, + operand_output_aliases={0: 0}, + backend_config=opaque, + ) + + # Finally we reshape the arry to the expected shape. + return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), result).results + + # Note : This must no longer be jitted because it will have single device abstract shapes + # The actual jit is done in the pfft function in lower_fn + # This means that sfft should never be lowered as is and only be lowered in the context of pfft + @staticmethod + def impl(x, fft_type, adjoint): + + if isinstance(fft_type, str): + typ = _str_to_fft_type(fft_type) + elif isinstance(fft_type, xla_client.FftType): + typ = fft_type + else: + raise TypeError(f"Unknown FFT type value '{fft_type}'") + + if typ in [xla_client.FftType.RFFT, xla_client.FftType.IRFFT]: raise TypeError("only complex FFTs are currently supported through pfft.") - # Make sure to get back the original shape of the X-Pencil - global_shape = tuple([global_shape[i] for i in transpose_back_shape]) - # Compute the descriptor for our FFT - config = _jaxdecomp.GridConfig() - - config.pdims = pdims - config.gdims = global_shape[::-1] - config.halo_comm_backend = jaxdecomp.config.halo_comm_backend - config.transpose_comm_backend = jaxdecomp.config.transpose_comm_backend - workspace_size, opaque = _jaxdecomp.build_fft_descriptor( - config, forward, is_double, adjoint) - - n = len(a_type.shape) - layout = tuple(range(n - 1, -1, -1)) - - # We ask XLA to allocate a workspace for this operation. - # TODO: check that the memory is not used all the time, just when needed - workspace = mlir.full_like_aval( - ctx, 0, jax.core.ShapedArray(shape=[workspace_size], dtype=np.byte)) - - # Run the custom op with same input and output shape, so that we can perform operations - # inplace. - result = custom_call( - "pfft3d", - result_types=[a_type], - operands=[a, workspace], - operand_layouts=[layout, (0,)], - result_layouts=[layout], - has_side_effect=True, - operand_output_aliases={0: 0}, - backend_config=opaque, - ) - - # Finally we reshape the arry to the expected shape. 
- out_type = ir.RankedTensorType.get(aval_out.shape, a_type.element_type) - return hlo.ReshapeOp(out_type, result).results - - -def _fft_transpose_rule(x, operand, fft_type, pdims, global_shape, adjoint): - assert fft_type in [FftType.FFT, FftType.IFFT] - if fft_type == FftType.FFT: - result = sfft(x, FftType.IFFT, ~adjoint, pdims, global_shape) - elif fft_type == FftType.IFFT: - result = sfft(x, FftType.FFT, ~adjoint, pdims, global_shape) - else: - raise NotImplementedError - - return (result,) + (x,) = promote_dtypes_complex(x) -def get_axis_size(sharding, index): - axis_name = sharding.spec[index] - if axis_name == None: - return 1 - else: - return sharding.mesh.shape[sharding.spec[index]] - + pdims = (1, jax.device_count()) + global_shape = x.shape -# Only named sharding have a spec -# this function is actually useless because positional sharding do not have a spec -# in case the user does not use a context mesh this will fail -# this is a placeholder function for the future -# the spec needs to be carried by a custom object that we create ourselfs -# to get inspired : https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuFFTMp/JAX_FFT/src/xfft/xfft.py#L20 -def to_named_sharding(sharding): - return NamedSharding(sharding.mesh, P(*sharding.spec)) + return FFTPrimitive.inner_primitive.bind( + x, + fft_type=typ, + pdims=pdims, + global_shape=global_shape, + adjoint=adjoint) + @staticmethod + def per_shard_impl(x, kind, pdims, global_shape, adjoint): + return FFTPrimitive.inner_primitive.bind( + x, kind=kind, pdims=pdims, global_shape=global_shape, adjoint=adjoint) -def partition(fft_type, adjoint, mesh, arg_shapes, result_shape): - """ - Tells XLA how to partition the primitive + @staticmethod + def infer_sharding_from_operands(fft_type, adjoint, mesh: Mesh, + arg_infos: Tuple[ShapeDtypeStruct], + result_infos: Tuple[ShapedArray]): + """ + Tell XLA how to infer the sharding of the output from the input sharding. Args: mesh (Mesh): The contextual mesh arg_shapes (tuple): A tuple of ShapeDtypeStruct that contains the shape and the sharding of each input operand - result_shape (ShapeDtypeStruct) : a ShapeDtypeStruct reprsenting a single output + result_shape (ShapedArray) : a single ShapedArray reprsenting a single output without the sharding information Returns: - Mesh (Mesh) : The mesh. - - function: The lowered function, to allow the user to redefine how the primitive is called in a context of a specific sharding result_sharding (XLACompatibleSharding): The sharding result for example a NamedSharding. 
- arg_shardings (tuple): a tuple of all XLACompatibleSharding of the input operands """ + input_sharding = arg_infos[0].sharding + return NamedSharding(mesh, P(*input_sharding.spec)) - # pfft only has one operand - input_sharding = arg_shapes[0].sharding + @staticmethod + def partition(fft_type, adjoint, mesh, arg_shapes, result_shape): + """ + Tells XLA how to partition the primitive - def lower_fn(operand): - # Operand is a local slice and arg_shapes contains the global shape - # No need to retranpose in the relowered function because abstract eval understands sliced input - # and in the original lowering we use aval.out - # it cannot work any other way because custom partition compares the output of the lower_fn with the abs eval (after comparing the global one) - # this means that the abs eval should handle both global shapes and slice shape + Args: + mesh (Mesh): The contextual mesh - global_shape = arg_shapes[0].shape - pdims = (get_axis_size(input_sharding, 1), get_axis_size(input_sharding, 0)) + arg_shapes (tuple): A tuple of ShapeDtypeStruct that contains the shape and the sharding of each input operand - output = sfft(operand, fft_type, adjoint, pdims, global_shape) - - # This is supposed to let us avoid making an extra transpose in the YZ case - # it does not work - # # In case of YZ slab the cuda code tranposes only once - # # We transpose again to give back the Z-Pencil to the user in case of FFT and the X-Pencil in case of IFFT - # # this transposition is supposed to compiled out by XLA when doing a gradient (forward followed by backward) - # if get_axis_size(input_sharding, 0) == 1: - # if fft_type == FftType.FFT: - # output = output.transpose((1, 2, 0)) - # elif fft_type == FftType.IFFT: - # output = output.transpose((2, 0, 1)) - return output - - return mesh, lower_fn, \ - to_named_sharding(result_shape.sharding), \ - (to_named_sharding(arg_shapes[0].sharding),) - - -def infer_sharding_from_operands(fft_type, adjoint, mesh, arg_shapes, - result_shape): - # Static arguments fft_type adjoint are carried along - """ - Tell XLA how to infer the sharding of the output from the input sharding. + result_shape (ShapeDtypeStruct) : a ShapeDtypeStruct reprsenting a single output - Args: - mesh (Mesh): The contextual mesh + Returns: + Mesh (Mesh) : The mesh. - arg_shapes (tuple): A tuple of ShapeDtypeStruct that contains the shape and the sharding of each input operand + function: The lowered function, to allow the user to redefine how the primitive is called in a context of a specific sharding - result_shape (ShapedArray) : a single ShapedArray reprsenting a single output without the sharding information + result_sharding (XLACompatibleSharding): The sharding result for example a NamedSharding. - Returns: + arg_shardings (tuple): a tuple of all XLACompatibleSharding of the input operands + """ - result_sharding (XLACompatibleSharding): The sharding result for example a NamedSharding. 
+ # pfft only has one operand + input_sharding = NamedSharding(mesh, P(*arg_shapes[0].sharding.spec)) + output_sharding = NamedSharding(mesh, P(*result_shape.sharding.spec)) - """ - # only one operand is used in pfft - input_sharding = arg_shapes[0].sharding - return NamedSharding(mesh, P(*input_sharding.spec)) + pdims = (get_axis_size(input_sharding, 1), get_axis_size(input_sharding, 0)) + global_shape = arg_shapes[0].shape + + impl = partial( + FFTPrimitive.per_shard_impl, + fft_type=fft_type, + pdims=pdims, + global_shape=global_shape) + + return mesh, impl, output_sharding, (input_sharding,) + + +register_primitive(FFTPrimitive) -@partial(custom_partitioning, static_argnums=(1, 2)) def pfft_p_lower(x, fft_type, adjoint=False): - # the product of the fake dim has to be equal to the product of the global shape - # Fake dims and shape values are irrelevant because they are never used as concrete values only as Traced values - # their shapes however are used in the abstract eval of the custom partitioning - - size = jax.device_count() - # The pdims product must be equal to the number of devices because this is checked both in the abstract eval and in cudecomp - dummy_pdims = (1, size) - dummy_global = x.shape - return sfft(x, fft_type, adjoint, dummy_pdims, dummy_global) - - -sfft_p = Primitive("pfft") -sfft_p.def_impl(partial(xla.apply_primitive, sfft_p)) -sfft_p.def_abstract_eval(sfft_abstract_eval) -ad.deflinear2(sfft_p, _fft_transpose_rule) -mlir.register_lowering(sfft_p, sfft_lowering, platform="gpu") - -# Define the partitioning for the primitive -pfft_p_lower.def_partition( - partition=partition, - infer_sharding_from_operands=infer_sharding_from_operands) - -# declaring a differentiable SPMD primitive -# Inspired from -# https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/cpp_extensions.py#L188 -# https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/cpp_extensions.py#L694 -# https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/layernorm.py#L49 -# Note TE does a cleaner way of defining the primitive by using register_primitive(cls): that declares two primitives -# An inner which is represented here by sfft -# And an outer which by analogy shoud be represented here by pfft - - -# Do not jit this -# the jit is happening in jaxdecomp/fft.py: _do_pfft + return FFTPrimitive.outer_primitive.bind( + x, fft_type=fft_type, adjoint=adjoint) + + @partial(jax.custom_vjp, nondiff_argnums=(1, 2)) -def pfft(x, fft_type, adjoint=False): +def pfft(x, fft_type, adjoint): output, _ = _pfft_fwd_rule(x, fft_type=fft_type, adjoint=adjoint) return output From 6ccdfc97e67aeb6a66a643df539cf1624160dcfe Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Wed, 19 Jun 2024 17:06:24 +0200 Subject: [PATCH 03/17] Clean-up Halo --- jaxdecomp/_src/halo.py | 258 ++++++++++++++++++++--------------------- 1 file changed, 129 insertions(+), 129 deletions(-) diff --git a/jaxdecomp/_src/halo.py b/jaxdecomp/_src/halo.py index c526fb1..b973b2f 100644 --- a/jaxdecomp/_src/halo.py +++ b/jaxdecomp/_src/halo.py @@ -14,100 +14,127 @@ import jaxdecomp from jaxdecomp._src import _jaxdecomp +from jaxdecomp._src.spmd_ops import (BasePrimitive, get_axis_size, + register_primitive) + + +class HaloPrimitive(BasePrimitive): + + name = "halo_exchange" + multiple_results = False + impl_static_args = (1, 2, 3) + inner_primitive = None + + @staticmethod + def abstract(x, halo_extents, halo_periods, reduce_halo, pdims, global_shape): + return x.update(shape=x.shape, 
dtype=x.dtype) + + @staticmethod + def outer_abstract(x, halo_extents, halo_periods, reduce_halo, pdims, + global_shape): + return x.update(shape=x.shape, dtype=x.dtype) + + @staticmethod + def lowering(ctx, x, halo_extents, halo_periods, reduce_halo, pdims, + global_shape): + (x_aval,) = ctx.avals_in + x_type = ir.RankedTensorType(x.type) + n = len(x_type.shape) + + is_double = np.finfo(x_aval.dtype).dtype == np.float64 + + # Compute the descriptor for our FFT + config = _jaxdecomp.GridConfig() + config.pdims = pdims + config.gdims = global_shape[::-1] + config.halo_comm_backend = jaxdecomp.config.halo_comm_backend + config.transpose_comm_backend = jaxdecomp.config.transpose_comm_backend + + workspace_size, opaque = _jaxdecomp.build_halo_descriptor( + config, is_double, halo_extents[::-1], halo_periods[::-1], 0) + layout = tuple(range(n - 1, -1, -1)) + + workspace = mlir.full_like_aval( + ctx, 0, jax.core.ShapedArray(shape=[workspace_size], dtype=np.byte)) + + # reduce_halo is not used in the inner primitive + out = custom_call( + "halo", + result_types=[x_type], + operands=[x, workspace], + operand_layouts=[layout, (0,)], + result_layouts=[layout], + has_side_effect=True, + operand_output_aliases={0: 0}, + backend_config=opaque, + ) + return out.results + + @staticmethod + def impl(x, halo_extents, halo_periods, reduce_halo): + + pdims = (1, jax.device_count()) + global_shape = x.shape + + return HaloPrimitive.inner_primitive.bind( + x, + halo_extents=halo_extents, + halo_periods=halo_periods, + reduce_halo=reduce_halo, + pdims=pdims, + global_shape=global_shape, + ) + + @staticmethod + def per_shard_impl(x, halo_extents, halo_periods, reduce_halo, pdims, + global_shape): + output = HaloPrimitive.inner_primitive.bind( + x, + halo_extents=halo_extents, + halo_periods=halo_periods, + reduce_halo=reduce_halo, + pdims=pdims, + global_shape=global_shape, + ) + if reduce_halo: -# This is the inner primitive and it should not be used directly -def inner_halo_exchange(x, - *, - halo_extents: Tuple[int, int, int], - halo_periods: Tuple[bool, bool, bool], - reduce_halo: bool = True, - pdims: Tuple[int, int] = (1, 1), - global_shape: Tuple[int, int, int] = [1024, 1024, - 1024]): - # TODO: check float or real - return inner_halo_p.bind( - x, - halo_extents=halo_extents, - halo_periods=halo_periods, - reduce_halo=reduce_halo, - pdims=pdims, - global_shape=global_shape) - - -def inner_halo_abstract_eval(x, halo_extents, halo_periods, reduce_halo, pdims, - global_shape): - # The return shape is equal to the global shape for the inner primitive (the one that is not exposed) - # The return shape is equal to the slice shape for the outer primitive (the one that is exposed) - # in all cases it is x.shape - return x.update(shape=x.shape, dtype=x.dtype) - - -def inner_halo_lowering(ctx, x, *, halo_extents, halo_periods, reduce_halo, - pdims, global_shape): - (x_aval,) = ctx.avals_in - x_type = ir.RankedTensorType(x.type) - n = len(x_type.shape) - - is_double = np.finfo(x_aval.dtype).dtype == np.float64 - - # Compute the descriptor for our FFT - config = _jaxdecomp.GridConfig() - config.pdims = pdims - config.gdims = global_shape[::-1] - config.halo_comm_backend = jaxdecomp.config.halo_comm_backend - config.transpose_comm_backend = jaxdecomp.config.transpose_comm_backend - - workspace_size, opaque = _jaxdecomp.build_halo_descriptor( - config, is_double, halo_extents[::-1], halo_periods[::-1], 0) - layout = tuple(range(n - 1, -1, -1)) - - workspace = mlir.full_like_aval( - ctx, 0, 
jax.core.ShapedArray(shape=[workspace_size], dtype=np.byte)) - - # reduce_halo is not used in the inner primitive - out = custom_call( - "halo", - result_types=[x_type], - operands=[x, workspace], - operand_layouts=[layout, (0,)], - result_layouts=[layout], - has_side_effect=True, - operand_output_aliases={0: 0}, - backend_config=opaque, - ) - return out.results - - -def inner_halo_transpose_rule(x, operand, halo_extents, halo_periods, pdims, - global_shape): - result = halo_exchange(x, halo_extents, halo_periods, pdims, global_shape) - return (result,) - - -# Custom Primitive - - -def get_axis_size(sharding, index): - axis_name = sharding.spec[index] - if axis_name == None: - return 1 - else: - return sharding.mesh.shape[sharding.spec[index]] + halo_x, halo_y, halo_z = halo_extents + ## Apply correction along x + if halo_x > 0: + output = output.at[halo_x:halo_x + halo_x // 2].add(output[:halo_x // + 2]) + output = output.at[-(halo_x + halo_x // 2):-halo_x].add( + output[-halo_x // 2:]) + ## Apply correction along y + if halo_y > 0: + output = output.at[:, halo_y:halo_y + halo_y // 2].add( + output[:, :halo_y // 2]) + output = output.at[:, -(halo_y + halo_y // 2):-halo_y].add( + output[:, -halo_y // 2:]) + ## Apply correction along z + if halo_z > 0: + output = output.at[:, :, halo_z:halo_z + halo_z // 2].add( + output[:, :, :halo_z // 2]) + output = output.at[:, :, -(halo_z + halo_z // 2):-halo_z].add( + output[:, :, -halo_z // 2:]) -def partition(halo_extents, halo_periods, reduce_halo, mesh, arg_shapes, - result_shape): + return output - # halo_exchange has three operands - # x, halo_extents, halo_periods - # (sanity check) - #assert len(arg_shapes) == 3 , "halo_exchange must have only 3 operands in the partitioning lower function" - # only x is sharded the other are fully replicated - halo_exchange_sharding = arg_shapes[0].sharding + @staticmethod + def infer_sharding_from_operands(halo_extents, halo_periods, reduce_halo, + mesh, arg_shapes, result_shape): + # Sharding is the same here aswell because halo_exchange is a pointwise operation + halo_exchange_sharding = arg_shapes[0].sharding + return NamedSharding(mesh, P(*halo_exchange_sharding.spec)) - def lower_fn(operand): + @staticmethod + def partition(halo_extents, halo_periods, reduce_halo, mesh, arg_shapes, + result_shape): + halo_exchange_sharding = NamedSharding(mesh, + P(*arg_shapes[0].sharding)) global_shape = arg_shapes[0].shape pdims = (get_axis_size(halo_exchange_sharding, 1), get_axis_size(halo_exchange_sharding, 0)) @@ -116,59 +143,32 @@ def lower_fn(operand): global_shape[1] - 2 * pdims[0] * halo_extents[1],\ global_shape[2] - 2 * halo_extents[2]) - output = inner_halo_exchange(operand, halo_extents=halo_extents, \ - halo_periods=halo_periods, pdims=pdims, global_shape=shape_without_halo) - - if reduce_halo: - ## Apply correction along x - output = output.at[halo_extents[0]:2 * halo_extents[0]].add( - output[:halo_extents[0]]) - output = output.at[-2 * halo_extents[0]:-halo_extents[0]].add( - output[-halo_extents[0]:]) - ## Apply correction along y - output = output.at[:, halo_extents[1]:2 * halo_extents[1]].add( - output[:, :halo_extents[1]]) - output = output.at[:, -2 * halo_extents[1]:-halo_extents[1]].add( - output[:, -halo_extents[1]:]) - - return output + impl = partial( + HaloPrimitive.per_shard_impl, + halo_extents=halo_extents, + halo_periods=halo_periods, + reduce_halo=reduce_halo, + pdims=pdims, + global_shape=shape_without_halo) - return mesh, lower_fn, \ - result_shape.sharding, \ - 
(halo_exchange_sharding,) + return mesh, impl, halo_exchange_sharding, (halo_exchange_sharding,) -def infer_sharding_from_operands(halo_extents, halo_periods, reduce_halo, mesh, - arg_shapes, result_shape): - # Sharding is the same here aswell because halo_exchange is a pointwise operation - halo_exchange_sharding = arg_shapes[0].sharding - return halo_exchange_sharding +register_primitive(HaloPrimitive) -@partial(custom_partitioning, static_argnums=(1, 2, 3)) def halo_p_lower(x, halo_extents, halo_periods, reduce_halo): - size = jax.device_count() - # The pdims product must be equal to the number of devices because this is checked both in the abstract eval and in cudecomp - dummy_pdims = (1, size) - dummy_global = x.shape - return inner_halo_exchange(x, halo_extents=halo_extents, halo_periods=halo_periods,reduce_halo = reduce_halo,\ - pdims=dummy_pdims, global_shape=dummy_global) + return HaloPrimitive.outer_primitive.bind( + x, + halo_extents=halo_extents, + halo_periods=halo_periods, + reduce_halo=reduce_halo, + ) # declare primitive -inner_halo_p = Primitive("halo_exchange") -inner_halo_p.def_impl(partial(xla.apply_primitive, inner_halo_p)) -inner_halo_p.def_abstract_eval(inner_halo_abstract_eval) -ad.deflinear2(inner_halo_p, inner_halo_transpose_rule) -mlir.register_lowering(inner_halo_p, inner_halo_lowering, platform="gpu") - -# Define the partitioning for the primitive -halo_p_lower.def_partition( - partition=partition, - infer_sharding_from_operands=infer_sharding_from_operands) - # Custom Partitioning @partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3)) From 32e79298254cf0ec142cdec953adbcece90be6a6 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Mon, 24 Jun 2024 17:22:03 +0200 Subject: [PATCH 04/17] Make all tests run together --- tests/conftest.py | 32 ++++++++++++++++++++++++++++++++ tests/run_all_tests.sh | 2 ++ tests/test_allgather.py | 4 ++-- tests/test_fft.py | 13 ++----------- tests/test_halo.py | 25 +++++++++---------------- tests/test_padding.py | 9 ++------- tests/test_transpose.py | 10 ++-------- 7 files changed, 51 insertions(+), 44 deletions(-) create mode 100644 tests/conftest.py create mode 100755 tests/run_all_tests.sh diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1bbd47e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,32 @@ +import jax +import pytest + +import jaxdecomp + +setup_done = False + + +def initialize_distributed(): + global setup_done + if not setup_done: + jax.distributed.initialize() + setup_done = True + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_teardown_session(): + # Code to run at the start of the session + print("Starting session...") + initialize_distributed() + # Setup code here + # e.g., connecting to a database, initializing some resources, etc. + + yield + + # Code to run at the end of the session + print("Ending session...") + jaxdecomp.finalize() + jax.distributed.shutdown() + + # Teardown code here + # e.g., closing connections, cleaning up resources, etc. 
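For reference, a minimal sketch of how the test modules below consume this fixture. Only initialize_distributed and the session-scoped fixture come from conftest.py; the test function and its assertion are illustrative assumptions rather than part of the patch, and they presume one GPU-backed process per task as launched by run_all_tests.sh.

    # Illustrative sketch, not part of the patch: mirrors the pattern used in
    # test_allgather.py, test_fft.py, test_halo.py and test_transpose.py below.
    import jax
    from conftest import initialize_distributed

    # Module-level call: safe to repeat across test files because
    # initialize_distributed() is guarded by the setup_done flag above.
    initialize_distributed()

    rank = jax.process_index()
    size = jax.process_count()


    def test_process_grid():
      # Assumes every process drives the same number of local devices
      # (typically one GPU per SLURM task).
      assert size * jax.local_device_count() == jax.device_count()

Keeping the shutdown in a session-scoped autouse fixture means jaxdecomp.finalize() and jax.distributed.shutdown() run exactly once after the last test module, which is why the per-file test_end() helpers are removed from test_fft.py, test_halo.py, test_padding.py and test_transpose.py in the hunks below.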
diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh new file mode 100755 index 0000000..c6db0da --- /dev/null +++ b/tests/run_all_tests.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python -m pytest -s -v test_transpose.py test_padding.py test_halo.py test_fft.py test_allgather.py &> validation_${SLURM_PROCID}.log diff --git a/tests/test_allgather.py b/tests/test_allgather.py index 9dfb265..3d1d70f 100644 --- a/tests/test_allgather.py +++ b/tests/test_allgather.py @@ -6,6 +6,7 @@ from math import prod import pytest +from conftest import initialize_distributed from jax.experimental import mesh_utils, multihost_utils from jax.experimental.shard_map import shard_map from jax.sharding import Mesh @@ -13,8 +14,7 @@ from numpy.testing import assert_array_equal # Initialize jax distributed to instruct jax local process which GPU to use -jaxdecomp.init() -jax.distributed.initialize() +initialize_distributed() rank = jax.process_index() size = jax.process_count() diff --git a/tests/test_fft.py b/tests/test_fft.py index 3a8de08..293f77a 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -6,6 +6,7 @@ import jax.numpy as jnp import pytest +from conftest import initialize_distributed from jax.experimental import mesh_utils, multihost_utils from jax.experimental.shard_map import shard_map from jax.sharding import Mesh @@ -15,7 +16,7 @@ import jaxdecomp # Initialize cuDecomp -jax.distributed.initialize() +initialize_distributed() rank = jax.process_index() size = jax.process_count() @@ -205,13 +206,3 @@ def test_vmap(): atol=1e-10) # Check the reverse FFT assert_allclose(array, rec_array, rtol=1e-10, atol=1e-10) - - -# find a way to finalize pytest -def test_end(): - # Make sure that it is cleaned up - # This has to be this way because pytest runs the "global code" before running the tests - # There are other solutions https://stackoverflow.com/questions/41871278/pytest-run-a-function-at-the-end-of-the-tests - # but this require the least amount of work - jaxdecomp.finalize() - jax.distributed.shutdown() diff --git a/tests/test_halo.py b/tests/test_halo.py index a34c53b..b9dd6db 100644 --- a/tests/test_halo.py +++ b/tests/test_halo.py @@ -7,6 +7,7 @@ from math import prod import jax.numpy as jnp +from conftest import initialize_distributed from jax.experimental import mesh_utils, multihost_utils from jax.experimental.shard_map import shard_map from jax.sharding import Mesh @@ -17,8 +18,7 @@ from jaxdecomp import slice_pad, slice_unpad # Initialize jax distributed to instruct jax local process which GPU to use -jaxdecomp.init() -jax.distributed.initialize() +initialize_distributed() rank = jax.process_index() size = jax.process_count() @@ -209,18 +209,17 @@ def sharded_add_multiply(arr): assert_array_equal(exchanged_slice[-halo_size:], periodic_exchanged_slice[-halo_size:]) - # Test reduced halo - # Lower center of the previous slice - previous_halo_extension = prev_slice[-2 * halo_size:-halo_size] + previous_halo_extension = prev_slice[-2 * halo_size:-3 * (halo_size // 2)] # Upper center of the next slice - next_halo_extension = next_slice[halo_size:2 * halo_size] + next_halo_extension = next_slice[3 * (halo_size // 2):2 * halo_size] # Upper and lower center of the reduced slice - upper_halo_reduced = reduced_slice[:halo_size] - lower_halo_reduced = reduced_slice[-halo_size:] + + upper_halo_reduced = reduced_slice[:halo_size // 2] + lower_halo_reduced = reduced_slice[-(halo_size // 2):] # Upper and lower center of the original slice (after update no exchange and halo reduction) - upper_halo_original = 
unpadded_slice[:halo_size] - lower_halo_original = unpadded_slice[-halo_size:] + upper_halo_original = unpadded_slice[:halo_size // 2] + lower_halo_original = unpadded_slice[-(halo_size // 2):] # Upper slice should be equal to original upper slice + lower center of the previous slice assert_array_equal(upper_halo_reduced, @@ -228,9 +227,3 @@ def sharded_add_multiply(arr): # Lower slice should be equal to original lower slice + upper center of the next slice assert_array_equal(lower_halo_reduced, (next_halo_extension + lower_halo_original)) - - -def test_end(): - # fake test to finalize the MPI processes - jaxdecomp.finalize() - jax.distributed.shutdown() diff --git a/tests/test_padding.py b/tests/test_padding.py index 0f4917f..133b49f 100644 --- a/tests/test_padding.py +++ b/tests/test_padding.py @@ -6,6 +6,7 @@ jax.config.update("jax_enable_x64", True) import jax.numpy as jnp +from conftest import initialize_distributed from jax import lax from jax.experimental import mesh_utils, multihost_utils from jax.experimental.shard_map import shard_map @@ -17,8 +18,7 @@ from jaxdecomp._src.padding import slice_pad, slice_unpad # Initialize jax distributed to instruct jax local process which GPU to use -jaxdecomp.init() -jax.distributed.initialize() +initialize_distributed() rank = jax.process_index() size = jax.process_count() @@ -204,8 +204,3 @@ def test_complex_unpad(pdims, global_shape): # Make sure the unpadded arrays is equal to the original array assert_array_equal(gathered_original, gathered_unpadded) - -def test_end(): - # fake test to finalize the MPI processes - jaxdecomp.finalize() - jax.distributed.shutdown() diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 561e376..5c88862 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -9,6 +9,7 @@ import jax.numpy as jnp import numpy as np +from conftest import initialize_distributed from jax.experimental import mesh_utils, multihost_utils from jax.experimental.shard_map import shard_map from jax.sharding import Mesh @@ -19,8 +20,7 @@ from jaxdecomp import (transposeXtoY, transposeYtoX, transposeYtoZ, transposeZtoY) -jaxdecomp.init() -jax.distributed.initialize() +initialize_distributed() rank = jax.process_index() size = jax.process_count() @@ -193,9 +193,3 @@ def jax_transpose(global_array): print(f"Shape of JAX array {jax_grad.shape}") # Check the gradients assert_allclose(jax_grad, gathered_grads, rtol=1e-5, atol=1e-5) - - -def test_end(): - # fake test to finalize the MPI processes - jaxdecomp.finalize() - jax.distributed.shutdown() From 8de863c58bf463b4f07cc913661c1a0e32f0d4b4 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Mon, 24 Jun 2024 17:22:21 +0200 Subject: [PATCH 05/17] Minor fft fixes --- jaxdecomp/_src/fft.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/jaxdecomp/_src/fft.py b/jaxdecomp/_src/fft.py index 3a83b1f..3345829 100644 --- a/jaxdecomp/_src/fft.py +++ b/jaxdecomp/_src/fft.py @@ -9,8 +9,7 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.util import promote_dtypes_complex from jax.core import Primitive, ShapedArray -from jax.interpreters import ad, mlir, xla -from jax.interpreters.mlir import hlo +from jax.interpreters import ad, xla from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call @@ -43,7 +42,7 @@ class FFTPrimitive(BasePrimitive): name = "fft" multiple_results = False - impl_static_args = (1,) + impl_static_args = (1, 2) inner_primitive = None outer_primitive = None @@ -168,8 +167,6 @@ def 
impl(x, fft_type, adjoint): if typ in [xla_client.FftType.RFFT, xla_client.FftType.IRFFT]: raise TypeError("only complex FFTs are currently supported through pfft.") - (x,) = promote_dtypes_complex(x) - pdims = (1, jax.device_count()) global_shape = x.shape @@ -181,9 +178,13 @@ def impl(x, fft_type, adjoint): adjoint=adjoint) @staticmethod - def per_shard_impl(x, kind, pdims, global_shape, adjoint): + def per_shard_impl(x, fft_type, pdims, global_shape, adjoint): return FFTPrimitive.inner_primitive.bind( - x, kind=kind, pdims=pdims, global_shape=global_shape, adjoint=adjoint) + x, + fft_type=fft_type, + pdims=pdims, + global_shape=global_shape, + adjoint=adjoint) @staticmethod def infer_sharding_from_operands(fft_type, adjoint, mesh: Mesh, @@ -240,7 +241,8 @@ def partition(fft_type, adjoint, mesh, arg_shapes, result_shape): FFTPrimitive.per_shard_impl, fft_type=fft_type, pdims=pdims, - global_shape=global_shape) + global_shape=global_shape, + adjoint=adjoint) return mesh, impl, output_sharding, (input_sharding,) @@ -248,13 +250,16 @@ def partition(fft_type, adjoint, mesh, arg_shapes, result_shape): register_primitive(FFTPrimitive) -def pfft_p_lower(x, fft_type, adjoint=False): +def pfft_p_lower(x, fft_type, adjoint): + + (x,) = promote_dtypes_complex(x) + return FFTPrimitive.outer_primitive.bind( x, fft_type=fft_type, adjoint=adjoint) @partial(jax.custom_vjp, nondiff_argnums=(1, 2)) -def pfft(x, fft_type, adjoint): +def pfft(x, fft_type, adjoint=False): output, _ = _pfft_fwd_rule(x, fft_type=fft_type, adjoint=adjoint) return output From 2f3a8d26504de0392e04528821bba5475f6205ee Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Mon, 24 Jun 2024 17:23:31 +0200 Subject: [PATCH 06/17] Minor halo fixes --- jaxdecomp/_src/halo.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jaxdecomp/_src/halo.py b/jaxdecomp/_src/halo.py index b973b2f..66e872f 100644 --- a/jaxdecomp/_src/halo.py +++ b/jaxdecomp/_src/halo.py @@ -30,8 +30,7 @@ def abstract(x, halo_extents, halo_periods, reduce_halo, pdims, global_shape): return x.update(shape=x.shape, dtype=x.dtype) @staticmethod - def outer_abstract(x, halo_extents, halo_periods, reduce_halo, pdims, - global_shape): + def outer_abstract(x, halo_extents, halo_periods, reduce_halo): return x.update(shape=x.shape, dtype=x.dtype) @staticmethod @@ -134,7 +133,7 @@ def partition(halo_extents, halo_periods, reduce_halo, mesh, arg_shapes, result_shape): halo_exchange_sharding = NamedSharding(mesh, - P(*arg_shapes[0].sharding)) + P(*arg_shapes[0].sharding.spec)) global_shape = arg_shapes[0].shape pdims = (get_axis_size(halo_exchange_sharding, 1), get_axis_size(halo_exchange_sharding, 0)) From 8d26b00ec074beb4162b2560ba5e69259381b669 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Mon, 24 Jun 2024 17:23:41 +0200 Subject: [PATCH 07/17] Minor tranpose fixes --- jaxdecomp/_src/transpose.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jaxdecomp/_src/transpose.py b/jaxdecomp/_src/transpose.py index 8f5eba0..e14ac0c 100644 --- a/jaxdecomp/_src/transpose.py +++ b/jaxdecomp/_src/transpose.py @@ -10,8 +10,6 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array, ArrayLike from jax.core import ShapedArray -from jax.interpreters import mlir -from jax.interpreters.mlir import hlo from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P from jaxlib.hlo_helpers import custom_call From e1fd5581f2b0f9a6c19def6b01beb9e6d6831be5 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Thu, 27 Jun 2024 
15:57:30 +0200 Subject: [PATCH 08/17] Fix halo descriptor hashing --- include/halo.h | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/include/halo.h b/include/halo.h index 70bd68f..aa40710 100644 --- a/include/halo.h +++ b/include/halo.h @@ -22,10 +22,13 @@ class haloDescriptor_t { ~haloDescriptor_t() = default; bool operator==(const haloDescriptor_t& other) const { - return (double_precision == other.double_precision && halo_extents == other.halo_extents && - halo_periods == other.halo_periods && axis == other.axis && config.gdims[0] == other.config.gdims[0] && - config.gdims[1] == other.config.gdims[1] && config.gdims[2] == other.config.gdims[2] && - config.pdims[0] == other.config.pdims[0] && config.pdims[1] == other.config.pdims[1]); + return (double_precision == other.double_precision && halo_extents[0] == other.halo_extents[0] && + halo_extents[1] == other.halo_extents[1] && halo_extents[2] == other.halo_extents[2] && + halo_periods[0] == other.halo_periods[0] && halo_periods[1] == other.halo_periods[1] && + halo_periods[2] == other.halo_periods[2] && axis == other.axis && + config.gdims[0] == other.config.gdims[0] && config.gdims[1] == other.config.gdims[1] && + config.gdims[2] == other.config.gdims[2] && config.pdims[0] == other.config.pdims[0] && + config.pdims[1] == other.config.pdims[1]); } }; From c69b64d1823bfb49fb01b22e278a0fe6aacea8c1 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Thu, 13 Jun 2024 13:42:05 +0200 Subject: [PATCH 09/17] Fix complex slice_unpad problem --- tests/test_padding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_padding.py b/tests/test_padding.py index 133b49f..819ddc6 100644 --- a/tests/test_padding.py +++ b/tests/test_padding.py @@ -203,4 +203,3 @@ def test_complex_unpad(pdims, global_shape): # Make sure the unpadded arrays is equal to the original array assert_array_equal(gathered_original, gathered_unpadded) - From f1e6089e334a8f3c3dc963c8c9ae847af8cd5a78 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Thu, 27 Jun 2024 20:51:30 +0200 Subject: [PATCH 10/17] numpy doc and type hinting --- jaxdecomp/_src/fft.py | 328 +++++++++++++++++++++++++++++------- jaxdecomp/_src/halo.py | 321 +++++++++++++++++++++++++++++++---- jaxdecomp/_src/padding.py | 258 +++++++++++++++++++++++++--- jaxdecomp/_src/transpose.py | 266 ++++++++++++++++++++++++++--- jaxdecomp/fft.py | 73 +++++++- 5 files changed, 1091 insertions(+), 155 deletions(-) diff --git a/jaxdecomp/_src/fft.py b/jaxdecomp/_src/fft.py index 3345829..013921f 100644 --- a/jaxdecomp/_src/fft.py +++ b/jaxdecomp/_src/fft.py @@ -8,24 +8,40 @@ from jax._src.interpreters import mlir from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.util import promote_dtypes_complex +from jax._src.typing import Array from jax.core import Primitive, ShapedArray -from jax.interpreters import ad, xla from jax.lib import xla_client +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P from jaxlib.hlo_helpers import custom_call import jaxdecomp from jaxdecomp._src import _jaxdecomp - -FftType = xla_client.FftType -from jax.experimental.custom_partitioning import custom_partitioning -from jax.sharding import Mesh, NamedSharding -from jax.sharding import PartitionSpec as P - from jaxdecomp._src.spmd_ops import (BasePrimitive, get_axis_size, register_primitive) +FftType = xla_client.FftType + def _str_to_fft_type(s: str) -> xla_client.FftType: + """ + Convert a string to an FFT type enum. 
+ + Parameters + ---------- + s : str + String representation of FFT type. + + Returns + ------- + xla_client.FftType + Corresponding FFT type enum. + + Raises + ------ + ValueError + If the string `s` does not match known FFT types. + """ if s in ("fft", "FFT"): return xla_client.FftType.FFT elif s in ("ifft", "IFFT"): @@ -39,6 +55,9 @@ def _str_to_fft_type(s: str) -> xla_client.FftType: class FFTPrimitive(BasePrimitive): + """ + Custom primitive for FFT operations. + """ name = "fft" multiple_results = False @@ -47,8 +66,30 @@ class FFTPrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract(x, fft_type, pdims, global_shape, adjoint): - + def abstract(x: Array, fft_type: xla_client.FftType, pdims: Tuple[int, int], + global_shape: Tuple[int, int, + int], adjoint: bool) -> ShapedArray: + """ + Abstract function to compute the shape of FFT output. + + Parameters + ---------- + x : Array + Input array. + fft_type : xla_client.FftType + Type of FFT operation. + pdims : Tuple[int, int] + Parallel dimensions. + global_shape : Tuple[int, int, int] + Global shape of the array. + adjoint : bool + Whether to compute the adjoint FFT. + + Returns + ------- + ShapedArray + Shape of the output array. + """ if global_shape == x.shape: return FFTPrimitive.outer_abstract(x, fft_type=fft_type, adjoint=adjoint) @@ -74,8 +115,25 @@ def abstract(x, fft_type, pdims, global_shape, adjoint): return ShapedArray(output_shape, x.dtype) @staticmethod - def outer_abstract(x, fft_type, adjoint): - + def outer_abstract(x: Array, fft_type: xla_client.FftType, + adjoint: bool) -> ShapedArray: + """ + Abstract function for outer FFT operation. + + Parameters + ---------- + x : Array + Input array. + fft_type : xla_client.FftType + Type of FFT operation. + adjoint : bool + Whether to compute the adjoint FFT. + + Returns + ------- + ShapedArray + Shape of the output array. + """ match fft_type: case xla_client.FftType.FFT: # FFT is X to Y to Z so Z-Pencil is returned @@ -93,7 +151,32 @@ def outer_abstract(x, fft_type, adjoint): return ShapedArray(output_shape, x.dtype) @staticmethod - def lowering(ctx, a, *, fft_type, pdims, global_shape, adjoint): + def lowering(ctx, a: Array, *, fft_type: xla_client.FftType, + pdims: Tuple[int, int], global_shape: Tuple[int, int, + int], adjoint: bool): + """ + Lowering function for FFT primitive. + + Parameters + ---------- + ctx + Context. + a : Primitive + Input primitive. + fft_type : xla_client.FftType + Type of FFT operation. + pdims : Tuple[int, int] + Parallel dimensions. + global_shape : Tuple[int, int, int] + Global shape of the array. + adjoint : bool + Whether to compute the adjoint FFT. + + Returns + ------- + list + List of results from the operation. + """ (x_aval,) = ctx.avals_in (aval_out,) = ctx.avals_out dtype = x_aval.dtype @@ -119,7 +202,6 @@ def lowering(ctx, a, *, fft_type, pdims, global_shape, adjoint): global_shape = tuple([global_shape[i] for i in transpose_back_shape]) # Compute the descriptor for our FFT config = _jaxdecomp.GridConfig() - config.pdims = pdims config.gdims = global_shape[::-1] config.halo_comm_backend = jaxdecomp.config.halo_comm_backend @@ -131,7 +213,6 @@ def lowering(ctx, a, *, fft_type, pdims, global_shape, adjoint): layout = tuple(range(n - 1, -1, -1)) # We ask XLA to allocate a workspace for this operation. 
- # TODO: check that the memory is not used all the time, just when needed workspace = mlir.full_like_aval( ctx, 0, jax.core.ShapedArray(shape=[workspace_size], dtype=np.byte)) @@ -151,12 +232,25 @@ def lowering(ctx, a, *, fft_type, pdims, global_shape, adjoint): # Finally we reshape the arry to the expected shape. return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), result).results - # Note : This must no longer be jitted because it will have single device abstract shapes - # The actual jit is done in the pfft function in lower_fn - # This means that sfft should never be lowered as is and only be lowered in the context of pfft @staticmethod - def impl(x, fft_type, adjoint): - + def impl(x: Array, fft_type: Union[str, xla_client.FftType], adjoint: bool): + """ + Implementation function for FFT primitive. + + Parameters + ---------- + x : Array + Input array. + fft_type : Union[str, xla_client.FftType] + Type of FFT operation. + adjoint : bool + Whether to compute the adjoint FFT. + + Returns + ------- + Primitive + Result of the operation. + """ if isinstance(fft_type, str): typ = _str_to_fft_type(fft_type) elif isinstance(fft_type, xla_client.FftType): @@ -178,7 +272,30 @@ def impl(x, fft_type, adjoint): adjoint=adjoint) @staticmethod - def per_shard_impl(x, fft_type, pdims, global_shape, adjoint): + def per_shard_impl(x: Array, fft_type: xla_client.FftType, pdims: Tuple[int, + int], + global_shape: Tuple[int, int, int], adjoint: bool): + """ + Implementation function for per-shard FFT primitive. + + Parameters + ---------- + x : Array + Input array. + fft_type : xla_client.FftType + Type of FFT operation. + pdims : Tuple[int, int] + Parallel dimensions. + global_shape : Tuple[int, int, int] + Global shape of the array. + adjoint : bool + Whether to compute the adjoint FFT. + + Returns + ------- + Primitive + Result of the operation. + """ return FFTPrimitive.inner_primitive.bind( x, fft_type=fft_type, @@ -187,50 +304,60 @@ def per_shard_impl(x, fft_type, pdims, global_shape, adjoint): adjoint=adjoint) @staticmethod - def infer_sharding_from_operands(fft_type, adjoint, mesh: Mesh, - arg_infos: Tuple[ShapeDtypeStruct], - result_infos: Tuple[ShapedArray]): + def infer_sharding_from_operands( + fft_type: xla_client.FftType, adjoint: bool, mesh: Mesh, + arg_infos: Tuple[ShapeDtypeStruct], + result_infos: Tuple[ShapedArray]) -> NamedSharding: """ - Tell XLA how to infer the sharding of the output from the input sharding. - - Args: - mesh (Mesh): The contextual mesh - - arg_shapes (tuple): A tuple of ShapeDtypeStruct that contains the shape and the sharding of each input operand - - result_shape (ShapedArray) : a single ShapedArray reprsenting a single output without the sharding information - - Returns: - - result_sharding (XLACompatibleSharding): The sharding result for example a NamedSharding. - + Infer sharding for FFT primitive based on operands. + + Parameters + ---------- + fft_type : xla_client.FftType + Type of FFT operation. + adjoint : bool + Whether to compute the adjoint FFT. + mesh : Mesh + Contextual mesh for sharding. + arg_infos : Tuple[ShapeDtypeStruct] + Shape and sharding information of input operands. + result_infos : Tuple[ShapedArray] + Shape information of output. + + Returns + ------- + NamedSharding + Sharding information for the result. 
""" input_sharding = arg_infos[0].sharding return NamedSharding(mesh, P(*input_sharding.spec)) @staticmethod - def partition(fft_type, adjoint, mesh, arg_shapes, result_shape): + def partition( + fft_type: xla_client.FftType, adjoint: bool, mesh: Mesh, + arg_shapes: Tuple[ShapeDtypeStruct], result_shape: ShapeDtypeStruct + ) -> Tuple[Mesh, partial, NamedSharding, Tuple[NamedSharding]]: + """ + Partition the FFT primitive for XLA. + + Parameters + ---------- + fft_type : xla_client.FftType + Type of FFT operation. + adjoint : bool + Whether to compute the adjoint FFT. + mesh : Mesh + Contextual mesh for sharding. + arg_shapes : Tuple[ShapeDtypeStruct] + Shape and sharding information of input operands. + result_shape : ShapeDtypeStruct + Shape and sharding information of output. + + Returns + ------- + Tuple[Mesh, partial, NamedSharding, Tuple[NamedSharding]] + Mesh, lowered function, output sharding, and input operand sharding. """ - Tells XLA how to partition the primitive - - Args: - mesh (Mesh): The contextual mesh - - arg_shapes (tuple): A tuple of ShapeDtypeStruct that contains the shape and the sharding of each input operand - - result_shape (ShapeDtypeStruct) : a ShapeDtypeStruct reprsenting a single output - - Returns: - Mesh (Mesh) : The mesh. - - function: The lowered function, to allow the user to redefine how the primitive is called in a context of a specific sharding - - result_sharding (XLACompatibleSharding): The sharding result for example a NamedSharding. - - arg_shardings (tuple): a tuple of all XLACompatibleSharding of the input operands - """ - - # pfft only has one operand input_sharding = NamedSharding(mesh, P(*arg_shapes[0].sharding.spec)) output_sharding = NamedSharding(mesh, P(*result_shape.sharding.spec)) @@ -250,8 +377,25 @@ def partition(fft_type, adjoint, mesh, arg_shapes, result_shape): register_primitive(FFTPrimitive) -def pfft_p_lower(x, fft_type, adjoint): - +def pfft_p_lower(x: Array, fft_type: Union[str, xla_client.FftType], + adjoint: bool) -> Primitive: + """ + Lowering function for pfft primitive. + + Parameters + ---------- + x : Array + Input array. + fft_type : Union[str, xla_client.FftType] + Type of FFT operation. + adjoint : bool + Whether to compute the adjoint FFT. + + Returns + ------- + Primitive + Result of the operation. + """ (x,) = promote_dtypes_complex(x) return FFTPrimitive.outer_primitive.bind( @@ -259,18 +403,74 @@ def pfft_p_lower(x, fft_type, adjoint): @partial(jax.custom_vjp, nondiff_argnums=(1, 2)) -def pfft(x, fft_type, adjoint=False): +def pfft(x: Array, + fft_type: Union[str, xla_client.FftType], + adjoint: bool = False) -> Primitive: + """ + Custom VJP definition for pfft. + + Parameters + ---------- + x : Array + Input array. + fft_type : Union[str, xla_client.FftType] + Type of FFT operation. + adjoint : bool, optional + Whether to compute the adjoint FFT. Defaults to False. + + Returns + ------- + Primitive + Result of the operation. + """ output, _ = _pfft_fwd_rule(x, fft_type=fft_type, adjoint=adjoint) return output -def _pfft_fwd_rule(x, fft_type: str, adjoint: bool = False): - # Linear function has no residuals +def _pfft_fwd_rule(x: Array, + fft_type: Union[str, xla_client.FftType], + adjoint: bool = False) -> Tuple[Primitive, None]: + """ + Forward rule for pfft. + + Parameters + ---------- + x : Array + Input array. + fft_type : Union[str, xla_client.FftType] + Type of FFT operation. + adjoint : bool, optional + Whether to compute the adjoint FFT. Defaults to False. 
+ + Returns + ------- + Tuple[Primitive, None] + Result of the operation and None (no residuals). + """ return pfft_p_lower(x, fft_type=fft_type, adjoint=adjoint), None -def _pfft_bwd_rule(fft_type, adjoint, ctx, g): - +def _pfft_bwd_rule(fft_type: Union[str, xla_client.FftType], adjoint: bool, ctx, + g: Primitive) -> Tuple[Primitive]: + """ + Backward rule for pfft. + + Parameters + ---------- + fft_type : Union[str, xla_client.FftType] + Type of FFT operation. + adjoint : bool + Whether to compute the adjoint FFT. + ctx + Context. + g : Primitive + Gradient value. + + Returns + ------- + Tuple[Primitive] + Result of the operation. + """ assert fft_type in [FftType.FFT, FftType.IFFT] if fft_type == FftType.FFT: fft_type = FftType.IFFT diff --git a/jaxdecomp/_src/halo.py b/jaxdecomp/_src/halo.py index 66e872f..671b254 100644 --- a/jaxdecomp/_src/halo.py +++ b/jaxdecomp/_src/halo.py @@ -5,9 +5,8 @@ import jaxlib.mlir.ir as ir import numpy as np from jax._src.interpreters import mlir +from jax._src.typing import Array from jax.core import Primitive -from jax.experimental.custom_partitioning import custom_partitioning -from jax.interpreters import ad, xla from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P from jaxlib.hlo_helpers import custom_call @@ -19,30 +18,108 @@ class HaloPrimitive(BasePrimitive): + """ + Custom primitive for performing halo exchange operation. + """ name = "halo_exchange" multiple_results = False impl_static_args = (1, 2, 3) inner_primitive = None + outer_primitive = None @staticmethod - def abstract(x, halo_extents, halo_periods, reduce_halo, pdims, global_shape): + def abstract(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool], reduce_halo: bool, + pdims: Tuple[int, int], global_shape: Tuple[int, int, + int]) -> Array: + """ + Abstract function for determining the shape and dtype after the halo exchange operation. + + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + reduce_halo : bool + Flag indicating whether to reduce the halo. + pdims : Tuple[int, int] + Processor dimensions. + global_shape : Tuple[int, int, int] + Global shape of the array. + + Returns + ------- + Array + Abstract array after the halo exchange operation. + """ return x.update(shape=x.shape, dtype=x.dtype) @staticmethod - def outer_abstract(x, halo_extents, halo_periods, reduce_halo): + def outer_abstract(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, + bool], reduce_halo: bool) -> Array: + """ + Abstract function for determining the shape and dtype without considering inner details. + + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + reduce_halo : bool + Flag indicating whether to reduce the halo. + + Returns + ------- + Array + Abstract array after the halo exchange operation. 
+ """ return x.update(shape=x.shape, dtype=x.dtype) @staticmethod - def lowering(ctx, x, halo_extents, halo_periods, reduce_halo, pdims, - global_shape): + def lowering(ctx, x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool], reduce_halo: bool, + pdims: Tuple[int, int], global_shape: Tuple[int, int, + int]) -> Array: + """ + Lowering function to generate the MLIR representation for halo exchange. + + Parameters + ---------- + ctx + Context for the operation. + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + reduce_halo : bool + Flag indicating whether to reduce the halo. + pdims : Tuple[int, int] + Processor dimensions. + global_shape : Tuple[int, int, int] + Global shape of the array. + + Returns + ------- + Array + Resulting array after the halo exchange operation. + """ (x_aval,) = ctx.avals_in x_type = ir.RankedTensorType(x.type) n = len(x_type.shape) is_double = np.finfo(x_aval.dtype).dtype == np.float64 - # Compute the descriptor for our FFT + # Compute the descriptor for the halo exchange operation config = _jaxdecomp.GridConfig() config.pdims = pdims config.gdims = global_shape[::-1] @@ -56,7 +133,7 @@ def lowering(ctx, x, halo_extents, halo_periods, reduce_halo, pdims, workspace = mlir.full_like_aval( ctx, 0, jax.core.ShapedArray(shape=[workspace_size], dtype=np.byte)) - # reduce_halo is not used in the inner primitive + # Perform custom call for halo exchange out = custom_call( "halo", result_types=[x_type], @@ -70,8 +147,28 @@ def lowering(ctx, x, halo_extents, halo_periods, reduce_halo, pdims, return out.results @staticmethod - def impl(x, halo_extents, halo_periods, reduce_halo): - + def impl(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, + bool], reduce_halo: bool) -> Primitive: + """ + Implementation function for performing halo exchange. + + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + reduce_halo : bool + Flag indicating whether to reduce the halo. + + Returns + ------- + Primitive + Inner primitive bound with input parameters. + """ pdims = (1, jax.device_count()) global_shape = x.shape @@ -85,8 +182,33 @@ def impl(x, halo_extents, halo_periods, reduce_halo): ) @staticmethod - def per_shard_impl(x, halo_extents, halo_periods, reduce_halo, pdims, - global_shape): + def per_shard_impl(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool], reduce_halo: bool, + pdims: Tuple[int, int], global_shape: Tuple[int, int, + int]) -> Array: + """ + Implementation function for performing halo exchange per shard. + + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + reduce_halo : bool + Flag indicating whether to reduce the halo. + pdims : Tuple[int, int] + Processor dimensions. + global_shape : Tuple[int, int, int] + Global shape of the array. + + Returns + ------- + Array + Resulting array after the halo exchange operation. 
+ """ output = HaloPrimitive.inner_primitive.bind( x, halo_extents=halo_extents, @@ -97,22 +219,21 @@ def per_shard_impl(x, halo_extents, halo_periods, reduce_halo, pdims, ) if reduce_halo: - halo_x, halo_y, halo_z = halo_extents - ## Apply correction along x + # Apply corrections along x if halo_x > 0: output = output.at[halo_x:halo_x + halo_x // 2].add(output[:halo_x // 2]) output = output.at[-(halo_x + halo_x // 2):-halo_x].add( output[-halo_x // 2:]) - ## Apply correction along y + # Apply corrections along y if halo_y > 0: output = output.at[:, halo_y:halo_y + halo_y // 2].add( output[:, :halo_y // 2]) output = output.at[:, -(halo_y + halo_y // 2):-halo_y].add( output[:, -halo_y // 2:]) - ## Apply correction along z + # Apply corrections along z if halo_z > 0: output = output.at[:, :, halo_z:halo_z + halo_z // 2].add( output[:, :, :halo_z // 2]) @@ -122,24 +243,74 @@ def per_shard_impl(x, halo_extents, halo_periods, reduce_halo, pdims, return output @staticmethod - def infer_sharding_from_operands(halo_extents, halo_periods, reduce_halo, - mesh, arg_shapes, result_shape): - # Sharding is the same here aswell because halo_exchange is a pointwise operation + def infer_sharding_from_operands( + halo_extents: Tuple[int, int, + int], halo_periods: Tuple[bool, bool, + bool], reduce_halo: bool, + mesh: NamedSharding, arg_shapes: Tuple[ir.ShapeDtypeStruct], + result_shape: ir.ShapedArray) -> NamedSharding: + """ + Infer sharding information for halo exchange operation. + + Parameters + ---------- + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + reduce_halo : bool + Flag indicating whether to reduce the halo. + mesh : NamedSharding + Mesh object for sharding. + arg_shapes : Tuple[ir.ShapeDtypeStruct] + Shapes and dtypes of input operands. + result_shape : ir.ShapedArray + Shape and dtype of the output result. + + Returns + ------- + NamedSharding + Sharding information for halo exchange operation. + """ halo_exchange_sharding = arg_shapes[0].sharding return NamedSharding(mesh, P(*halo_exchange_sharding.spec)) @staticmethod - def partition(halo_extents, halo_periods, reduce_halo, mesh, arg_shapes, - result_shape): - + def partition(halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool], reduce_halo: bool, + mesh: NamedSharding, arg_shapes: Tuple[ir.ShapeDtypeStruct], + result_shape: ir.ShapedArray) -> Tuple[NamedSharding, partial]: + """ + Partition function for halo exchange operation. + + Parameters + ---------- + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + reduce_halo : bool + Flag indicating whether to reduce the halo. + mesh : NamedSharding + Mesh object for sharding. + arg_shapes : Tuple[ir.ShapeDtypeStruct] + Shapes and dtypes of input operands. + result_shape : ir.ShapedArray + Shape and dtype of the output result. + + Returns + ------- + Tuple[NamedSharding, partial] + Mesh object, implementation function, sharding information, and its tuple. 
+ """ halo_exchange_sharding = NamedSharding(mesh, P(*arg_shapes[0].sharding.spec)) global_shape = arg_shapes[0].shape pdims = (get_axis_size(halo_exchange_sharding, 1), get_axis_size(halo_exchange_sharding, 0)) - shape_without_halo = (global_shape[0] - 2 * pdims[1] * halo_extents[0],\ - global_shape[1] - 2 * pdims[0] * halo_extents[1],\ + shape_without_halo = (global_shape[0] - 2 * pdims[1] * halo_extents[0], + global_shape[1] - 2 * pdims[0] * halo_extents[1], global_shape[2] - 2 * halo_extents[2]) impl = partial( @@ -156,8 +327,28 @@ def partition(halo_extents, halo_periods, reduce_halo, mesh, arg_shapes, register_primitive(HaloPrimitive) -def halo_p_lower(x, halo_extents, halo_periods, reduce_halo): - +def halo_p_lower(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, + bool], reduce_halo: bool) -> Primitive: + """ + Lowering function for the halo exchange operation. + + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + reduce_halo : bool + Flag indicating whether to reduce the halo. + + Returns + ------- + Primitive + Inner primitive bound with input parameters. + """ return HaloPrimitive.outer_primitive.bind( x, halo_extents=halo_extents, @@ -166,25 +357,89 @@ def halo_p_lower(x, halo_extents, halo_periods, reduce_halo): ) -# declare primitive - - # Custom Partitioning @partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3)) -def halo_exchange(x, halo_extents, halo_periods, reduce_halo=False): +def halo_exchange(x: Array, + halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool], + reduce_halo: bool = False) -> Array: + """ + Halo exchange operation with custom VJP. + + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + reduce_halo : bool, optional + Flag indicating whether to reduce the halo. Default is False. + + Returns + ------- + Array + Output array after the halo exchange operation. + """ output, _ = _halo_fwd_rule(x, halo_extents, halo_periods, reduce_halo) return output -def _halo_fwd_rule(x, halo_extents, halo_periods, reduce_halo): - # Linear function has no residuals +def _halo_fwd_rule(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool], + reduce_halo: bool) -> Tuple[Array, None]: + """ + Forward rule for the halo exchange operation. + + Parameters + ---------- + x : Array + Input array. + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. + halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + reduce_halo : bool + Flag indicating whether to reduce the halo. + + Returns + ------- + Tuple[Array, None] + Output array after the halo exchange operation and None for no residuals. + """ return halo_p_lower(x, halo_extents, halo_periods, reduce_halo), None -def _halo_bwd_rule(halo_extents, halo_periods, reduce_halo, ctx, g): +def _halo_bwd_rule(halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool], reduce_halo: bool, + ctx, g: Array) -> Tuple[Array]: + """ + Backward rule for the halo exchange operation. + + Parameters + ---------- + halo_extents : Tuple[int, int, int] + Extents of the halo in x, y, and z dimensions. 
+ halo_periods : Tuple[bool, bool, bool] + Periodicity of the halo in x, y, and z dimensions. + reduce_halo : bool + Flag indicating whether to reduce the halo. + ctx + Context for the operation. + g : Array + Gradient array. + + Returns + ------- + Tuple[Array] + Gradient array after the halo exchange operation. + """ return halo_p_lower(g, halo_extents, halo_periods, reduce_halo), +# Define VJP for custom halo_exchange operation halo_exchange.defvjp(_halo_fwd_rule, _halo_bwd_rule) +# JIT compile the halo_exchange operation halo_exchange = jax.jit(halo_exchange, static_argnums=(1, 2, 3)) diff --git a/jaxdecomp/_src/padding.py b/jaxdecomp/_src/padding.py index 473e799..c541cf5 100644 --- a/jaxdecomp/_src/padding.py +++ b/jaxdecomp/_src/padding.py @@ -1,5 +1,3 @@ -from abc import ABCMeta, abstractmethod -from dataclasses import dataclass from functools import partial from typing import Tuple @@ -8,20 +6,32 @@ from jax._src.api import ShapeDtypeStruct from jax._src.core import ShapedArray from jax._src.typing import Array, ArrayLike -from jax.experimental.custom_partitioning import custom_partitioning from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P from jaxdecomp._src.spmd_ops import CustomParPrimitive, register_primitive -## Padding Custom SPMD lowering class SlicePaddingPrimitive(CustomParPrimitive): - - name = "slice_pad" - multiple_results = False - impl_static_args = (1, 2, 3) - outer_pritimive = None + """ + Custom primitive for slice padding operation. + + Attributes + ---------- + name : str + The name of the primitive operation. + multiple_results : bool + Whether the operation produces multiple results. + impl_static_args : tuple + Static arguments for the implementation. + outer_pritimive : object + Outer primitive used for the operation. + """ + + name: str = "slice_pad" + multiple_results: bool = False + impl_static_args: Tuple[int, int, int] = (1, 2, 3) + outer_pritimive: object = None # Global array implementation is used purely for its abstract eval # at jit time, the shape of the global array output is infered from this function @@ -33,7 +43,25 @@ def impl(arr: ArrayLike, padding_width: int | tuple[int], pdims: tuple[int], mode: str = 'constant') -> Array: - + """ + Implementation of the slice padding operation. + + Parameters + ---------- + arr : ArrayLike + Input array to be padded. + padding_width : int | tuple[int] + Width of padding to apply. + pdims : tuple[int] + Dimensions for padding. + mode : str, optional + Padding mode ('constant' by default). + + Returns + ------- + Array + Padded array. + """ assert arr.ndim == 3, "Only 3D arrays are supported" assert len(pdims) == 2, "Only 2D pdims are supported" @@ -83,11 +111,27 @@ def impl(arr: ArrayLike, return arr - # Actual per slice implementation of the primitive @staticmethod def per_shard_impl(arr: ArrayLike, padding_width: int | tuple[int], mode: str = 'constant') -> Array: + """ + Per-shard implementation of the slice padding operation. + + Parameters + ---------- + arr : ArrayLike + Input array to be padded. + padding_width : int | tuple[int] + Width of padding to apply. + mode : str, optional + Padding mode ('constant' by default). + + Returns + ------- + Array + Padded array. 
+ """ return jnp.pad(arr, padding_width, mode=mode) @staticmethod @@ -95,6 +139,29 @@ def infer_sharding_from_operands(padding_width: int | tuple[int], pdims: tuple[int], mode: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], result_infos: Tuple[ShapedArray]): + """ + Infers sharding information from operands for the slice padding operation. + + Parameters + ---------- + padding_width : int | tuple[int] + Width of padding to apply. + pdims : tuple[int] + Dimensions for padding. + mode : str + Padding mode. + mesh : Mesh + Computational mesh. + arg_infos : Tuple[ShapeDtypeStruct] + Information about operands. + result_infos : Tuple[ShapedArray] + Information about results. + + Returns + ------- + NamedSharding + Sharding information. + """ input_sharding = arg_infos[0].sharding return NamedSharding(input_sharding.mesh, P(*input_sharding.spec)) @@ -102,8 +169,29 @@ def infer_sharding_from_operands(padding_width: int | tuple[int], def partition(padding_width: int | tuple[int], pdims: tuple[int], mode: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], result_infos: Tuple[ShapedArray]): - - # Only one non static input and one output + """ + Partitions the slice padding operation across a computational mesh. + + Parameters + ---------- + padding_width : int | tuple[int] + Width of padding to apply. + pdims : tuple[int] + Dimensions for padding. + mode : str + Padding mode. + mesh : Mesh + Computational mesh. + arg_infos : Tuple[ShapeDtypeStruct] + Information about operands. + result_infos : Tuple[ShapedArray] + Information about results. + + Returns + ------- + Tuple + Mesh, implementation, output sharding, and input sharding. + """ input_sharding = NamedSharding(mesh, P(*arg_infos[0].sharding.spec)) output_sharding = NamedSharding(mesh, P(*result_infos.sharding.spec)) impl = partial( @@ -121,25 +209,69 @@ def slice_pad(x: ArrayLike, padding_width: int | tuple[int], pdims: tuple[int], mode: str = 'constant') -> Array: + """ + JIT-compiled function for slice padding operation. + + Parameters + ---------- + x : ArrayLike + Input array to be padded. + padding_width : int | tuple[int] + Width of padding to apply. + pdims : tuple[int] + Dimensions for padding. + mode : str, optional + Padding mode ('constant' by default). + + Returns + ------- + Array + Padded array. + """ return SlicePaddingPrimitive.outer_lowering(x, padding_width, pdims, mode) -## Unpadding Custom SPMD lowering - - class SliceUnPaddingPrimitive(CustomParPrimitive): + """ + Custom primitive for slice unpading operation. + + Attributes + ---------- + name : str + The name of the primitive operation. + multiple_results : bool + Whether the operation produces multiple results. + impl_static_args : tuple + Static arguments for the implementation. + outer_pritimive : object + Outer primitive used for the operation. + """ + + name: str = "slice_unpad" + multiple_results: bool = False + impl_static_args: Tuple[int, int] = (1, 2) + outer_pritimive: object = None - name = "slice_unpad" - multiple_results = False - impl_static_args = (1, 2) - outer_pritimive = None - - # Same as padding, the global array implementation is used purely for its abstract eval @staticmethod def impl(arr: ArrayLike, padding_width: int | tuple[int], pdims: tuple[int]) -> Array: - - # If padding width is an integer then unpad the entire array + """ + Implementation of the slice unpading operation. + + Parameters + ---------- + arr : ArrayLike + Input array to be unpadded. + padding_width : int | tuple[int] + Width of padding to remove. 
+ pdims : tuple[int] + Dimensions for unpadding. + + Returns + ------- + Array + Unpadded array. + """ if isinstance(padding_width, int): unpadding_width = ((padding_width, padding_width),) * arr.ndim elif isinstance(padding_width, tuple): @@ -175,10 +307,23 @@ def impl(arr: ArrayLike, padding_width: int | tuple[int], return arr - # Actual per slice implementation of the primitive @staticmethod def per_shard_impl(arr: ArrayLike, padding_width: int | tuple[int]) -> Array: - # If padding width is an integer then unpad the entire array + """ + Per-shard implementation of the slice unpading operation. + + Parameters + ---------- + arr : ArrayLike + Input array to be unpadded. + padding_width : int | tuple[int] + Width of padding to remove. + + Returns + ------- + Array + Unpadded array. + """ if isinstance(padding_width, int): unpadding_width = ((padding_width, padding_width),) * arr.ndim elif isinstance(padding_width, tuple): @@ -200,6 +345,27 @@ def infer_sharding_from_operands(padding_width: int | tuple[int], pdims: tuple[int], mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], result_infos: Tuple[ShapedArray]): + """ + Infers sharding information from operands for the slice unpading operation. + + Parameters + ---------- + padding_width : int | tuple[int] + Width of padding to remove. + pdims : tuple[int] + Dimensions for unpadding. + mesh : Mesh + Computational mesh. + arg_infos : Tuple[ShapeDtypeStruct] + Information about operands. + result_infos : Tuple[ShapedArray] + Information about results. + + Returns + ------- + NamedSharding + Sharding information. + """ input_sharding = arg_infos[0].sharding return NamedSharding(input_sharding.mesh, P(*input_sharding.spec)) @@ -207,8 +373,27 @@ def infer_sharding_from_operands(padding_width: int | tuple[int], def partition(padding_width: int | tuple[int], pdims: tuple[int], mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], result_infos: Tuple[ShapedArray]): - - # Only one non static input and one output + """ + Partitions the slice unpading operation across a computational mesh. + + Parameters + ---------- + padding_width : int | tuple[int] + Width of padding to remove. + pdims : tuple[int] + Dimensions for unpadding. + mesh : Mesh + Computational mesh. + arg_infos : Tuple[ShapeDtypeStruct] + Information about operands. + result_infos : Tuple[ShapedArray] + Information about results. + + Returns + ------- + Tuple + Mesh, implementation, output sharding, and input sharding. + """ input_sharding = NamedSharding(mesh, P(*arg_infos[0].sharding.spec)) output_sharding = NamedSharding(mesh, P(*result_infos.sharding.spec)) impl = partial( @@ -222,4 +407,21 @@ def partition(padding_width: int | tuple[int], pdims: tuple[int], mesh: Mesh, @partial(jit, static_argnums=(1, 2)) def slice_unpad(arr: ArrayLike, padding_width: int | tuple[int], pdims: tuple[int]) -> Array: + """ + JIT-compiled function for slice unpading operation. + + Parameters + ---------- + arr : ArrayLike + Input array to be unpadded. + padding_width : int | tuple[int] + Width of padding to remove. + pdims : tuple[int] + Dimensions for unpadding. + + Returns + ------- + Array + Unpadded array. 
+ """ return SliceUnPaddingPrimitive.outer_lowering(arr, padding_width, pdims) diff --git a/jaxdecomp/_src/transpose.py b/jaxdecomp/_src/transpose.py index e14ac0c..d6b8eba 100644 --- a/jaxdecomp/_src/transpose.py +++ b/jaxdecomp/_src/transpose.py @@ -1,5 +1,4 @@ from functools import partial -from os import name from typing import Tuple import jax @@ -21,16 +20,51 @@ class TransposePrimitive(BasePrimitive): - - name = "transpose" - multiple_results = False - impl_static_args = (1,) - inner_primitive = None - outer_primitive = None + """ + JAX primitive for transposing arrays with different partitioning strategies. + + Attributes + ---------- + name : str + Name of the primitive ("transpose"). + multiple_results : bool + Boolean indicating if the primitive returns multiple results (False). + impl_static_args : tuple + Static arguments for the implementation (tuple containing (1,)). + inner_primitive : object + Inner core.Primitive object for the primitive. + outer_primitive : object + Outer core.Primitive object for the primitive. + """ + + name: str = "transpose" + multiple_results: bool = False + impl_static_args: Tuple[int] = (1,) + inner_primitive: object = None + outer_primitive: object = None @staticmethod - def abstract(x, kind, pdims, global_shape): - + def abstract(x: ArrayLike, kind: str, pdims: Tuple[int], + global_shape: Tuple[int]) -> ShapedArray: + """ + Abstract method to describe the shape of the output array after transposition. + + Parameters + ---------- + x : ArrayLike + Input array. + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + pdims : tuple[int] + Partition dimensions. + global_shape : tuple[int] + Global shape of the input array. + + Returns + ------- + ShapedArray + Abstract shape of the output array after transposition. + """ if global_shape == x.shape: return TransposePrimitive.outer_abstract(x, kind) # Make sure that global_shape is divisible by pdims and equals to slice @@ -58,8 +92,22 @@ def abstract(x, kind, pdims, global_shape): return ShapedArray(shape, x.dtype) @staticmethod - def outer_abstract(x, kind): - + def outer_abstract(x: ArrayLike, kind: str) -> ShapedArray: + """ + Abstract method for transposition that does not require knowledge of global shape. + + Parameters + ---------- + x : ArrayLike + Input array. + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + + Returns + ------- + ShapedArray + Abstract shape of the output array after transposition. + """ assert kind in ['x_y', 'y_z', 'z_y', 'y_x'] match kind: # From X to Y the axis are rolled by 1 and pdims are swapped wrt to the input pdims @@ -74,7 +122,29 @@ def outer_abstract(x, kind): return ShapedArray(shape, x.dtype) @staticmethod - def lowering(ctx, x, *, kind, pdims, global_shape): + def lowering(ctx, x: ArrayLike, *, kind: str, pdims: Tuple[int], + global_shape: Tuple[int]): + """ + Method to lower the transposition operation to MLIR. + + Parameters + ---------- + ctx : object + Context for the operation. + x : ArrayLike + Input array. + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + pdims : tuple[int] + Partition dimensions. + global_shape : tuple[int] + Global shape of the input array. + + Returns + ------- + List + List of lowered results. 
+ """ assert kind in ['x_y', 'y_z', 'z_y', 'y_x'] (aval_in,) = ctx.avals_in (aval_out,) = ctx.avals_out @@ -130,23 +200,76 @@ def lowering(ctx, x, *, kind, pdims, global_shape): return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), result).results @staticmethod - def impl(x, kind): + def impl(x: ArrayLike, kind: str): + """ + Implementation method for the transposition primitive. + + Parameters + ---------- + x : ArrayLike + Input array. + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + + Returns + ------- + Object + Result of binding the inner primitive with input arguments. + """ size = jax.device_count() - # The pdims product must be equal to the number of devices because this is checked both in the abstract eval and in cudecomp - pdims = (size, 1) + pdims = (size, 1) # pdims product must be equal to the number of devices global_shape = x.shape return TransposePrimitive.inner_primitive.bind( x, kind=kind, pdims=pdims, global_shape=global_shape) @staticmethod - def per_shard_impl(x, kind, pdims, global_shape): + def per_shard_impl(x: ArrayLike, kind: str, pdims: Tuple[int], + global_shape: Tuple[int]): + """ + Per-shard implementation method for the transposition primitive. + + Parameters + ---------- + x : ArrayLike + Input array. + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + pdims : tuple[int] + Partition dimensions. + global_shape : tuple[int] + Global shape of the input array. + + Returns + ------- + Object + Result of binding the inner primitive with input arguments. + """ return TransposePrimitive.inner_primitive.bind( x, kind=kind, pdims=pdims, global_shape=global_shape) @staticmethod - def infer_sharding_from_operands(kind: str, mesh: Mesh, - arg_infos: Tuple[ShapeDtypeStruct], - result_infos: Tuple[ShapedArray]): + def infer_sharding_from_operands( + kind: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], + result_infos: Tuple[ShapedArray]) -> NamedSharding: + """ + Method to infer sharding information from operands for custom partitioning. + + Parameters + ---------- + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + mesh : Mesh + Sharding mesh information. + arg_infos : Tuple[ShapeDtypeStruct] + Tuple of ShapeDtypeStruct for input operands. + result_infos : Tuple[ShapedArray] + Tuple of ShapedArray for result information. + + Returns + ------- + NamedSharding + Named sharding information. + """ input_sharding = arg_infos[0].sharding tranposed_pdims = (input_sharding.spec[1], input_sharding.spec[0], None) @@ -154,9 +277,29 @@ def infer_sharding_from_operands(kind: str, mesh: Mesh, return NamedSharding(input_sharding.mesh, P(*tranposed_pdims)) @staticmethod - def partition(kind: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], - result_infos: Tuple[ShapedArray]): - + def partition( + kind: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], + result_infos: Tuple[ShapedArray] + ) -> Tuple[Mesh, callable, NamedSharding, Tuple[NamedSharding]]: + """ + Method to partition the transposition operation for custom partitioning. + + Parameters + ---------- + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + mesh : Mesh + Sharding mesh information. + arg_infos : Tuple[ShapeDtypeStruct] + Tuple of ShapeDtypeStruct for input operands. + result_infos : Tuple[ShapedArray] + Tuple of ShapedArray for result information. + + Returns + ------- + Tuple + Tuple containing mesh, implementation function, output sharding, and input sharding. 
+ """ input_sharding = NamedSharding(mesh, P(*arg_infos[0].sharding.spec)) output_sharding = NamedSharding(mesh, P(*result_infos.sharding.spec)) global_shape = arg_infos[0].shape @@ -180,14 +323,44 @@ def partition(kind: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], register_primitive(TransposePrimitive) -@partial(jax.jit, static_argnums=(1)) -def transpose(x, kind: str) -> Array: +@partial(jax.jit, static_argnums=(1,)) +def transpose(x: ArrayLike, kind: str) -> Array: + """ + JIT-compiled function for performing transposition using the outer primitive. + + Parameters + ---------- + x : ArrayLike + Input array. + kind : str + Kind of transposition ('x_y', 'y_z', 'z_y', 'y_x'). + + Returns + ------- + Array + Transposed array. + """ return TransposePrimitive.outer_primitive.bind(x, kind=kind) -# X to Y +# Custom transposition functions + + @jax.custom_vjp def transposeXtoY(x: ArrayLike) -> Array: + """ + Custom JAX transposition function for X to Y. + + Parameters + ---------- + x : ArrayLike + Input array. + + Returns + ------- + Array + Transposed array. + """ return transpose(x, 'x_y') @@ -199,9 +372,21 @@ def transposeXtoY_bwd(_, g): return transpose(g, 'y_x'), -# Y to X @jax.custom_vjp def transposeYtoZ(x: ArrayLike) -> Array: + """ + Custom JAX transposition function for Y to Z. + + Parameters + ---------- + x : ArrayLike + Input array. + + Returns + ------- + Array + Transposed array. + """ return transpose(x, 'y_z') @@ -213,9 +398,21 @@ def transposeYtoZ_bwd(_, g): return transpose(g, 'z_y'), -# Z to Y @jax.custom_vjp def transposeZtoY(x: ArrayLike) -> Array: + """ + Custom JAX transposition function for Z to Y. + + Parameters + ---------- + x : ArrayLike + Input array. + + Returns + ------- + Array + Transposed array. + """ return transpose(x, 'z_y') @@ -227,9 +424,21 @@ def transposeZtoY_bwd(_, g): return transpose(g, 'y_z'), -# Y to X @jax.custom_vjp def transposeYtoX(x: ArrayLike) -> Array: + """ + Custom JAX transposition function for Y to X. + + Parameters + ---------- + x : ArrayLike + Input array. + + Returns + ------- + Array + Transposed array. + """ return transpose(x, 'y_x') @@ -241,6 +450,7 @@ def transposeYtoX_bwd(_, g): return transpose(g, 'x_y'), +# Define VJPs for custom transposition functions transposeXtoY.defvjp(transposeXtoY_fwd, transposeXtoY_bwd) transposeYtoZ.defvjp(transposeYtoZ_fwd, transposeYtoZ_bwd) transposeZtoY.defvjp(transposeZtoY_fwd, transposeZtoY_bwd) diff --git a/jaxdecomp/fft.py b/jaxdecomp/fft.py index 137c842..e7938b6 100644 --- a/jaxdecomp/fft.py +++ b/jaxdecomp/fft.py @@ -17,6 +17,28 @@ def _fft_norm(s: Array, func_name: str, norm: str) -> Array: + """ + Compute the normalization factor for FFT operations. + + Parameters + ---------- + s : Array + Shape of the input array. + func_name : str + Name of the FFT function ("fft" or "ifft"). + norm : str + Type of normalization ("backward", "ortho", or "forward"). + + Returns + ------- + Array + Normalization factor. + + Raises + ------ + ValueError + If an invalid norm value is provided. + """ if norm == "backward": return 1 / jnp.prod(s) if func_name.startswith("i") else jnp.array(1) elif norm == "ortho": @@ -37,8 +59,25 @@ def _do_pfft( arr: ArrayLike, norm: Optional[str], ) -> Array: - # this is not allowed in a multi host setup - # arr = jnp.asarray(a) + """ + Perform 3D FFT or inverse 3D FFT on the input array. + + Parameters + ---------- + func_name : str + Name of the FFT function ("fft" or "ifft"). + fft_type : xla_client.FftType + Type of FFT operation. 
+ arr : ArrayLike + Input array to transform. + norm : Optional[str] + Type of normalization ("backward", "ortho", or "forward"). + + Returns + ------- + Array + Transformed array after FFT or inverse FFT. + """ transformed = _pfft(arr, fft_type) transformed *= _fft_norm( jnp.array(arr.shape, dtype=transformed.dtype), func_name, norm) @@ -46,8 +85,38 @@ def _do_pfft( def pfft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array: + """ + Perform 3D FFT on the input array. + + Parameters + ---------- + a : ArrayLike + Input array to transform. + norm : Optional[str], optional + Type of normalization ("backward", "ortho", or "forward"), by default "backward". + + Returns + ------- + Array + Transformed array after 3D FFT. + """ return _do_pfft("fft", xla_client.FftType.FFT, a, norm=norm) def pifft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array: + """ + Perform inverse 3D FFT on the input array. + + Parameters + ---------- + a : ArrayLike + Input array to transform. + norm : Optional[str], optional + Type of normalization ("backward", "ortho", or "forward"), by default "backward". + + Returns + ------- + Array + Transformed array after inverse 3D FFT. + """ return _do_pfft("ifft", xla_client.FftType.IFFT, a, norm=norm) From 09cc77d2250039ccefea2044e237c6d6bc342a04 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Thu, 27 Jun 2024 22:22:26 +0200 Subject: [PATCH 11/17] minor fix --- jaxdecomp/_src/halo.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/jaxdecomp/_src/halo.py b/jaxdecomp/_src/halo.py index 671b254..e61fe8c 100644 --- a/jaxdecomp/_src/halo.py +++ b/jaxdecomp/_src/halo.py @@ -4,9 +4,10 @@ import jax import jaxlib.mlir.ir as ir import numpy as np +from jax import ShapeDtypeStruct from jax._src.interpreters import mlir from jax._src.typing import Array -from jax.core import Primitive +from jax.core import Primitive, ShapedArray from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P from jaxlib.hlo_helpers import custom_call @@ -247,8 +248,8 @@ def infer_sharding_from_operands( halo_extents: Tuple[int, int, int], halo_periods: Tuple[bool, bool, bool], reduce_halo: bool, - mesh: NamedSharding, arg_shapes: Tuple[ir.ShapeDtypeStruct], - result_shape: ir.ShapedArray) -> NamedSharding: + mesh: NamedSharding, arg_infos: Tuple[ShapeDtypeStruct], + result_infos: Tuple[ShapedArray]) -> NamedSharding: """ Infer sharding information for halo exchange operation. @@ -272,14 +273,16 @@ def infer_sharding_from_operands( NamedSharding Sharding information for halo exchange operation. """ - halo_exchange_sharding = arg_shapes[0].sharding + halo_exchange_sharding = arg_infos[0].sharding return NamedSharding(mesh, P(*halo_exchange_sharding.spec)) @staticmethod - def partition(halo_extents: Tuple[int, int, int], - halo_periods: Tuple[bool, bool, bool], reduce_halo: bool, - mesh: NamedSharding, arg_shapes: Tuple[ir.ShapeDtypeStruct], - result_shape: ir.ShapedArray) -> Tuple[NamedSharding, partial]: + def partition( + halo_extents: Tuple[int, int, + int], halo_periods: Tuple[bool, bool, + bool], reduce_halo: bool, + mesh: NamedSharding, arg_shapes: Tuple[ShapeDtypeStruct], + result_shape: ShapeDtypeStruct) -> Tuple[NamedSharding, partial]: """ Partition function for halo exchange operation. 
From 947fd6a8c6452438a3abbd7962ad2a0e32657f6f Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Fri, 28 Jun 2024 12:35:54 +0200 Subject: [PATCH 12/17] Allow users to build with cuda 11 or 12 (12 by default) --- CMakeLists.txt | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e0d58b..fced2db 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,9 @@ cmake_minimum_required(VERSION 3.19...3.25) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) # Latest JAX v0.4.26 no longer supports cuda 11.8 -set(NVHPC_CUDA_VERSION 12.2) +# By default, build for CUDA 12.2, users can override this with -DNVHPC_CUDA_VERSION=11.8 +set(NVHPC_CUDA_VERSION 12.2 CACHE STRING "CUDA version to build for" ) + # Build debug # set(CMAKE_BUILD_TYPE Debug) add_subdirectory(third_party/cuDecomp) @@ -15,8 +17,8 @@ option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF) option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF) option(CUDECOMP_BUILD_EXTRAS "Build benchmark, examples, and tests" OFF) -set(CUDECOMP_CUDA_CC_LIST "70;80" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.") - +# 70: Volta, 80: Ampere, 89: RTX 4060 +set(CUDECOMP_CUDA_CC_LIST "70;80;89" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.") # Add pybind11 and cuDecomp subdirectories add_subdirectory(pybind11) From 3b414bb30a52932b8f0f664d86e3c5f8928f06be Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Thu, 4 Jul 2024 19:07:49 +0200 Subject: [PATCH 13/17] Update change log --- CHANGELOG.md | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f993dd9..753b753 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,19 +1,17 @@ # Change log - +## jaxdecomp 0.0.1 +* New version compatible with JAX 0.4.30 +* jaxDecomp now works in a multi-host environment +* Added custom partitioning for FFTs +* Added custom partitioning for halo exchange +* Added custom partitioning for slice_pad and slice_unpad +* Add example for multi-host FFTs in `examples/jaxdecomp_lpt.py` -## jaxdecomp 0.0.1 + +## jaxdecomp 0.0.1rc2 * Changes * Added utility to run autotuning From d7d1308ff86a0db0f6fc6e58c38ff65f941efe62 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Sun, 7 Jul 2024 17:18:13 -0400 Subject: [PATCH 14/17] Update CHANGELOG.md --- CHANGELOG.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 753b753..10cbb41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,12 +3,13 @@ ## jaxdecomp 0.0.1 -* New version compatible with JAX 0.4.30 -* jaxDecomp now works in a multi-host environment -* Added custom partitioning for FFTs -* Added custom partitioning for halo exchange -* Added custom partitioning for slice_pad and slice_unpad -* Add example for multi-host FFTs in `examples/jaxdecomp_lpt.py` +* Changes + * New version compatible with JAX 0.4.30 + * jaxDecomp now works in a multi-host environment + * Added custom partitioning for FFTs + * Added custom partitioning for halo exchange + * Added custom partitioning for slice_pad and slice_unpad + * Add example for multi-host FFTs in `examples/jaxdecomp_lpt.py` ## jaxdecomp 0.0.1rc2 From b22659083b1cde97f83c8e89567a218056819205 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Mon, 8 Jul 2024 00:29:37 +0200 Subject: [PATCH 15/17] Remove reduce halo --- README.md | 3 +- jaxdecomp/_src/halo.py | 96 ++++++++---------------------------------- tests/test_halo.py | 35 +-------------- 3 
files changed, 21 insertions(+), 113 deletions(-) diff --git a/README.md b/README.md index d266c5d..009e04e 100644 --- a/README.md +++ b/README.md @@ -57,8 +57,7 @@ with mesh: # Perform a halo exchange + reduce exchanged_reduced = jaxdecomp.halo_exchange(padded_array, halo_extents=(32,32,32), - halo_periods=(True,True,True), - reduce_halo=True) + halo_periods=(True,True,True)) # Remove the halo regions recarray = jaxdecomp.slice_unpad(exchanged_reduced, padding_width, pdims) diff --git a/jaxdecomp/_src/halo.py b/jaxdecomp/_src/halo.py index e61fe8c..ddcf5cc 100644 --- a/jaxdecomp/_src/halo.py +++ b/jaxdecomp/_src/halo.py @@ -31,9 +31,8 @@ class HaloPrimitive(BasePrimitive): @staticmethod def abstract(x: Array, halo_extents: Tuple[int, int, int], - halo_periods: Tuple[bool, bool, bool], reduce_halo: bool, - pdims: Tuple[int, int], global_shape: Tuple[int, int, - int]) -> Array: + halo_periods: Tuple[bool, bool, bool], pdims: Tuple[int, int], + global_shape: Tuple[int, int, int]) -> Array: """ Abstract function for determining the shape and dtype after the halo exchange operation. @@ -45,8 +44,6 @@ def abstract(x: Array, halo_extents: Tuple[int, int, int], Extents of the halo in x, y, and z dimensions. halo_periods : Tuple[bool, bool, bool] Periodicity of the halo in x, y, and z dimensions. - reduce_halo : bool - Flag indicating whether to reduce the halo. pdims : Tuple[int, int] Processor dimensions. global_shape : Tuple[int, int, int] @@ -61,8 +58,7 @@ def abstract(x: Array, halo_extents: Tuple[int, int, int], @staticmethod def outer_abstract(x: Array, halo_extents: Tuple[int, int, int], - halo_periods: Tuple[bool, bool, - bool], reduce_halo: bool) -> Array: + halo_periods: Tuple[bool, bool, bool]) -> Array: """ Abstract function for determining the shape and dtype without considering inner details. @@ -74,8 +70,6 @@ def outer_abstract(x: Array, halo_extents: Tuple[int, int, int], Extents of the halo in x, y, and z dimensions. halo_periods : Tuple[bool, bool, bool] Periodicity of the halo in x, y, and z dimensions. - reduce_halo : bool - Flag indicating whether to reduce the halo. Returns ------- @@ -86,9 +80,8 @@ def outer_abstract(x: Array, halo_extents: Tuple[int, int, int], @staticmethod def lowering(ctx, x: Array, halo_extents: Tuple[int, int, int], - halo_periods: Tuple[bool, bool, bool], reduce_halo: bool, - pdims: Tuple[int, int], global_shape: Tuple[int, int, - int]) -> Array: + halo_periods: Tuple[bool, bool, bool], pdims: Tuple[int, int], + global_shape: Tuple[int, int, int]) -> Array: """ Lowering function to generate the MLIR representation for halo exchange. @@ -102,8 +95,6 @@ def lowering(ctx, x: Array, halo_extents: Tuple[int, int, int], Extents of the halo in x, y, and z dimensions. halo_periods : Tuple[bool, bool, bool] Periodicity of the halo in x, y, and z dimensions. - reduce_halo : bool - Flag indicating whether to reduce the halo. pdims : Tuple[int, int] Processor dimensions. global_shape : Tuple[int, int, int] @@ -149,8 +140,7 @@ def lowering(ctx, x: Array, halo_extents: Tuple[int, int, int], @staticmethod def impl(x: Array, halo_extents: Tuple[int, int, int], - halo_periods: Tuple[bool, bool, - bool], reduce_halo: bool) -> Primitive: + halo_periods: Tuple[bool, bool, bool]) -> Primitive: """ Implementation function for performing halo exchange. @@ -162,8 +152,6 @@ def impl(x: Array, halo_extents: Tuple[int, int, int], Extents of the halo in x, y, and z dimensions. halo_periods : Tuple[bool, bool, bool] Periodicity of the halo in x, y, and z dimensions. 
- reduce_halo : bool - Flag indicating whether to reduce the halo. Returns ------- @@ -177,14 +165,13 @@ def impl(x: Array, halo_extents: Tuple[int, int, int], x, halo_extents=halo_extents, halo_periods=halo_periods, - reduce_halo=reduce_halo, pdims=pdims, global_shape=global_shape, ) @staticmethod def per_shard_impl(x: Array, halo_extents: Tuple[int, int, int], - halo_periods: Tuple[bool, bool, bool], reduce_halo: bool, + halo_periods: Tuple[bool, bool, bool], pdims: Tuple[int, int], global_shape: Tuple[int, int, int]) -> Array: """ @@ -198,8 +185,6 @@ def per_shard_impl(x: Array, halo_extents: Tuple[int, int, int], Extents of the halo in x, y, and z dimensions. halo_periods : Tuple[bool, bool, bool] Periodicity of the halo in x, y, and z dimensions. - reduce_halo : bool - Flag indicating whether to reduce the halo. pdims : Tuple[int, int] Processor dimensions. global_shape : Tuple[int, int, int] @@ -214,40 +199,15 @@ def per_shard_impl(x: Array, halo_extents: Tuple[int, int, int], x, halo_extents=halo_extents, halo_periods=halo_periods, - reduce_halo=reduce_halo, pdims=pdims, global_shape=global_shape, ) - if reduce_halo: - halo_x, halo_y, halo_z = halo_extents - - # Apply corrections along x - if halo_x > 0: - output = output.at[halo_x:halo_x + halo_x // 2].add(output[:halo_x // - 2]) - output = output.at[-(halo_x + halo_x // 2):-halo_x].add( - output[-halo_x // 2:]) - # Apply corrections along y - if halo_y > 0: - output = output.at[:, halo_y:halo_y + halo_y // 2].add( - output[:, :halo_y // 2]) - output = output.at[:, -(halo_y + halo_y // 2):-halo_y].add( - output[:, -halo_y // 2:]) - # Apply corrections along z - if halo_z > 0: - output = output.at[:, :, halo_z:halo_z + halo_z // 2].add( - output[:, :, :halo_z // 2]) - output = output.at[:, :, -(halo_z + halo_z // 2):-halo_z].add( - output[:, :, -halo_z // 2:]) - return output @staticmethod def infer_sharding_from_operands( - halo_extents: Tuple[int, int, - int], halo_periods: Tuple[bool, bool, - bool], reduce_halo: bool, + halo_extents: Tuple[int, int, int], halo_periods: Tuple[bool, bool, bool], mesh: NamedSharding, arg_infos: Tuple[ShapeDtypeStruct], result_infos: Tuple[ShapedArray]) -> NamedSharding: """ @@ -259,8 +219,6 @@ def infer_sharding_from_operands( Extents of the halo in x, y, and z dimensions. halo_periods : Tuple[bool, bool, bool] Periodicity of the halo in x, y, and z dimensions. - reduce_halo : bool - Flag indicating whether to reduce the halo. mesh : NamedSharding Mesh object for sharding. arg_shapes : Tuple[ir.ShapeDtypeStruct] @@ -278,9 +236,7 @@ def infer_sharding_from_operands( @staticmethod def partition( - halo_extents: Tuple[int, int, - int], halo_periods: Tuple[bool, bool, - bool], reduce_halo: bool, + halo_extents: Tuple[int, int, int], halo_periods: Tuple[bool, bool, bool], mesh: NamedSharding, arg_shapes: Tuple[ShapeDtypeStruct], result_shape: ShapeDtypeStruct) -> Tuple[NamedSharding, partial]: """ @@ -292,8 +248,6 @@ def partition( Extents of the halo in x, y, and z dimensions. halo_periods : Tuple[bool, bool, bool] Periodicity of the halo in x, y, and z dimensions. - reduce_halo : bool - Flag indicating whether to reduce the halo. mesh : NamedSharding Mesh object for sharding. 
arg_shapes : Tuple[ir.ShapeDtypeStruct] @@ -320,7 +274,6 @@ def partition( HaloPrimitive.per_shard_impl, halo_extents=halo_extents, halo_periods=halo_periods, - reduce_halo=reduce_halo, pdims=pdims, global_shape=shape_without_halo) @@ -331,8 +284,7 @@ def partition( def halo_p_lower(x: Array, halo_extents: Tuple[int, int, int], - halo_periods: Tuple[bool, bool, - bool], reduce_halo: bool) -> Primitive: + halo_periods: Tuple[bool, bool, bool]) -> Primitive: """ Lowering function for the halo exchange operation. @@ -344,8 +296,6 @@ def halo_p_lower(x: Array, halo_extents: Tuple[int, int, int], Extents of the halo in x, y, and z dimensions. halo_periods : Tuple[bool, bool, bool] Periodicity of the halo in x, y, and z dimensions. - reduce_halo : bool - Flag indicating whether to reduce the halo. Returns ------- @@ -356,16 +306,13 @@ def halo_p_lower(x: Array, halo_extents: Tuple[int, int, int], x, halo_extents=halo_extents, halo_periods=halo_periods, - reduce_halo=reduce_halo, ) # Custom Partitioning @partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3)) -def halo_exchange(x: Array, - halo_extents: Tuple[int, int, int], - halo_periods: Tuple[bool, bool, bool], - reduce_halo: bool = False) -> Array: +def halo_exchange(x: Array, halo_extents: Tuple[int, int, int], + halo_periods: Tuple[bool, bool, bool]) -> Array: """ Halo exchange operation with custom VJP. @@ -377,21 +324,18 @@ def halo_exchange(x: Array, Extents of the halo in x, y, and z dimensions. halo_periods : Tuple[bool, bool, bool] Periodicity of the halo in x, y, and z dimensions. - reduce_halo : bool, optional - Flag indicating whether to reduce the halo. Default is False. Returns ------- Array Output array after the halo exchange operation. """ - output, _ = _halo_fwd_rule(x, halo_extents, halo_periods, reduce_halo) + output, _ = _halo_fwd_rule(x, halo_extents, halo_periods) return output def _halo_fwd_rule(x: Array, halo_extents: Tuple[int, int, int], - halo_periods: Tuple[bool, bool, bool], - reduce_halo: bool) -> Tuple[Array, None]: + halo_periods: Tuple[bool, bool, bool]) -> Tuple[Array, None]: """ Forward rule for the halo exchange operation. @@ -403,20 +347,18 @@ def _halo_fwd_rule(x: Array, halo_extents: Tuple[int, int, int], Extents of the halo in x, y, and z dimensions. halo_periods : Tuple[bool, bool, bool] Periodicity of the halo in x, y, and z dimensions. - reduce_halo : bool - Flag indicating whether to reduce the halo. Returns ------- Tuple[Array, None] Output array after the halo exchange operation and None for no residuals. """ - return halo_p_lower(x, halo_extents, halo_periods, reduce_halo), None + return halo_p_lower(x, halo_extents, halo_periods), None def _halo_bwd_rule(halo_extents: Tuple[int, int, int], - halo_periods: Tuple[bool, bool, bool], reduce_halo: bool, - ctx, g: Array) -> Tuple[Array]: + halo_periods: Tuple[bool, bool, + bool], ctx, g: Array) -> Tuple[Array]: """ Backward rule for the halo exchange operation. @@ -426,8 +368,6 @@ def _halo_bwd_rule(halo_extents: Tuple[int, int, int], Extents of the halo in x, y, and z dimensions. halo_periods : Tuple[bool, bool, bool] Periodicity of the halo in x, y, and z dimensions. - reduce_halo : bool - Flag indicating whether to reduce the halo. ctx Context for the operation. g : Array @@ -438,7 +378,7 @@ def _halo_bwd_rule(halo_extents: Tuple[int, int, int], Tuple[Array] Gradient array after the halo exchange operation. 
""" - return halo_p_lower(g, halo_extents, halo_periods, reduce_halo), + return halo_p_lower(g, halo_extents, halo_periods), # Define VJP for custom halo_exchange operation diff --git a/tests/test_halo.py b/tests/test_halo.py index b9dd6db..c2ace74 100644 --- a/tests/test_halo.py +++ b/tests/test_halo.py @@ -134,13 +134,6 @@ def sharded_add_multiply(arr): updated_array, halo_extents=(halo_size, 0, 0), halo_periods=(False, False, False)) - exchanged_reduced_array = jaxdecomp.halo_exchange( - updated_array, - halo_extents=(halo_size, 0, 0), - halo_periods=(True, True, True), - reduce_halo=True) - unpadded_reduced_array = slice_unpad(exchanged_reduced_array, padding, - pdims) unpadded_updated_array = slice_unpad(updated_array, padding, pdims) # Gather array from all processes @@ -151,8 +144,6 @@ def sharded_add_multiply(arr): periodic_exchanged_array, tiled=True) updated_gathered_array = multihost_utils.process_allgather( updated_array, tiled=True) - gathered_exchanged_reduced_array = multihost_utils.process_allgather( - unpadded_reduced_array, tiled=True) gathered_unpadded_updated_array = multihost_utils.process_allgather( unpadded_updated_array, tiled=True) # Get the slices using array_split @@ -162,17 +153,14 @@ def sharded_add_multiply(arr): updated_gathered_array, size, axis=0) gathered_periodic_exchanged_slices = jnp.array_split( periodic_exchanged_gathered_array, size, axis=0) - gathered_reduced_exchange_slices = jnp.array_split( - gathered_exchanged_reduced_array, size, axis=0) gathered_unpadded_updated_slices = jnp.array_split( gathered_unpadded_updated_array, size, axis=0) gathered_arrays = zip(gathered_periodic_exchanged_slices, gathered_array_slices\ - , gathered_reduced_exchange_slices, gathered_updated_slices \ + , gathered_updated_slices \ , gathered_unpadded_updated_slices) - for slice_indx, (periodic_exchanged_slice, exchanged_slice, reduced_slice, - original_slice, + for slice_indx, (periodic_exchanged_slice, exchanged_slice, original_slice, unpadded_slice) in enumerate(gathered_arrays): next_indx = slice_indx + 1 @@ -208,22 +196,3 @@ def sharded_add_multiply(arr): periodic_exchanged_slice[:halo_size]) assert_array_equal(exchanged_slice[-halo_size:], periodic_exchanged_slice[-halo_size:]) - - # Lower center of the previous slice - previous_halo_extension = prev_slice[-2 * halo_size:-3 * (halo_size // 2)] - # Upper center of the next slice - next_halo_extension = next_slice[3 * (halo_size // 2):2 * halo_size] - # Upper and lower center of the reduced slice - - upper_halo_reduced = reduced_slice[:halo_size // 2] - lower_halo_reduced = reduced_slice[-(halo_size // 2):] - # Upper and lower center of the original slice (after update no exchange and halo reduction) - upper_halo_original = unpadded_slice[:halo_size // 2] - lower_halo_original = unpadded_slice[-(halo_size // 2):] - - # Upper slice should be equal to original upper slice + lower center of the previous slice - assert_array_equal(upper_halo_reduced, - (previous_halo_extension + upper_halo_original)) - # Lower slice should be equal to original lower slice + upper center of the next slice - assert_array_equal(lower_halo_reduced, - (next_halo_extension + lower_halo_original)) From 5e5484e74998332d17383c7774663c23a8122006 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Sun, 7 Jul 2024 20:00:22 -0400 Subject: [PATCH 16/17] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 009e04e..733127c 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ 
with mesh: halo_extents=(32,32,32), halo_periods=(True,True,True)) # Remove the halo regions - recarray = jaxdecomp.slice_unpad(exchanged_reduced, padding_width, pdims) + recarray = jaxdecomp.slice_unpad(exchanged_array, padding_width, pdims) # Gather the results (only if it fits on CPU memory) gathered_array = multihost_utils.process_allgather(recarray, tiled=True) From ac8f2efb4b12e246baf2b3dfc90589c1eb319587 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Sun, 7 Jul 2024 20:00:28 -0400 Subject: [PATCH 17/17] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 733127c..bbeb096 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,8 @@ with mesh: # Add halo regions to our array padding_width = ((32,32),(32,32),(32,32)) # Has to a tuple of tuples padded_array = jaxdecomp.slice_pad(recarray, padding_width , pdims) - # Perform a halo exchange + reduce - exchanged_reduced = jaxdecomp.halo_exchange(padded_array, + # Perform a halo exchange + exchanged_array = jaxdecomp.halo_exchange(padded_array, halo_extents=(32,32,32), halo_periods=(True,True,True)) # Remove the halo regions
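
Usage note (not part of the patches above): the sketch below shows how the API looks once this series is applied end to end — `slice_pad`, `halo_exchange` without the `reduce_halo` flag removed in patch 15, `slice_unpad`, and the `pfft3d`/`pifft3d` wrappers from `jaxdecomp/fft.py`. It is a minimal sketch under stated assumptions, not the project's reference example: the mesh axis names `'z'`/`'y'`, the `pdims = (2, 2)` processor grid, the `global_shape`, the explicit `jaxdecomp.fft` import path, and running on four devices (possibly across hosts) are all assumptions; only the function names and keyword arguments come from the diffs above.

```python
# Minimal usage sketch (assumptions: 4 devices, mesh axes named 'z' and 'y',
# pfft3d/pifft3d imported from jaxdecomp.fft rather than a top-level re-export).
import jax
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

import jaxdecomp
from jaxdecomp.fft import pfft3d, pifft3d

jax.distributed.initialize()  # only needed in a multi-host setup

pdims = (2, 2)                                # assumed 2x2 processor grid
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('z', 'y'))   # axis names are an assumption

# Build a globally sharded 3D array from per-process slices (assumed shape).
global_shape = (256, 256, 256)
local_slice = jax.random.normal(
    jax.random.PRNGKey(jax.process_index()),
    (global_shape[0] // pdims[0], global_shape[1] // pdims[1], global_shape[2]))
field = multihost_utils.host_local_array_to_global_array(
    local_slice, mesh, P('z', 'y'))

with mesh:
  # Distributed forward and inverse 3D FFT (norm handling as in jaxdecomp/fft.py).
  k_field = pfft3d(field, norm="backward")
  field = pifft3d(k_field, norm="backward").real

  # Pad, exchange halos (no reduce_halo argument after patch 15), then unpad,
  # mirroring the README excerpt above.
  padding_width = ((32, 32), (32, 32), (32, 32))
  padded_array = jaxdecomp.slice_pad(field, padding_width, pdims)
  exchanged_array = jaxdecomp.halo_exchange(padded_array,
                                            halo_extents=(32, 32, 32),
                                            halo_periods=(True, True, True))
  field = jaxdecomp.slice_unpad(exchanged_array, padding_width, pdims)

# Gather to the host only if the global array fits in CPU memory.
gathered_array = multihost_utils.process_allgather(field, tiled=True)
```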