-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New solver based on Woodbury matrix identity #97
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, thank you for putting this together! I really like the look of this.
I've just left an initial review -- let me know what you think :)
(K, L) = self.C.shape | ||
if K != L: | ||
raise ValueError(f"expecting square operator for C, got {K} by {L}") | ||
N = N.shape[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
N
is an arbitrary PyTree, it doesn't necessarily have a .shape
attribute.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, This is mostly my ignorance of how lineax is using PyTrees as linear operators.
I am not sure how allowing arbitrary PyTrees for A meshes with the requirement of a Woodbury structure. In the context of the Woodbury matrix identity, I would think A would have to have a matrix representation of an n by n matrix. For a PyTree representation then would C, U and V also need to be PyTrees such that each leaf of the tree can be made to have Woodbury structure?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, so! Basically what's going on here is exactly what jax.jacfwd
does as well.
Consider for example:
import jax
import jax.numpy as jnp
def f(x: tuple[jax.Array, jax.Array]):
return {"a": x[0] * x[1], "b": x[0]}
x = (jnp.arange(2.)[:, None], jnp.arange(3.)[None, :])
jac = jax.jacfwd(f)(x)
What should we get?
Well, a PyTree-of-arrays is basically isomorphic to a vector (flatten every array and concatenate them all together), and the Jacobian of a function f: R^n -> R^m
is a matrix of shape (m, n)
.
Reasoning by analogy, we can see that:
- given an input PyTree whose leaves are enumerated by
i
, and for which each leaf has shapea_i
(a tuple); - given an output PyTree whose leaves are enumerated by
j
, and for which each leaf has shapeb_j
(also a tuple);
then the Jacobian should be a PyTree whose leaves are numerated by(j, i)
(a PyTree-of-PyTrees if you will), where each leaf has shape(*b_j, *a_i)
. (Here unpacking each tuple using Python notation.)
And indeed this is exactly what we see:
import equinox as eqx
eqx.tree_pprint(jac)
# {'a': (f32[2,3,2,1], f32[2,3,1,3]), 'b': (f32[2,1,2,1], f32[2,1,1,3])}
the "outer" PyTree has structure {'a': *, 'b': *}
(corresponding to the output of our function), the "inner" PyTree has structure (*, *)
(corresponding to the input of our function).
Meanwhile, each leaf has a shape obtained by concatenating the shapes of the corresponding pair of input and output leaves. For all possible pairs, notably! So in our use case here we wouldn't have a pytree-of-things-with-Woodbury-structure. Rather, we would have a single PyTree, which when thought of as a linear operator (much like the Jacobian), would itself have Woodbury structure!
Okay, hopefully that makes some kind of sense!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just thought I'd response on this since it has been a bit - thanks for the detailed response, it makes sense, I am (slowly) working on changes that would make the Woodbury implementation pytree compatible.
I think the size checking is easy enough if I have understood you correctly here (PyTree-of-arrays is basically isomorphic to a vector (flatten every array and concatenate them all together). When checking U and V we need to use the in_size of A and C for N and K.
There will need to be a few tree_unflatten's to move between the flattened space (where U and V live) and the pytree input space (where A and C potentially live). This makes the pushthrough operator a bit tricky but should be do-able with a little time.
I suppose my question would be, is there a nice way to wrap this interlink between flattened vector space and pytree space so that implementing this kind of thing will be easier in the future? Does it already exist somewhere outside (or potentially inside) lineax?
lineax/_solver/Woodbury.py
Outdated
vmapped_solve = jax.vmap( | ||
lambda x_vec: A_solver.compute(A_state, x_vec, {})[0], in_axes=1, out_axes=1 | ||
) | ||
pushthrough_mat = jnp.linalg.inv(C) + V @ vmapped_solve(U) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we make C
an AbstractLinearOperator
as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I originally had C as an AbstractLinearOperator but it always had to be as_matrix() so its inverse can be computed by jnp.linalg. Therefore, it made the most sense that it is given as a matrix at input so that as_matrix operation is not hidden. If some linearoperators could have an inverse method then perhaps it would make sense for C to be an operator?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, right!
Yes, you're completely correct. Indeed I think the appropriate thing to do would be to have a built-in notion of the inverse of an operator. (Tagging also @f0uriest as we were discussing this in #96.)
The inverse of an operator is another operator, and it is basically defined by the action of a linear solve. I think that should mean:
class InverseLinearOperator(AbstractLinearOperator):
operator: AbstractLinearOperator
solver: AbstractLinearSolver = AutoLinearSolver(well_posed=True)
def mv(self, vector):
return linear_solve(self.operator, vector)
... # other methods here
Which you can then use as InverseLinearOperator(C)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, this seems fair. I tried to implement this way and just wanted to raise a two points which you might have comments on.
- The InverseLinearOperator has to go in _solve.py (or at least not _operator.py) to avoid circular imports. This is potentially a little confusing but maybe it is really the right place for it.
- Having no conjugate method for LinearOperators makes for a bit of head-ache for this one. conj is attached to the solver so you can update the state of the solver for a conjugate inside InverseLinearOperator (similarly for transpose). However, when you do it this way your operator is not updated. The most obvious way to me to implement the as_matrix method was to use jnp.linalg.inv(self.operator.as_matrix()) - but this ignores the solver_state which might have been transposed or conjugated. I suppose my question is, should LinearOperators have a conjugate method or should LinearSolvers have an as_matrix method for the inverse?
This probably should be its own issue, PR, etc.
|
||
full_matrix = WB.as_matrix() | ||
|
||
true_x = jr.normal(getkey(), (N,)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
true_x = jr.normal(getkey(), (N,)) | |
true_x = jr.normal(getkey(), (N,), dtype=dtype) |
and in two other places too, and then you can add complex128
to type parametrization
In issue #3 , a solver based on Woodbury matrix identity was given as a new solver feature for lineax.
Here I have made a first attempt at implementing it. It works slightly differently than the other base LinearOperators as it takes a LinearOperator as an argument (for A), with U, C and V as JAX arrays. This allows for the correct specialised solvers to be used for inverse(A) operations. E.g. A can be a DiagonalLinearOperator/TridiagonalLinearOperator/etc. - see test_Woodbury in test_operator.py
Things I am not certain of:
But it made me notice that materialise_zeros is from equinox 0.11.4 but pyproject.toml requires >=0.11.3. Does this require correcting?