Updating readme and bumping version (#62)
* update readme

* bump to v0.2.5

* format with isort and ruff
ASKabalan authored Feb 25, 2025
1 parent 12fe88a commit 5fb1779
Showing 23 changed files with 508 additions and 495 deletions.
20 changes: 13 additions & 7 deletions README.md
@@ -9,7 +9,7 @@
> **Important**
> Version `0.2.0` includes a **pure JAX backend** that **no longer requires MPI**. For multi-node runs, MPI and NCCL backends are still available through **cuDecomp**.
`jaxDecomp` provides JAX bindings for NVIDIA's [cuDecomp](https://nvidia.github.io/cuDecomp/index.html) library [(Romero et al. 2022)](https://dl.acm.org/doi/abs/10.1145/3539781.3539797), enabling **multi-node parallel FFTs and halo exchanges** directly in low-level NCCL/CUDA-Aware MPI from your JAX code.
JAX reimplementation and bindings for NVIDIA's [cuDecomp](https://nvidia.github.io/cuDecomp/index.html) library [(Romero et al. 2022)](https://dl.acm.org/doi/abs/10.1145/3539781.3539797), enabling **multi-node parallel FFTs and halo exchanges** directly in low-level NCCL/CUDA-Aware MPI from your JAX code.

---

@@ -43,12 +43,16 @@ rec_array = jaxdecomp.fft.pifft3d(a)
exchanged = jaxdecomp.halo_exchange(a, halo_extents=(16, 16), halo_periods=(True, True))
```

All these functions are **JIT**-compatible and support **automatic differentiation** (with [some caveats](docs/02_caveats.md)).
All these functions are **JIT**-compatible and support **automatic differentiation** (with [some caveats](docs/02-caveats.md)).

See also:
- [Basic Usage](docs/01-basic_usage.md)
- [Distributed LPT Example](examples/lpt_nbody_demo.py)

> **Important**
> Multi-node FFTs work with both JAX and cuDecomp backends\
> For CPU with JAX, Multi-node is supported starting JAX v0.5.1 (with `gloo` backend)
---

## Running on an HPC Cluster
@@ -64,11 +68,12 @@ mpirun -n 8 python demo.py

See the Slurm [README](slurms/README.md) and [template script](slurms/template.slurm) for more details.


---

## Using cuDecomp (MPI and NCCL)

For **multi-node** or advanced features, compile and install with cuDecomp enabled:
For other features, compile and install with cuDecomp enabled as described in [install](#2-jax--cudecomp-backend-advanced):

```python
import jaxdecomp
@@ -109,7 +114,7 @@ This setup uses the pure-JAX backend—**no** MPI required.

### 2. JAX + cuDecomp Backend (Advanced)

If you need **multi-node** support, you can build from GitHub with cuDecomp enabled. This requires the [NVIDIA HPC SDK](https://developer.nvidia.com/hpc-sdk) or a similar environment providing a CUDA-aware MPI toolchain.
If you need to use `MPI` instead of `NCCL` for `GPU` or gloo for CPU, you can build from GitHub with cuDecomp enabled. This requires the [NVIDIA HPC SDK](https://developer.nvidia.com/hpc-sdk) or a similar environment providing a CUDA-aware MPI toolchain.

```bash
pip install -U pip
@@ -126,9 +131,10 @@ pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -C

## Machine-Specific Notes

### IDRIS Jean Zay (HPE SGI 8600)
### IDRIS [Jean Zay](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-hw-eng.html) HPE SGI 8600 supercomputer


As of October 2024, loading modules **in this exact order** works:
As of February 2025, loading modules **in this exact order** works:

```bash
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
@@ -143,7 +149,7 @@ pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -C

**Note**: If using only the pure-JAX backend, you do not need NVHPC.

### NERSC Perlmutter (HPE Cray EX)
#### NERSC [Perlmutter](https://docs.nersc.gov/systems/perlmutter/architecture/) HPE Cray EX supercomputer

As of November 2022:

44 changes: 22 additions & 22 deletions examples/lpt_nbody_demo.py
@@ -1,15 +1,15 @@
import argparse
import os
from collections.abc import Hashable
from functools import partial
from typing import Any, Callable
from collections.abc import Hashable

Specs = Any
AxisName = Hashable

import jax

jax.config.update("jax_enable_x64", False)
jax.config.update('jax_enable_x64', False)

import jax.numpy as jnp
import jax_cosmo as jc
@@ -58,8 +58,8 @@ def fttk(nc: int):

@partial(
shmap,
in_specs=(P("x"), P("y"), P(None)),
out_specs=(P("x"), P(None, "y"), P(None)),
in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)),
)
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
@@ -148,11 +148,11 @@ def cic_paint(displacement, halo_size):
local_mesh_shape = _global_to_local_size(displacement.shape[0])
hs = halo_size

@partial(shmap, in_specs=(P("x", "y"),), out_specs=P("x", "y"))
@partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
def cic_op(disp):
"""CiC operation on each local slice of the mesh."""
# Create a mesh to paint the particles on for the local slice
mesh = jnp.zeros(disp.shape[:-1], dtype="float32")
mesh = jnp.zeros(disp.shape[:-1], dtype='float32')

# Padding the mesh along the two first dimensions
mesh = jnp.pad(mesh, [[hs, hs], [hs, hs], [0, 0]])
@@ -162,7 +162,7 @@ def cic_op(disp):
jnp.arange(local_mesh_shape[0]),
jnp.arange(local_mesh_shape[1]),
jnp.arange(local_mesh_shape[2]),
indexing="ij",
indexing='ij',
)

# adding an offset of size halo size
@@ -179,7 +179,7 @@ def cic_op(disp):
# Run halo exchange to get the correct values at the boundaries
field = jaxdecomp.halo_exchange(field, halo_extents=(hs // 2, hs // 2, 0), halo_periods=(True, True, True))

@partial(shmap, in_specs=(P("x", "y"),), out_specs=P("x", "y"))
@partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
def unpad(x):
"""Removes the padding and reduce the halo regions"""
x = x.at[hs : hs + hs // 2].add(x[: hs // 2])
@@ -193,7 +193,7 @@ def unpad(x):
return field


@partial(jax.jit, static_argnames=("nc", "box_size", "halo_size"))
@partial(jax.jit, static_argnames=('nc', 'box_size', 'halo_size'))
def simulation_fn(key, nc, box_size, halo_size, a=1.0):
"""
Run a simulation to generate initial conditions and density field using LPT.
@@ -230,7 +230,7 @@ def simulation_fn(key, nc, box_size, halo_size, a=1.0):


def main(args):
print(f"Running with arguments {args}")
print(f'Running with arguments {args}')

# Setting up distributed jax
jax.distributed.initialize()
@@ -242,9 +242,9 @@ def main(args):
key = jax.random.split(master_key, size)[rank]

# Create computing mesh and sharding information
pdims = tuple(map(int, args.pdims.split("x")))
pdims = tuple(map(int, args.pdims.split('x')))
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=("x", "y"))
mesh = Mesh(devices.T, axis_names=('x', 'y'))

# Run the simulation on the compute mesh
with mesh:
@@ -253,21 +253,21 @@ def main(args):
# Create output directory to save the results
output_dir = args.output
os.makedirs(output_dir, exist_ok=True)
np.save(f"{output_dir}/initial_conditions_{rank}.npy", initial_conds.addressable_data(0))
np.save(f"{output_dir}/field_{rank}.npy", final_field.addressable_data(0))
print(f"Finished saved to {output_dir}")
np.save(f'{output_dir}/initial_conditions_{rank}.npy', initial_conds.addressable_data(0))
np.save(f'{output_dir}/field_{rank}.npy', final_field.addressable_data(0))
print(f'Finished saved to {output_dir}')

# Closing distributed jax
jax.distributed.shutdown()


if __name__ == "__main__":
parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
parser.add_argument("--pdims", type=str, default="1x1", help="Processor grid dimensions")
parser.add_argument("--nc", type=int, default=256, help="Number of cells in the mesh")
parser.add_argument("--box_size", type=float, default=512.0, help="Box size in Mpc/h")
parser.add_argument("--halo_size", type=int, default=32, help="Halo size for painting")
parser.add_argument("--output", type=str, default="out")
if __name__ == '__main__':
parser = argparse.ArgumentParser('Distributed LPT N-body simulation.')
parser.add_argument('--pdims', type=str, default='1x1', help='Processor grid dimensions')
parser.add_argument('--nc', type=int, default=256, help='Number of cells in the mesh')
parser.add_argument('--box_size', type=float, default=512.0, help='Box size in Mpc/h')
parser.add_argument('--halo_size', type=int, default=32, help='Halo size for painting')
parser.add_argument('--output', type=str, default='out')
args = parser.parse_args()

main(args)
34 changes: 17 additions & 17 deletions examples/visualizer.ipynb
@@ -15,8 +15,8 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt"
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
@@ -70,7 +70,7 @@
}
],
"source": [
"folder = \"../out\"\n",
"folder = '../out'\n",
"pdims = (4, 4)\n",
"\n",
"init_field_slices = []\n",
@@ -82,8 +82,8 @@
"\n",
" for j in range(pdims[1]):\n",
" slice_index = i * pdims[1] + j\n",
" row_field.append(np.load(f\"{folder}/field_{slice_index}.npy\"))\n",
" row_init_field.append(np.load(f\"{folder}/initial_conditions_{slice_index}.npy\"))\n",
" row_field.append(np.load(f'{folder}/field_{slice_index}.npy'))\n",
" row_init_field.append(np.load(f'{folder}/initial_conditions_{slice_index}.npy'))\n",
"\n",
" field_slices.append(np.vstack(row_field))\n",
" init_field_slices.append(np.vstack(row_init_field))\n",
@@ -156,22 +156,22 @@
" # Plot initial conditions\n",
" axes[0].imshow(\n",
" initial_conditions[slicing].sum(axis=proj_axis),\n",
" cmap=\"magma\",\n",
" cmap='magma',\n",
" extent=[0, box_size + 5, 0, box_size + 5],\n",
" )\n",
" axes[0].set_xlabel(\"Mpc/h\")\n",
" axes[0].set_ylabel(\"Mpc/h\")\n",
" axes[0].set_title(\"Initial conditions\")\n",
" axes[0].set_xlabel('Mpc/h')\n",
" axes[0].set_ylabel('Mpc/h')\n",
" axes[0].set_title('Initial conditions')\n",
"\n",
" # Plot LPT density field at z=0\n",
" axes[1].imshow(\n",
" field[slicing].sum(axis=proj_axis),\n",
" cmap=\"magma\",\n",
" cmap='magma',\n",
" extent=[0, box_size + 5, 0, box_size + 5],\n",
" )\n",
" axes[1].set_xlabel(\"Mpc/h\")\n",
" axes[1].set_ylabel(\"Mpc/h\")\n",
" axes[1].set_title(\"LPT density field at z=0\")\n",
" axes[1].set_xlabel('Mpc/h')\n",
" axes[1].set_ylabel('Mpc/h')\n",
" axes[1].set_title('LPT density field at z=0')\n",
"\n",
"\n",
"for i in range(3):\n",
@@ -211,12 +211,12 @@
"# Generate the plot\n",
"plt.imshow(\n",
" np.log10(field[:16].sum(axis=proj_axis) + 1),\n",
" cmap=\"magma\",\n",
" cmap='magma',\n",
" extent=[0, box_size, 0, box_size],\n",
")\n",
"plt.xlabel(\"Mpc/h\")\n",
"plt.ylabel(\"Mpc/h\")\n",
"plt.title(\"LPT density field at z=0\")\n",
"plt.xlabel('Mpc/h')\n",
"plt.ylabel('Mpc/h')\n",
"plt.title('LPT density field at z=0')\n",
"\n",
"# Display the plot\n",
"plt.show()"
10 changes: 9 additions & 1 deletion pyproject.toml
@@ -3,7 +3,7 @@ requires = ["scikit-build-core>=0.4.0", "pybind11>=2.9.0"]
build-backend = "scikit_build_core.build"
[project]
name = "jaxdecomp"
version = "0.2.4"
version = "0.2.5"
description = "JAX bindings for the cuDecomp library"
authors = [
{ name = "Wassim Kabalan" },
@@ -46,6 +46,7 @@ test-command = "pytest {project}/tests"

[tool.ruff]
line-length = 150
fix = true
src = ["src"]
exclude = ["third_party"]

@@ -59,7 +60,10 @@ select = [
'UP',
# flake8-debugger
'T10',
# isort
'I',
]

ignore = [
'E402', # module level import not at top of file
'E203',
@@ -69,3 +73,7 @@ ignore = [
'E722',
'UP037', # conflicts with jaxtyping Array annotations
]


[tool.ruff.format]
quote-style = 'single'
8 changes: 4 additions & 4 deletions scripts/autotune.py
@@ -20,7 +20,7 @@
tuned_config = jaxdecomp.get_autotuned_config(config, False, False, True, True, (32, 32, 32), (True, True, True))

if rank == 0:
print(rank, "*** Results of optimization ***")
print(rank, "pdims", tuned_config.pdims)
print(rank, "halo_comm_backend", tuned_config.halo_comm_backend)
print(rank, "transpose_comm_backend", tuned_config.transpose_comm_backend)
print(rank, '*** Results of optimization ***')
print(rank, 'pdims', tuned_config.pdims)
print(rank, 'halo_comm_backend', tuned_config.halo_comm_backend)
print(rank, 'transpose_comm_backend', tuned_config.transpose_comm_backend)
8 changes: 4 additions & 4 deletions scripts/test_fft3d.py
@@ -24,8 +24,8 @@

# Remap to the global array from the local slice
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=("z", "y"))
global_array = multihost_utils.host_local_array_to_global_array(array, mesh, P("z", "y"))
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(array, mesh, P('z', 'y'))


@jax.jit
@@ -38,7 +38,7 @@ def do_fft(x):
before = time.time()
karray = do_fft(global_array).block_until_ready()
after = time.time()
print(rank, "took", after - before, "s")
print(rank, 'took', after - before, 's')

# And now, let's do the inverse FFT
rec_array = jaxdecomp.fft.pifft3d(karray)
@@ -47,7 +47,7 @@ def do_fft(x):

# Let's test if things are like we expect
if rank == 0:
print("maximum reconstruction difference", diff)
print('maximum reconstruction difference', diff)

jaxdecomp.finalize()
jax.distributed.shutdown()