Skip to content

Commit

Permalink
Diagonal solver now works with different input+output structures
Browse files Browse the repository at this point in the history
Honestly, it's a little suspicious whether we should even allow this: should being diagonal perhaps imply that the input and output structures are identical?

Right now I'm choosing to allow this because it's pretty subtle, and users can apply their own diagonal tags, so if nothing else it's an easy mistake to be tolerant of.

In particular, *we* were making this mistake by treating scalar operators as diagonal even when they had different structures. This was causing a downstream issue in Diffrax.
  • Loading branch information
patrick-kidger committed Oct 21, 2024
1 parent 4d99283 commit efe3f6c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 16 deletions.
6 changes: 5 additions & 1 deletion lineax/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ def tree_where(

def resolve_rcond(rcond, n, m, dtype):
if rcond is None:
return jnp.finfo(dtype).eps * max(n, m)
# This `2 *` is a heuristic: I have seen very rare failures without it, in ways
# that seem to depend on JAX compilation state. (E.g. running unrelated JAX
# computations beforehand, in a completely different JIT-compiled region, can
# result in differences in the success/failure of the solve.)
return 2 * jnp.finfo(dtype).eps * max(n, m)
else:
return jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)

Expand Down
44 changes: 30 additions & 14 deletions lineax/_solver/diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,23 @@
from typing import Any, Optional
from typing_extensions import TypeAlias

import jax.flatten_util as jfu
import jax.numpy as jnp
from jaxtyping import Array, PyTree

from .._misc import resolve_rcond
from .._operator import AbstractLinearOperator, diagonal, has_unit_diagonal, is_diagonal
from .._solution import RESULTS
from .._solve import AbstractLinearSolver
from .misc import (
pack_structures,
PackedStructures,
ravel_vector,
transpose_packed_structures,
unravel_solution,
)


_DiagonalState: TypeAlias = Optional[Array]
_DiagonalState: TypeAlias = tuple[Optional[Array], PackedStructures]


class Diagonal(AbstractLinearSolver[_DiagonalState], strict=True):
Expand All @@ -52,39 +58,49 @@ def init(
raise ValueError(
"`Diagonal` may only be used for linear solves with diagonal matrices"
)
packed_structures = pack_structures(operator)
if has_unit_diagonal(operator):
return None
return None, packed_structures
else:
return diagonal(operator)
return diagonal(operator), packed_structures

def compute(
self, state: _DiagonalState, vector: PyTree[Array], options: dict[str, Any]
) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]:
diag = state
diag, packed_structures = state
del state, options
unit_diagonal = diag is None
# diagonal => symmetric => (in_structure == out_structure) =>
# we don't need to use packed structures.
vector = ravel_vector(vector, packed_structures)
if unit_diagonal:
solution = vector
else:
vector, unflatten = jfu.ravel_pytree(vector)
if not self.well_posed:
(size,) = diag.shape
rcond = resolve_rcond(self.rcond, size, size, diag.dtype)
abs_diag = jnp.abs(diag)
diag = jnp.where(abs_diag > rcond * jnp.max(abs_diag), diag, jnp.inf)
solution = unflatten(vector / diag)
solution = vector / diag
solution = unravel_solution(solution, packed_structures)
return solution, RESULTS.successful, {}

def transpose(self, state: _DiagonalState, options: dict[str, Any]):
# Matrix is symmetric
return state, options
del options
diag, packed_structures = state
transposed_packed_structures = transpose_packed_structures(packed_structures)
transpose_state = diag, transposed_packed_structures
transpose_options = {}
return transpose_state, transpose_options

def conj(self, state: _DiagonalState, options: dict[str, Any]):
if state is None:
return None, options
return state.conj(), options
del options
diag, packed_structures = state
if diag is None:
conj_diag = None
else:
conj_diag = diag.conj()
conj_options = {}
conj_state = conj_diag, packed_structures
return conj_state, conj_options

def allow_dependent_columns(self, operator):
return not self.well_posed
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "lineax"
version = "0.0.6"
version = "0.0.7"
description = "Linear solvers in JAX and Equinox."
readme = "README.md"
requires-python ="~=3.9"
Expand Down

0 comments on commit efe3f6c

Please sign in to comment.