
# Abstract

Cosmological simulations are key tools in understanding the distribution of galaxies and dark matter in the universe. Differentiable simulations provide access to gradients, significantly accelerating the inference process. Fast Particle Mesh (PM) simulations are excellent candidates due to their speed, simplicity, and inherent differentiability. However, as we enter the exascale era, simulation sizes are surpassing the maximum available memory, even for high-end HPC GPUs. Therefore, a multi-node distributed Particle Mesh simulation is necessary to simulate large cosmological volumes effectively.

The only step requiring communication in fast PM simulations is the fast Fourier transform (FFT). Several distributed FFT implementations exist in the computer science community, such as [@2DECOMP&FFT], which provides distributed FFTs on CPUs, and the GPU-based [@cuDecomp], which uses the NVIDIA Collective Communication Library (NCCL) for communication. However, these libraries do not provide differentiable 3D FFTs, which are essential for gradient-based inference techniques such as Hamiltonian Monte Carlo (HMC) sampling and variational inference.
JAX [@JAX] has seen widespread adoption in both machine learning and scientific computing due to its flexibility and performance, as demonstrated in projects like JAX-Cosmo [@JAXCOSMO]. However, its use in distributed high-performance computing (HPC) has been limited by the complex inter-GPU communication patterns required by HPC scientific software, which are more demanding than those of typical deep learning workloads. Previous solutions, such as mpi4jax [@mpi4jax], provided support for single program multiple data (SPMD) operations but faced significant scaling limitations.

To address this, we introduce `jaxDecomp`, a `JAX` library based on `cuDecomp` that efficiently decomposes simulation data into pencils to provide multi-node, parallel, and differentiable Fast Fourier Transforms (FFTs) and halo exchanges. It calls highly optimized compiled code directly from `JAX`. This library enables the large-scale distribution of simulations on High Performance Computing (HPC) clusters and aims to integrate seamlessly with existing open-source simulation codes like `JaxPM` and [@pmwd].

Recently, JAX has made a major push towards simplified SPMD programming, with the unification of the JAX array API and the introduction of several powerful APIs, such as `pjit`, `shard_map`, and `custom_partitioning`. However, not all native JAX operations have specialized distribution strategies, and `pjit`-ing a program can lead to excessive communication overhead for some operations, particularly the 3D Fast Fourier Transform (FFT), which is one of the most critical and widely used algorithms in scientific computing. Distributed FFTs are essential for many simulations and solvers, especially in fields like cosmology and fluid dynamics, where large-scale data processing is required.

# Statement of Need
To address these limitations, we introduce jaxDecomp, a JAX library that wraps NVIDIA's cuDecomp domain decomposition library [@cuDecomp]. jaxDecomp provides JAX primitives backed by highly efficient CUDA implementations for key operations such as 3D FFTs and halo exchanges. By integrating seamlessly with JAX, jaxDecomp supports running on multiple GPUs and nodes, enabling large-scale, distributed scientific computations. Implemented as JAX primitives, jaxDecomp builds directly on top of JAX's distributed Array strategy and is compatible with JAX transformations such as `jax.grad` and `jax.jit`, ensuring fast execution and differentiability with a Pythonic, easy-to-use interface. Through cuDecomp, jaxDecomp can switch between NCCL and CUDA-aware MPI for its distributed transpose operations, so it can be tuned to the specific HPC cluster configuration.
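
As a minimal sketch of this composability (assuming the distributed FFT is exposed as `jaxdecomp.fft.pfft3d`; exact names may differ), a function built on the distributed FFT can be jitted and differentiated like any other JAX function:

```python
import jax
import jax.numpy as jnp
import jaxdecomp

@jax.jit
def loss(field):
    # `field` is a 3D array sharded over a JAX device mesh; the distributed
    # FFT runs without gathering the array onto a single GPU.
    k_field = jaxdecomp.fft.pfft3d(field)
    return jnp.sum(jnp.abs(k_field) ** 2)

grad_fn = jax.jit(jax.grad(loss))  # gradients flow through the distributed FFT
```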

Particle mesh simulations are essential for cosmological data analysis, particularly for full-field inference. They simulate the large-scale structure of the universe and are used to evaluate the likelihood of the observed data given the cosmological parameters. Given the high dimensionality of these simulations, advanced sampling techniques such as Hamiltonian Monte Carlo (HMC) and variational inference are required to explore the parameter space efficiently. Differentiable simulations are crucial in this context, as they provide access to gradients, which significantly accelerate the inference process.

To maximize the potential of particle mesh simulations, it is crucial to use a very fine grid, resolving small-scale structures and yielding a power spectrum close to that of hydrodynamical simulations. However, such a fine grid significantly increases memory consumption: for a grid size of $4096^3$, the force vector field alone is about 1.5 TB, and computing and storing the intermediate fields at each step can easily require 5 to 10 TB of memory, which is not feasible on a single GPU, even on high-end data center GPUs like the H100. Distributing the simulation across multiple GPUs and nodes is necessary to overcome this limitation. Since the only step requiring communication in fast PM simulations is the FFT, building on a 3D domain decomposition library like `cuDecomp` and integrating its functionality with `JAX` allows us to perform distributed and differentiable 3D FFTs, enabling the simulation of large cosmological volumes on HPC clusters.
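
As a rough check of this memory estimate, assuming 64-bit precision, the force field alone occupies

$$
4096^3 \times 3 \;\text{components} \times 8\;\text{bytes} \approx 1.6 \times 10^{12}\;\text{bytes} \approx 1.5\;\text{TB},
$$

before accounting for the other fields and intermediate buffers needed at each step.
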
For numerical simulations on HPC systems, a distributed, easy-to-use, and differentiable FFT is critical for achieving peak performance and scalability. While it is technically feasible to implement distributed FFTs in native JAX, and the raw performance difference may be marginal, jaxDecomp additionally provides control over the communication backend (NCCL or CUDA-aware MPI), 1D and 2D domain decompositions, and distributed halo exchanges, which makes it a valuable tool for HPC applications.

To perform these simulations efficiently on modern HPC clusters, distributed FFTs are required. `jaxDecomp` addresses this need by distributing FFTs across multiple GPUs and nodes in a way that is fully compatible with `JAX`. This not only enables high-performance simulation but also keeps the FFT operations differentiable, which is crucial for gradient-based optimization and inference within `JAX`.
In scientific applications such as particle mesh (PM) simulations for cosmology, existing frameworks illustrate the limits of current tools: [@FlowPM] is distributed, but it is built on Mesh-TensorFlow, which is no longer actively maintained, while JAX-based frameworks such as [@pmwd] are limited to volumes of about $512^3$ due to the lack of distribution capabilities. These examples underscore the critical need for scalable and efficient solutions. jaxDecomp addresses this gap by enabling distributed and differentiable 3D FFTs within JAX, thereby facilitating the simulation of large cosmological volumes on HPC clusters.

# Implementation

jaxDecomp wraps cuDecomp operations as custom JAX primitives, injecting CUDA code into the HLO graph via XLA's `custom_call` mechanism. Using the recent `custom_partitioning` JAX API, the partitioning information is embedded directly in the HLO graph, and the state of cuDecomp, including the processor grid and allocated memory, is maintained transparently for the user.

## Domain Decomposition

jaxDecomp supports 1D (slab) and 2D (pencil) domain decompositions: in a 1D decomposition, arrays are split along a single axis into slabs, while in a 2D decomposition they are split along two axes into pencils. This flexibility allows data to be distributed efficiently across multiple GPUs while preserving locality. At any point of the simulation, the data is distributed across two of its three dimensions, with the remaining dimension kept local, so simulation data of any size can be handled as long as it fits in the combined memory of the GPUs.
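
As an illustration of how such a decomposition is expressed on the JAX side (a sketch only; the processor grid shape and axis names below are arbitrary choices, not part of the library), a 2D pencil decomposition corresponds to a two-axis device mesh:

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

pdims = (2, 4)                                  # hypothetical 2 x 4 processor grid (8 GPUs)
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('y', 'z'))     # illustrative axis names

# Shard the first two axes of a global 3D array; the third axis stays local,
# which is exactly the 2D "pencil" layout described above.
sharding = NamedSharding(mesh, P('y', 'z'))
```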

## Distributed FFTs

![Distributed FFTs using jaxDecomp](assets/fft.svg){width=60%}

### Forward FFTs

For distributed FFTs, jaxDecomp performs a series of 1D FFTs using cuFFT along the undistributed axis, followed by a multi-GPU transposition using cuDecomp that makes the next axis local. For example, in a 2D decomposition where the X axis is initially undistributed and the Y and Z axes are distributed across GPUs, the 1D FFTs are first applied along X, the data is then transposed from an X-pencil to a Y-pencil, the FFTs are applied along the now-local Y axis, and the same steps are repeated for the Z axis.
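
Conceptually, the forward transform amounts to the following sequence (a sketch only, not the actual implementation: the real `pfft3d` primitive fuses these steps, dispatches the local 1D FFTs to cuFFT through XLA custom calls, and the axis indices depend on the pencil layout):

```python
import jax.numpy as jnp
import jaxdecomp

def sketch_pfft3d(x_pencil):
    kx = jnp.fft.fft(x_pencil, axis=-1)       # 1D FFTs along the locally stored X axis
    y_pencil = jaxdecomp.transposeXtoY(kx)    # global transpose: Y becomes the local axis
    ky = jnp.fft.fft(y_pencil, axis=-1)       # 1D FFTs along Y
    z_pencil = jaxdecomp.transposeYtoZ(ky)    # global transpose: Z becomes the local axis
    return jnp.fft.fft(z_pencil, axis=-1)     # 1D FFTs along Z
```
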
### Inverse FFTs

Inverse FFTs follow the same process in reverse: starting from the distributed layout (a Z-pencil in the example above), jaxDecomp performs a 1D inverse FFT along the Z axis, transposes the data from a Z-pencil to a Y-pencil, applies a 1D inverse FFT along Y, and finally transposes from a Y-pencil back to an X-pencil for the last 1D inverse FFT along the X axis.

## Distributed Halo Exchange

![Distributed Halo Exchange using jaxDecomp](assets/halo-exchange.svg){width=60%}

In a particle mesh simulation, the 3D FFT is used to estimate the force field acting on the particles; the forces are then interpolated back to the particles, which are moved accordingly. Particles close to the boundary of a local domain must be updated with information from the neighbouring domains. jaxDecomp handles this with a distributed halo exchange: each local slice of the simulation is padded, and a halo exchange synchronizes the padded regions across the edges of adjacent domains, keeping neighbouring domains consistent and the results of distributed simulations on HPC clusters accurate.


# API description

Below, we show how to perform a distributed 3D FFT using `jaxDecomp` and `JAX`. The code snippet demonstrates the initialization of the distributed mesh, the creation of the initial distributed array, and the execution of a distributed 3D FFT, a halo exchange, and global transposes using `jaxDecomp`.

```python
import jax

# ... the imports of `jax.numpy` and `jaxdecomp`, the initialization of the
# distributed mesh, the sharded input array `z`, the quantities `pdims`,
# `padding`, `halo_extents` and `periods`, and the `do_fft` helper are
# elided in this excerpt ...

def do_halo_exchange(z):
    # Pad each local slice, exchange the padded regions with the neighbouring
    # domains, then strip the padding again. The halo extents and periodicity
    # are placeholders defined with the rest of the (elided) setup.
    z = jaxdecomp.slice_pad(z, padding, pdims)
    z = jaxdecomp.halo_exchange(z,
                                halo_extents=halo_extents,
                                halo_periods=periods)
    z = jaxdecomp.slice_unpad(z, padding, pdims)
    return z

def do_transpose(x_pencil):
    y_pencil = jaxdecomp.transposeXtoY(x_pencil)
    z_pencil = jaxdecomp.transposeYtoZ(y_pencil)
    y_pencil = jaxdecomp.transposeZtoY(z_pencil)
    x_pencil = jaxdecomp.transposeYtoX(y_pencil)
    return x_pencil


with mesh:
    z = do_fft(z)
    z = do_halo_exchange(z)
    z = do_transpose(z)

```
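
To run the snippet above on several processes and nodes, the JAX distributed runtime is typically initialized first, with one process launched per GPU (a sketch; the exact launcher and arguments depend on the cluster setup):

```python
import jax

# e.g. launched as:  srun -n 8 python demo.py   (one process per GPU)
jax.distributed.initialize()  # set up the multi-process JAX runtime
print(jax.device_count(), jax.local_device_count())
```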

# Example

The following example computes the gravitational potential of a density field `delta` using the distributed FFTs. The code runs on multiple GPUs and nodes using `jaxDecomp` and `JAX` while remaining fully differentiable.


```python
import jax.numpy as jnp

def potential(delta):
    # pfft3d / ipfft3d are jaxDecomp's distributed forward / inverse 3D FFTs
    delta_k = pfft3d(delta)
    # wavenumbers along each axis; the distributed FFT permutes the axes,
    # hence the (ky, kz, kx) ordering
    ky, kz, kx = [jnp.fft.fftfreq(s) * 2 * jnp.pi for s in delta.shape]
    kk = ky[:, None, None]**2 + kz[None, :, None]**2 + kx[None, None, :]**2
    # inverse Laplace kernel, regularized at k = 0
    laplace_kernel = jnp.where(kk == 0, 1., -1. / kk)
    potential_k = delta_k * laplace_kernel
    return ipfft3d(potential_k)
```
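
Since `pfft3d` and `ipfft3d` are differentiable, gradients of any scalar built from this potential can be taken directly; for instance (a small illustration, assuming `jax` and `jax.numpy` are imported as in the previous snippets):

```python
# Gradient of a toy scalar summary of the potential with respect to the
# input density field; this works unchanged when the field is distributed
# over multiple GPUs and nodes.
grad_fn = jax.grad(lambda delta: jnp.mean(jnp.abs(potential(delta)) ** 2))
```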


A more detailed example of an LPT simulation can be found in [jaxdecomp_lpt](../examples/jaxdecomp_lpt.py).


# Benchmark

### TO REDO (ADD BENCHMARKS VS DISTRIBUTED JAX)

We benchmarked the distributed FFTs of `jaxDecomp` on V100 GPUs with 32 GB of memory and compared its performance with the base `JAX` FFT implementation. At a resolution of $2048^3$, the base `JAX` implementation could not fit the data on a single GPU, while `jaxDecomp` could distribute the same volume across 4 GPUs.
