Skip to content

Commit

Permalink
Iterative Solvers return successfully when run with max_steps only (#…
Browse files Browse the repository at this point in the history
…129)

* Add helper for Poisson matrix

* Test iterative solvers not throwing when only run with max_steps

* Use consistent naming

* Add test doc

* Return successful if max_steps reached without tolerances
  • Loading branch information
Ceyron authored Jan 13, 2025
1 parent 936c31c commit 58f2a8b
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 4 deletions.
2 changes: 1 addition & 1 deletion lineax/_solver/bicgstab.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def body_fun(carry):
else:
result = RESULTS.where(
(num_steps == self.max_steps),
RESULTS.max_steps_reached,
RESULTS.max_steps_reached if has_scale else RESULTS.successful,
RESULTS.successful,
)
# breakdown is only an issue if we did not converge
Expand Down
2 changes: 1 addition & 1 deletion lineax/_solver/cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def cheap_r():
else:
result = RESULTS.where(
num_steps == max_steps,
RESULTS.max_steps_reached,
RESULTS.max_steps_reached if has_scale else RESULTS.successful,
RESULTS.successful,
)

Expand Down
2 changes: 1 addition & 1 deletion lineax/_solver/gmres.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def body_fun(carry):
else:
result = RESULTS.where(
(num_steps == self.max_steps),
RESULTS.max_steps_reached,
RESULTS.max_steps_reached if has_scale else RESULTS.successful,
RESULTS.successful,
)
result = RESULTS.where(
Expand Down
9 changes: 9 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ def construct_singular_matrix(getkey, solver, tags, num=1, dtype=jnp.float64):
return tuple(matrix[:, 1:] for matrix in matrices)


def construct_poisson_matrix(size, dtype=jnp.float64):
matrix = (
-2 * jnp.diag(jnp.ones(size, dtype=dtype))
+ jnp.diag(jnp.ones(size - 1, dtype=dtype), 1)
+ jnp.diag(jnp.ones(size - 1, dtype=dtype), -1)
)
return matrix


if jax.config.jax_enable_x64: # pyright: ignore
tol = 1e-12
else:
Expand Down
24 changes: 23 additions & 1 deletion tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import lineax as lx
import pytest

from .helpers import tree_allclose
from .helpers import construct_poisson_matrix, tree_allclose


def test_gmres_large_dense(getkey):
Expand Down Expand Up @@ -156,3 +156,25 @@ def to_grad(x):

x = (jnp.arange(3.0), jnp.arange(3.0))
to_grad(x)


@pytest.mark.parametrize(
"solver",
(
lx.CG(0.0, 0.0, max_steps=2),
lx.NormalCG(0.0, 0.0, max_steps=2),
lx.BiCGStab(0.0, 0.0, max_steps=2),
lx.GMRES(0.0, 0.0, max_steps=2),
),
)
def test_iterative_solver_max_steps_only(solver):
"""Iterative solvers should work with max_steps only (no Equinox errors)."""
SIZE = 100

poisson_matrix = construct_poisson_matrix(SIZE)
poisson_operator = lx.MatrixLinearOperator(
poisson_matrix, tags=(lx.negative_semidefinite_tag, lx.symmetric_tag)
)
rhs = jax.random.normal(jax.random.key(0), (SIZE,))

lx.linear_solve(poisson_operator, rhs, solver)

0 comments on commit 58f2a8b

Please sign in to comment.