Implementation of LSMR for iterative least squares. #86

Open: wants to merge 30 commits into base: main

@f0uriest commented Mar 16, 2024

Tests seem to be passing locally. I added LSMR in the places where it seemed to make sense (mainly places where I saw SVD, which solves similar problems). Let me know if there are other tests that should be added.

A few things that could use some input:

  • There are a large number of state variables to keep track of. Right now I'm just using nested tuples, but there's probably a more elegant way.
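One possible alternative to nested tuples is a set of `typing.NamedTuple` classes grouped by role. JAX treats NamedTuples as pytrees, so a (nested) instance can be passed directly as the carry of `jax.lax.while_loop(cond_fun, body_fun, init_val)`; an `equinox.Module` would serve the same purpose. The sketch below is pure Python for illustration only: the class and field names are hypothetical (modelled on the variable names in SciPy's `lsmr`), tuples of floats stand in for arrays, and the initial values assume `x0 = 0`.

```python
from typing import NamedTuple, Tuple

# Tuples of floats stand in for JAX arrays in this illustration.
Vector = Tuple[float, ...]


class Bidiag(NamedTuple):
    """Golub-Kahan bidiagonalization vectors and scalars (hypothetical grouping)."""
    u: Vector
    v: Vector
    alpha: float
    beta: float


class Rotations(NamedTuple):
    """Scalars produced by the plane rotations applied each iteration."""
    alphabar: float
    rho: float
    rhobar: float
    cbar: float
    sbar: float
    zetabar: float


class NormEstimates(NamedTuple):
    """Running estimates consumed by the stopping tests."""
    normA: float
    condA: float
    normx: float
    normr: float
    normar: float


class LSMRState(NamedTuple):
    """Whole loop carry; a nested NamedTuple is still a single pytree."""
    itn: int
    x: Vector
    h: Vector
    hbar: Vector
    bidiag: Bidiag
    rot: Rotations
    norms: NormEstimates


def init_state(u: Vector, v: Vector, alpha: float, beta: float) -> LSMRState:
    """Initial carry for x0 = 0, given the first bidiagonalization step
    (beta = ||b||, u = b / beta, alpha = ||A^T u||, v = A^T u / alpha)."""
    zeros = (0.0,) * len(v)
    return LSMRState(
        itn=0,
        x=zeros,
        h=v,
        hbar=zeros,
        bidiag=Bidiag(u=u, v=v, alpha=alpha, beta=beta),
        rot=Rotations(alphabar=alpha, rho=1.0, rhobar=1.0, cbar=1.0,
                      sbar=0.0, zetabar=alpha * beta),
        norms=NormEstimates(normA=alpha, condA=1.0, normx=0.0,
                            normr=beta, normar=alpha * beta),
    )
```

Inside a loop body, fields are updated functionally, e.g. `state._replace(itn=state.itn + 1, rot=state.rot._replace(rho=new_rho))` (with `new_rho` a hypothetical updated value). This returns a new tuple and leaves the old one unchanged, which matches how `body_fun` must return a fresh carry in `jax.lax.while_loop`.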

Resolves #85

TODO:
- Figure out early exiting
- Add tests
- Add type annotations
- Add docstrings
- Figure out complex dtypes
- Clean up large number of state variables
- Return actual status message on failure
@f0uriest marked this pull request as draft March 16, 2024 22:31
@f0uriest marked this pull request as ready for review March 26, 2024 01:49
@patrick-kidger (Owner) left a comment

This looks really cool! Thank you for contributing it; I can see this must have been a lot of effort. :D I've gone through and left various comments, which should all be very small. In addition to those, can we add LSMR to the well-posed square tests?

On the exit conditions: how many does LSMR potentially have?

On complex dtypes: indeed, but I think we now have things almost entirely working. We should probably get that fixed on our end in the next couple of weeks, and then we're just waiting on a new JAX release so that the upstream fixes are generally available. (I think we should let all of that happen before merging this PR, so that we can add LSMR to the complex tests as well.)

Review threads on lineax/_solver/lsmr.py: 10, all resolved (8 of them on since-outdated code).
@f0uriest (Author)

The original version has 7 possible exit conditions:

  1. the iteration count exceeds maxiter
  2. x solves Ax=b to the user-specified tolerance (i.e., no least squares needed; the problem appears well-posed)
  3. x is the least-squares solution to the user-specified tolerance
  4. cond(A) exceeds the user-specified conlim
  5. x solves Ax=b to tolerance eps (same as 2, but guards against users setting tolerances to zero?)
  6. x is the least-squares solution to tolerance eps (same as 3?)
  7. cond(A) > 1/eps (same as 4?)

I think we could get rid of 5-7, provided that maxiter is always finite. 2 and 3 could probably be combined into a standard "successful" exit, assuming the user doesn't really care whether it's the "true" solution or the least-squares one. The main new one would be the condition-number exit. We could maybe shoehorn that into lx.RESULTS.singular or lx.RESULTS.breakdown, or add a new result just for it.
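To make the proposed consolidation concrete, here is a minimal pure-Python sketch. It reproduces the stopping tests in the order SciPy's `scipy.sparse.linalg.lsmr` evaluates them (SciPy numbers the codes differently from the list above: maxiter is `istop = 7`, the Ax = b test is `istop = 1`), then maps the codes onto three outcomes. SciPy's source comments describe the eps-based tests as guarding against atol, btol, or conlim being set to zero. The names `Outcome`, `scipy_istop` and `consolidate`, and the mapping itself, are hypothetical and not lineax's API; the sketch assumes `normb > 0`.

```python
from enum import Enum
from typing import Optional


class Outcome(Enum):
    """Hypothetical consolidated exit categories."""
    SUCCESSFUL = "successful"        # SciPy istop 1, 2, 4, 5
    ILL_CONDITIONED = "conlim"       # SciPy istop 3, 6
    MAX_STEPS_REACHED = "max_steps"  # SciPy istop 7


def scipy_istop(normr, normar, normA, normx, condA, normb, itn,
                atol, btol, conlim, maxiter) -> int:
    """SciPy-style LSMR stopping code; 0 means no test fired (keep iterating).

    Later assignments overwrite earlier ones, so the user tolerances take
    precedence over the eps guards, which take precedence over maxiter."""
    ctol = 1.0 / conlim if conlim > 0 else 0.0
    test1 = normr / normb  # relative residual ||r|| / ||b||
    test2 = normar / (normA * normr) if normA * normr != 0 else float("inf")
    test3 = 1.0 / condA    # inverse condition-number estimate
    t1 = test1 / (1.0 + normA * normx / normb)
    rtol = btol + atol * normA * normx / normb

    istop = 0
    if itn >= maxiter:
        istop = 7
    # eps-level guards: 1 + t <= 1 only when t is below machine precision
    if 1.0 + test3 <= 1.0:
        istop = 6
    if 1.0 + test2 <= 1.0:
        istop = 5
    if 1.0 + t1 <= 1.0:
        istop = 4
    # user-specified tolerances
    if test3 <= ctol:
        istop = 3
    if test2 <= atol:
        istop = 2
    if test1 <= rtol:
        istop = 1
    return istop


def consolidate(istop: int) -> Optional[Outcome]:
    """Map a SciPy-style code onto the three proposed outcomes (None = continue)."""
    if istop == 0:
        return None
    if istop in (1, 2, 4, 5):
        return Outcome.SUCCESSFUL
    if istop in (3, 6):
        return Outcome.ILL_CONDITIONED
    return Outcome.MAX_STEPS_REACHED
```

With finite maxiter, codes 4-6 only ever fire when a user tolerance is zero (or below eps), which is why folding them into the same three outcomes loses little information.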

@f0uriest requested a review from patrick-kidger April 30, 2024 01:55
@andycasey commented Oct 15, 2024

I know it's bad form to comment just to 👍🏻, but having LSMR in lineax basically took my problem from "intractable and 10x slower with JAX than SciPy" to tractable.

The problem here is one where I have a large (complex) linear system and never want to construct the design matrix, so iterative methods are all I can really use. QR or SVD had huge memory overheads and were very slow. The JAX implementation of GMRES (in lineax or jax.scipy.sparse.linalg.gmres) doesn't seem to accept non-square matrices, although the scipy.sparse.linalg.gmres implementation does.
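LSMR only touches the matrix through the products A v and Aᵀ u (Aᴴ u in the complex case), which is what makes it usable when the design matrix is never formed. A minimal pure-Python sketch of such a matrix-free rectangular operator follows. It uses a hypothetical real-valued polynomial design matrix A[i][j] = t_i**j (m rows, n columns) purely for illustration, together with the standard dot-product test that the two products are consistent: ⟨A c, y⟩ = ⟨c, Aᵀ y⟩. In lineax, a function like `matvec` would typically be wrapped in `lx.FunctionLinearOperator`, with the transpose derived by JAX rather than written by hand.

```python
from typing import List, Sequence


def matvec(t: Sequence[float], c: Sequence[float]) -> List[float]:
    """(A c)_i = sum_j c_j * t_i**j, computed without storing the m x n matrix."""
    return [sum(cj * ti ** j for j, cj in enumerate(c)) for ti in t]


def rmatvec(t: Sequence[float], y: Sequence[float], n: int) -> List[float]:
    """(A^T y)_j = sum_i y_i * t_i**j: the transpose product, also matrix-free."""
    return [sum(yi * ti ** j for ti, yi in zip(t, y)) for j in range(n)]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def adjoint_mismatch(t: Sequence[float], c: Sequence[float], y: Sequence[float]) -> float:
    """Relative gap between <A c, y> and <c, A^T y>; close to 0 if rmatvec is right."""
    lhs = dot(matvec(t, c), y)
    rhs = dot(c, rmatvec(t, y, len(c)))
    return abs(lhs - rhs) / max(abs(lhs), abs(rhs), 1e-300)
```

Running the dot-product test on random vectors before handing `matvec`/`rmatvec` to an iterative solver is a cheap way to catch a mismatched transpose, which would otherwise show up only as stagnation or a wrong least-squares solution.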

So this is a big 👍🏻 to this feature.

Development: successfully merging this pull request may close the issue "Iterative least squares solvers".
4 participants