Skip to content

Commit

Permalink
Merge pull request #18 from DifferentiableUniverseInitiative/fix_slic…
Browse files Browse the repository at this point in the history
…e_unpad

Fix slice unpad
  • Loading branch information
ASKabalan authored Jul 4, 2024
2 parents 80172aa + f388649 commit b3978e4
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 28 deletions.
18 changes: 6 additions & 12 deletions jaxdecomp/_src/padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
from typing import Tuple

import jax.numpy as jnp
from jax import jit, lax
from jax import jit
from jax._src.api import ShapeDtypeStruct
from jax._src.core import ShapedArray
from jax._src.typing import Array, ArrayLike
from jax.experimental.custom_partitioning import custom_partitioning
from jax.lax import dynamic_slice
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

Expand Down Expand Up @@ -190,16 +189,11 @@ def per_shard_impl(arr: ArrayLike, padding_width: int | tuple[int]) -> Array:
first_x, last_x = unpadding_width[0]
first_y, last_y = unpadding_width[1]
first_z, last_z = unpadding_width[2]
# Interior padding is padding between each row .. not needed here
interiour_padding = 0

# unlike jnp.pad lax.pad can unpad if given negative values
return lax.pad(
arr,
padding_value=0.0,
padding_config=((-first_x, -last_x, interiour_padding),
(-first_y, -last_y, interiour_padding),
(-first_z, -last_z, interiour_padding)))
last_x = arr.shape[0] - last_x
last_y = arr.shape[1] - last_y
last_z = arr.shape[2] - last_z

return arr[first_x:last_x, first_y:last_y, first_z:last_z]

@staticmethod
def infer_sharding_from_operands(padding_width: int | tuple[int],
Expand Down
73 changes: 57 additions & 16 deletions tests/test_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,35 @@


# Helper function to create a 3D array and remap it to the global array
def create_spmd_array(global_shape, pdims):
def create_spmd_array(global_shape, pdims, complex=False):

assert (len(global_shape) == 3)
assert (len(pdims) == 2)
assert (prod(pdims) == size
), "The product of pdims must be equal to the number of MPI processes"

local_array = jax.random.normal(
shape=[
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2]
],
key=jax.random.PRNGKey(rank))
if complex:

local_array = jax.random.normal(
shape=[
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2]
],
key=jax.random.PRNGKey(rank)) + 1j * jax.random.normal(
shape=[
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2]
],
key=jax.random.PRNGKey(rank + 1))

else:

local_array = jax.random.normal(
shape=[
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2]
],
key=jax.random.PRNGKey(rank))
# Remap to the global array from the local slice
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
Expand All @@ -50,19 +66,18 @@ def create_spmd_array(global_shape, pdims):

pencil_1 = (size // 2, size // (size // 2))
pencil_2 = (size // (size // 2), size // 2)
decomp = [(size, 1), (1, size), pencil_1, pencil_2]
global_shapes = [(32, 32, 32), (29 * size, 19 * size, 17 * size)]


@pytest.mark.parametrize(
"pdims",
[(1, size),
(size, 1), pencil_1, pencil_2]) # Test with Slab and Pencil decompositions
def test_padding(pdims):
@pytest.mark.parametrize("pdims",
decomp) # Test with Slab and Pencil decompositions
@pytest.mark.parametrize("global_shape",
global_shapes) # Test cubes, non-cubes and primes
def test_padding(pdims, global_shape):

print("*" * 80)
print(f"Testing with pdims {pdims}")

global_shape = (29 * size, 19 * size, 17 * size
) # These sizes are prime numbers x size of the pmesh
print(f"Testing with pdims {pdims} and global_shape {global_shape}")

global_array, mesh = create_spmd_array(global_shape, pdims)

Expand Down Expand Up @@ -164,6 +179,32 @@ def sharded_unpad(arr):
assert_array_equal(gathered_original, gathered_unpadded)


@pytest.mark.parametrize("pdims",
decomp) # Test with Slab and Pencil decompositions
@pytest.mark.parametrize("global_shape",
global_shapes) # Test cubes, non-cubes and primes
def test_complex_unpad(pdims, global_shape):

print("*" * 80)
print(f"Testing with pdims {pdims} and global_shape {global_shape}")

global_array, mesh = create_spmd_array(global_shape, pdims, complex=True)

padding = ((32, 32), (32, 32), (0, 0))

with mesh:
jaxdecomp_padded = slice_pad(global_array, padding, pdims)
jaxdecomp_unpadded = slice_unpad(jaxdecomp_padded, padding, pdims)

gathered_original = multihost_utils.process_allgather(
global_array, tiled=True)
gathered_unpadded = multihost_utils.process_allgather(
jaxdecomp_unpadded, tiled=True)

# Make sure the unpadded arrays is equal to the original array
assert_array_equal(gathered_original, gathered_unpadded)


def test_end():
# fake test to finalize the MPI processes
jaxdecomp.finalize()
Expand Down

0 comments on commit b3978e4

Please sign in to comment.