Skip to content

Commit

Permalink
gh-405: array API support for glass.core.algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Nov 26, 2024
1 parent 4ef0167 commit 96b15e9
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions glass/core/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@

from __future__ import annotations

import numpy as np
import numpy.typing as npt
import typing

if typing.TYPE_CHECKING:
import cupy as cp
import jax.typing as jxt
import numpy as np
import numpy.typing as npt


def nnls(
a: npt.NDArray[np.float64],
b: npt.NDArray[np.float64],
a: npt.NDArray[np.float64] | cp.ndarray | jxt.ArrayLike,
b: npt.NDArray[np.float64] | cp.ndarray | jxt.ArrayLike,
*,
tol: float = 0.0,
maxiter: int | None = None,
) -> npt.NDArray[np.float64]:
) -> npt.NDArray[np.float64] | cp.ndarray | jxt.ArrayLike:
"""
Compute a non-negative least squares solution.
Expand Down Expand Up @@ -51,8 +56,11 @@ def nnls(
Chemometrics, 11, 393-401.
"""
a = np.asanyarray(a)
b = np.asanyarray(b)
if a.__array_namespace__() != b.__array_namespace__():
msg = "input arrays should belong to the same array library"
raise ValueError(msg)

xp = a.__array_namespace__()

if a.ndim != 2:
msg = "input `a` is not a matrix"
Expand All @@ -69,25 +77,25 @@ def nnls(
if maxiter is None:
maxiter = 3 * n

index = np.arange(n)
p = np.full(n, fill_value=False)
x = np.zeros(n)
index = xp.arange(n)
p = xp.full(n, fill_value=False)
x = xp.zeros(n)
for _ in range(maxiter):
if np.all(p):
if xp.all(p):
break
w = np.dot(b - a @ x, a)
m = index[~p][np.argmax(w[~p])]
w = xp.dot(b - a @ x, a)
m = index[~p][xp.argmax(w[~p])]
if w[m] <= tol:
break
p[m] = True
while True:
ap = a[:, p]
xp = x[p]
sp = np.linalg.solve(ap.T @ ap, b @ ap)
x_new = x[p]
sp = xp.linalg.solve(ap.T @ ap, b @ ap)
t = sp <= 0
if not np.any(t):
if not xp.any(t):
break
alpha = -np.min(xp[t] / (xp[t] - sp[t]))
alpha = -xp.min(xp[t] / (x_new[t] - sp[t]))
x[p] += alpha * (sp - xp)
p[x <= 0] = False
x[p] = sp
Expand Down

0 comments on commit 96b15e9

Please sign in to comment.