New solver for block tridiagonal matrices #80
base: main
Conversation
Ah, this is excellent! Thank you for the contribution. I really like this implementation, which looks very clean throughout. This does touch on an interesting point: @packquickly and I were discussing adding a general "block operator" abstraction, which wraps around other operators. (For example, it's common to have a block structure …) I'm wondering if it might be possible to support block-tridiagonal as a special case of that. I suspect the general case of doing a linear solve with respect to arbitrary block matrices might be a bit tricky, though. @packquickly WDYT?
First off, excellent stuff!! Thank you very much for this PR. Regarding block matrices, now may be a good time to settle on an abstraction for them. General linear solves against the compositional block matrices Patrick mentioned do seem a bit painful to do efficiently, but only a bit. At first glance, following the non-recursive implementation of block LU in Golub and Van Loan should work as a generic solver for these block operators. Looking at the code here, and at the implementation of block tridiagonal outlined in Golub and Van Loan, I think going from this implementation of block tridiagonal to one using the compositional "block operator" Patrick suggested would not require too many changes either.
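For reference, the block tridiagonal special case of that block LU is the "block Thomas" recurrence this PR implements. A sketch in my own notation (not Golub and Van Loan's), for a system with sub-diagonal blocks $A_i$, diagonal blocks $D_i$, super-diagonal blocks $C_i$, and right-hand-side blocks $d_i$:

$$
\hat{C}_1 = D_1^{-1} C_1, \qquad \hat{d}_1 = D_1^{-1} d_1,
$$

$$
\hat{C}_i = (D_i - A_i \hat{C}_{i-1})^{-1} C_i, \qquad
\hat{d}_i = (D_i - A_i \hat{C}_{i-1})^{-1} (d_i - A_i \hat{d}_{i-1}), \qquad i = 2, \dots, n,
$$

followed by back substitution $x_n = \hat{d}_n$, then $x_i = \hat{d}_i - \hat{C}_i x_{i+1}$. Each step only ever factorizes an $m \times m$ block, which is where the advantage over dense LU on the full matrix comes from.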
Any progress on this? I'd be super interested in having a general abstraction for block linear operators, especially if the blocks themselves can be block linear operators. For example, a problem I'm working on now has a block triangular matrix where each block is itself block tridiagonal, with the innermost blocks being dense. A general block LU would be a good starting point, though having specialized options for common sparsity patterns (e.g. block tridiagonal, triangular, Hessenberg, diagonal, etc.) would be useful too. The main limitation I see with having nested block operators is that JAX doesn't really like array-of-struct type things, so maybe there's a more clever way?
lineax/_solver/blocktridiagonal.py (outdated diff)

    return step, y

    _, new_d_p = matrix_linear_solve_vec(0, d - jnp.matmul(a, d_p))
    _, new_c_p = lax.scan(matrix_linear_solve_vec, 0, c.T)
I think this `scan` can be replaced by `vmap`, which ends up being a significant speedup on GPU.
Thanks for reviewing this, I have swapped to `vmap` in the latest commit.
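For context, a minimal sketch of the kind of change being discussed (the solve function and shapes below are illustrative, not the PR's actual code): when the per-column solves are independent, `jax.vmap` batches them into a single vectorised solve rather than running them sequentially as `lax.scan` does.

```python
import jax
import jax.numpy as jnp

# Illustrative only: solve A x = b for many right-hand sides, stored as
# the columns of C. `solve_one` stands in for the per-column solve.
def solve_one(A, col):
    return jnp.linalg.solve(A, col)

A = 2.0 * jnp.eye(4)
C = jnp.ones((4, 8))  # 8 right-hand sides

# Sequential: one solve per scan step.
_, xs_scan = jax.lax.scan(lambda carry, col: (carry, solve_one(A, col)), 0, C.T)

# Batched: all solves dispatched at once, typically much faster on GPU.
xs_vmap = jax.vmap(lambda col: solve_one(A, col))(C.T)

assert jnp.allclose(xs_scan, xs_vmap)
```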
I'd be very happy to see this revived as well! On the topic of block operators: @packquickly and I discussed this a bit offline, and we don't think there's actually a nice way to express this in general. IIRC this basically boiled down to the mathematics: there isn't (?) a nice way to express a linear-solve-against-a-block-operator in terms of linear solves against the individual operators. That leaves me wondering if the best way to handle this is something like (a) nonetheless introducing a general block operator, but (b) not trying to introduce a corresponding general "block solver": rather, have individual solvers such as the one implemented here, and have them verify that the components of the block operator look like what they expect. I'm fairly open to suggestions/implementations on this one. :)
Can you clarify a bit what you mean by …? Do you just mean that at some point you end up with things like …?
Hmmm. So to your point, I think we can sort-of do this. The following identity certainly exists (do others like it exist for NxM blocks, not just 2x2?): https://en.wikipedia.org/wiki/Block_matrix#Inversion

I'm guessing the above is what you have in mind? I think the question is on the solver: how would the above help us avoid writing the custom solver implemented in this PR? I don't see a way to do that. It's a neat mathematical identity, but in the dense case it's not better than just materialising and using LU, whilst in the structured case we still need to write custom solvers. So the place I land on is that whilst we can certainly create a block operator abstraction without difficulty, we would still need individual solver implementations corresponding to different block structures.
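For reference, the 2x2 case of the linked identity, assuming $A$ and the Schur complement $S$ are invertible:

$$
\begin{pmatrix} A & B \\ C & D \end{pmatrix}^{-1}
=
\begin{pmatrix}
A^{-1} + A^{-1} B S^{-1} C A^{-1} & -A^{-1} B S^{-1} \\
-S^{-1} C A^{-1} & S^{-1}
\end{pmatrix},
\qquad S = D - C A^{-1} B.
$$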
Ok yeah, there would still be a need for specific solvers like this one, but I think the block abstraction would still be useful, possibly combining it with the existing tag system for structured block matrices, each with specialized solvers (like the existing ones for diagonal, tridiagonal, triangular, PSD, etc.). I'll make an issue for tracking further ideas for this.
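For reference, the existing tag system mentioned here works roughly like this in current lineax (a block-structure tag would be a new addition; only `tridiagonal_tag` below is existing API):

```python
import jax.numpy as jnp
import lineax as lx

# Tagging an operator lets lineax's automatic solver selection dispatch to
# a specialised solver; a hypothetical block-tridiagonal tag could plug in
# the same way.
A = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 2.0, 1.0],
               [0.0, 1.0, 2.0]])
operator = lx.MatrixLinearOperator(A, tags=lx.tridiagonal_tag)
solution = lx.linear_solve(operator, jnp.ones(3))  # dispatches to a tridiagonal solver
```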
class BlockTridiagonal(AbstractLinearSolver[_BlockTridiagonalState]):
    """Block tridiagonal solver for linear systems, using the Thomas algorithm."""

    def init(self, operator: AbstractLinearOperator, options: dict[str, Any]):
For most of the other dense matrix solvers, `Solver.init` also performs the necessary factorizations (see https://github.com/patrick-kidger/lineax/blob/main/lineax/_solver/lu.py for an example). This could also be done here by splitting up the compute method a bit; I have a version like this that I've been using here:
https://gist.github.com/f0uriest/c4764bd8f7882a4b63457b1585b41ad9

Generally the factorization costs O(n^3), but the backsolve only O(n^2), so it's useful if you're re-using the operator against multiple RHS in a way that can't be easily vmapped (e.g. using a block tridiagonal preconditioner for one of the iterative solvers).

(My version also has a `reverse` option for the factorization, which I've found is needed in some cases when spectrally discretizing PDEs. This can easily be done by the user as a pre-processing step, so maybe not needed, but it might be a useful option.)
Hi,

I needed a block tridiagonal solver for a project of mine, so I took a stab at adding one to lineax (very nice package, thank you!). The solve is a simple extension of the Thomas algorithm, and it can scale better than LU, which doesn't exploit the banded structure. I tested my implementation against MatrixLinearOperator on the materialised matrix, and it can be considerably faster (~4x faster for 100 diagonal 2x2 blocks).
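For reference, a sketch of the dense baseline from that comparison (the block sizes follow the numbers above; only the baseline is shown, since the PR's `BlockTridiagonal` solver has its own construction API and works on the blocks directly rather than materialising the full matrix):

```python
import jax
import jax.numpy as jnp
import lineax as lx

# Dense baseline: materialise the full block tridiagonal matrix and solve
# with LU via MatrixLinearOperator.
n, m = 100, 2  # 100 diagonal blocks, each 2x2
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
diag = jax.random.normal(k1, (n, m, m)) + 4.0 * jnp.eye(m)  # bias toward diagonal dominance
lower = jax.random.normal(k2, (n - 1, m, m))
upper = jax.random.normal(k3, (n - 1, m, m))

# Assemble the (n*m, n*m) dense matrix from its blocks.
A = jnp.zeros((n * m, n * m))
for i in range(n):
    A = A.at[i * m:(i + 1) * m, i * m:(i + 1) * m].set(diag[i])
for i in range(n - 1):
    A = A.at[(i + 1) * m:(i + 2) * m, i * m:(i + 1) * m].set(lower[i])
    A = A.at[i * m:(i + 1) * m, (i + 1) * m:(i + 2) * m].set(upper[i])

b = jnp.ones(n * m)
solution = lx.linear_solve(lx.MatrixLinearOperator(A), b, solver=lx.LU())
```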
I have run the existing tests and have written another of my own to test the block tridiagonal representation and solve. All tests pass except "test_identity_with_different_structures_complex", but this also fails for me on the main branch(?).

I will admit that the tag and token methodology used by lineax isn't super familiar to me, so apologies if I have not used it properly for the new block tridiagonal class.

Hopefully this addition is of use.