Differentiating w.r.t. initial guess throws an error #104
I think this is working as intended. We don't support any notion of differentiating with respect to the initial guess. On using an initial guess for the backward pass: indeed, right now we don't seem to support this. Probably the correct thing to do would be to just use the transpose of the initial guess for the forward pass, by filling in these two methods (lines 236 to 248 in commit 4a7b108).

I'd be happy to take a PR on this!
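For an exact (converged) solve, both points follow from differentiating $x = A^{-1} b$. This is a standard derivation, not taken from the lineax source: the solution does not depend on the initial guess $y_0$ at all, and the backward pass is itself a linear solve with the transposed operator, which is where a second initial guess could be used.

```math
\begin{aligned}
x &= A^{-1} b \quad\Rightarrow\quad \frac{\partial x}{\partial y_0} = 0, \\
\mathrm{d}x &= A^{-1}\left(\mathrm{d}b - \mathrm{d}A\, x\right), \\
\bar b &= A^{-\top} \bar x, \qquad \bar A = -\bar b\, x^{\top}.
\end{aligned}
```

Here $\bar x$ is the incoming cotangent; computing $\bar b$ requires solving $A^{\top} \bar b = \bar x$.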
I see. I will look into it. Great library, BTW!
The gradient of the solution of an iteratively solved linear system with respect to the initial guess should be zero. Instead, the following snippet throws an error (lineax version 0.0.5). The problem is quickly resolved by applying `jax.lax.stop_gradient` to the initial guess.
For reference, JAX's solver works fine.
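A sketch of such a comparison, assuming the JAX solver meant here is `jax.scipy.sparse.linalg.cg` (the specific solver is an assumption). Differentiating with respect to a traced, non-zero `x0` runs without error and yields a zero gradient, while the gradient with respect to `b` is the usual $A^{-\top}\bar x$:

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Same illustrative symmetric positive definite system.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])


def loss(x0, b):
    # cg returns (solution, info); info is a placeholder and is ignored.
    x, _ = cg(A, b, x0=x0, tol=1e-6)
    return jnp.sum(x)


x0 = jnp.ones(2)  # deliberately non-zero initial guess
grad_x0, grad_b = jax.grad(loss, argnums=(0, 1))(x0, b)
print(grad_x0)  # zeros: the solution does not depend on x0
print(grad_b)
```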
Even though this is a corner case, the initial guess may end up traced as part of a more complex computational graph (this was my use case). It would also be great to be able to specify the initial guess for the backward pass.
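One way that request could look, as a hypothetical sketch built on `jax.custom_vjp` and `jax.scipy.sparse.linalg.cg` rather than on lineax. The helper `solve_with_guesses` and its arguments are invented for illustration. The forward solve starts from `y0_fwd`, the transposed backward solve $A^{\top}\bar b = \bar x$ starts from a separately supplied `y0_bwd`, and both guesses receive zero gradient on the assumption that both solves converge:

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


@jax.custom_vjp
def solve_with_guesses(A, b, y0_fwd, y0_bwd):
    """Hypothetical helper: solve A x = b by CG, starting from y0_fwd."""
    x, _ = cg(A, b, x0=y0_fwd)
    return x


def _fwd(A, b, y0_fwd, y0_bwd):
    x = solve_with_guesses(A, b, y0_fwd, y0_bwd)
    # Residuals needed by the backward pass.
    return x, (A, x, y0_fwd, y0_bwd)


def _bwd(res, x_bar):
    A, x, y0_fwd, y0_bwd = res
    # Transposed solve A^T b_bar = x_bar, started from the user's guess.
    b_bar, _ = cg(A.T, x_bar, x0=y0_bwd)
    A_bar = -jnp.outer(b_bar, x)
    # Assuming convergence, the solution is independent of both guesses.
    return A_bar, b_bar, jnp.zeros_like(y0_fwd), jnp.zeros_like(y0_bwd)


solve_with_guesses.defvjp(_fwd, _bwd)

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
y0_fwd = jnp.ones(2)
y0_bwd = jnp.ones(2)  # e.g. a guess kept from a previous backward pass


def loss(A, b, y0_fwd, y0_bwd):
    return jnp.sum(solve_with_guesses(A, b, y0_fwd, y0_bwd))


g_A, g_b, g_fwd, g_bwd = jax.grad(loss, argnums=(0, 1, 2, 3))(A, b, y0_fwd, y0_bwd)
```

The design choice is to keep the two guesses as separate arguments, so a caller can warm-start the transposed solve independently of the forward one.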