JOSS paper of jaxDecomp #20
Conversation
force-pushed from aa5403d to d053753
Lol, @ASKabalan ^^ so many force pushes, can we forbid force pushes from now on?
Co-authored-by: Wassim KABALAN <[email protected]>
Thanks @ASKabalan for the draft, it's a very good start. I have some high-level comments that I will add here, and maybe make some particular comments on the markdown file. My main overarching comment is that this is not a jaxpm paper, it's a jaxDecomp paper. PM simulations are just one potential example of a real-world application, not the only raison d'être of the library.
I think our story here in the abstract could be the following:
```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

# Create an array
x = jax.random.normal(jax.random.key(0), (32, 32, 32))
# Distribute the array across devices
sharding = PositionalSharding(mesh_utils.create_device_mesh((2, 2, 1)))
x = jax.device_put(x, sharding)
# Perform a 1D FFT along the last dimension, then transpose the array
x = jnp.fft.fft(x).transpose(2, 0, 1)  # [z', x, y]
x = jnp.fft.fft(x).transpose(2, 0, 1)  # [y', z', x]
x = jnp.fft.fft(x)                     # [y', z', x']
```
If we have such a comparison in the benchmark, we can refer to it here as a statement of need. Then at the end of the statement, we can mention a real-world application, and that's where we can talk about PM simulations for cosmology. We can in particular mention FlowPM (distributed, but in TF, so useless) and pmwd (not distributed, and so limited to 512³ volumes).
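For contrast, the jaxDecomp version collapses all of this into a single call. A minimal sketch, assuming a `jaxdecomp.fft.pfft3d` entry point (the exact import path is an assumption; `pfft3d` is the name used elsewhere in this thread):

```python
import jaxdecomp

# Same sharded array x as above: one pencil-decomposed 3D FFT call,
# with the device-to-device transposes handled internally.
# NOTE: jaxdecomp.fft.pfft3d is an assumed entry point.
k_array = jaxdecomp.fft.pfft3d(x)
```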
You can add a couple of lines of code to illustrate the API for points 2 and 3 above.
```python
def potential(delta):
    delta_k = pfft3d(delta)
    kvec = ...                     # wavenumber grid (elided in the original)
    kk = sum(k**2 for k in kvec)   # squared wavenumber magnitude |k|^2
    laplace_kernel = 1 / kk        # inverse Laplacian in Fourier space
    potential_k = delta_k * laplace_kernel
    return ipfft3d(potential_k)
```
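For illustration, a self-contained sketch of how that pseudocode could be fleshed out on a cubic mesh of side N, guarding the k = 0 mode and ignoring the transposed layout a distributed FFT would introduce (the fftfreq-based kvec construction is my assumption, not jaxDecomp API):

```python
import jax.numpy as jnp

def potential(delta, N=32):
    # pfft3d / ipfft3d as in the pseudocode above
    delta_k = pfft3d(delta)
    freqs = 2 * jnp.pi * jnp.fft.fftfreq(N)
    kvec = jnp.meshgrid(freqs, freqs, freqs, indexing='ij')
    kk = sum(k**2 for k in kvec)  # squared wavenumber |k|^2
    # 1/k^2 kernel with the zero mode set to 0 (double-where avoids 0/0)
    laplace_kernel = jnp.where(kk > 0, 1.0 / jnp.where(kk > 0, kk, 1.0), 0.0)
    return ipfft3d(delta_k * laplace_kernel)
```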
joss-paper/paper.md (Outdated)

## Distributed Halo Exchange
In a particle mesh simulation, we use the 3D FFT to estimate the force field acting on the particles. The force field is then interpolated to the particles, and the particles are moved accordingly. Particles close to the boundary of the local domain need to be updated using data from the neighboring domains. This is done with a halo exchange operation: we pad each slice of the simulation, then perform a halo exchange to update the particles that are close to the boundary of the local domain.
We shouldn't motivate the halo exchange from the PM simulation; halo exchanges are very common operations in distributed computing: https://wgropp.cs.illinois.edu/courses/cs598-s15/lectures/lecture25.pdf
(first result on google)
So here I think we just want to explain that cuDecomp allows for the exchange of border regions, a pattern necessary to handle border crossings in many types of simulations.
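To make the pattern concrete, here is a toy sketch of a one-cell halo exchange along a single sharded axis, using `shard_map` and `lax.ppermute` (my own illustration of the pattern, not jaxDecomp's or cuDecomp's API; assumes a periodic 1D device mesh):

```python
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

n = jax.device_count()  # number of shards along 'x'
mesh = Mesh(mesh_utils.create_device_mesh((n,)), axis_names=('x',))

@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def halo_exchange(block):
    # Each shard sends its last row forward and its first row backward,
    # receiving its neighbours' border rows in return (periodic wrap).
    from_prev = lax.ppermute(block[-1:], 'x', [(i, (i + 1) % n) for i in range(n)])
    from_next = lax.ppermute(block[:1], 'x', [(i, (i - 1) % n) for i in range(n)])
    # Return the local block padded with the received one-cell halo regions.
    return jnp.concatenate([from_prev, block, from_next], axis=0)
```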
What's wrong with the colors?
Oh ..
…ss-paper Proposal for simplified demo
Note, I found this previous implementation I had made using xmap of a 3D distributed FFT:

```python
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from jax.experimental.maps import xmap

@partial(xmap,
         in_axes={0: 'x', 1: 'y'},
         out_axes=['x', 'y', ...],
         axis_resources={'x': 'nx', 'y': 'ny'})
@jax.jit
def pfft3d(mesh):
    # [x, y, z]
    mesh = jnp.fft.fft(mesh)                # Transform on z
    mesh = lax.all_to_all(mesh, 'x', 0, 0)  # Now x is exposed, [z, y, x]
    mesh = jnp.fft.fft(mesh)                # Transform on x
    mesh = lax.all_to_all(mesh, 'y', 0, 0)  # Now y is exposed, [z, x, y]
    mesh = jnp.fft.fft(mesh)                # Transform on y
    # [z, x, y]
    return mesh

@partial(xmap,
         in_axes={0: 'x', 1: 'y'},
         out_axes=['x', 'y', ...],
         axis_resources={'x': 'nx', 'y': 'ny'})
@jax.jit
def pifft3d(mesh):
    # [z, x, y]
    mesh = jnp.fft.ifft(mesh)               # Transform on y
    mesh = lax.all_to_all(mesh, 'y', 0, 0)  # Now x is exposed, [z, y, x]
    mesh = jnp.fft.ifft(mesh)               # Transform on x
    mesh = lax.all_to_all(mesh, 'x', 0, 0)  # Now z is exposed, [x, y, z]
    mesh = jnp.fft.ifft(mesh)               # Transform on z
    # [x, y, z]
    return mesh
```

Something like this, but using shard_map, is probably what we want to benchmark jaxDecomp against.
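Following up on that last remark, here is a rough shard_map translation of the same pencil-decomposed FFT that could serve as the baseline. This is an untested sketch: the 2x2 mesh shape is assumed, the axis names are carried over from the xmap version, and the output sharding is transposed relative to the input, as noted in the comments.

```python
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes 4 devices arranged in a 2x2 mesh; axis names follow the xmap version.
mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), axis_names=('nx', 'ny'))

@partial(shard_map, mesh=mesh,
         in_specs=P('nx', 'ny', None),
         out_specs=P('ny', None, 'nx'))
def pfft3d_shard_map(m):
    # Local block: [x/nx, y/ny, z]
    m = jnp.fft.fft(m, axis=2)  # transform on z (local)
    # Gather x, scatter z across 'nx' -> [x, y/ny, z/nx]
    m = lax.all_to_all(m, 'nx', split_axis=2, concat_axis=0, tiled=True)
    m = jnp.fft.fft(m, axis=0)  # transform on x (now local)
    # Gather y, scatter x across 'ny' -> [x/ny, y, z/nx]
    m = lax.all_to_all(m, 'ny', split_axis=0, concat_axis=1, tiled=True)
    m = jnp.fft.fft(m, axis=1)  # transform on y (now local)
    return m

# Example usage:
x = jax.device_put(jax.random.normal(jax.random.key(0), (32, 32, 32)),
                   NamedSharding(mesh, P('nx', 'ny')))
k = pfft3d_shard_map(x)
```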
Could you push the benchmark scripts @ASKabalan when you are back? Curious to see if we can gain a bit of performance.
@EiffL I am trying to make MPI4JAX work |
Adding a draft of the JOSS paper