Description
The gradient of the solution of an iteratively solved linear system with respect to the initial guess should be zero. Instead, the following snippet
import lineax as lx
from jax import numpy as jnp
import jax
operator = lx.MatrixLinearOperator(jnp.array([[1,0],[0,1]]),tags=lx.positive_semidefinite_tag)
b = jnp.array([1.,2.])
def f(x0):
    return lx.linear_solve(operator, b, options={'y0': x0}, solver=lx.CG(atol=1e-12, rtol=1e-12)).value.sum()
x0 = jnp.zeros(2)
print(jax.grad(f)(x0))
gives (lineax version 0.0.5):
Traceback (most recent call last):
File "/home/romanodev/Project/JAX-BTE/test_lineax.py", line 15, in <module>
print(jax.grad(f)(x0))
^^^^^^^^^^^^^^^
File "/home/romanodev/Project/JAX-BTE/test_lineax.py", line 11, in f
return lx.linear_solve(operator, b,options={'y0':x0}, solver=lx.GMRES(atol=1e-12,rtol=1e-12)).value.sum()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Unexpected tangent. `lineax.linear_solve(..., options=...)` cannot be autodifferentiated.
The problem is quickly resolved by wrapping the initial guess in jax.lax.stop_gradient:
import lineax as lx
from jax import numpy as jnp
import jax
operator = lx.MatrixLinearOperator(jnp.array([[1,0],[0,1]]),tags=lx.positive_semidefinite_tag)
b = jnp.array([1.,2.])
def f(x0):
    return lx.linear_solve(operator, b, options={'y0': jax.lax.stop_gradient(x0)}, solver=lx.CG(atol=1e-12, rtol=1e-12)).value.sum()
x0 = jnp.zeros(2)
print(jax.grad(f)(x0))
[0. 0.]
For reference, JAX's solver works fine:
from jax import numpy as jnp
import jax
A = jnp.array([[1,0],[0,1]])
b = jnp.array([1.,2.])
def f(x0):
    return jax.scipy.sparse.linalg.cg(lambda x: A.dot(x), b, tol=1e-10, x0=x0)[0].sum()
x0 = jnp.zeros(2)
print(jax.grad(f)(x0))
[0. 0.]
Although this is a corner case, the initial guess can end up traced as part of a larger computational graph (this was my use case). It would also be great to be able to specify the initial guess for the backward pass.
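To make the traced-guess scenario concrete, here is a minimal sketch in plain JAX (it uses `jax.scipy.sparse.linalg.cg`, not lineax). The matrix, the right-hand side, and the helper name `solve_sum` are made up for illustration. It differentiates with respect to both the right-hand side and the initial guess in one call, which is roughly what happens when the guess is part of a larger graph. The expectation is that only the right-hand side receives a nonzero gradient.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Hypothetical symmetric positive-definite system, chosen only for illustration.
A = jnp.array([[2.0, 0.0], [0.0, 4.0]])


def solve_sum(b, x0):
    # x0 is only a warm start: the converged solution A^{-1} b does not depend on it.
    x, _ = cg(lambda v: A @ v, b, x0=x0, tol=1e-6)
    return x.sum()


b = jnp.array([1.0, 2.0])
x0 = jnp.ones(2)  # nonzero initial guess, traced by jax.grad below

# Differentiate with respect to both arguments at once.
grad_b, grad_x0 = jax.grad(solve_sum, argnums=(0, 1))(b, x0)

# Analytically, d/db sum(A^{-1} b) = A^{-T} 1 = [0.5, 0.25] (up to solver tolerance),
# while the derivative with respect to x0 is zero.
print(grad_b)
print(grad_x0)
```

If lineax behaved the same way, passing a traced `y0` through `options` would contribute a zero gradient and would not raise an error.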