Support Hessian of gamma-distributed samples #21432

Open
NeilGirdhar wants to merge 1 commit into main from gamma_hessian

NeilGirdhar
Contributor

@NeilGirdhar NeilGirdhar commented May 25, 2024

Fixes #16076

@NeilGirdhar NeilGirdhar force-pushed the gamma_hessian branch 2 times, most recently from f5cbd12 to 38c61c0 Compare May 25, 2024 22:38
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.lib.mlir.dialects import chlo
from jax._src.typing import Array, ArrayLike

def _while_loop_scan(cond_fun, body_fun, init_val, max_iter):
Contributor Author

Should I move this to api_util?

@NeilGirdhar
Contributor Author

NeilGirdhar commented Jul 22, 2024

@jakevdp Is there any interest in merging something like this? It's really hard to do the same thing in client code without copying thousands of lines of Jax code and then keeping them updated (which is extremely time-consuming). Or perhaps you have an alternative idea for how I can accomplish this?

@jakevdp
Collaborator

jakevdp commented Aug 1, 2024

Hi @NeilGirdhar, sorry for the delay on this. I'm hoping @froystig or @mattjj can weigh in here.

I'm a bit concerned about the approach here, because the bounded while-loop might have memory impacts when computing the second derivative for large numbers of samples: my understanding is it would require statically allocating a buffer 256 times larger than the buffer of samples you're generating. It may be that remat can help with that, but I'm not entirely sure. I'd like to make sure we have the right solution here to avoid potentially adding a memory-related footgun to JAX library code.
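For readers following along, here is a minimal sketch of the kind of bounded while-loop being discussed. It is not the PR's `_while_loop_scan`; the names `bounded_while_loop` and `newton_sqrt` are hypothetical, and Newton's method for a square root stands in for the iterative sampler. The loop is built on `lax.scan` with a fixed trip count, which is what makes it reverse-mode differentiable and also why residuals are kept for every one of the `max_iter` steps.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_while_loop(cond_fun, body_fun, init_val, max_iter):
    """Apply body_fun while cond_fun holds, for at most max_iter steps (hypothetical helper)."""
    def step(val, _):
        # Once cond_fun turns false, the value is carried through unchanged.
        new_val = lax.cond(cond_fun(val), body_fun, lambda v: v, val)
        return new_val, None

    # scan has a static trip count, so reverse-mode AD works, but residuals
    # are saved for all max_iter steps, whether or not the loop "stopped" early.
    final_val, _ = lax.scan(step, init_val, None, length=max_iter)
    return final_val


def newton_sqrt(a):
    # Toy stand-in for an iterative algorithm: Newton's method for sqrt(a).
    a = jnp.asarray(a, dtype=jnp.float32)
    cond_fun = lambda x: jnp.abs(x * x - a) > 1e-6
    body_fun = lambda x: 0.5 * (x + a / x)
    return bounded_while_loop(cond_fun, body_fun, a, max_iter=16)


first = jax.grad(newton_sqrt)(4.0)             # close to 0.25
second = jax.grad(jax.grad(newton_sqrt))(4.0)  # close to -0.03125
```

With `max_iter=256` and a large batch of samples, the saved per-step residuals are where the roughly 256-fold buffer would come from.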

Perhaps there are ways to compute this second derivative more directly, without differentiating through the implementation of the first derivative?

@jakevdp jakevdp requested a review from mattjj August 1, 2024 18:13
@NeilGirdhar
Contributor Author

Absolutely no problem about the delay. Congratulations on completing the Jax implementation of the Array API so quickly!

> I'm a bit concerned about the approach here, because the bounded while-loop might have memory impacts when computing the second derivative for large numbers of samples: my understanding is it would require statically allocating a buffer 256 times larger than the buffer of samples you're generating. It may be that remat can help with that, but I'm not entirely sure. I'd like to make sure we have the right solution here to avoid potentially adding a memory-related footgun to JAX library code.
>
> Perhaps there are ways to compute this second derivative more directly, without differentiating through the implementation of the first derivative?

Your concerns make perfect sense to me. I'm going to let the others weigh in, but in the interest of eliminating back-and-forth, I'll make some comments and suggestions if that's okay 😄

First of all, the reason for my frequent force-pushes (sorry if that was noisy) is that this feature is so important to me that I now point my repo at my PR branch rather than at Jax directly, which gives me access to the feature. I tried lifting the gamma random generation out of Jax, but it's a large mass of code that has changed over the last year, so keeping a copy updated was too much work.

As for the time and space concerns, I want to first remind readers that only the second derivative code is slow. I agree with your point that this could be a footgun.

The ideal approach is probably to replace the for loop with a fixed-point solve. (So, it would go back to being just a while loop in all cases.) Is there any precedent for fixed-point solving in Jax's source code? I know that JaxOpt and tjax have fixed point solvers. It would be some work, but it should be possible to recast the algorithm slightly so that it fits the fixed point interface.
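To illustrate the fixed-point idea, here is a rough sketch under stated assumptions, not a proposal for the actual gamma code. The names `fixed_point`, `contraction`, and `solve` are hypothetical, and the problem is a scalar toy whose solution is `sqrt(a)`. The forward pass is a plain `lax.while_loop`, and the derivative comes from the implicit function theorem via `jax.custom_jvp`, so nothing is stored per iteration. Because the JVP rule calls `fixed_point` again, it can itself be differentiated to get second derivatives.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.custom_jvp, nondiff_argnums=(0,))
def fixed_point(f, a, x0):
    """Iterate x <- f(a, x) from x0 until successive iterates agree (scalar sketch)."""
    x0 = jnp.asarray(x0, dtype=jnp.float32)

    def cond_fun(carry):
        x_prev, x = carry
        return jnp.abs(x - x_prev) > 1e-6

    def body_fun(carry):
        _, x = carry
        return x, f(a, x)

    _, x_star = lax.while_loop(cond_fun, body_fun, (x0, f(a, x0)))
    return x_star


@fixed_point.defjvp
def fixed_point_jvp(f, primals, tangents):
    a, x0 = primals
    a_dot, _ = tangents  # the solution does not depend on the starting guess
    x_star = fixed_point(f, a, x0)
    # Implicit function theorem for x* = f(a, x*): dx*/da = f_a / (1 - f_x).
    f_a, f_x = jax.grad(f, argnums=(0, 1))(a, x_star)
    return x_star, f_a / (1.0 - f_x) * a_dot


def contraction(a, x):
    # Toy map whose fixed point is sqrt(a); a contraction near x = 2 for a = 4.
    return x - 0.1 * (x * x - a)


def solve(a):
    return fixed_point(contraction, a, 1.0)


a = jnp.float32(4.0)
first = jax.grad(solve)(a)             # close to 0.25
second = jax.grad(jax.grad(solve))(a)  # close to -0.03125
```

A vector-valued version would need a linear solve in place of the scalar division, which is the part that would take some work to fit to the gamma algorithm.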

An alternative approach would be to tune the 256 constant. I think it's about ten times too big. I didn't think about speed or memory, and I just wanted to get it working. Tuning this constant might solve the speed problem, but the fixed point solution makes the memory cost constant, I think.

What do you think?

Development

Successfully merging this pull request may close these issues.

Please consider implementing differentiation for the Hessian of gamma variates
3 participants