jaxdecomp proto #21

Open
wants to merge 114 commits into
base: main
a742065
adding example of distributed solution
EiffL Jul 9, 2024
6408aff
put back old functgion
EiffL Jul 9, 2024
319942a
update formatting
EiffL Jul 9, 2024
ac86468
add halo exchange and slice pad
ASKabalan Jul 9, 2024
e62cd84
apply formatting
ASKabalan Jul 18, 2024
5775a37
implement distributed optimized cic_paint
ASKabalan Jul 18, 2024
7501b5b
Use new cic_paint with halo
ASKabalan Jul 18, 2024
7f48cfa
Fix seed for distributed normal
ASKabalan Jul 18, 2024
c81d4d2
Wrap interpolation function to avoid all gather
ASKabalan Jul 18, 2024
abde543
Return normal order frequencies for single GPU
ASKabalan Jul 18, 2024
82be568
add example
ASKabalan Jul 18, 2024
4f508b7
format
ASKabalan Jul 18, 2024
5f6d42e
add optimised bench script
ASKabalan Jul 18, 2024
1f6b9c3
times in ms
ASKabalan Jul 18, 2024
ed8cf8e
add lpt2
ASKabalan Jul 18, 2024
0216837
update benchmark and add slurm
ASKabalan Jul 18, 2024
5b7f595
Visualize only final field
ASKabalan Jul 18, 2024
1f20351
Update scripts/distributed_pm.py
ASKabalan Jul 19, 2024
f25eb7d
Adjust pencil type for frequencies
ASKabalan Jul 28, 2024
8c5bd76
fix painting issue with slabs
ASKabalan Aug 2, 2024
75604d2
Shared operation in fourrier space now take inverted sharding axis for
ASKabalan Aug 2, 2024
ccbfee3
add assert to make pyright happy
ASKabalan Aug 2, 2024
aebc3e7
adjust test for hpc-plotter
ASKabalan Aug 2, 2024
9af4659
add PMWD test
ASKabalan Aug 2, 2024
831291c
bench
Aug 2, 2024
ece8c93
format
ASKabalan Aug 2, 2024
783a974
Merge remote-tracking branch 'upstream/main' into ASKabalan/jaxdecomp…
ASKabalan Aug 2, 2024
02754cf
added github workflow
ASKabalan Aug 2, 2024
8da3149
fix formatting from main
ASKabalan Aug 2, 2024
2ea05a1
Update for jaxDecomp pure JAX
Aug 7, 2024
ab86699
merge with JZ
ASKabalan Oct 18, 2024
afecb13
revert single halo extent change
ASKabalan Oct 20, 2024
01b9527
update for latest jaxDecomp
ASKabalan Oct 21, 2024
ff1c5e8
remove fourrier_space in autoshmap
ASKabalan Oct 21, 2024
0ce7219
make normal_field work with single controller
ASKabalan Oct 21, 2024
9c94f99
format
ASKabalan Oct 21, 2024
375f204
make distributed pm work in single controller
ASKabalan Oct 21, 2024
5a587fd
merge bench_pm
ASKabalan Oct 21, 2024
a160a3f
update to leapfrog
ASKabalan Oct 22, 2024
38714cf
add a strict dependency on jaxdecomp
ASKabalan Oct 22, 2024
591ee32
global mesh no longer needed
ASKabalan Oct 22, 2024
a5b267b
kernels.py no longer uses global mesh
ASKabalan Oct 22, 2024
56ffd26
quick fix in distributed
ASKabalan Oct 22, 2024
80c56dc
pm.py no longer uses global mesh
ASKabalan Oct 22, 2024
105568e
painting.py no longer uses global mesh
ASKabalan Oct 22, 2024
4d944f0
update demo script
ASKabalan Oct 22, 2024
a8b194f
quick fix in kernels
ASKabalan Oct 22, 2024
0433c61
quick fix in distributed
ASKabalan Oct 22, 2024
2f50993
update demo
ASKabalan Oct 22, 2024
8623308
merge hugos LPT2 code
ASKabalan Oct 22, 2024
82b8f56
format
ASKabalan Oct 22, 2024
85cca44
Merge remote-tracking branch 'upstream/main' into ASKabalan/jaxdecomp…
ASKabalan Oct 22, 2024
d28982e
Small fix
ASKabalan Oct 22, 2024
82f2987
format
ASKabalan Oct 22, 2024
31ca41b
remove duplicate get_ode_fn
ASKabalan Oct 22, 2024
cf799b6
update visualizer
ASKabalan Oct 22, 2024
0bb992f
update compensate CIC
ASKabalan Oct 22, 2024
45b2c7f
By default check_rep is false for shard_map
ASKabalan Oct 22, 2024
505f2ec
remove experimental distributed code
ASKabalan Oct 25, 2024
5d4f438
update PGDCorrection and neural ode to use new fft3d
ASKabalan Oct 25, 2024
8e8e896
jaxDecomp pfft3d promotes to complex automatically
ASKabalan Oct 25, 2024
69c35d1
Merge remote-tracking branch 'upstream/main' into ASKabalan/jaxdecomp…
ASKabalan Oct 25, 2024
0f833f0
remove deprecated stuff
EiffL Oct 24, 2024
d2f1eb2
fix painting issue with read_cic
ASKabalan Oct 26, 2024
ff8856d
use jnp interp instead of jc interp
ASKabalan Oct 26, 2024
0c96a4d
delete old slurms
ASKabalan Oct 26, 2024
49dd18a
add notebook examples
ASKabalan Oct 26, 2024
11f7e90
Merge remote-tracking branch 'upstream/ASKabalan/jaxdecomp_proto' int…
ASKabalan Oct 26, 2024
4342279
apply formatting
ASKabalan Oct 26, 2024
cc4f310
add distributed zeros
ASKabalan Oct 27, 2024
d62c38f
fix code in LPT2
ASKabalan Oct 27, 2024
b4fdb74
jit cic_paint
ASKabalan Oct 27, 2024
c93894f
update notebooks
ASKabalan Oct 27, 2024
19011d0
apply formating
ASKabalan Oct 27, 2024
a757b62
get local shape and zeros can be used by users
ASKabalan Oct 30, 2024
f3b431a
add a user facing function to create uniform particle grid
ASKabalan Oct 30, 2024
2ad035a
use jax interp instead of jax_cosmo
ASKabalan Oct 30, 2024
b3a264a
use float64 for enmeshing
ASKabalan Oct 30, 2024
b09580d
Allow applying weights with relative cic paint
ASKabalan Oct 30, 2024
e9529d3
Weights can be traced
ASKabalan Oct 30, 2024
4da4c66
remove script folder
ASKabalan Oct 30, 2024
72457d6
update example notebooks
ASKabalan Oct 30, 2024
a030ec4
delete outdated design file
ASKabalan Oct 30, 2024
f0c43f8
add readme for tutorials
ASKabalan Oct 30, 2024
a067954
update readme
ASKabalan Oct 30, 2024
42d8e89
fix small error
ASKabalan Oct 30, 2024
6256fba
forgot particles in multi host
ASKabalan Oct 30, 2024
2472a5d
clarifying why cic_paint_dx is slower
ASKabalan Nov 10, 2024
ad45666
clarifying the halo size dependence on the box size
ASKabalan Nov 10, 2024
12c74e2
ability to choose snapshots number with MultiHost script
ASKabalan Nov 10, 2024
0946842
Adding animation notebook
ASKabalan Nov 10, 2024
435c7c8
Put plotting in package
ASKabalan Nov 14, 2024
b32014b
Add finite difference laplace kernel + powerspec functions from Hugo
ASKabalan Dec 5, 2024
c1b276d
Put plotting utils in package
ASKabalan Dec 5, 2024
e0c118a
By default use absoulute painting with
ASKabalan Dec 5, 2024
21373b8
update code
ASKabalan Dec 6, 2024
36ef18e
update notebooks
ASKabalan Dec 6, 2024
f70583b
add tests
ASKabalan Dec 6, 2024
7823fda
Upgrade setup.py to pyproject
ASKabalan Dec 6, 2024
af29c40
Format
ASKabalan Dec 8, 2024
97f39bd
format tests
ASKabalan Dec 8, 2024
ac4ef9e
update test dependencies
ASKabalan Dec 8, 2024
5d34d3c
add test workflow
ASKabalan Dec 8, 2024
adaf7d2
fix deprecated FftType in jaxpm.kernels
ASKabalan Dec 8, 2024
d8c68ac
Add aboucaud comments
ASKabalan Dec 8, 2024
b264da5
JAX version is 0.4.35 until Diffrax new release
ASKabalan Dec 8, 2024
47c69c6
add numpy explicitly as dependency for tests
ASKabalan Dec 8, 2024
8951f5c
fix install order for tests
ASKabalan Dec 8, 2024
3ce0be6
add numpy to be installed
ASKabalan Dec 8, 2024
ae684c9
enforce no build isolation for fastpm
ASKabalan Dec 8, 2024
7384343
pip install jaxpm test without build isolation
ASKabalan Dec 9, 2024
3be5dae
bump jaxdecomp version
ASKabalan Dec 9, 2024
158478c
revert test workflow
ASKabalan Dec 9, 2024
f91aa93
remove outdated tests
ASKabalan Dec 9, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -14,4 +14,4 @@ repos:
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
name: isort (python)
69 changes: 69 additions & 0 deletions dev/jaxdecomp.py
@@ -0,0 +1,69 @@
import argparse

import jax
import numpy as np

# Setting up distributed jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()

import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental import mesh_utils
from jax.sharding import Mesh

from jaxpm.painting import cic_paint
from jaxpm.pm import linear_field, lpt

mesh_shape = [256, 256, 256]
box_size = [256., 256., 256.]
snapshots = jnp.linspace(0.1, 1., 2)


@jax.jit
def run_simulation(omega_c, sigma8, seed):
    # Create a cosmology
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk
                                                  ).reshape(x.shape)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=seed)

    # Initialize particle displacements
    dx, p, f = lpt(cosmo, initial_conditions, 1.0)

    field = cic_paint(jnp.zeros_like(initial_conditions), dx)
    return field


def main(args):
    # Setting up distributed random numbers
    master_key = jax.random.PRNGKey(42)
    key = jax.random.split(master_key, size)[rank]

    # Create computing mesh and sharding information
    devices = mesh_utils.create_device_mesh((2, 2))
    mesh = Mesh(devices.T, axis_names=('x', 'y'))

    # Run the simulation on the compute mesh
    with mesh:
        field = run_simulation(0.32, 0.8, key)

    print('done')
    np.save(f'field_{rank}.npy', field.addressable_data(0))

    # Closing distributed jax
    jax.distributed.shutdown()


if __name__ == '__main__':
    parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
    args = parser.parse_args()
    main(args)
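
The `pk_fn` lambda in `run_simulation` flattens its input, interpolates a tabulated spectrum, and restores the original shape. The sketch below shows that pattern in isolation. It is only an illustration: it uses `np.interp` in place of `jc.scipy.interpolate.interp`, and the power-law `pk` is a made-up stand-in for `jc.power.linear_matter_power`.

```python
import numpy as np

# Tabulated spectrum. The k**-2 power law is a placeholder (assumption),
# not the linear matter power spectrum computed in the script.
k = np.logspace(-4, 1, 128)
pk = k**-2.0


def pk_fn(x):
    """Interpolate the tabulated spectrum at every entry of an N-d array."""
    flat = np.interp(x.reshape(-1), k, pk)  # 1-D interpolation on the flattened input
    return flat.reshape(x.shape)  # restore the mesh shape


# A small 3-D array of wave numbers, all equal to one of the table knots.
kmesh = np.full((4, 4, 4), k[10])
out = pk_fn(kmesh)
```

The output has the same shape as the input mesh, which is what `linear_field` needs when it evaluates the spectrum on a 3-D grid of wave numbers.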
151 changes: 151 additions & 0 deletions jaxpm/distributed.py
@@ -0,0 +1,151 @@
from typing import Any, Callable, Hashable

Specs = Any
AxisName = Hashable

try:
    import jaxdecomp
    distributed = True
except ImportError:
    print("jaxdecomp not installed. Distributed functions will not work.")
    distributed = False

from functools import partial

import jax
import jax.numpy as jnp
from jax._src import mesh as mesh_lib
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

# NOTE
# This should not be used as a decorator
# Must be used inside a function only
# Example
# BAD
# @autoshmap
# def foo():
# pass
# GOOD
# def foo():
# return autoshmap(foo_impl)()


def autoshmap(f: Callable,
              in_specs: Specs,
              out_specs: Specs,
              check_rep: bool = True,
              auto: frozenset[AxisName] = frozenset()):
    """Helper function to wrap the provided function in a shard map if
    the code is being executed in a mesh context."""
    mesh = mesh_lib.thread_resources.env.physical_mesh
    if mesh.empty:
        return f
    else:
        return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)


def fft3d(x):
    if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
        return jaxdecomp.pfft3d(x.astype(jnp.complex64))
    else:
        return jnp.fft.fftn(x.astype(jnp.complex64))


def ifft3d(x):
    if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
        return jaxdecomp.pifft3d(x).real
    else:
        return jnp.fft.ifftn(x).real


def get_halo_size(halo_size):
    mesh = mesh_lib.thread_resources.env.physical_mesh
    if mesh.empty:
        zero_ext = (0, 0, 0)
        zero_tuple = (0, 0)
        return (zero_tuple, zero_tuple, zero_tuple), zero_ext
    else:
        pdims = mesh.devices.shape
        halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size)
        halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size)

        halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2
        halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2
        return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext, 0))


def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
    mesh = mesh_lib.thread_resources.env.physical_mesh
    if distributed and not (mesh.empty) and (halo_extents[0] > 0
                                             or halo_extents[1] > 0):
        return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
    else:
        return x


def slice_unpad_impl(x, pad_width):

    halo_x, _ = pad_width[0]
    halo_y, _ = pad_width[1]

    # Apply corrections along x
    x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
    x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
    # Apply corrections along y
    x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
    x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])

    return x[halo_x:-halo_x, halo_y:-halo_y, :]


def slice_pad(x, pad_width):
    mesh = mesh_lib.thread_resources.env.physical_mesh
    if distributed and not (mesh.empty) and (pad_width[0][0] > 0
                                             or pad_width[1][0] > 0):
        return autoshmap((partial(jnp.pad, pad_width=pad_width)),
                         in_specs=(P('x', 'y')),
                         out_specs=P('x', 'y'))(x)
    else:
        return x


def slice_unpad(x, pad_width):
    mesh = mesh_lib.thread_resources.env.physical_mesh
    if distributed and not (mesh.empty) and (pad_width[0][0] > 0
                                             or pad_width[1][0] > 0):
        return autoshmap(partial(slice_unpad_impl, pad_width=pad_width),
                         in_specs=(P('x', 'y')),
                         out_specs=P('x', 'y'))(x)
    else:
        return x


def get_local_shape(mesh_shape):
    """ Helper function to get the local size of a mesh given the global size.
    """
    if mesh_lib.thread_resources.env.physical_mesh.empty:
        return mesh_shape
    else:
        pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
        return [
            mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]
        ]


def normal_field(mesh_shape, seed=None):
    """Generate a field of independent standard-normal samples with the
    given global mesh shape."""
    if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
        local_mesh_shape = get_local_shape(mesh_shape)
        if seed is None:
            key = None
        else:
            size = jax.process_count()
            rank = jax.process_index()
            key = jax.random.split(seed, size)[rank]
        return autoshmap(
            partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
            in_specs=P(None),
            out_specs=P('x', 'y'))(key)  # yapf: disable
    else:
        return jax.random.normal(shape=mesh_shape, key=seed)
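
The shape and halo bookkeeping in `get_local_shape` and `get_halo_size` can be checked by hand with a sketch that takes the device grid as an explicit argument. The functions `local_shape` and `halo_size` below are hypothetical stand-ins: they copy the arithmetic from the diff but read `pdims` from a parameter instead of the active mesh context, so they run without JAX or jaxdecomp.

```python
def local_shape(mesh_shape, pdims):
    """Per-device shape: the first two axes are split across the device grid."""
    return [mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]]


def halo_size(halo, pdims):
    """Padding widths and exchange extents for a given device grid.

    An axis that is not split (a single device along it) gets no halo.
    """
    halo_x = (0, 0) if pdims[0] == 1 else (halo, halo)
    halo_y = (0, 0) if pdims[1] == 1 else (halo, halo)
    ext_x = 0 if pdims[0] == 1 else halo // 2
    ext_y = 0 if pdims[1] == 1 else halo // 2
    return (halo_x, halo_y, (0, 0)), (ext_x, ext_y, 0)


# 256^3 mesh on a 2x2 device grid, as in dev/jaxdecomp.py
shape_2x2 = local_shape([256, 256, 256], (2, 2))

# Slab decomposition (4x1): only the x axis is split, so only x is padded.
# The halo width of 32 is an arbitrary example value.
pad_slab, ext_slab = halo_size(32, (4, 1))
```

With a 2x2 grid each device holds a 128 x 128 x 256 block. In the slab case the y axis gets zero padding, which matches the `pdims[1] == 1` branch of `get_halo_size`.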
60 changes: 41 additions & 19 deletions jaxpm/kernels.py
@@ -1,23 +1,49 @@
from functools import partial

import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from jax._src import mesh as mesh_lib
from jax.sharding import PartitionSpec as P

from jaxpm.distributed import autoshmap


def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
    """ Return k_vector given a shape (nc, nc, nc) and box_size
def fftk(shape, dtype=np.float32):
    """
    Generate Fourier transform wave numbers for a given mesh.

    Args:
        shape (tuple of int): Shape of the mesh grid.

    Returns:
        list: List of wave number arrays for each dimension in
        the order [kx, ky, kz].
    """
    k = []
    for d in range(len(shape)):
        kd = np.fft.fftfreq(shape[d])
        kd *= 2 * np.pi
        kdshape = np.ones(len(shape), dtype='int')
        if symmetric and d == len(shape) - 1:
            kd = kd[:shape[d] // 2 + 1]
        kdshape[d] = len(kd)
        kd = kd.reshape(kdshape)
    kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape]

    @partial(autoshmap,
             in_specs=(P('x'), P('y'), P(None)),
             out_specs=(P('x'), P(None, 'y'), P(None)))
    def get_kvec(ky, kz, kx):
        return (ky.reshape([-1, 1, 1]),
                kz.reshape([1, -1, 1]),
                kx.reshape([1, 1, -1]))  # yapf: disable

    if not mesh_lib.thread_resources.env.physical_mesh.empty:
        ky, kz, kx = get_kvec(ky, kz, kx)  # The order corresponds
    else:
        kx, ky, kz = get_kvec(kx, ky, kz)  # The order corresponds

    # to the order of dimensions in the transposed FFT
    return kx, ky, kz


def interpolate_power_spectrum(input, k, pk):

        k.append(kd.astype(dtype))
    del kd, kdshape
    return k
    pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
                                                  ).reshape(x.shape)
    return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'))(input)


def gradient_kernel(kvec, direction, order=1):
@@ -60,11 +86,7 @@ def laplace_kernel(kvec):
        Complex kernel
    """
    kk = sum(ki**2 for ki in kvec)
    mask = (kk == 0).nonzero()
    kk[mask] = 1
    wts = 1. / kk
    imask = (~(kk == 0)).astype(int)
    wts *= imask
    wts = jnp.where(kk == 0, 1., 1. / kk)
    return wts
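
The `laplace_kernel` hunk replaces in-place masking, which JAX arrays do not support, with a single `jnp.where`. The sketch below compares the two formulations in plain NumPy on a small single-device mesh. The helper names `fftk_single_device`, `inverse_k2_inplace` and `inverse_k2_where` are hypothetical, and `np.where` stands in for `jnp.where`.

```python
import numpy as np


def fftk_single_device(shape):
    """Wave numbers per axis, reshaped so they broadcast as [kx, ky, kz]."""
    kx, ky, kz = [np.fft.fftfreq(s) * 2 * np.pi for s in shape]
    return kx.reshape(-1, 1, 1), ky.reshape(1, -1, 1), kz.reshape(1, 1, -1)


def inverse_k2_inplace(kvec):
    """Removed formulation: overwrite the k=0 entry in place, then invert."""
    kk = sum(ki**2 for ki in kvec)
    mask = (kk == 0).nonzero()
    kk[mask] = 1
    wts = 1. / kk
    # kk no longer contains zeros here, so imask is all ones
    imask = (~(kk == 0)).astype(int)
    wts *= imask
    return wts


def inverse_k2_where(kvec):
    """Added formulation: select 1 at k=0 and 1/k^2 elsewhere."""
    kk = sum(ki**2 for ki in kvec)
    with np.errstate(divide="ignore"):  # 1/0 is evaluated, then discarded by where
        return np.where(kk == 0, 1., 1. / kk)


kvec = fftk_single_device((4, 4, 4))
old = inverse_k2_inplace(kvec)
new = inverse_k2_where(kvec)
```

Because `imask` is computed after `kk` has been overwritten, the removed code also returns 1 at the zero mode, so on this mesh the two versions agree entry by entry.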

