diff --git a/jaxdecomp/_src/padding.py b/jaxdecomp/_src/padding.py index 7c658a6..473e799 100644 --- a/jaxdecomp/_src/padding.py +++ b/jaxdecomp/_src/padding.py @@ -4,12 +4,11 @@ from typing import Tuple import jax.numpy as jnp -from jax import jit, lax +from jax import jit from jax._src.api import ShapeDtypeStruct from jax._src.core import ShapedArray from jax._src.typing import Array, ArrayLike from jax.experimental.custom_partitioning import custom_partitioning -from jax.lax import dynamic_slice from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P @@ -190,16 +189,11 @@ def per_shard_impl(arr: ArrayLike, padding_width: int | tuple[int]) -> Array: first_x, last_x = unpadding_width[0] first_y, last_y = unpadding_width[1] first_z, last_z = unpadding_width[2] - # Interior padding is padding between each row .. not needed here - interiour_padding = 0 - - # unlike jnp.pad lax.pad can unpad if given negative values - return lax.pad( - arr, - padding_value=0.0, - padding_config=((-first_x, -last_x, interiour_padding), - (-first_y, -last_y, interiour_padding), - (-first_z, -last_z, interiour_padding))) + last_x = arr.shape[0] - last_x + last_y = arr.shape[1] - last_y + last_z = arr.shape[2] - last_z + + return arr[first_x:last_x, first_y:last_y, first_z:last_z] @staticmethod def infer_sharding_from_operands(padding_width: int | tuple[int], diff --git a/tests/test_padding.py b/tests/test_padding.py index cbf4477..0f4917f 100644 --- a/tests/test_padding.py +++ b/tests/test_padding.py @@ -26,19 +26,35 @@ # Helper function to create a 3D array and remap it to the global array -def create_spmd_array(global_shape, pdims): +def create_spmd_array(global_shape, pdims, complex=False): assert (len(global_shape) == 3) assert (len(pdims) == 2) assert (prod(pdims) == size ), "The product of pdims must be equal to the number of MPI processes" - local_array = jax.random.normal( - shape=[ - global_shape[0] // pdims[1], global_shape[1] // pdims[0], - global_shape[2] - ], - key=jax.random.PRNGKey(rank)) + if complex: + + local_array = jax.random.normal( + shape=[ + global_shape[0] // pdims[1], global_shape[1] // pdims[0], + global_shape[2] + ], + key=jax.random.PRNGKey(rank)) + 1j * jax.random.normal( + shape=[ + global_shape[0] // pdims[1], global_shape[1] // pdims[0], + global_shape[2] + ], + key=jax.random.PRNGKey(rank + 1)) + + else: + + local_array = jax.random.normal( + shape=[ + global_shape[0] // pdims[1], global_shape[1] // pdims[0], + global_shape[2] + ], + key=jax.random.PRNGKey(rank)) # Remap to the global array from the local slice devices = mesh_utils.create_device_mesh(pdims[::-1]) mesh = Mesh(devices, axis_names=('z', 'y')) @@ -50,19 +66,18 @@ def create_spmd_array(global_shape, pdims): pencil_1 = (size // 2, size // (size // 2)) pencil_2 = (size // (size // 2), size // 2) +decomp = [(size, 1), (1, size), pencil_1, pencil_2] +global_shapes = [(32, 32, 32), (29 * size, 19 * size, 17 * size)] -@pytest.mark.parametrize( - "pdims", - [(1, size), - (size, 1), pencil_1, pencil_2]) # Test with Slab and Pencil decompositions -def test_padding(pdims): +@pytest.mark.parametrize("pdims", + decomp) # Test with Slab and Pencil decompositions +@pytest.mark.parametrize("global_shape", + global_shapes) # Test cubes, non-cubes and primes +def test_padding(pdims, global_shape): print("*" * 80) - print(f"Testing with pdims {pdims}") - - global_shape = (29 * size, 19 * size, 17 * size - ) # These sizes are prime numbers x size of the pmesh + print(f"Testing with pdims {pdims} and global_shape {global_shape}") global_array, mesh = create_spmd_array(global_shape, pdims) @@ -164,6 +179,32 @@ def sharded_unpad(arr): assert_array_equal(gathered_original, gathered_unpadded) +@pytest.mark.parametrize("pdims", + decomp) # Test with Slab and Pencil decompositions +@pytest.mark.parametrize("global_shape", + global_shapes) # Test cubes, non-cubes and primes +def test_complex_unpad(pdims, global_shape): + + print("*" * 80) + print(f"Testing with pdims {pdims} and global_shape {global_shape}") + + global_array, mesh = create_spmd_array(global_shape, pdims, complex=True) + + padding = ((32, 32), (32, 32), (0, 0)) + + with mesh: + jaxdecomp_padded = slice_pad(global_array, padding, pdims) + jaxdecomp_unpadded = slice_unpad(jaxdecomp_padded, padding, pdims) + + gathered_original = multihost_utils.process_allgather( + global_array, tiled=True) + gathered_unpadded = multihost_utils.process_allgather( + jaxdecomp_unpadded, tiled=True) + + # Make sure the unpadded arrays is equal to the original array + assert_array_equal(gathered_original, gathered_unpadded) + + def test_end(): # fake test to finalize the MPI processes jaxdecomp.finalize()