diff --git a/.github/workflows/joss-paper-pdf.yml b/.github/workflows/joss-paper-pdf.yml index d0af894..4d72122 100644 --- a/.github/workflows/joss-paper-pdf.yml +++ b/.github/workflows/joss-paper-pdf.yml @@ -1,11 +1,9 @@ name: Draft PDF on: push: - paths: - - paper.md - - paper.bib - - assets/* - - .github/workflows/draft-pdf.yml + branches: [ "main" ] + pull_request: + branches: [ "main" ] jobs: paper: diff --git a/joss-paper/paper.bib b/joss-paper/paper.bib index f1c84cd..69f4ece 100644 --- a/joss-paper/paper.bib +++ b/joss-paper/paper.bib @@ -49,3 +49,11 @@ @misc{pmwd primaryClass={astro-ph.IM}, url={https://arxiv.org/abs/2211.09958}, } + +@software{JAX, + author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, + title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, + url = {http://github.com/google/jax}, + version = {0.3.13}, + year = {2018}, +} diff --git a/joss-paper/paper.md b/joss-paper/paper.md index a888357..ad783da 100644 --- a/joss-paper/paper.md +++ b/joss-paper/paper.md @@ -1,10 +1,10 @@ --- title: 'jaxDecomp : JAX Library for 3D Domain Decomposition and Parallel FFTs' tags: + - Jax + - Cuda - Python - Hpc - - Cuda - - Jax - FFT - Simulations authors: @@ -22,6 +22,8 @@ affiliations: index: 1 - name: Université Paris-Saclay, Université Paris Cité, CEA, CNRS, AIM, 91191, Gif-sur-Yvette, France index: 2 + - name: Université Paris-Saclay, Université Paris Cité, CEA, CNRS, AIM, 91191, Gif-sur-Yvette, France + index: 3 date: 26 June 2024 bibliography: paper.bib @@ -30,55 +32,92 @@ bibliography: paper.bib # Summary -Cosmological simulations are a key tool to help us understand the distribution of galaxies and dark matter in the universe. Differentiable simulations give access to the gradients which significantly accelerate the inference process. Fast Particle Mesh (PM) simulations are a very good candidate due to their speed and simplicity and thus differentiability. However, entering the exascale era, simulation sizes are surpassing the maximum available memory even for the high-end HPC GPUs. For that, a multi-node distributed Particle Mesh simulation is needed to be truly able to simulate the large cosmological volumes. The only step that requires communications in the fast PM simulations is the fast Fourier transform (FFT). There are a few implementations of distributed FFTs coming from the computer science community, like [@2DECOMP&FFT] that allows distributed FFTs on CPUs and the GPU version [@cuDecomp] that uses NVIDIA Collective Communication Library (NCCL) for the communication. However, these libraries are not integrated with the differentiable simulation libraries like JAX. To address this, we introduce `jaxDecomp`, a `JAX` library based on `cuDecomp` that efficiently decomposes the simulation data into 2D slices (pencils) to facilitate multi-node parallel Fast Fourier Transforms (FFTs) and halo exchanges, leveraging the power of compiled code directly within `JAX` code. This library will enable the large-scale distribution of simulations on High Performance Computing (HPC) clusters and will seamlessly integrate with existing open-source simulation codes like `JaxPM` or [@pmwd]. +Cosmological simulations are key tools in understanding the distribution of galaxies and dark matter in the universe. Differentiable simulations provide access to gradients, significantly accelerating the inference process. Fast Particle Mesh (PM) simulations are excellent candidates due to their speed, simplicity, and inherent differentiability. However, as we enter the exascale era, simulation sizes are surpassing the maximum available memory, even for high-end HPC GPUs. Therefore, a multi-node distributed Particle Mesh simulation is necessary to simulate large cosmological volumes effectively. + +The only step requiring communication in fast PM simulations is the fast Fourier transform (FFT). There are several implementations of distributed FFTs from the computer science community, such as [@2DECOMP&FFT], which allows distributed FFTs on CPUs, and the GPU implementation [@cuDecomp], which uses the NVIDIA Collective Communication Library (NCCL) for communication. However, these libraries do not provide differentiable 3D FFTs, which are essential for gradient-based monte-carlo sampling techniques like Hamiltonian Monte Carlo (HMC) and variational inference. + +To address this, we introduce `jaxDecomp`, a `JAX` library based on `cuDecomp` that efficiently decomposes simulation data into 2D slices (pencils) to facilitate multi-node parallel and differentiable Fast Fourier Transforms (FFTs) and halo exchanges. It leverages the power of compiled code directly within `JAX` code. This library will enable the large-scale distribution of simulations on High Performance Computing (HPC) clusters and aims to integrate seamlessly with existing open-source simulation codes like `JaxPM` and [@pmwd]. + # Statement of Need -Particle mesh simulations are essential for cosmological data analysis, particularly in full field inference. They simulate the large-scale structure of the universe and generate the likelihood of the data given the cosmological parameters. Differentiable simulations unlock advanced sampling techniques like Hamiltonian Monte Carlo (HMC) and variational inference. To maximize the potential of particle mesh simulations, it is crucial to use a very fine grid, achieving a high resolution of small-scale structures and a power spectrum close to that of hydrodynamical simulations. +Particle mesh simulations are essential for cosmological data analysis, particularly in full field inference. They simulate the large-scale structure of the universe and generate the likelihood of the data given the cosmological parameters. Given the high dimensionality of these simulations, advanced sampling techniques such as Hamiltonian Monte Carlo (HMC) and variational inference are required to efficiently explore the parameter space. Differentiable simulations are crucial in this context as they provide access to gradients, significantly accelerating the inference process. -Full field inference, based on Bayesian hierarchical models, allows for the inference of cosmological parameters from galaxy survey data. This method utilizes all available data, rather than just estimated two-point correlation functions, this requires fast simulations of sizable fractions of the universe. Particle mesh simulations play a pivotal role in this process by generating the likelihood of the data given the cosmological parameters. +To maximize the potential of particle mesh simulations, it is crucial to use a very fine grid, achieving high resolution of small-scale structures and a power spectrum close to that of hydrodynamical simulations. However, this fine grid significantly increases memory consumption. For instance, for a grid size of $4096^3$, the force vector field is about 1.5 TB. Ideally, computing and storing the vector field at each step can easily require 5 to 10 TB of memory, which is not feasible on a single GPU, even with high-end data center GPUs like the H100. Distributing the simulation across multiple GPUs and nodes is necessary to overcome this limitation. The only step requiring communication in fast PM simulations is the FFT. Using a 3D decomposition library like `cuDecomp` and integrating its functionality with `JAX` will allow us to perform distributed and differentiable 3D FFTs, enabling the simulation of large cosmological volumes on HPC clusters. -To perform these simulations efficiently on modern HPC clusters, distributed FFTs are required. `jaxDecomp` addresses this need by distributing FFTs across multiple GPUs and nodes, fully compatible with JAX. This allows for simple Python API usage while benefiting from the performance of compiled code on single or multiple GPUs. +To perform these simulations efficiently on modern HPC clusters, distributed FFTs are required. `jaxDecomp` addresses this need by distributing FFTs across multiple GPUs and nodes, fully compatible with `JAX`. This capability not only facilitates high-performance simulation but also ensures that the FFT operations remain differentiable, crucial for incorporating gradient-based optimization techniques like backpropagation in machine learning frameworks integrated with `JAX`. # Implementation ## Distributed FFTs -The implementation of `jaxDecomp` focuses on 2D decomposition for parallel FFTs and efficient halo exchange. The process begins by creating a 2D grid of GPUs using the JAX API, followed by the creation of particles on this grid. The steps for performing FFTs and their inverse are as follows: +The implementation of `jaxDecomp` does a serie of 1D FFTs using `cuFFT` on the undistributed axis, followed a multi GPU transposition using `cuDecomp` on the newly transposed undistributed axis.\ +Starting with a 2D decomposition, The X axis is not distributed, and the Y and Z axes are distributed across multiple GPUs. The 1D FFTs are performed on the X axis, and the transposition is done from a X pencil to a Y pencil. The transposed data is then distributed across the GPUs. The 1D FFTs are performed on the Y axis, and the same is carried out for the Z axis. -```python -FFT1D_X(particles) -y_pencil = TransposeXtoY(particles) -FFT1D_Y(y_pencil) -z_pencil = TransposeYtoZ(y_pencil) -FFT1D_Z(z_pencil) -``` +And inverse FFTs goes the other way around, by running a 1D inverse FFT on the Z axis, then transposing the data from a Z pencil to a Y pencil, and running a 1D inverse FFT on the Y axis, and finally transposing the data from a Y pencil to an X pencil and running a 1D inverse FFT on the X axis. -And inverse FFTs goes the other way around +![Distributed FFTs using jaxDecomp](assets/fft.svg){width=40%} -```python -FFT1D_Z_inv(z_pencil) -y_pencil = TransposeZtoY(z_pencil) -FFT1D_Y_inv(y_pencil) -x_pencil = TransposeYtoX(y_pencil) -FFT1D_X_inv(x_pencil) -``` -
- - +At any point of the simulation, the data is distributed accross 2 dimensions, with the third dimension being undistributed. This allows us to store simulation data of any size on the GPUs, as long as the data fits in the combined memory of the GPUs. +## Distributed Halo Exchange +In a particle mesh simulation, we use the 3DFFT to estimate the force field acting on the particles. The force field is then interpolated to the particles, and the particles are moved accordingly. The particles that are close to the boundary of the local domain need to be updated using the data from the neighboring domains. This is done using a halo exchange operation. Where we pad each slice of the simulation then we perform a halo exchange operation to update the particles that are close to the boundary of the local domain. -## Distributed Halo Exchange +![Distributed Halo Exchange using jaxDecomp](assets/halo-exchange.svg){width=40%} -During the update step, particles cannot leave the local slice (GPU). jaxDecomp facilitates a halo exchange to handle particles on the edge of slices efficiently using the NCCL library. -
- -
+## Example +In this example, we show how to perform a distributed 3D FFT using `jaxDecomp` and `JAX`. The code snippet below demonstrates the initialization of the distributed mesh, the creation of the initial distributed tensor, and the execution of a distributed 3D FFT using `jaxDecomp`. + +```python +import jax +import jaxdecomp + +# Setup +master_key = jax.random.PRNGKey(42) +key = jax.random.split(master_key, size)[rank] +pdims = (2 , 2) +mesh_shape = [2048, 2048 , 2048] +halo_size = (256 , 256 , 0) + +# Create computing mesgh +devices = mesh_utils.create_device_mesh(pdims) +mesh = Mesh(devices, axis_names=('y', 'z')) +sharding = jax.sharding.NamedSharding(mesh, P('z', 'y')) + +### Create all initial distributed tensors ### +local_mesh_shape = [ + mesh_shape[0] // pdims[1], mesh_shape[1] // pdims[0], mesh_shape[2] +] + +z = jax.make_array_from_single_device_arrays( + shape=mesh_shape, + sharding=sharding, + arrays=[jax.random.normal(key, local_mesh_shape)]) + + +@jax.jit +def step(z): + with mesh: + padding = ((halo_size[0] , halo_size[0]) , (halo_size[1] , halo_size[1]) , (0 , 0)) + k_array = jaxdecomp.fft.pfft3d(z).real + k_array = k_array * 2 # element wise operation is distributed automatically by jax + k_array = jaxdecomp.slice_pad(k_array , padding , pdims) + k_array = jaxdecomp.halo_exchange( + k_array, + halo_extents=halo_size, + halo_periods=(True, True, True)) + k_array = jaxdecomp.slice_unpad(k_array , padding , pdims) + + return jaxdecomp.fft.ifft3d(k_array).real + +z = step(z) + +``` + +A more detailed example of a LPT simulation can be found in the [documentation](https://jaxdecomp.readthedocs.io/en/latest/). -Key FeaturesDifferentiable: Seamlessly integrates with JAX for differentiable simulations.Scalable: Efficiently distributes FFTs across multiple GPUs and nodes.User-Friendly: Provides a simple Python API, leveraging the power of compiled code.Halo Exchange: Performs efficient halo exchange using NCCL for particle updates.AcknowledgementsWe acknowledge contributions from François Lanusse and support from the Université Paris Cité and Université Paris-Saclay. Special thanks to the developers of cuDecomp and 2DECOMP&FFT for their foundational work in distributed FFTs.ReferencesyamlCopy code ---- -This draft incorporates your points and fills out the required sections. You can add specific figures, benchmarks, and additional details to further complete your paper. If you need further assistance with any particular section or additional information, feel free to ask! +## Stability and releases