Skip to content

Commit

Permalink
NumPyInterface: Use a GMRES wrapper that is compatible with newer Sci…
Browse files Browse the repository at this point in the history
…Py versions.

This wrapper is copied from TransiFlow. It also makes sure we actually use
maxit number of iterations, instead of maxit number of restarts.
  • Loading branch information
Sbte committed Jul 5, 2024
1 parent 9ca117a commit 5cb0e8c
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions jadapy/NumPyInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,34 @@ def solve(self, op, x, tol, maxit):

out = x.copy()
for i in range(x.shape[1]):
out[:, i], info = linalg.gmres(op, x[:, i], restart=100, maxiter=maxit, tol=tol, atol=0)
out[:, i], info, iterations = gmres(op, x[:, i], maxit, tol)
if info < 0:
raise Exception('GMRES returned ' + str(info))
elif info > 0 and maxit > 1:
warnings.warn('GMRES did not converge in ' + str(info) + ' iterations')
elif info > 0 and maxit > 10:
warnings.warn('GMRES did not converge in ' + str(iterations) + ' iterations')
return out


def gmres(A, b, maxit, tol, restart=None, prec=None):
iterations = 0

def callback(_r):
nonlocal iterations
iterations += 1

if restart is None:
restart = min(maxit, 100)

maxiter = (maxit - 1) // restart + 1

try:
y, info = linalg.gmres(A, b, restart=restart, maxiter=maxiter,
rtol=tol, atol=0, M=prec,
callback=callback, callback_type='pr_norm')
except TypeError:
# Compatibility with SciPy <= 1.11
y, info = linalg.gmres(A, b, restart=restart, maxiter=maxiter,
tol=tol, atol=0, M=prec,
callback=callback, callback_type='pr_norm')

return y, info, iterations

0 comments on commit 5cb0e8c

Please sign in to comment.