Skip to content

Commit

Permalink
Merge pull request #20 from DifferentiableUniverseInitiative/joss-paper
Browse files Browse the repository at this point in the history
Joss paper of jaxDecomp
  • Loading branch information
ASKabalan authored Jul 18, 2024
2 parents 1be36aa + adb0d8b commit 067ee89
Show file tree
Hide file tree
Showing 12 changed files with 987 additions and 7 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/joss-paper-pdf.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: Draft PDF
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

jobs:
paper:
runs-on: ubuntu-latest
name: Paper Draft
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Build draft PDF
uses: openjournals/openjournals-draft-action@master
with:
journal: joss
# This should be the path to the paper within your repo.
paper-path: joss-paper/paper.md
- name: Upload
uses: actions/upload-artifact@v3
with:
name: paper
# This is the output path where Pandoc will write the compiled
# PDF. Note, this should be the same directory as the input
# paper.md
path: joss-paper/paper.pdf
7 changes: 0 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,3 @@ repos:
hooks:
- id: isort
name: isort (python)
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.4
hooks:
- id: clang-format
files: '\.(c|cc|cpp|h|hpp|cxx|hh|cu|cuh)$'
exclude: '^third_party/|/pybind11/'
name: clang-format
23 changes: 23 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Use-Case Examples

This directory contains examples of how to use the jaxDecomp library on a few use cases.

## Distributed LPT Cosmological Simulation

This example demonstrates the use of the 3D distributed FFT and halo exchange functions in the `jaxDecomp` library to implement a distributed LPT cosmological simulation. We provide a notebook to visualize the results of the simulation in [visualizer.ipynb](visualizer.ipynb).

To run the demo, some additional dependencies are required. You can install them by running:

```bash
pip install jax-cosmo
```

Then, you can run the example by executing the following command:
```bash
mpirun -n 4 python lpt_nbody_demo.py --nc 256 --box_size 256 --pdims 4x4 --halo_size 32 --output out
```

We also include an example of a slurm script in [submit_rusty.sbatch](submit_rusty.sbatch) that can be used to run the example on a slurm cluster with:
```bash
sbatch submit_rusty.sbatch
```
280 changes: 280 additions & 0 deletions examples/lpt_nbody_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
import argparse
import os
from functools import partial
from typing import Any, Callable, Hashable

Specs = Any
AxisName = Hashable

import jax

jax.config.update('jax_enable_x64', False)

import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from jax._src import mesh as mesh_lib
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from scatter import scatter

import jaxdecomp


def shmap(f: Callable,
in_specs: Specs,
out_specs: Specs,
check_rep: bool = True,
auto: frozenset[AxisName] = frozenset()):
"""Helper function to create a shard_map function that extracts the mesh from the
context."""
mesh = mesh_lib.thread_resources.env.physical_mesh
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)


def _global_to_local_size(nc: int):
""" Helper function to get the local size of a mesh given the global size.
"""
pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
return [nc // pdims[0], nc // pdims[1], nc]


def fttk(nc: int) -> list:
"""
Generate Fourier transform wave numbers for a given mesh.
Args:
nc (int): Shape of the mesh grid.
Returns:
list: List of wave number arrays for each dimension in
the order [kx, ky, kz].
"""
kd = np.fft.fftfreq(nc) * 2 * np.pi

@partial(
shmap,
in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)))
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable
ky, kz, kx = get_kvec(kd, kd, kd) # The order of the output
# corresponds to the order of dimensions in the transposed FFT
# output
return kx, ky, kz


def gravity_kernel(kx, ky, kz):
""" Computes a Fourier kernel combining laplace and derivative
operators to compute gravitational forces.
Args:
kvec (tuple of float): Wave numbers in Fourier space.
Returns:
tuple of jnp.ndarray: kernels for each dimension.
"""
kk = kx**2 + ky**2 + kz**2
laplace_kernel = jnp.where(kk == 0, 1., 1. / kk)

grav_kernel = (laplace_kernel * 1j * kx,
laplace_kernel * 1j * ky,
laplace_kernel * 1j * kz) # yapf: disable
return grav_kernel


def gaussian_field_and_forces(key, nc, box_size, power_spectrum):
"""
Generate a Gaussian field with a given power spectrum, along with gravitational forces.
Args:
key (int): Key for the random number generator.
nc (int): Number of cells in the mesh.
box_size (float): Size of the box.
power_spectrum (callable): Power spectrum function.
Returns:
tuple of jnp.ndarray: The generated Gaussian field and the gravitational forces.
"""
local_mesh_shape = _global_to_local_size(nc)

# Create a distributed field drawn from a Gaussian distribution in real space
delta = shmap(
partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
in_specs=P(None),
out_specs=P('x', 'y'))(key) # yapf: disable

# Compute the Fourier transform of the field
delta_k = jaxdecomp.fft.pfft3d(delta.astype(jnp.complex64))

# Compute the Fourier wavenumbers of the field
kx, ky, kz = fttk(nc)
kk = jnp.sqrt(kx**2 + ky**2 + kz**2) * (nc / box_size)

# Apply power spectrum to Fourier modes
delta_k *= (power_spectrum(kk) * (nc / box_size)**3)**0.5

# Compute inverse Fourier transform to recover the initial conditions in real space
delta = jaxdecomp.fft.pifft3d(delta_k).real

# Compute gravitational forces associated with this field
grav_kernel = gravity_kernel(kx, ky, kz)
forces_k = [g * delta_k for g in grav_kernel]

# Retrieve the forces in real space by inverse Fourier transforming
forces = jnp.stack([jaxdecomp.fft.pifft3d(f).real for f in forces_k], axis=-1)

return delta, forces


def cic_paint(displacement, halo_size):
""" Paints particles on a mesh using Cloud-In-Cell interpolation.
Args:
displacement (jnp.ndarray): Displacement of each particle.
halo_size (int): Halo size for painting.
Returns:
jnp.ndarray: Density field.
"""
local_mesh_shape = _global_to_local_size(displacement.shape[0])
hs = halo_size

@partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
def cic_op(disp):
""" CiC operation on each local slice of the mesh."""
# Create a mesh to paint the particles on for the local slice
mesh = jnp.zeros(disp.shape[:-1], dtype='float32')

# Padding the mesh along the two first dimensions
mesh = jnp.pad(mesh, [[hs, hs], [hs, hs], [0, 0]])

# Compute the position of the particles on a regular grid
pos_x, pos_y, pos_z = jnp.meshgrid(
jnp.arange(local_mesh_shape[0]),
jnp.arange(local_mesh_shape[1]),
jnp.arange(local_mesh_shape[2]),
indexing='ij')

# adding an offset of size halo size
pos = jnp.stack([pos_x + hs, pos_y + hs, pos_z], axis=-1)

# Apply scatter operation to paint the particles on the local mesh
field = scatter(pos.reshape([-1, 3]), disp.reshape([-1, 3]), mesh)

return field

# Performs painting on a padded mesh, with halos on the two first dimensions
field = cic_op(displacement)

# Run halo exchange to get the correct values at the boundaries
field = jaxdecomp.halo_exchange(
field,
halo_extents=(hs // 2, hs // 2, 0),
halo_periods=(True, True, True))

@partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
def unpad(x):
""" Removes the padding and reduce the halo regions"""
x = x.at[hs:hs + hs // 2].add(x[:hs // 2])
x = x.at[-(hs + hs // 2):-hs].add(x[-hs // 2:])
x = x.at[:, hs:hs + hs // 2].add(x[:, :hs // 2])
x = x.at[:, -(hs + hs // 2):-hs].add(x[:, -hs // 2:])
return x[hs:-hs, hs:-hs, :]

# Unpad the output array
field = unpad(field)
return field


@partial(jax.jit, static_argnames=('nc', 'box_size', 'halo_size'))
def simulation_fn(key, nc, box_size, halo_size, a=1.0):
"""
Run a simulation to generate initial conditions and density field using LPT.
Args:
key (list of int): Jax random key for the random number generator.
nc (int): Size of the mesh grid.
box_size (float): Size of the box.
halo_size (int): Halo size for painting.
a (float): Scale factor of final field.
Returns:
tuple of jnp.ndarray: Initial conditions and final density field.
"""
# Build a default cosmology
cosmo = jc.Planck15()

# Create a small function to generate the linear matter power spectrum at arbitrary k
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(cosmo, k)
pk_fn = jax.jit(lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).
reshape(x.shape))

# Generate a Gaussian field and gravitational forces from a power spectrum
intial_conditions, initial_forces = gaussian_field_and_forces(
key=key, nc=nc, box_size=box_size, power_spectrum=pk_fn)

# Compute the LPT displacement that particles initialy placed on a regular grid
# would experience at scale factor a, by simple Zeldovich approximation
initial_displacement = jc.background.growth_factor(
cosmo, jnp.atleast_1d(a)) * initial_forces

# Paints the displaced particles on a mesh to obtain the density field
final_field = cic_paint(initial_displacement, halo_size)

return intial_conditions, final_field


def main(args):
print(f"Running with arguments {args}")

# Setting up distributed jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()

# Setting up distributed random numbers
master_key = jax.random.PRNGKey(42)
key = jax.random.split(master_key, size)[rank]

# Create computing mesh and sharding information
pdims = tuple(map(int, args.pdims.split('x')))
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))

# Run the simulation on the compute mesh
with mesh:
initial_conds, final_field = simulation_fn(
key=key, nc=args.nc, box_size=args.box_size, halo_size=args.halo_size)

# Create output directory to save the results
output_dir = args.output
os.makedirs(output_dir, exist_ok=True)
np.save(f'{output_dir}/initial_conditions_{rank}.npy',
initial_conds.addressable_data(0))
np.save(f'{output_dir}/field_{rank}.npy', final_field.addressable_data(0))
print(f"Finished saved to {output_dir}")

# Closing distributed jax
jax.distributed.shutdown()


if __name__ == '__main__':
parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
parser.add_argument(
'--pdims', type=str, default='1x1', help="Processor grid dimensions")
parser.add_argument(
'--nc', type=int, default=256, help="Number of cells in the mesh")
parser.add_argument(
'--box_size', type=float, default=512., help="Box size in Mpc/h")
parser.add_argument(
'--halo_size', type=int, default=32, help="Halo size for painting")
parser.add_argument('--output', type=str, default='out')
args = parser.parse_args()

main(args)
Loading

0 comments on commit 067ee89

Please sign in to comment.