diff --git a/glass/core/algorithm.py b/glass/core/algorithm.py index 09fd7c38..dad772d0 100644 --- a/glass/core/algorithm.py +++ b/glass/core/algorithm.py @@ -2,17 +2,22 @@ from __future__ import annotations -import numpy as np -import numpy.typing as npt +import typing + +if typing.TYPE_CHECKING: + import cupy as cp + import jax.typing as jxt + import numpy as np + import numpy.typing as npt def nnls( - a: npt.NDArray[np.float64], - b: npt.NDArray[np.float64], + a: npt.NDArray[np.float64] | cp.ndarray | jxt.ArrayLike, + b: npt.NDArray[np.float64] | cp.ndarray | jxt.ArrayLike, *, tol: float = 0.0, maxiter: int | None = None, -) -> npt.NDArray[np.float64]: +) -> npt.NDArray[np.float64] | cp.ndarray | jxt.ArrayLike: """ Compute a non-negative least squares solution. @@ -51,8 +56,11 @@ def nnls( Chemometrics, 11, 393-401. """ - a = np.asanyarray(a) - b = np.asanyarray(b) + if a.__array_namespace__() != b.__array_namespace__(): + msg = "input arrays should belong to the same array library" + raise ValueError(msg) + + xp = a.__array_namespace__() if a.ndim != 2: msg = "input `a` is not a matrix" @@ -69,25 +77,25 @@ def nnls( if maxiter is None: maxiter = 3 * n - index = np.arange(n) - p = np.full(n, fill_value=False) - x = np.zeros(n) + index = xp.arange(n) + p = xp.full(n, fill_value=False) + x = xp.zeros(n) for _ in range(maxiter): - if np.all(p): + if xp.all(p): break - w = np.dot(b - a @ x, a) - m = index[~p][np.argmax(w[~p])] + w = xp.dot(b - a @ x, a) + m = index[~p][xp.argmax(w[~p])] if w[m] <= tol: break p[m] = True while True: ap = a[:, p] - xp = x[p] - sp = np.linalg.solve(ap.T @ ap, b @ ap) + x_new = x[p] + sp = xp.linalg.solve(ap.T @ ap, b @ ap) t = sp <= 0 - if not np.any(t): + if not xp.any(t): break - alpha = -np.min(xp[t] / (xp[t] - sp[t])) + alpha = -xp.min(xp[t] / (x_new[t] - sp[t])) x[p] += alpha * (sp - xp) p[x <= 0] = False x[p] = sp