New solver for block tridiagonal matrices #80

Open · wants to merge 4 commits into main

Conversation

aidancrilly

Hi,

I needed a block tridiagonal solver for a project of mine, so I took a stab at adding one to lineax (very nice package, thank you!). The solver is a simple extension of the Thomas algorithm, and it can scale better than LU, which doesn't exploit the banded structure. I tested my implementation against MatrixLinearOperator on the same matrix and it can be considerably faster (~4x faster for 100 diagonal 2x2 blocks).
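(For readers unfamiliar with the method, a minimal sketch of the block Thomas recurrence is below. This is illustrative, not the PR's actual code: the (n, m, m) block layout, the names, and the use of `jnp.linalg.solve` for the per-block solves are all assumptions.)

```python
import jax.numpy as jnp
from jax import lax


def block_thomas_solve(lower, diag, upper, rhs):
    """Block Thomas: forward elimination then back substitution.

    lower: (n - 1, m, m) sub-diagonal blocks
    diag:  (n, m, m)     diagonal blocks
    upper: (n - 1, m, m) super-diagonal blocks
    rhs:   (n, m)        right-hand side
    """
    m = rhs.shape[1]
    # Pad so every block row carries a sub- and super-diagonal block;
    # the padded zero blocks only ever multiply zeros.
    a = jnp.concatenate([jnp.zeros((1, m, m)), lower])
    c = jnp.concatenate([upper, jnp.zeros((1, m, m))])

    def forward(carry, inputs):
        # Eliminate the sub-diagonal block of this row using the modified
        # super-diagonal block and rhs segment of the previous row.
        c_prev, d_prev = carry
        a_i, b_i, c_i, d_i = inputs
        denom = b_i - a_i @ c_prev
        c_new = jnp.linalg.solve(denom, c_i)
        d_new = jnp.linalg.solve(denom, d_i - a_i @ d_prev)
        return (c_new, d_new), (c_new, d_new)

    init = (jnp.zeros((m, m)), jnp.zeros(m))
    _, (c_mod, d_mod) = lax.scan(forward, init, (a, diag, c, rhs))

    def backward(x_next, inputs):
        c_i, d_i = inputs
        x_i = d_i - c_i @ x_next
        return x_i, x_i

    # Back substitution runs over the block rows in reverse.
    _, x = lax.scan(backward, jnp.zeros(m), (c_mod, d_mod), reverse=True)
    return x  # (n, m)
```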

I have run the existing tests and written another of my own to test the block tridiagonal representation and solve. All tests pass except test_identity_with_different_structures_complex, but that one also fails for me on the main branch(?).
I will admit that the tag and token methodology used by lineax isn't super familiar to me, so apologies if I have not used it properly for the new block tridiagonal class.

Hopefully this addition is of use.

@patrick-kidger
Owner

patrick-kidger commented Feb 2, 2024

Ah, this is excellent! Thank you for the contribution. I really like this implementation, which looks very clean throughout.

This does touch on an interesting point: @packquickly and I were discussing adding a general "block operator" abstraction, which wraps around other operators. (For example, it's common to have a block [dense identity] matrix when solving Levenberg-Marquardt type problems, and that could benefit from some specialist solving as well.)

I'm wondering if it might be possible to support block-tridiagonal as a special case of that. I suspect the general case of doing a linear solve with respect to arbitrary block matrices might be a bit tricky, though.

@packquickly WDYT?
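(To make the compositional idea concrete, here is a rough sketch of what such a block operator might look like. This is hypothetical, not lineax API: it assumes array-valued sub-operators and skips the tag/token machinery a real lineax operator would need to handle.)

```python
import jax
import jax.numpy as jnp
import lineax as lx


class Block2x2Operator(lx.AbstractLinearOperator):
    """[[A, B], [C, D]] acting on the concatenation [x1, x2]."""

    A: lx.AbstractLinearOperator
    B: lx.AbstractLinearOperator
    C: lx.AbstractLinearOperator
    D: lx.AbstractLinearOperator

    def mv(self, vector):
        # Split the input at A's width, apply each block, re-concatenate.
        x1, x2 = jnp.split(vector, [self.A.in_size()])
        return jnp.concatenate(
            [self.A.mv(x1) + self.B.mv(x2), self.C.mv(x1) + self.D.mv(x2)]
        )

    def as_matrix(self):
        return jnp.block(
            [
                [self.A.as_matrix(), self.B.as_matrix()],
                [self.C.as_matrix(), self.D.as_matrix()],
            ]
        )

    def transpose(self):
        # The transpose has transposed blocks, swapped across the diagonal.
        return Block2x2Operator(
            self.A.transpose(), self.C.transpose(),
            self.B.transpose(), self.D.transpose(),
        )

    def in_structure(self):
        dtype = self.A.in_structure().dtype  # assumes array-valued operators
        return jax.ShapeDtypeStruct((self.A.in_size() + self.B.in_size(),), dtype)

    def out_structure(self):
        dtype = self.A.out_structure().dtype
        return jax.ShapeDtypeStruct((self.A.out_size() + self.C.out_size(),), dtype)
```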

@packquickly
Collaborator

First off, excellent stuff!! Thank you very much for this PR.

Regarding block matrices, now may be a good time to settle on an abstraction for them. General linear solves against the compositional block matrices Patrick mentioned do seem a bit painful to do efficiently, but only a bit: at first glance, following the non-recursive implementation of block LU in Golub and Van Loan should work as a generic solver for these block operators. Comparing the code here with the block tridiagonal implementation outlined in Golub and Van Loan, I also think going from this implementation to one using the compositional "block operator" Patrick suggested would not require too many changes.
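(For reference, the 2x2 kernel of that block elimination in plain JAX, sketched under the assumption of dense, invertible blocks; Golub and Van Loan's non-recursive block LU applies this elimination block column by block column.)

```python
import jax.numpy as jnp


def block_solve_2x2(A, B, C, D, b1, b2):
    """Solve [[A, B], [C, D]] @ [x1, x2] = [b1, b2] by block elimination."""
    A_inv_B = jnp.linalg.solve(A, B)
    A_inv_b1 = jnp.linalg.solve(A, b1)
    # The Schur complement of A acts as the pivot for the second block row.
    S = D - C @ A_inv_B
    x2 = jnp.linalg.solve(S, b2 - C @ A_inv_b1)
    x1 = A_inv_b1 - A_inv_B @ x2
    return x1, x2
```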

@f0uriest

f0uriest commented May 7, 2024

Any progress on this? I'd be super interested in having a general abstraction for block linear operators, especially if the blocks themselves can be block linear operators. For example, a problem I'm working on now has a block triangular matrix where each block is itself block tridiagonal, with the innermost blocks being dense.

A general block LU would be a good starting point, though having specialized options for common sparsity patterns (e.g. block tridiagonal, triangular, Hessenberg, diagonal, etc.) would be useful too.

The main limitation I see with having nested block operators is that JAX doesn't really like array-of-structs type things, so maybe there's a more clever way?

```python
return step, y

_, new_d_p = matrix_linear_solve_vec(0, d - jnp.matmul(a, d_p))
_, new_c_p = lax.scan(matrix_linear_solve_vec, 0, c.T)
```

I think this scan can be replaced by vmap, which ends up being a significant speedup on GPU.
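(Illustratively, the suggested change, using the names from the diff above: since the scan's carry is a dummy, the column solves are independent and can be batched.)

```python
import jax
from jax import lax

# Before: the columns of c.T are solved one after another, even though the
# carry (0) is never used, so the iterations are independent.
_, new_c_p = lax.scan(matrix_linear_solve_vec, 0, c.T)

# After: batch the independent solves with vmap, so the GPU can dispatch
# them in parallel rather than serially.
new_c_p = jax.vmap(lambda col: matrix_linear_solve_vec(0, col)[1])(c.T)
```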

Author


Thanks for reviewing this; I have swapped to vmap in the latest commit.

@patrick-kidger
Owner

I'd be very happy to see this revived as well!

On the topic of block operators: @packquickly and I discussed this a bit offline, and we don't think there's actually a nice way to express this in general. IIRC this basically boiled down to the mathematics: there isn't (?) a nice way to express a linear-solve-against-a-block-operator in terms of linear solves against the individual operators.

That leaves me wondering if the best way to handle this is something like (a) nonetheless introducing a general block operator, but (b) not trying to introduce a corresponding general "block solver": rather, have individual solvers such as the one implemented here, and have them verify that the components of the block-operator look like what they expect.

I'm fairly open to suggestions/implementations on this one. :)

@f0uriest

f0uriest commented May 9, 2024

Can you clarify a bit what you mean by

> there isn't (?) a nice way to express a linear-solve-against-a-block-operator in terms of linear solves against the individual operators.

Do you just mean that at some point you end up with things like solve(operator1, operator2) (which it seems would then require materializing operator2)? I would think that could be handled with some sort of InverseLinearOperator abstraction plus the regular ComposedLinearOperator, such that solve(operator1, operator2) -> ComposedLinearOperator(InverseLinearOperator(operator1), operator2).

InverseLinearOperator.mv would basically be a wrapper around linear_solve, maybe with some precomputation in cases where it makes sense (like only doing an LU factorization once, similar to the Solver.init methods some of the solvers have).
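(A minimal sketch of that idea: InverseLinearOperator is hypothetical, not part of lineax, and as with any operator sketch here it assumes array-valued operands and skips lineax's tag registration.)

```python
import jax.numpy as jnp
import lineax as lx


class InverseLinearOperator(lx.AbstractLinearOperator):
    """Lazily represents A^{-1}: mv performs a linear solve rather than
    materialising the inverse."""

    operator: lx.AbstractLinearOperator

    def mv(self, vector):
        # A^{-1} v is exactly the solution x of A x = v.
        return lx.linear_solve(self.operator, vector).value

    def as_matrix(self):
        # Only materialise the inverse if explicitly asked to.
        return jnp.linalg.inv(self.operator.as_matrix())

    def transpose(self):
        # (A^{-1})^T = (A^T)^{-1}
        return InverseLinearOperator(self.operator.transpose())

    def in_structure(self):
        return self.operator.out_structure()

    def out_structure(self):
        return self.operator.in_structure()


# solve(operator1, operator2) then becomes a lazy composition, e.g.
# InverseLinearOperator(operator1) @ operator2.
```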

@patrick-kidger
Owner

Hmmm. So to your point, I think we can sort-of do this. The following identity certainly exists (do others like it exist for NxM blocks, not just 2x2?):

https://en.wikipedia.org/wiki/Block_matrix#Inversion

and the $(D - CA^{-1}B)^{-1}$ components could be handled in the way you describe above.
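(For reference, the 2x2 case of that identity, writing $S = D - CA^{-1}B$ for the Schur complement of $A$:

$$
\begin{pmatrix} A & B \\ C & D \end{pmatrix}^{-1}
=
\begin{pmatrix}
A^{-1} + A^{-1} B \, S^{-1} C A^{-1} & -A^{-1} B \, S^{-1} \\
-S^{-1} C A^{-1} & S^{-1}
\end{pmatrix}
$$
)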

I'm guessing the above is what you have in mind?

I think the question is on the solver. How would the above help us avoid writing the custom solver implemented in this PR? I don't see a way to do that. It's a neat mathematical identity, but in the dense case it's not better than just materialising and using LU, whilst in the structured case we still need to write custom solvers.

So the place I land on is that whilst we can certainly create a block operator abstraction without difficulty, we would still need individual solver implementations corresponding to different block structures.

@f0uriest

Ok yeah, there would still be a need for specific solvers like this one, but I think the block abstraction would still be useful. It could possibly be combined with the existing tag system, so that structured block matrices each get specialized solvers (like the existing ones for diagonal, tridiagonal, triangular, PSD, etc.).

I'll make an issue for tracking further ideas for this.

```python
class BlockTridiagonal(AbstractLinearSolver[_BlockTridiagonalState]):
    """Block tridiagonal solver for linear systems, using the Thomas algorithm."""

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
```


For most of the other dense matrix solvers, Solver.init also performs the necessary factorizations (see https://github.com/patrick-kidger/lineax/blob/main/lineax/_solver/lu.py for an example)

This could also be done here by splitting up the compute method a bit. I have a version like this that I've been using here:
https://gist.github.com/f0uriest/c4764bd8f7882a4b63457b1585b41ad9

Generally the factorization costs O(n^3) but the backsolve only O(n^2), so it's useful if you're re-using the operator against multiple right-hand sides in a way that can't easily be vmapped (i.e., using a block tridiagonal preconditioner for one of the iterative solvers).
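(As a sketch of that reuse pattern, assuming the init/compute split of lineax's AbstractLinearSolver interface and the solver class from this PR; exact signatures should be checked against the final code.)

```python
# Factor once via init, then reuse the factorization against each new
# right-hand side via compute.
solver = BlockTridiagonal()
state = solver.init(operator, options={})
for b in right_hand_sides:
    solution, result, stats = solver.compute(state, b, options={})
```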

(My version also has a reverse option for the factorization, which I've found is needed in some cases when spectrally discretizing PDEs. This can easily be done by the user as a pre-processing step, so it's maybe not needed, but it might be a useful option.)
