Differentiating w.r.t. initial guess throws an error #104

Open
romanodev opened this issue Aug 16, 2024 · 2 comments

Comments

@romanodev

The gradient of the solution of an iteratively solved linear system with respect to the initial guess should be zero. Instead, the following snippet

   
import lineax as lx
from jax import numpy as jnp
import jax

operator = lx.MatrixLinearOperator(
    jnp.array([[1, 0], [0, 1]]), tags=lx.positive_semidefinite_tag
)
b = jnp.array([1., 2.])


def f(x0):
    return lx.linear_solve(
        operator, b, options={"y0": x0}, solver=lx.CG(atol=1e-12, rtol=1e-12)
    ).value.sum()


x0 = jnp.zeros(2)

print(jax.grad(f)(x0))

gives (with lineax version 0.0.5):

Traceback (most recent call last):
  File "/home/romanodev/Project/JAX-BTE/test_lineax.py", line 15, in <module>
    print(jax.grad(f)(x0))
          ^^^^^^^^^^^^^^^
  File "/home/romanodev/Project/JAX-BTE/test_lineax.py", line 11, in f
    return lx.linear_solve(operator, b,options={'y0':x0}, solver=lx.GMRES(atol=1e-12,rtol=1e-12)).value.sum()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Unexpected tangent. `lineax.linear_solve(..., options=...)` cannot be autodifferentiated.

The problem is quickly resolved by wrapping the initial guess in jax.lax.stop_gradient:

   
import lineax as lx
from jax import numpy as jnp
import jax

operator = lx.MatrixLinearOperator(
    jnp.array([[1, 0], [0, 1]]), tags=lx.positive_semidefinite_tag
)
b = jnp.array([1., 2.])


def f(x0):
    return lx.linear_solve(
        operator,
        b,
        options={"y0": jax.lax.stop_gradient(x0)},
        solver=lx.CG(atol=1e-12, rtol=1e-12),
    ).value.sum()


x0 = jnp.zeros(2)

print(jax.grad(f)(x0))
[0. 0.]

For reference, JAX's own CG solver (jax.scipy.sparse.linalg.cg) handles this fine:

from jax import numpy as jnp
import jax

A = jnp.array([[1, 0], [0, 1]])
b = jnp.array([1., 2.])


def f(x0):
    return jax.scipy.sparse.linalg.cg(lambda x: A.dot(x), b, tol=1e-10, x0=x0)[0].sum()


x0 = jnp.zeros(2)

print(jax.grad(f)(x0))
[0. 0.]

Even though this is a corner case, the initial guess may end up being traced inside a larger computational graph (as in my use case). It would also be useful to be able to specify an initial guess for the backward pass.

@patrick-kidger
Owner

I think this is working as intended. We don't support any notion of differentiating with respect to options, so a user should explicitly opt out of this -- rather than potentially getting silently unexpected gradients.
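For illustration, a minimal sketch of what opting out explicitly can look like: applying jax.lax.stop_gradient leaf-wise across the whole options pytree, so that any traced arrays in it (such as y0) stop carrying tangents. The helper name here is purely illustrative, not part of lineax's API.

import jax

def stop_option_gradients(options):
    # Leaf-wise stop_gradient over the options dict; a traced initial
    # guess (e.g. options["y0"]) then contributes no tangent.
    return jax.tree_util.tree_map(jax.lax.stop_gradient, options)

# e.g. lx.linear_solve(operator, b, options=stop_option_gradients({"y0": x0}), ...)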

On using an initial guess for the backward pass -- indeed, right now we don't seem to support this. Probably the correct thing to do would be to just use the transpose of the initial guess for the forward pass, by filling in these two methods:

lineax/lineax/_solver/cg.py, lines 236 to 248 in 4a7b108:

def transpose(self, state: _CGState, options: dict[str, Any]):
    del options
    psd_op, is_nsd = state
    transpose_state = psd_op.transpose(), is_nsd
    transpose_options = {}
    return transpose_state, transpose_options

def conj(self, state: _CGState, options: dict[str, Any]):
    del options
    psd_op, is_nsd = state
    conj_state = conj(psd_op), is_nsd
    conj_options = {}
    return conj_state, conj_options
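
For illustration, a rough sketch of what forwarding an initial guess to the transposed (backward-pass) solve might look like; the option name "y0_transpose" is purely hypothetical and not an existing lineax option:

def transpose(self, state: _CGState, options: dict[str, Any]):
    psd_op, is_nsd = state
    transpose_state = psd_op.transpose(), is_nsd
    transpose_options = {}
    # Hypothetical: forward a user-supplied initial guess to the
    # transposed solve, if one was provided in the forward options.
    if "y0_transpose" in options:
        transpose_options["y0"] = options["y0_transpose"]
    return transpose_state, transpose_options

(conj would presumably be handled analogously, conjugating the supplied guess.)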

I'd be happy to take a PR on this!

@romanodev
Author

I see. I will look into it. Great library, BTW!
