Efficient NewtonCG Implementation #24
That's great to hear, thank you! On HVPs: if I understand you correctly, this is a general JAX question, rather than specifically a question of how to integrate a solve into Optimistix? You're looking to get both the gradient and an HVP without having to treat them separately (which would be 3 sweeps in total). E.g.:

```python
def to_jvp(x):
    return jax.value_and_grad(fn)(x)

# One jvp over value_and_grad: the value, the gradient, and H @ v in a single sweep.
(f, dfdxi), (_, dfdxidxjvj) = jax.jvp(to_jvp, (x,), (v,))
```
Yes exactly! It isn't clear to me how to work this into a Newton-like CG solver, but I'll keep toying around. Thanks as always for your enthusiasm and detailed help. It is greatly appreciated!
Newton CG is one of the algorithms which I think could be somewhat involved to implement. It is further from the existing solvers in Optimistix, so there's more custom work that needs to be done to get it running. There are two steps I would take if I were implementing it (which I may in the future):
This should be pretty much everything though. Just use
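For context, here is a hedged, plain-JAX sketch of the matrix-free inner solve that a Newton-CG step performs, combining the jvp-of-grad trick above with `jax.scipy.sparse.linalg.cg`. This is not the Optimistix custom-solver structure being described; `fn`, `x`, and the CG settings are placeholders.

```python
import jax
import jax.numpy as jnp


def newton_cg_step(fn, x, maxiter=20):
    # One value-and-grad sweep at the current iterate.
    value, grad = jax.value_and_grad(fn)(x)

    def hvp(v):
        # Forward-over-reverse Hessian-vector product: jvp of the gradient.
        return jax.jvp(jax.grad(fn), (x,), (v,))[1]

    # Approximately solve H p = -g with matrix-free conjugate gradients.
    step, _ = jax.scipy.sparse.linalg.cg(hvp, -grad, maxiter=maxiter)
    return value, x + step


# Example: one Newton-CG step on a toy quartic objective.
value, x_new = newton_cg_step(lambda z: jnp.sum(z**4), jnp.ones(3))
```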
Hi all, thanks for the phenomenal library. We're already using it in several statistical genetics methods in my group!
I've been porting over some older code of mine to use Optimistix, rather than hand-rolled inference procedures, and could use some advice. Currently, I am performing some variational inference using a mix of closed-form updates for the variational parameters, as well as gradient-based updates for some hyperparameters. It *roughly* works like this:
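(A hypothetical sketch only, since the concrete code isn't shown; `negative_elbo`, `q_params`, `hyper`, and the plain gradient step are all stand-ins for illustration.)

```python
import jax


def vi_step(negative_elbo, hyper, q_params, lr=1e-2):
    # Stand-in objective: negative_elbo(hyper, q_params) -> (loss, updated_q_params).
    # The closed-form variational updates happen inside and come back as aux.
    (value, q_params), grad = jax.value_and_grad(negative_elbo, has_aux=True)(
        hyper, q_params
    )
    # Plain gradient step on the hyperparameters.
    hyper = jax.tree_util.tree_map(lambda h, g: h - lr * g, hyper, grad)
    return value, q_params, hyper
```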
I'd *like* to retool the above to not only report the current value, the aux values (i.e. the updated variational parameters), and the gradient w.r.t. the hyperparameters, but also return an *hvp* function that could be used in a Newton-CG-like step in Optimistix. I know of the new `minimize` function, but what isn't clear is how to set up the scenario to not only report gradients, but also return a `hvp` function internally without having to take two additional passes over the graph (i.e. once for value and grad, another two for the hvp => forward + backward). Is this doable? Apologies if this is somewhat nebulous; I'm happy to clarify.
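For illustration, one way to expose both the gradient and a reusable HVP closure from a single linearisation (a sketch in plain JAX using `jax.linearize` over `value_and_grad`, not the Optimistix-internal setup being asked about; `fn` here is an assumed scalar objective):

```python
import jax
import jax.numpy as jnp


def value_grad_and_hvp(fn, x):
    # Linearise value_and_grad once at x. The primal outputs are (value, grad);
    # the returned linear map sends a tangent v to (grad @ v, H @ v), so each
    # HVP costs roughly one extra forward pass instead of a fresh forward+backward.
    (value, grad), jvp_fn = jax.linearize(jax.value_and_grad(fn), x)

    def hvp(v):
        return jvp_fn(v)[1]

    return value, grad, hvp


# Toy usage with a quadratic objective.
value, grad, hvp = value_grad_and_hvp(lambda z: 0.5 * jnp.sum(z**2), jnp.arange(3.0))
print(value, grad, hvp(jnp.ones(3)))
```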