Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New solver based on Woodbury matrix identity #97

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lineax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
is_symmetric as is_symmetric,
is_tridiagonal as is_tridiagonal,
is_upper_triangular as is_upper_triangular,
is_Woodbury as is_Woodbury,
JacobianLinearOperator as JacobianLinearOperator,
linearise as linearise,
materialise as materialise,
Expand All @@ -45,6 +46,7 @@
TangentLinearOperator as TangentLinearOperator,
tridiagonal as tridiagonal,
TridiagonalLinearOperator as TridiagonalLinearOperator,
WoodburyLinearOperator as WoodburyLinearOperator,
)
from ._solution import RESULTS as RESULTS, Solution as Solution
from ._solve import (
Expand Down
143 changes: 143 additions & 0 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,78 @@ def out_structure(self):
return self.operator.out_structure()


class WoodburyLinearOperator(AbstractLinearOperator, strict=True):
"""As [`lineax.MatrixLinearOperator`][], but for specifically a matrix
with A + U C V structure, such that the Woodbury identity can be used"""

A: AbstractLinearOperator
C: Inexact[Array, " k k"]
U: Inexact[Array, " n k"]
V: Inexact[Array, " k n"]
UCV: Array
tags: frozenset[object] = eqx.field(static=True)

def __init__(
self,
A: AbstractLinearOperator,
C: Inexact[Array, " k k"],
U: Inexact[Array, " n k"],
V: Inexact[Array, " k n"],
tags: Union[object, frozenset[object]] = (),
aidancrilly marked this conversation as resolved.
Show resolved Hide resolved
):
"""**Arguments:**

Matrix of form A + U C V, such that the inverse can be computed
using Woodbury matrix identity

- `A`: Linear operator, in/out shape (n,n)
- `C`: A rank-two JAX array. Shape (k,k)
- `U`: A rank-two JAX array. Shape (n,k)
- `V`: A rank-two JAX array. Shape (k,n)

"""
self.A = A
self.C = inexact_asarray(C)
self.U = inexact_asarray(U)
self.V = inexact_asarray(V)
(N, M) = self.A.in_structure(), self.A.out_structure()
if N != M:
aidancrilly marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"expecting square operator for A, got {N} by {M}")
(K, L) = self.C.shape
if K != L:
raise ValueError(f"expecting square operator for C, got {K} by {L}")
N = N.shape[0]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

N is an arbitrary PyTree, it doesn't necessarily have a .shape attribute.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, This is mostly my ignorance of how lineax is using PyTrees as linear operators.

I am not sure how allowing arbitrary PyTrees for A meshes with the requirement of a Woodbury structure. In the context of the Woodbury matrix identity, I would think A would have to have a matrix representation of an n by n matrix. For a PyTree representation then would C, U and V also need to be PyTrees such that each leaf of the tree can be made to have Woodbury structure?

Copy link
Owner

@patrick-kidger patrick-kidger Jun 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, so! Basically what's going on here is exactly what jax.jacfwd does as well.
Consider for example:

import jax
import jax.numpy as jnp

def f(x: tuple[jax.Array, jax.Array]):
    return {"a": x[0] * x[1], "b": x[0]}

x = (jnp.arange(2.)[:, None], jnp.arange(3.)[None, :])
jac = jax.jacfwd(f)(x)

What should we get?

Well, a PyTree-of-arrays is basically isomorphic to a vector (flatten every array and concatenate them all together), and the Jacobian of a function f: R^n -> R^m is a matrix of shape (m, n).

Reasoning by analogy, we can see that:

  • given an input PyTree whose leaves are enumerated by i, and for which each leaf has shape a_i (a tuple);
  • given an output PyTree whose leaves are enumerated by j, and for which each leaf has shape b_j (also a tuple);
    then the Jacobian should be a PyTree whose leaves are numerated by (j, i) (a PyTree-of-PyTrees if you will), where each leaf has shape (*b_j, *a_i). (Here unpacking each tuple using Python notation.)

And indeed this is exactly what we see:

import equinox as eqx
eqx.tree_pprint(jac)
# {'a': (f32[2,3,2,1], f32[2,3,1,3]), 'b': (f32[2,1,2,1], f32[2,1,1,3])}

the "outer" PyTree has structure {'a': *, 'b': *} (corresponding to the output of our function), the "inner" PyTree has structure (*, *) (corresponding to the input of our function).

Meanwhile, each leaf has a shape obtained by concatenating the shapes of the corresponding pair of input and output leaves. For all possible pairs, notably! So in our use case here we wouldn't have a pytree-of-things-with-Woodbury-structure. Rather, we would have a single PyTree, which when thought of as a linear operator (much like the Jacobian), would itself have Woodbury structure!


Okay, hopefully that makes some kind of sense!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just thought I'd response on this since it has been a bit - thanks for the detailed response, it makes sense, I am (slowly) working on changes that would make the Woodbury implementation pytree compatible.

I think the size checking is easy enough if I have understood you correctly here (PyTree-of-arrays is basically isomorphic to a vector (flatten every array and concatenate them all together). When checking U and V we need to use the in_size of A and C for N and K.

There will need to be a few tree_unflatten's to move between the flattened space (where U and V live) and the pytree input space (where A and C potentially live). This makes the pushthrough operator a bit tricky but should be do-able with a little time.

I suppose my question would be, is there a nice way to wrap this interlink between flattened vector space and pytree space so that implementing this kind of thing will be easier in the future? Does it already exist somewhere outside (or potentially inside) lineax?

if self.U.shape != (N, K):
raise ValueError("U does not have consistent shape with A and C")
if self.V.shape != (K, N):
raise ValueError("V does not have consistent shape with A and C")
self.UCV = self.U @ (self.C @ self.V)
aidancrilly marked this conversation as resolved.
Show resolved Hide resolved
self.tags = _frozenset(tags)

def mv(self, vector):
Ax = self.A.mv(vector)
UCVx = self.UCV @ vector
return Ax + UCVx

def as_matrix(self):
matrix = self.A.as_matrix() + self.UCV
return matrix

def transpose(self):
return WoodburyLinearOperator(
self.A.transpose(),
jnp.transpose(self.C),
jnp.transpose(self.V),
jnp.transpose(self.U),
)

def in_structure(self):
return self.A.in_structure()

def out_structure(self):
return self.A.out_structure()


#
# All operators below here are private to lineax.
#
Expand Down Expand Up @@ -1207,6 +1279,7 @@ def linearise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
@linearise.register(IdentityLinearOperator)
@linearise.register(DiagonalLinearOperator)
@linearise.register(TridiagonalLinearOperator)
@linearise.register(WoodburyLinearOperator)
def _(operator):
return operator

Expand Down Expand Up @@ -1283,6 +1356,7 @@ def materialise(operator: AbstractLinearOperator) -> AbstractLinearOperator:
@materialise.register(IdentityLinearOperator)
@materialise.register(DiagonalLinearOperator)
@materialise.register(TridiagonalLinearOperator)
@materialise.register(WoodburyLinearOperator)
def _(operator):
return operator

Expand Down Expand Up @@ -1343,6 +1417,7 @@ def diagonal(operator: AbstractLinearOperator) -> Shaped[Array, " size"]:


@diagonal.register(MatrixLinearOperator)
@diagonal.register(WoodburyLinearOperator)
@diagonal.register(PyTreeLinearOperator)
@diagonal.register(JacobianLinearOperator)
@diagonal.register(FunctionLinearOperator)
Expand Down Expand Up @@ -1397,6 +1472,7 @@ def tridiagonal(


@tridiagonal.register(MatrixLinearOperator)
@tridiagonal.register(WoodburyLinearOperator)
@tridiagonal.register(PyTreeLinearOperator)
@tridiagonal.register(JacobianLinearOperator)
@tridiagonal.register(FunctionLinearOperator)
Expand Down Expand Up @@ -1429,6 +1505,33 @@ def _(operator):
return operator.diagonal, operator.lower_diagonal, operator.upper_diagonal


@ft.singledispatch
def woodbury(
operator: AbstractLinearOperator,
) -> tuple[
AbstractLinearOperator,
Shaped[Array, " k k"],
Shaped[Array, " n k"],
Shaped[Array, " k n"],
]:
"""Extracts the A, C, U, V Woodbury structure, from a linear
operator. Returns one linear operators and three matrices.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
A 4-tuple, consisting of
- A which is a linear operator
- C, U and V which are matrices
For all but the Woodbury operator this extraction is not possible
"""
_default_not_implemented("woodbury", operator)
aidancrilly marked this conversation as resolved.
Show resolved Hide resolved


@woodbury.register(WoodburyLinearOperator)
def _(operator):
return operator.A, operator.C, operator.U, operator.V


# is_symmetric


Expand All @@ -1451,6 +1554,7 @@ def is_symmetric(operator: AbstractLinearOperator) -> bool:


@is_symmetric.register(MatrixLinearOperator)
@is_symmetric.register(WoodburyLinearOperator)
@is_symmetric.register(PyTreeLinearOperator)
@is_symmetric.register(JacobianLinearOperator)
@is_symmetric.register(FunctionLinearOperator)
Expand Down Expand Up @@ -1503,6 +1607,7 @@ def is_diagonal(operator: AbstractLinearOperator) -> bool:


@is_diagonal.register(MatrixLinearOperator)
@is_diagonal.register(WoodburyLinearOperator)
@is_diagonal.register(PyTreeLinearOperator)
@is_diagonal.register(JacobianLinearOperator)
@is_diagonal.register(FunctionLinearOperator)
Expand Down Expand Up @@ -1543,6 +1648,7 @@ def is_tridiagonal(operator: AbstractLinearOperator) -> bool:


@is_tridiagonal.register(MatrixLinearOperator)
@is_tridiagonal.register(WoodburyLinearOperator)
@is_tridiagonal.register(PyTreeLinearOperator)
@is_tridiagonal.register(JacobianLinearOperator)
@is_tridiagonal.register(FunctionLinearOperator)
Expand All @@ -1557,6 +1663,36 @@ def _(operator):
return True


@ft.singledispatch
def is_Woodbury(operator: AbstractLinearOperator) -> bool:
"""Returns whether an operator is marked as Woodbury.
See [the documentation on linear operator tags](../api/tags.md) for more
information.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
Either `True` or `False.`
"""
_default_not_implemented("is_Woodbury", operator)
aidancrilly marked this conversation as resolved.
Show resolved Hide resolved


@is_Woodbury.register(WoodburyLinearOperator)
def _(operator):
return True


@is_Woodbury.register(MatrixLinearOperator)
@is_Woodbury.register(PyTreeLinearOperator)
@is_Woodbury.register(JacobianLinearOperator)
@is_Woodbury.register(FunctionLinearOperator)
@is_Woodbury.register(IdentityLinearOperator)
@is_Woodbury.register(DiagonalLinearOperator)
@is_Woodbury.register(TridiagonalLinearOperator)
@is_Woodbury.register(TaggedLinearOperator) # TODO : check this
def _(operator):
return False


# has_unit_diagonal


Expand All @@ -1579,6 +1715,7 @@ def has_unit_diagonal(operator: AbstractLinearOperator) -> bool:


@has_unit_diagonal.register(MatrixLinearOperator)
@has_unit_diagonal.register(WoodburyLinearOperator)
@has_unit_diagonal.register(PyTreeLinearOperator)
@has_unit_diagonal.register(JacobianLinearOperator)
@has_unit_diagonal.register(FunctionLinearOperator)
Expand Down Expand Up @@ -1620,6 +1757,7 @@ def is_lower_triangular(operator: AbstractLinearOperator) -> bool:


@is_lower_triangular.register(MatrixLinearOperator)
@is_lower_triangular.register(WoodburyLinearOperator)
@is_lower_triangular.register(PyTreeLinearOperator)
@is_lower_triangular.register(JacobianLinearOperator)
@is_lower_triangular.register(FunctionLinearOperator)
Expand Down Expand Up @@ -1660,6 +1798,7 @@ def is_upper_triangular(operator: AbstractLinearOperator) -> bool:


@is_upper_triangular.register(MatrixLinearOperator)
@is_upper_triangular.register(WoodburyLinearOperator)
@is_upper_triangular.register(PyTreeLinearOperator)
@is_upper_triangular.register(JacobianLinearOperator)
@is_upper_triangular.register(FunctionLinearOperator)
Expand Down Expand Up @@ -1700,6 +1839,7 @@ def is_positive_semidefinite(operator: AbstractLinearOperator) -> bool:


@is_positive_semidefinite.register(MatrixLinearOperator)
@is_positive_semidefinite.register(WoodburyLinearOperator)
@is_positive_semidefinite.register(PyTreeLinearOperator)
@is_positive_semidefinite.register(JacobianLinearOperator)
@is_positive_semidefinite.register(FunctionLinearOperator)
Expand Down Expand Up @@ -1741,6 +1881,7 @@ def is_negative_semidefinite(operator: AbstractLinearOperator) -> bool:


@is_negative_semidefinite.register(MatrixLinearOperator)
@is_negative_semidefinite.register(WoodburyLinearOperator)
@is_negative_semidefinite.register(PyTreeLinearOperator)
@is_negative_semidefinite.register(JacobianLinearOperator)
@is_negative_semidefinite.register(FunctionLinearOperator)
Expand Down Expand Up @@ -1902,6 +2043,7 @@ def _(operator):
is_lower_triangular,
is_upper_triangular,
is_tridiagonal,
is_Woodbury,
):

@check.register(TangentLinearOperator)
Expand Down Expand Up @@ -1964,6 +2106,7 @@ def _(operator, check=check, tag=tag):
is_positive_semidefinite,
is_negative_semidefinite,
is_tridiagonal,
is_Woodbury,
):

@check.register(AddLinearOperator)
Expand Down
9 changes: 8 additions & 1 deletion lineax/_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
is_positive_semidefinite,
is_tridiagonal,
is_upper_triangular,
is_Woodbury,
linearise,
TangentLinearOperator,
)
Expand Down Expand Up @@ -268,7 +269,7 @@ def _linear_solve_transpose(inputs, cts_out):
_assert_defined, (operator, state, options, solver), is_leaf=_is_undefined
)
cts_solution = jtu.tree_map(
ft.partial(eqxi.materialise_zeros, allow_struct=True),
ft.partial(eqxi.materialise_zeros, allow_struct=True), # pyright: ignore
operator.in_structure(),
cts_solution,
)
Expand Down Expand Up @@ -498,6 +499,7 @@ def conj(
_cholesky_token = eqxi.str2jax("cholesky_token")
_lu_token = eqxi.str2jax("lu_token")
_svd_token = eqxi.str2jax("svd_token")
_woodbury_token = eqxi.str2jax("woodbury_token")


# Ugly delayed import because we have the dependency chain
Expand All @@ -518,6 +520,7 @@ def _lookup(token) -> AbstractLinearSolver:
_cholesky_token: _solver.Cholesky(), # pyright: ignore
_lu_token: _solver.LU(), # pyright: ignore
_svd_token: _solver.SVD(), # pyright: ignore
_woodbury_token: _solver.Woodbury(), # pyright: ignore
}
return _lookup_dict[token]

Expand All @@ -535,6 +538,7 @@ class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState], strict=True
- If the operator is triangular, then use [`lineax.Triangular`][].
- If the matrix is positive or negative definite, then use
[`lineax.Cholesky`][].
- If the matrix has structure A + U C V, then use [`lineax.Woodbury`][].
- Else use [`lineax.LU`][].

This is a good choice if you want to be certain that an error is raised for
Expand All @@ -554,6 +558,7 @@ class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState], strict=True
- If the operator is triangular, then use [`lineax.Triangular`][].
- If the matrix is positive or negative definite, then use
[`lineax.Cholesky`][].
- If the matrix has structure A + U C V, then use [`lineax.Woodbury`][].
- Else, use [`lineax.LU`][].

This is a good choice if your primary concern is computational efficiency. It will
Expand Down Expand Up @@ -582,6 +587,8 @@ def _select_solver(self, operator: AbstractLinearOperator):
operator
):
token = _cholesky_token
elif is_Woodbury(operator):
token = _woodbury_token
else:
token = _lu_token
elif self.well_posed is False:
Expand Down
Loading
Loading