Implement Hankel Transform / FFTLog #52
Comments
Yes, totally agree. Indeed FFTLog also comes up in #30, although I think @sukhdeep1989 wants to implement a different approach for that. I haven't dived into the details of mcfit before, but I'm pretty sure @eelregit would be interested in a JAX version. Another option I looked into was to just port this implementation to JAX: https://github.com/JoeMcEwen/FAST-PT/blob/master/fastpt/HT.py It seemed pretty easy at first glance. |
What I added is for projected correlation functions. Interfacing with mcfit will be great. FYI, Hankel transforms and window effects do not need to be differentiable (there is no dependence on parameters). So, I'm not sure if an implementation with JAX is super important.
|
Ok, so I think what you mean @sukhdeep1989 is that we may not need to write everything in native JAX code, because when integrals are involved, we can compute explicit JVPs and so write custom autodiff rules instead of asking JAX to figure it out.
See this other issue #47 I have opened to do this more generally for all integrals, hopefully making the JAX compilation faster than it currently is |
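A minimal sketch of what such a custom rule could look like, assuming the integral is discretized as a fixed matrix acting on P(k). The grids, the `KERNEL` matrix, the sin(x)/x stand-in for J, and the names `pk_to_xi` and `toy_pk` are all hypothetical and are not taken from jax_cosmo or mcfit. Because the transform is linear in P(k), the JVP is simply the same transform applied to the tangent:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical fixed grids; nothing here depends on cosmological parameters.
k = jnp.logspace(-3.0, 1.0, 512)
r = jnp.linspace(1.0, 50.0, 8)
dlnk = jnp.log(k[1] / k[0])

# Crude quadrature of xi(r) = int dk k P(k) J(kr) on the log grid (dk = k dlnk),
# with J taken to be sin(x)/x purely for illustration.
kr = r[:, None] * k[None, :]
KERNEL = (jnp.sin(kr) / kr) * (k**2 * dlnk)[None, :]


@jax.custom_jvp
def pk_to_xi(pk):
    """Map P(k) on the fixed k grid to xi(r) on the fixed r grid."""
    return KERNEL @ pk


@pk_to_xi.defjvp
def pk_to_xi_jvp(primals, tangents):
    # Explicit rule: the transform is linear, so the tangent of xi is the
    # same fixed kernel applied to the tangent of P(k).
    (pk,), (dpk,) = primals, tangents
    return pk_to_xi(pk), KERNEL @ dpk


def toy_pk(amplitude):
    # Hypothetical power spectrum with a single "cosmological" parameter.
    return amplitude * k / (1.0 + k**2) ** 2


def xi_of_amplitude(amplitude):
    return pk_to_xi(toy_pk(amplitude))


# Derivative of xi(r) with respect to the parameter, via the custom rule.
dxi_dA = jax.jacfwd(xi_of_amplitude)(2.0)
```

Since the tangent rule is itself a plain matrix product, reverse-mode `jax.grad` should also work through it in this sketch.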
Happy to help if needed!
xi = ∫ dk k P(k) J(kr)
xi' = ∫ dk k P'(k) J(kr)
I wonder if the second prime should be with respect to r and thus on J? |
The prime is with respect to cosmology, from what I understand. kr = ell theta is independent of cosmology.
|
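Written out, the point of this exchange is that only P depends on the cosmological parameters, so the derivative passes through the integral onto P alone. The symbol θ for those parameters is a label introduced here for illustration; it is not used in the thread:

```latex
\xi(r;\theta) = \int dk\, k\, P(k;\theta)\, J(kr)
\quad\Longrightarrow\quad
\frac{\partial \xi}{\partial \theta}(r;\theta)
  = \int dk\, k\, \frac{\partial P}{\partial \theta}(k;\theta)\, J(kr)
```

The kernel J(kr) carries no θ dependence, which is why the same transform can be reused for the derivative.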
I see. Will it suffice to replace |
It's very possible :-D |
I am thinking what's a good interface to switch between the numpy and jax backends. Also need to read some jax docs. Any recommendation? :) |
That's a good question. I don't think there is an easy way to do it, or at least not a generic one. They have a script that rewrites the code on the fly for various backends |
Indeed that looks quite complicated. I will look at them more closely. Otherwise it is straightforward to copy and replace numpy by jax, if you don't mind a PR like that ;) |
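One lightweight alternative to rewriting code on the fly, sketched here under the assumption that the function only uses calls shared by `numpy` and `jax.numpy`, is to pass the array module in as an argument. The function name `log_grid_weights` and the `xp` parameter are hypothetical and are not how mcfit or jax_cosmo actually do it:

```python
import numpy as np
import jax.numpy as jnp


def log_grid_weights(k, xp=np):
    """Trapezoid-rule weights in ln(k) for a sorted grid `k`.

    `xp` is the array module to use (numpy or jax.numpy); only functions
    that exist in both with the same signature are called.
    """
    dlnk = xp.diff(xp.log(k))
    zero = xp.zeros(1, dtype=dlnk.dtype)
    # Each point gets half of the interval on either side of it.
    return 0.5 * (xp.concatenate([dlnk, zero]) + xp.concatenate([zero, dlnk]))


# Same code path, two backends.
w_numpy = log_grid_weights(np.logspace(-3, 1, 64), xp=np)
w_jax = log_grid_weights(jnp.logspace(-3, 1, 64), xp=jnp)
```

This only covers the common subset of the two APIs; anything relying on in-place assignment or SciPy-only functions would still need backend-specific handling.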
Hi everyone! I had to implement the Hankel transform in jax for a project I was working on; perhaps it is useful here as well. You can find it in this repo: https://github.com/florpi/JaxHankel And actually, maybe you can help me understand why it works, since I'm converting jax arrays to numpy arrays and using a scipy function at one point (see here https://github.com/florpi/JaxHankel/blob/main/jax_fht/fht.py#L61), but the final derivatives seem fine |
Hey Carolina :-D That looks super useful indeed! To answer your question right away, it will work even if you use scipy functions (because implicitly it will convert back and forth between numpy and jax.numpy arrays) until you try to use jax.jit or jax.grad, and that will fail. But, it looks like at least Wouldn't that be enough for your use case? I think it's doing the same as |
Ah no, sorry, I see now that the argument can be imaginary :-| but then it should "just" be a matter of adding an implementation of loggamma |
Hey Carolina :) I agree with Francois. Maybe there's a way to cache scipy loggamma values in jnp.ndarray's, as long as one does not need derivatives with respect to the x values (scales). |
Hi both! thanks for your comments :) I also thought that if I was to call Regarding |
Ha yes, sorry, I had missed that you were getting the Jacobian. Yes, so, this works because the loggamma function is only used to compute coefficients for the Hankel transform. These coefficients are fixed, and you don't take derivatives with respect to them, so there is no problem during jit compilation or when taking gradients. In practice, when you jit compile that function, as long as the arguments are fixed, the scipy code will be called and the results stored as a "constant" that is then used in the rest of the jax code. So most likely no need to reimplement loggamma :-D |
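A toy sketch of that behaviour, assuming the SciPy call only ever sees concrete NumPy arrays. The grid `X_FIXED`, the coefficient formula, and the names `fixed_coeffs` and `toy_transform` are hypothetical and are not the actual JaxHankel or FFTLog kernel:

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.special import loggamma

# Fixed, parameter-independent grid: a plain NumPy array, never a JAX tracer.
X_FIXED = np.linspace(-3.0, 3.0, 16)


def fixed_coeffs():
    # SciPy handles the complex argument; the inputs are concrete NumPy
    # values, so this runs as ordinary Python even while JAX is tracing.
    ratio = np.exp(loggamma(1.0 + 1j * X_FIXED) - loggamma(1.0 - 1j * X_FIXED))
    return np.real(ratio)


def toy_transform(params):
    # Under jit, the SciPy result is computed at trace time and embedded
    # in the compiled program as a constant.
    coeffs = jnp.asarray(fixed_coeffs())
    return jnp.sum(coeffs * params**2)


params = jnp.arange(16.0)
value = jax.jit(toy_transform)(params)
gradient = jax.grad(toy_transform)(params)  # derivative w.r.t. params only
```

If `loggamma` were instead applied to something derived from `params`, SciPy would receive a traced value and the jit or grad call would be expected to fail, as described earlier in the thread.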
Thank you, that makes a lot of sense! Also, not having to reimplement loggamma makes me very happy :D |
I think the kernels (e.g. |
For jax_cosmo to be used in observational cosmology analyses (e.g. BAO, RSD, fNL), we need a JAX implementation of the FFTLog algorithm to facilitate convolutions of the survey window function with the power spectrum.
This would also be helpful for getting models of the correlation function, as mentioned in another post.
There's already a package that's used very often in cosmology:
https://github.com/eelregit/mcfit
It should be possible to implement it using JAX.