Commit
add missing bibtex
ASKabalan committed Jul 22, 2024
1 parent 1396d48 commit 78058ea
Showing 2 changed files with 17 additions and 2 deletions.
15 changes: 15 additions & 0 deletions joss-paper/paper.bib
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,18 @@ @book{HMC
author={Brooks, Steve and Gelman, Andrew and Jones, Galin and Meng, Xiao-Li},
year={2011},
month=may }

@article{mpi4jax,
title={mpi4jax: Zero-copy MPI communication of JAX arrays},
volume={6},
number={65},
pages={3419},
url={https://doi.org/10.21105/joss.03419},
DOI={10.21105/joss.03419},
journal={Journal of Open Source Software},
publisher={The Open Journal},
author={Häfner, Dion and Vicentini, Filippo},
year={2021} }


@article{JAXCOSMO,
title={JAX-COSMO: An End-to-End Differentiable and GPU Accelerated Cosmology Library},
volume={6},
ISSN={2565-6120},
url={http://dx.doi.org/10.21105/astro.2302.05163},
DOI={10.21105/astro.2302.05163},
journal={The Open Journal of Astrophysics},
publisher={Maynooth University},
author={Campagne, Jean-Eric and Lanusse, François and Zuntz, Joe and Boucaud, Alexandre and Casas, Santiago and Karamanis, Minas and Kirkby, David and Lanzieri, Denise and Peel, Austin and Li, Yin},
year={2023},
month=apr }
4 changes: 2 additions & 2 deletions joss-paper/paper.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ bibliography: paper.bib
# Abstract


JAX [@JAX] has seen widespread adoption in both machine learning and scientific computing due to its flexibility and performance, as demonstrated in projects like JAX-Cosmo [@JAXCOSMO]. However, its application in distributed high-performance computing (HPC) has been limited by the complex nature of inter-GPU communications required in HPC scientific software, which is more challenging compared to deep learning networks. Previous solutions, such as mpi4jax [@mpi4jax], provided support for single program multiple data (SPMD) operations but faced significant scaling limitations.
`JAX` [@JAX] has seen widespread adoption in both machine learning and scientific computing due to its flexibility and performance, as demonstrated in projects like `JAX-Cosmo` [@JAXCOSMO]. However, its application in distributed high-performance computing (HPC) has been limited by the complex inter-GPU communications required in HPC scientific software, which is more challenging than in deep learning workloads. Previous solutions, such as `mpi4jax` [@mpi4jax], provided support for single program multiple data (SPMD) operations but faced significant scaling limitations.


Recently, JAX has made a major push towards simplified SPMD programming, with the unification of the JAX array API and the introduction of several powerful APIs, such as `pjit`, `shard_map`, and `custom_partitioning`. However, not all native JAX operations have specialized distribution strategies, and `pjitting` a program can lead to excessive communication overhead for some operations, particularly the 3D Fast Fourier Transform (FFT), which is one of the most critical and widely used algorithms in scientific computing. Distributed FFTs are essential for many simulations and solvers, especially in fields like cosmology and fluid dynamics, where large-scale data processing is required.
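The overhead described above can be illustrated with a minimal sketch (this is not jaxDecomp's API; the array size and mesh layout are illustrative): sharding a 3D array along one axis and jitting `jnp.fft.fftn` leaves XLA to insert the cross-device communication itself, with no FFT-specific decomposition strategy.

```python
# Sketch: a sharded 3D FFT using only native JAX sharding APIs.
# XLA chooses the communication pattern; there is no pencil/slab
# decomposition as a dedicated distributed-FFT library would use.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D device mesh; runs on a single CPU device as well as on N GPUs.
mesh = Mesh(np.array(jax.devices()), axis_names=("z",))

# 8^3 complex array, sharded along the first axis.
x = jnp.ones((8, 8, 8), dtype=jnp.complex64)
x = jax.device_put(x, NamedSharding(mesh, P("z", None, None)))

@jax.jit
def fft3d(a):
    # Full 3D FFT over all axes; communication (if any) is implicit.
    return jnp.fft.fftn(a)

y = fft3d(x)
```

On multiple devices, inspecting the compiled HLO of `fft3d` shows the collectives XLA emits for the transposed axes, which is the cost a specialized decomposition aims to control.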
Expand All @@ -43,7 +43,7 @@ To address these limitations, we introduce jaxDecomp, a JAX library that wraps N

# Statement of Need

For numerical simulations on HPC systems, having a distributed, easy-to-use, and differentiable FFT is critical for achieving peak performance and scalability. While it is technically feasible to implement distributed FFTs using native JAX, for performance and memory-critical simulations, it is better to use specialized HPC codes. These codes, however, are not typically differentiable. The need for differentiable, performant, and memory-efficient code has risen due to the recent introduction of differentiable algorithms such as Hamiltonian Monte Carlo (HMC) and the No-U-Turn Sampler (NUTS).
For numerical simulations on HPC systems, having a distributed, easy-to-use, and differentiable FFT is critical for achieving peak performance and scalability. While it is technically feasible to implement distributed FFTs using native JAX, for performance and memory-critical simulations, it is better to use specialized HPC codes. These codes, however, are not typically differentiable. The need for differentiable, performant, and memory-efficient code has risen due to the recent introduction of differentiable algorithms such as Hamiltonian Monte Carlo (HMC) [@HMC] and the No-U-Turn Sampler (NUTS) [@NUTS].

In scientific applications such as particle mesh (PM) simulations for cosmology, existing frameworks like FlowPM, a Mesh TensorFlow-based simulation, are distributed but no longer actively maintained. Similarly, JAX-based frameworks like pmwd are limited to 512 volumes due to the lack of distribution capabilities. These examples underscore the critical need for scalable and efficient solutions. jaxDecomp addresses this gap by enabling distributed and differentiable 3D FFTs within JAX, thereby facilitating the simulation of large cosmological volumes on HPC clusters effectively.

Expand Down
