
Implement Hankel Transform / FFTLog #52

Open
minaskar opened this issue Jun 11, 2020 · 18 comments
Labels
enhancement New feature or request JAX Issues related to JAX

Comments

@minaskar
Contributor

For jax_cosmo to be used in observational cosmology analyses (e.g. BAO, RSD, fNL), we need a JAX implementation of the FFTLog algorithm to facilitate survey window function convolutions with the power spectrum.

This would also be helpful for getting models of the correlation function, as mentioned in another issue.

There's already a package that's used very often in cosmology:
https://github.com/eelregit/mcfit

It should be possible to implement it using JAX.

@EiffL
Member

EiffL commented Jun 11, 2020

Yes, totally agree. Indeed FFTLog also comes up in #30, although I think @sukhdeep1989 wants to implement a different approach for that.

I haven't dived into the details of mcfit before but I'm pretty sure @eelregit would be interested in a JAX version.

Another option I looked into was simply porting this implementation to JAX: https://github.com/JoeMcEwen/FAST-PT/blob/master/fastpt/HT.py

It seemed pretty easy at first glance.

@EiffL EiffL added enhancement New feature or request JAX Issues related to JAX labels Jun 11, 2020
@sukhdeep2
Collaborator

sukhdeep2 commented Jun 11, 2020 via email

@EiffL
Member

EiffL commented Jun 11, 2020

Ok, so I think what you mean, @sukhdeep1989, is that we may not need to write everything in native JAX code: when integrals are involved, we can compute explicit JVPs and write custom autodiff rules instead of asking JAX to figure it out.
so like:

xi(r)  = ∫ dk k P(k) J(kr)
xi'(r) = ∫ dk k P'(k) J(kr)

See issue #47, which I opened to do this more generally for all integrals, hopefully making JAX compilation faster than it currently is.
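The custom-JVP idea above can be sketched in JAX. This is a hypothetical discretization, not jax_cosmo's actual code: the k grid, the `kernel` matrix standing in for J(kr), and the function names are all illustrative. Because the transform is linear in P(k), the JVP is simply the same transform applied to the tangent P'(k), so JAX never has to differentiate through the quadrature.

```python
import jax
import jax.numpy as jnp

# Assumed k grid for a simple rectangle-rule quadrature (illustrative)
k = jnp.linspace(1e-3, 1.0, 128)
dk = k[1] - k[0]

@jax.custom_jvp
def hankel_like(P, kernel):
    # kernel[i, j] ~ J(k_j * r_i); xi(r_i) ~ sum_j dk k_j P(k_j) J(k_j r_i)
    return kernel @ (k * P) * dk

@hankel_like.defjvp
def hankel_like_jvp(primals, tangents):
    P, kernel = primals
    dP, _ = tangents
    # xi'(r) = ∫ dk k P'(k) J(kr): the same linear map applied to dP
    return hankel_like(P, kernel), kernel @ (k * dP) * dk
```

Since the map is linear, the custom rule here mostly serves to bypass tracing through the integral internals, which is the point raised above about compilation speed.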

@eelregit
Contributor

Happy to help if needed!

xi(r)  = ∫ dk k P(k) J(kr)
xi'(r) = ∫ dk k P'(k) J(kr)

I wonder if the second prime should be with respect to r and thus on J?

@sukhdeep2
Collaborator

sukhdeep2 commented Jun 11, 2020 via email

@eelregit
Contributor

I see.
I agree that the integral and d/d(cosmology) are orthogonal.

Will it suffice to replace numpy with jax.numpy in mcfit?

@EiffL
Member

EiffL commented Jun 11, 2020

It's very possible :-D

@eelregit
Contributor

I am thinking about what would be a good interface for switching between the numpy and jax backends. I also need to read some jax docs.

Any recommendation? :)

@EiffL
Member

EiffL commented Jun 11, 2020

That's a good question. I don't think there is an easy way to do it, or at least not a generic one.
One example I know of backend switching between TF and Numpy is here: https://github.com/google/edward2#using-the-numpy-backend
Or between TF, NumPy, and JAX:
https://github.com/tensorflow/probability/tree/master/tensorflow_probability/python/experimental/substrates
But the mechanism seems pretty complicated:

https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/experimental/substrates/meta/rewrite.py

They have a script that rewrites the code on the fly for the various backends.
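A much lighter alternative to source rewriting, sketched here as a hypothetical design (this is not mcfit's actual API), is to select the array module once at construction time and route every array operation through it:

```python
# Minimal backend-switching sketch (hypothetical class and method names).
import numpy

class Transform:
    def __init__(self, backend="numpy"):
        if backend == "jax":
            import jax.numpy as xp   # requires jax to be installed
        else:
            xp = numpy
        self.xp = xp  # all array ops go through this module

    def scale(self, x, a):
        # identical code runs on either backend, since jax.numpy
        # mirrors most of the numpy API
        return self.xp.asarray(x) * a
```

This works only for the subset of the numpy API that jax.numpy mirrors, which is why the TFP approach resorts to full source rewriting for the harder cases.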

@eelregit
Contributor

Indeed, that looks quite complicated. I will look at them more closely. Otherwise it is straightforward to copy the code and replace numpy with jax.numpy, if you don't mind a PR like that ;)

@florpi

florpi commented Jan 25, 2022

Hi everyone!

I had to implement the Hankel transform in jax for a project I was working on; perhaps it is useful here as well. You can find it in this repo: https://github.com/florpi/JaxHankel

And actually, maybe you can help me understand why it works, since I'm converting jax arrays to numpy arrays and using a scipy function at one point (see https://github.com/florpi/JaxHankel/blob/main/jax_fht/fht.py#L61). But the final derivatives seem fine.

@EiffL
Member

EiffL commented Jan 25, 2022

Hey Carolina :-D

That looks super useful indeed! To answer your question right away: it will work even if you use scipy functions (because implicitly JAX will convert back and forth between numpy and jax.numpy arrays) until you try to use jax.jit or jax.grad, at which point it will fail.

But, it looks like at least gammaln is implemented in jax already: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.special.gammaln.html#jax.scipy.special.gammaln

Wouldn't that be enough for your use case? I think it does the same as loggamma(x) when x is positive.
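To illustrate the point about gammaln: unlike a raw scipy call inside a traced function, jax.scipy.special.gammaln is fully traceable, so it works under both jit and grad. Its derivative is the digamma function, so for example grad(gammaln) at 1.0 gives minus the Euler constant (about -0.5772):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

# gammaln is differentiable in JAX; its derivative is digamma
dgammaln = jax.grad(gammaln)

# it also composes with jit, which a scipy call would not
jitted_gammaln = jax.jit(gammaln)
```

As noted in the next comment, the catch for FFTLog is that the argument there is complex, which gammaln does not cover.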

@EiffL
Member

EiffL commented Jan 25, 2022

Ah no, sorry, I see now that the argument can be complex :-| But then it should "just" be a matter of adding an implementation of loggamma.

@eelregit
Contributor

eelregit commented Jan 25, 2022

Hey Carolina :)

I agree with Francois. Maybe there's a way to cache the scipy loggamma values in jnp.ndarrays, as long as one does not need derivatives with respect to the x values (scales).

@florpi

florpi commented Jan 26, 2022

Hi both! Thanks for your comments :) I also thought that calling grad or jacobian would fail, but you can see that it works in this test: https://github.com/florpi/JaxHankel/blob/main/test_jax_fht/test_analytical_cosmology.py#L39 Maybe I'm misunderstanding how jacobian works?

Regarding loggamma, the issue I had was exactly the extension to complex numbers hehe. I haven't looked too much into it, so it might not be hard. Any ideas are welcome!

@EiffL
Member

EiffL commented Jan 27, 2022

Ha yes, sorry, I had missed that you were taking the jacobian.

Yes, so, this works because the loggamma function is only used to compute coefficients for the Hankel transform. These coefficients are fixed, and you don't take derivatives with respect to them, so there is no problem during jit compiling or when taking gradients.

In practice, when you jit compile that function, as long as those arguments are fixed, the scipy code is called once at trace time and the result stored as a "constant" that is then used in the rest of the jax code.

So most likely there's no need to reimplement loggamma :-D
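The mechanism described here can be shown with a small hypothetical sketch (the evaluation points `z` and the function names are illustrative, not JaxHankel's code): the complex loggamma coefficients are computed with scipy outside any traced code, and jit then captures them as constants, so gradients flow only through the power spectrum argument.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.special import loggamma

# Coefficients computed once with scipy's complex loggamma, outside jit
# (assumed evaluation points; real part taken purely for illustration)
z = 1.5 + 1j * np.arange(4)
coeff = jnp.asarray(np.exp(loggamma(z)).real)

@jax.jit
def weighted_sum(P):
    # coeff is a closed-over constant: no scipy call is traced here
    return jnp.sum(coeff * P)

grad_fn = jax.grad(weighted_sum)  # differentiates w.r.t. P only
```

The gradient of this function with respect to P is exactly `coeff`, confirming that jit and grad never need to see inside the scipy computation.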

@florpi

florpi commented Jan 27, 2022

Thank you, that makes a lot of sense! Also, not having to reimplement loggamma makes me very happy :D

@eelregit
Contributor

I think the kernels (e.g. _u here) are evaluated at points determined by the input k or r scales (via e.g. Delta in that link).
So the kernel values are not jit constants unless we make the input scales static?
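One way to realize the "make the input scales static" caveat, sketched here with hypothetical function names rather than any existing API: if the scales arrive as a traced array, scipy cannot be applied to them under jit, but passing them as a static Python tuple lets plain numpy/scipy run at trace time, at the cost of recompiling whenever the scales change.

```python
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from scipy.special import loggamma

# scales is marked static, so inside the traced function it is an
# ordinary Python tuple and scipy can evaluate the kernel on it
@partial(jax.jit, static_argnums=1)
def transform(P, scales):
    # runs at trace time with concrete values, then becomes a constant
    u = np.exp(loggamma(np.asarray(scales) + 0.5)).real
    return jnp.asarray(u) * P
```

Each distinct `scales` tuple triggers a fresh trace (and a fresh scipy evaluation), which is exactly the trade-off being pointed out above.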
