
Differentiating w.r.t. initial guess throws an error #104

@romanodev


The gradient of the solution of an iteratively solved linear system with respect to the initial guess should be zero. Instead, the following snippet

   
import lineax as lx
from jax import numpy as jnp
import jax

operator = lx.MatrixLinearOperator(jnp.array([[1, 0], [0, 1]]), tags=lx.positive_semidefinite_tag)
b = jnp.array([1., 2.])

def f(x0):
    solver = lx.CG(atol=1e-12, rtol=1e-12)
    return lx.linear_solve(operator, b, options={'y0': x0}, solver=solver).value.sum()

x0 = jnp.zeros(2)

print(jax.grad(f)(x0))

raises the following error (lineax version 0.0.5):

Traceback (most recent call last):
  File "/home/romanodev/Project/JAX-BTE/test_lineax.py", line 15, in <module>
    print(jax.grad(f)(x0))
          ^^^^^^^^^^^^^^^
  File "/home/romanodev/Project/JAX-BTE/test_lineax.py", line 11, in f
    return lx.linear_solve(operator, b,options={'y0':x0}, solver=lx.GMRES(atol=1e-12,rtol=1e-12)).value.sum()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Unexpected tangent. `lineax.linear_solve(..., options=...)` cannot be autodifferentiated.
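A short sketch of why zero is the expected answer, assuming the iterative solve has converged: the returned value $x^\star$ satisfies $A x^\star = b$, an equation in which the initial guess $x_0$ never appears. Differentiating that identity implicitly gives

```latex
A x^\star = b
\;\Longrightarrow\;
\mathrm{d}A\,x^\star + A\,\mathrm{d}x^\star = \mathrm{d}b
\;\Longrightarrow\;
\mathrm{d}x^\star = A^{-1}\bigl(\mathrm{d}b - \mathrm{d}A\,x^\star\bigr),
\qquad
\frac{\partial x^\star}{\partial x_0} = 0 .
```

The initial guess only changes how many iterations are needed, not the converged answer, so a tangent on `y0` can be treated as zero instead of raising an error.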

Wrapping the initial guess in `jax.lax.stop_gradient` works around the problem:

   
import lineax as lx
from jax import numpy as jnp
import jax

operator = lx.MatrixLinearOperator(jnp.array([[1, 0], [0, 1]]), tags=lx.positive_semidefinite_tag)
b = jnp.array([1., 2.])

def f(x0):
    solver = lx.CG(atol=1e-12, rtol=1e-12)
    return lx.linear_solve(operator, b, options={'y0': jax.lax.stop_gradient(x0)}, solver=solver).value.sum()

x0 = jnp.zeros(2)

print(jax.grad(f)(x0))
This prints [0. 0.].
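The workaround above can be wrapped in a small helper so that every call site detaches the guess the same way. This is a minimal sketch: `detach_initial_guess` and `uses_guess` are hypothetical names, not part of lineax, and the check below uses a stand-in function instead of a real solver call.

```python
import jax
import jax.numpy as jnp


def detach_initial_guess(options):
    """Return a copy of a solver-options dict with the "y0" entry detached.

    Hypothetical helper (not part of lineax): the initial guess is wrapped in
    jax.lax.stop_gradient, so no tangent reaches the solver through it. All
    other options are passed through unchanged.
    """
    options = dict(options)  # shallow copy; do not mutate the caller's dict
    if "y0" in options:
        options["y0"] = jax.lax.stop_gradient(options["y0"])
    return options


# Stand-in for a solver call: the detached guess contributes no gradient.
def uses_guess(x0):
    opts = detach_initial_guess({"y0": x0})
    return (2.0 * opts["y0"]).sum()


print(jax.grad(uses_guess)(jnp.zeros(2)))  # [0. 0.]
```

With lineax, the call from the report would then read `lx.linear_solve(operator, b, options=detach_initial_guess({'y0': x0}), solver=lx.CG(atol=1e-12, rtol=1e-12))`. Detaching is only safe because the converged solution does not depend on the guess; it would be wrong for an option that genuinely changes the answer.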

For reference, JAX's own CG solver, `jax.scipy.sparse.linalg.cg`, differentiates with respect to `x0` without error:

from jax import numpy as jnp
import jax

A = jnp.array([[1,0],[0,1]])
b = jnp.array([1.,2.])

def f(x0):
    return jax.scipy.sparse.linalg.cg(lambda x: A.dot(x), b, tol=1e-10, x0=x0)[0].sum()

x0 = jnp.zeros(2)

print(jax.grad(f)(x0))
This prints [0. 0.].
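With the identity matrix CG converges in a single step, so a slightly stronger check uses a non-diagonal matrix. The sketch below assumes a small, made-up symmetric positive-definite `A` and `b`, solves from two different initial guesses with `jax.scipy.sparse.linalg.cg`, and differentiates the solution with respect to `x0`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Assumed example data: a small symmetric positive-definite system.
A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])


def solve_from(x0):
    # cg returns (solution, info); only the solution is needed here.
    x, _ = cg(A, b, x0=x0, tol=1e-6)
    return x


# Two different initial guesses should reach the same converged solution.
x_a = solve_from(jnp.zeros(3))
x_b = solve_from(jnp.array([1.0, -1.0, 0.5]))
print(jnp.max(jnp.abs(x_a - x_b)))

# The gradient with respect to the initial guess is zero.
g = jax.grad(lambda x0: solve_from(x0).sum())(jnp.zeros(3))
print(g)  # [0. 0. 0.]
```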

Although this is a corner case, the initial guess can end up being traced as part of a larger computational graph (this was my use case). It would also be useful to be able to specify an initial guess for the backward pass.
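The backward-pass request can be illustrated with a small sketch. It is not lineax's API: `spd_solve`, `y0_fwd` and `y0_bwd` are hypothetical names, and the solve uses `jax.scipy.sparse.linalg.cg` inside a `jax.custom_vjp`. It assumes `A` is a dense symmetric positive-definite matrix, so the adjoint system $A^\top \lambda = g$ can also be solved with CG, starting from its own guess. Both initial guesses receive zero cotangents, matching the behaviour this issue expects.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


@jax.custom_vjp
def spd_solve(A, b, y0_fwd, y0_bwd):
    # Forward solve A x = b, warm-started from y0_fwd.
    x, _ = cg(A, b, x0=y0_fwd, tol=1e-6)
    return x


def spd_solve_fwd(A, b, y0_fwd, y0_bwd):
    x = spd_solve(A, b, y0_fwd, y0_bwd)
    # Keep what the backward pass needs, including its own initial guess.
    return x, (A, x, y0_bwd)


def spd_solve_bwd(res, g):
    A, x, y0_bwd = res
    # Adjoint solve A^T lam = g; A^T == A by the SPD assumption.
    lam, _ = cg(A, g, x0=y0_bwd, tol=1e-6)
    grad_A = -jnp.outer(lam, x)  # cotangent of a dense, unstructured A
    grad_b = lam
    # The converged solution does not depend on either initial guess.
    return grad_A, grad_b, jnp.zeros_like(x), jnp.zeros_like(y0_bwd)


spd_solve.defvjp(spd_solve_fwd, spd_solve_bwd)

# Usage on assumed example data.
A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])
guess = jnp.zeros(3)


def loss(A, b, y0_fwd, y0_bwd):
    return spd_solve(A, b, y0_fwd, y0_bwd).sum()


grads = jax.grad(loss, argnums=(0, 1, 2, 3))(A, b, guess, guess)
```

Storing `y0_bwd` in the residuals is what lets a caller warm-start the adjoint solve, for example with the adjoint solution from a previous optimisation step.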
