From 30129cd9b250786278fb8797b187892b57025408 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Tue, 25 Apr 2023 11:07:12 +0200 Subject: [PATCH 01/15] init commit --- skglm/skglm_jax/README.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 skglm/skglm_jax/README.md diff --git a/skglm/skglm_jax/README.md b/skglm/skglm_jax/README.md new file mode 100644 index 000000000..976ea930d --- /dev/null +++ b/skglm/skglm_jax/README.md @@ -0,0 +1,22 @@ +## Installation + + +1. create then activate ``conda`` environnement +```shell +# create +conda create -n skglm-jax python=3.7 + +# activate env +conda activate skglm-jax +``` + +2. install ``skglm`` in editable mode +```shell +pip install skglm -e . +``` + +3. install dependencies +```shell +# jax +conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia +``` From 7c5d4f160ec50dfbc5b454296e47c67a161973a6 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Tue, 25 Apr 2023 11:36:13 +0200 Subject: [PATCH 02/15] datafit & penalty --- skglm/skglm_jax/__init__.py | 0 skglm/skglm_jax/datafits.py | 16 ++++++++++++++++ skglm/skglm_jax/penalties.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+) create mode 100644 skglm/skglm_jax/__init__.py create mode 100644 skglm/skglm_jax/datafits.py create mode 100644 skglm/skglm_jax/penalties.py diff --git a/skglm/skglm_jax/__init__.py b/skglm/skglm_jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/skglm/skglm_jax/datafits.py b/skglm/skglm_jax/datafits.py new file mode 100644 index 000000000..8eb3ceb80 --- /dev/null +++ b/skglm/skglm_jax/datafits.py @@ -0,0 +1,16 @@ +from jax.numpy.linalg import norm as jnorm + + +class QuadraticJax: + + def value(self, X, y, w): + n_samples = X.shape[0] + return ((X @ w - y) ** 2).sum() / (2. * n_samples) + + def gradient_1d(self, X, y, w, j): + n_samples = X.shape[0] + return X[:, j] @ (X @ w - y) / n_samples + + def get_features_lipschitz_cst(self, X, y): + n_samples = X.shape[0] + return jnorm(X, ord=2, axis=0) ** 2 / n_samples diff --git a/skglm/skglm_jax/penalties.py b/skglm/skglm_jax/penalties.py new file mode 100644 index 000000000..51f88bc0a --- /dev/null +++ b/skglm/skglm_jax/penalties.py @@ -0,0 +1,30 @@ +import jax.numpy as jnp + + +class L1Jax: + + def __init__(self, alpha): + self.alpha = alpha + + def value(self, w): + return (self.alpha * jnp.abs(w)).sum() + + def prox_1d(self, value, stepsize): + shifted_value = jnp.abs(value) - stepsize * self.alpha + return jnp.sign(value) * jnp.maximum(shifted_value, 0.) + + def subdiff_dist(self, w, grad, ws): + dist = jnp.zeros(len(ws)) + + for idx, j in enumerate(ws): + w_j = w[j] + grad_j = grad[j] + + if w_j == 0.: + dist_j = max(abs(grad_j) - self.alpha, 0.) + else: + dist_j = abs(grad_j + jnp.sign(w_j) * self.alpha) + + dist[idx] = dist_j + + return dist From c1a3014d61839bf85da30e4d4325e86f7957b24d Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Tue, 25 Apr 2023 16:39:20 +0200 Subject: [PATCH 03/15] CD solver --- skglm/skglm_jax/__init__.py | 5 ++++ skglm/skglm_jax/anderson_cd.py | 49 ++++++++++++++++++++++++++++++++++ skglm/skglm_jax/datafits.py | 13 +++++++++ skglm/skglm_jax/penalties.py | 3 ++- 4 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 skglm/skglm_jax/anderson_cd.py diff --git a/skglm/skglm_jax/__init__.py b/skglm/skglm_jax/__init__.py index e69de29bb..eab52d2e3 100644 --- a/skglm/skglm_jax/__init__.py +++ b/skglm/skglm_jax/__init__.py @@ -0,0 +1,5 @@ +# if not set, raises an error related to CUDA linking API. +# as recommended, setting the 'XLA_FLAGS' to bypass it. +# side-effect: (perhaps) slow compilation time. +import os +os.environ['XLA_FLAGS'] = '--xla_gpu_force_compilation_parallelism=1' # noqa* diff --git a/skglm/skglm_jax/anderson_cd.py b/skglm/skglm_jax/anderson_cd.py new file mode 100644 index 000000000..5fe8ee516 --- /dev/null +++ b/skglm/skglm_jax/anderson_cd.py @@ -0,0 +1,49 @@ +import jax.numpy as jnp +from skglm.skglm_jax.datafits import QuadraticJax +from skglm.skglm_jax.penalties import L1Jax + + +class AndersonCD: + + def __init__(self, max_iter=100, verbose=0) -> None: + self.max_iter = max_iter + self.verbose = verbose + + def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax): + X, y = self._transfer_to_device(X, y) + + n_samples, n_features = X.shape + lipschitz = datafit.get_features_lipschitz_cst(X, y) + + w = jnp.zeros(n_features) + all_features = jnp.arange(n_features) + + for it in range(self.max_iter): + + for j in all_features: + + if lipschitz[j] == 0.: + continue + + step = 1 / lipschitz[j] + grad_j = datafit.gradient_1d(X, y, w, j) + next_w_j = penalty.prox_1d(w[j] - step * grad_j, step) + + w = w.at[j].set(next_w_j) + + if self.verbose: + p_obj = datafit.value(X, y, w) + penalty.value(w) + + grad_ws = datafit.gradient_ws(X, y, w, all_features) + subdiff_dist = penalty.subdiff_dist(w, grad_ws, all_features) + stop_crit = jnp.max(subdiff_dist) + + print( + f"Iter {it}: p_obj={p_obj:.8f} stop_crit={stop_crit:.4e}" + ) + + return w + + def _transfer_to_device(self, X, y): + # TODO: other checks + return jnp.asarray(X), jnp.asarray(y) diff --git a/skglm/skglm_jax/datafits.py b/skglm/skglm_jax/datafits.py index 8eb3ceb80..bb48f205c 100644 --- a/skglm/skglm_jax/datafits.py +++ b/skglm/skglm_jax/datafits.py @@ -1,7 +1,9 @@ +import jax.numpy as jnp from jax.numpy.linalg import norm as jnorm class QuadraticJax: + """1 / (2 n_samples) ||y - Xw||^2""" def value(self, X, y, w): n_samples = X.shape[0] @@ -11,6 +13,17 @@ def gradient_1d(self, X, y, w, j): n_samples = X.shape[0] return X[:, j] @ (X @ w - y) / n_samples + def gradient_ws(self, X, y, w, ws): + n_samples = X.shape[0] + Xw_minus_y = X @ w - y + grad_ws = jnp.zeros(len(ws)) + + for idx, j in enumerate(ws): + grad_j = X[:, j] @ Xw_minus_y / n_samples + grad_ws = grad_ws.at[idx].set(grad_j) + + return grad_ws + def get_features_lipschitz_cst(self, X, y): n_samples = X.shape[0] return jnorm(X, ord=2, axis=0) ** 2 / n_samples diff --git a/skglm/skglm_jax/penalties.py b/skglm/skglm_jax/penalties.py index 51f88bc0a..6799d9838 100644 --- a/skglm/skglm_jax/penalties.py +++ b/skglm/skglm_jax/penalties.py @@ -2,6 +2,7 @@ class L1Jax: + """alpha ||w||_1""" def __init__(self, alpha): self.alpha = alpha @@ -25,6 +26,6 @@ def subdiff_dist(self, w, grad, ws): else: dist_j = abs(grad_j + jnp.sign(w_j) * self.alpha) - dist[idx] = dist_j + dist = dist.at[idx].set(dist_j) return dist From 911d1289d6812f1fc0a198ce16585c4c63ea6089 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Tue, 25 Apr 2023 16:43:27 +0200 Subject: [PATCH 04/15] unitest cd solver --- skglm/skglm_jax/tests/test_anderson_cd.py | 33 +++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 skglm/skglm_jax/tests/test_anderson_cd.py diff --git a/skglm/skglm_jax/tests/test_anderson_cd.py b/skglm/skglm_jax/tests/test_anderson_cd.py new file mode 100644 index 000000000..43914585c --- /dev/null +++ b/skglm/skglm_jax/tests/test_anderson_cd.py @@ -0,0 +1,33 @@ +import pytest + +import numpy as np +from numpy.linalg import norm +from skglm.utils.data import make_correlated_data + +from skglm.skglm_jax.anderson_cd import AndersonCD +from skglm.skglm_jax.datafits import QuadraticJax +from skglm.skglm_jax.penalties import L1Jax + +from skglm.estimators import Lasso + + +def test_solver(): + n_samples, n_features = 100, 10 + random_state = 135 + + X, y, _ = make_correlated_data(n_samples, n_features, random_state=random_state) + + lmbd_max = norm(X.T @ y, ord=np.inf) / n_samples + lmbd = 1e-2 * lmbd_max + + datafit = QuadraticJax() + penalty = L1Jax(lmbd) + w = AndersonCD(max_iter=30, verbose=1).solve(X, y, datafit, penalty) + + estimator = Lasso(alpha=lmbd, fit_intercept=False).fit(X, y) + + np.testing.assert_allclose(w, estimator.coef_, atol=1e-7) + + +if __name__ == "__main__": + test_solver() From bc9b1fbf9922eb332ffc17c0ac2923f30914a7ca Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Tue, 25 Apr 2023 17:36:07 +0200 Subject: [PATCH 05/15] jit CD epoch --- skglm/skglm_jax/anderson_cd.py | 43 ++++++++++++++++------- skglm/skglm_jax/datafits.py | 8 ++--- skglm/skglm_jax/tests/test_anderson_cd.py | 4 +-- 3 files changed, 36 insertions(+), 19 deletions(-) diff --git a/skglm/skglm_jax/anderson_cd.py b/skglm/skglm_jax/anderson_cd.py index 5fe8ee516..1906edf39 100644 --- a/skglm/skglm_jax/anderson_cd.py +++ b/skglm/skglm_jax/anderson_cd.py @@ -1,3 +1,6 @@ +from functools import partial + +import jax import jax.numpy as jnp from skglm.skglm_jax.datafits import QuadraticJax from skglm.skglm_jax.penalties import L1Jax @@ -5,7 +8,7 @@ class AndersonCD: - def __init__(self, max_iter=100, verbose=0) -> None: + def __init__(self, max_iter=100, verbose=0): self.max_iter = max_iter self.verbose = verbose @@ -16,25 +19,18 @@ def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax): lipschitz = datafit.get_features_lipschitz_cst(X, y) w = jnp.zeros(n_features) + Xw = jnp.zeros(n_samples) all_features = jnp.arange(n_features) for it in range(self.max_iter): - for j in all_features: - - if lipschitz[j] == 0.: - continue - - step = 1 / lipschitz[j] - grad_j = datafit.gradient_1d(X, y, w, j) - next_w_j = penalty.prox_1d(w[j] - step * grad_j, step) - - w = w.at[j].set(next_w_j) + w, Xw = AndersonCD._cd_epoch(X, y, w, Xw, all_features, lipschitz, + datafit, penalty) if self.verbose: p_obj = datafit.value(X, y, w) + penalty.value(w) - grad_ws = datafit.gradient_ws(X, y, w, all_features) + grad_ws = datafit.gradient_ws(X, y, w, Xw, all_features) subdiff_dist = penalty.subdiff_dist(w, grad_ws, all_features) stop_crit = jnp.max(subdiff_dist) @@ -47,3 +43,26 @@ def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax): def _transfer_to_device(self, X, y): # TODO: other checks return jnp.asarray(X), jnp.asarray(y) + + @staticmethod + @partial(jax.jit, static_argnums=(-2, -1)) + def _cd_epoch(X, y, w, Xw, ws, lipschitz, datafit, penalty): + for j in ws: + + # Null columns of X would break this functions + # as their corresponding lipschitz is 0 + # TODO: implement condition using lax + # if lipschitz[j] == 0.: + # continue + + step = 1 / lipschitz[j] + + grad_j = datafit.gradient_1d(X, y, w, Xw, j) + next_w_j = penalty.prox_1d(w[j] - step * grad_j, step) + + delta_w_j = next_w_j - w[j] + + w = w.at[j].set(next_w_j) + Xw = Xw + delta_w_j * X[:, j] + + return w, Xw diff --git a/skglm/skglm_jax/datafits.py b/skglm/skglm_jax/datafits.py index bb48f205c..620e0d4b5 100644 --- a/skglm/skglm_jax/datafits.py +++ b/skglm/skglm_jax/datafits.py @@ -9,13 +9,13 @@ def value(self, X, y, w): n_samples = X.shape[0] return ((X @ w - y) ** 2).sum() / (2. * n_samples) - def gradient_1d(self, X, y, w, j): + def gradient_1d(self, X, y, w, Xw, j): n_samples = X.shape[0] - return X[:, j] @ (X @ w - y) / n_samples + return X[:, j] @ (Xw - y) / n_samples - def gradient_ws(self, X, y, w, ws): + def gradient_ws(self, X, y, w, Xw, ws): n_samples = X.shape[0] - Xw_minus_y = X @ w - y + Xw_minus_y = Xw - y grad_ws = jnp.zeros(len(ws)) for idx, j in enumerate(ws): diff --git a/skglm/skglm_jax/tests/test_anderson_cd.py b/skglm/skglm_jax/tests/test_anderson_cd.py index 43914585c..817bc91ee 100644 --- a/skglm/skglm_jax/tests/test_anderson_cd.py +++ b/skglm/skglm_jax/tests/test_anderson_cd.py @@ -1,5 +1,3 @@ -import pytest - import numpy as np from numpy.linalg import norm from skglm.utils.data import make_correlated_data @@ -26,7 +24,7 @@ def test_solver(): estimator = Lasso(alpha=lmbd, fit_intercept=False).fit(X, y) - np.testing.assert_allclose(w, estimator.coef_, atol=1e-7) + np.testing.assert_allclose(w, estimator.coef_, atol=1e-6) if __name__ == "__main__": From 58edde9d8feb5b3feae1bd46facc9e18244a90f4 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Tue, 25 Apr 2023 18:38:07 +0200 Subject: [PATCH 06/15] jit subdiff --- skglm/skglm_jax/datafits.py | 9 +-------- skglm/skglm_jax/penalties.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/skglm/skglm_jax/datafits.py b/skglm/skglm_jax/datafits.py index 620e0d4b5..743326668 100644 --- a/skglm/skglm_jax/datafits.py +++ b/skglm/skglm_jax/datafits.py @@ -1,4 +1,3 @@ -import jax.numpy as jnp from jax.numpy.linalg import norm as jnorm @@ -16,13 +15,7 @@ def gradient_1d(self, X, y, w, Xw, j): def gradient_ws(self, X, y, w, Xw, ws): n_samples = X.shape[0] Xw_minus_y = Xw - y - grad_ws = jnp.zeros(len(ws)) - - for idx, j in enumerate(ws): - grad_j = X[:, j] @ Xw_minus_y / n_samples - grad_ws = grad_ws.at[idx].set(grad_j) - - return grad_ws + return X[:, ws].T @ (Xw_minus_y / n_samples) def get_features_lipschitz_cst(self, X, y): n_samples = X.shape[0] diff --git a/skglm/skglm_jax/penalties.py b/skglm/skglm_jax/penalties.py index 6799d9838..e5f278eba 100644 --- a/skglm/skglm_jax/penalties.py +++ b/skglm/skglm_jax/penalties.py @@ -1,3 +1,6 @@ +from functools import partial + +import jax import jax.numpy as jnp @@ -14,6 +17,7 @@ def prox_1d(self, value, stepsize): shifted_value = jnp.abs(value) - stepsize * self.alpha return jnp.sign(value) * jnp.maximum(shifted_value, 0.) + @partial(jax.jit, static_argnums=(0,)) def subdiff_dist(self, w, grad, ws): dist = jnp.zeros(len(ws)) @@ -21,10 +25,12 @@ def subdiff_dist(self, w, grad, ws): w_j = w[j] grad_j = grad[j] - if w_j == 0.: - dist_j = max(abs(grad_j) - self.alpha, 0.) - else: - dist_j = abs(grad_j + jnp.sign(w_j) * self.alpha) + dist_j = jax.lax.cond( + w_j == 0., + lambda w_j, grad_j, alpha: jnp.maximum(jnp.abs(grad_j) - alpha, 0.), + lambda w_j, grad_j, alpha: jnp.abs(grad_j + jnp.sign(w_j) * alpha), + w_j, grad_j, self.alpha + ) dist = dist.at[idx].set(dist_j) From 70c9db85717dedf8a3d5a7a9267a5ad9e1f925c7 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 26 Apr 2023 10:13:49 +0200 Subject: [PATCH 07/15] update jit method --- skglm/skglm_jax/__init__.py | 2 +- skglm/skglm_jax/datafits.py | 12 +++++++++++- skglm/skglm_jax/penalties.py | 6 +++--- skglm/skglm_jax/utils.py | 5 +++++ 4 files changed, 20 insertions(+), 5 deletions(-) create mode 100644 skglm/skglm_jax/utils.py diff --git a/skglm/skglm_jax/__init__.py b/skglm/skglm_jax/__init__.py index eab52d2e3..294abd5c0 100644 --- a/skglm/skglm_jax/__init__.py +++ b/skglm/skglm_jax/__init__.py @@ -2,4 +2,4 @@ # as recommended, setting the 'XLA_FLAGS' to bypass it. # side-effect: (perhaps) slow compilation time. import os -os.environ['XLA_FLAGS'] = '--xla_gpu_force_compilation_parallelism=1' # noqa* +os.environ['XLA_FLAGS'] = '--xla_gpu_force_compilation_parallelism=1' # noqa diff --git a/skglm/skglm_jax/datafits.py b/skglm/skglm_jax/datafits.py index 743326668..8e69312b4 100644 --- a/skglm/skglm_jax/datafits.py +++ b/skglm/skglm_jax/datafits.py @@ -1,5 +1,8 @@ +import jax.numpy as jnp from jax.numpy.linalg import norm as jnorm +from skglm.skglm_jax.utils import jax_jit_method + class QuadraticJax: """1 / (2 n_samples) ||y - Xw||^2""" @@ -12,10 +15,17 @@ def gradient_1d(self, X, y, w, Xw, j): n_samples = X.shape[0] return X[:, j] @ (Xw - y) / n_samples + @jax_jit_method def gradient_ws(self, X, y, w, Xw, ws): n_samples = X.shape[0] Xw_minus_y = Xw - y - return X[:, ws].T @ (Xw_minus_y / n_samples) + + grad_ws = jnp.zeros(len(ws)) + for idx, j in enumerate(ws): + grad_j = X[:, j] @ Xw_minus_y / n_samples + grad_ws = grad_ws.at[idx].set(grad_j) + + return grad_ws def get_features_lipschitz_cst(self, X, y): n_samples = X.shape[0] diff --git a/skglm/skglm_jax/penalties.py b/skglm/skglm_jax/penalties.py index e5f278eba..e7407bdc6 100644 --- a/skglm/skglm_jax/penalties.py +++ b/skglm/skglm_jax/penalties.py @@ -1,8 +1,8 @@ -from functools import partial - import jax import jax.numpy as jnp +from skglm.skglm_jax.utils import jax_jit_method + class L1Jax: """alpha ||w||_1""" @@ -17,7 +17,7 @@ def prox_1d(self, value, stepsize): shifted_value = jnp.abs(value) - stepsize * self.alpha return jnp.sign(value) * jnp.maximum(shifted_value, 0.) - @partial(jax.jit, static_argnums=(0,)) + @jax_jit_method def subdiff_dist(self, w, grad, ws): dist = jnp.zeros(len(ws)) diff --git a/skglm/skglm_jax/utils.py b/skglm/skglm_jax/utils.py new file mode 100644 index 000000000..043e6d705 --- /dev/null +++ b/skglm/skglm_jax/utils.py @@ -0,0 +1,5 @@ +import jax +from functools import partial + + +jax_jit_method = partial(jax.jit, static_argnums=(0,)) From 66142bf2781eb3d35693a7b44f00cf1a7ec4d192 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 26 Apr 2023 10:45:01 +0200 Subject: [PATCH 08/15] cd_epoch as method --- skglm/skglm_jax/anderson_cd.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/skglm/skglm_jax/anderson_cd.py b/skglm/skglm_jax/anderson_cd.py index 1906edf39..91bec5456 100644 --- a/skglm/skglm_jax/anderson_cd.py +++ b/skglm/skglm_jax/anderson_cd.py @@ -24,8 +24,8 @@ def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax): for it in range(self.max_iter): - w, Xw = AndersonCD._cd_epoch(X, y, w, Xw, all_features, lipschitz, - datafit, penalty) + w, Xw = self._cd_epoch(X, y, w, Xw, all_features, lipschitz, + datafit, penalty) if self.verbose: p_obj = datafit.value(X, y, w) + penalty.value(w) @@ -42,11 +42,11 @@ def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax): def _transfer_to_device(self, X, y): # TODO: other checks + # - skip if they are already jax array return jnp.asarray(X), jnp.asarray(y) - @staticmethod - @partial(jax.jit, static_argnums=(-2, -1)) - def _cd_epoch(X, y, w, Xw, ws, lipschitz, datafit, penalty): + @partial(jax.jit, static_argnums=(0, -2, -1)) + def _cd_epoch(self, X, y, w, Xw, ws, lipschitz, datafit, penalty): for j in ws: # Null columns of X would break this functions From 57f26320f022564eec61b058d30a58b138dfe7ef Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 26 Apr 2023 11:40:53 +0200 Subject: [PATCH 09/15] add working sets --- skglm/skglm_jax/anderson_cd.py | 69 +++++++++++++++++++---- skglm/skglm_jax/penalties.py | 7 ++- skglm/skglm_jax/tests/test_anderson_cd.py | 4 +- 3 files changed, 64 insertions(+), 16 deletions(-) diff --git a/skglm/skglm_jax/anderson_cd.py b/skglm/skglm_jax/anderson_cd.py index 91bec5456..e0825d775 100644 --- a/skglm/skglm_jax/anderson_cd.py +++ b/skglm/skglm_jax/anderson_cd.py @@ -8,8 +8,13 @@ class AndersonCD: - def __init__(self, max_iter=100, verbose=0): + EPS_TOL = 0.3 + + def __init__(self, max_iter=100, max_epochs=100, tol=1e-6, p0=10, verbose=0): self.max_iter = max_iter + self.max_epochs = max_epochs + self.tol = tol + self.p0 = p0 self.verbose = verbose def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax): @@ -24,26 +29,61 @@ def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax): for it in range(self.max_iter): - w, Xw = self._cd_epoch(X, y, w, Xw, all_features, lipschitz, - datafit, penalty) + # check convergence + grad = datafit.gradient_ws(X, y, w, Xw, all_features) + scores = penalty.subdiff_dist_ws(w, grad, all_features) + stop_crit = jnp.max(scores) if self.verbose: p_obj = datafit.value(X, y, w) + penalty.value(w) - grad_ws = datafit.gradient_ws(X, y, w, Xw, all_features) - subdiff_dist = penalty.subdiff_dist(w, grad_ws, all_features) - stop_crit = jnp.max(subdiff_dist) - print( - f"Iter {it}: p_obj={p_obj:.8f} stop_crit={stop_crit:.4e}" + f"Iteration {it}: p_obj_in={p_obj:.8f} " + f"stop_crit_in={stop_crit:.4e}" ) + if stop_crit <= self.tol: + break + + # build ws + gsupp_size = penalty.generalized_support(w).sum() + ws_size = min( + max(2 * gsupp_size, self.p0), + n_features + ) + ws = jnp.argsort(scores)[-ws_size:] + tol_in = AndersonCD.EPS_TOL * stop_crit + + w, Xw = self._solve_sub_problem(X, y, w, Xw, ws, lipschitz, tol_in, + datafit, penalty) + return w - def _transfer_to_device(self, X, y): - # TODO: other checks - # - skip if they are already jax array - return jnp.asarray(X), jnp.asarray(y) + def _solve_sub_problem(self, X, y, w, Xw, ws, lipschitz, tol_in, + datafit, penalty): + + for epoch in range(self.max_epochs): + + w, Xw = self._cd_epoch(X, y, w, Xw, ws, lipschitz, + datafit, penalty) + + # check convergence + grad_ws = datafit.gradient_ws(X, y, w, Xw, ws) + scores_ws = penalty.subdiff_dist_ws(w, grad_ws, ws) + stop_crit_in = jnp.max(scores_ws) + + if max(self.verbose - 1, 0): + p_obj_in = datafit.value(X, y, w) + penalty.value(w) + + print( + f"Epoch {epoch}: p_obj_in={p_obj_in:.8f} " + f"stop_crit_in={stop_crit_in:.4e}" + ) + + if stop_crit_in <= tol_in: + break + + return w, Xw @partial(jax.jit, static_argnums=(0, -2, -1)) def _cd_epoch(self, X, y, w, Xw, ws, lipschitz, datafit, penalty): @@ -66,3 +106,8 @@ def _cd_epoch(self, X, y, w, Xw, ws, lipschitz, datafit, penalty): Xw = Xw + delta_w_j * X[:, j] return w, Xw + + def _transfer_to_device(self, X, y): + # TODO: other checks + # - skip if they are already jax array + return jnp.asarray(X), jnp.asarray(y) diff --git a/skglm/skglm_jax/penalties.py b/skglm/skglm_jax/penalties.py index e7407bdc6..e17eca0de 100644 --- a/skglm/skglm_jax/penalties.py +++ b/skglm/skglm_jax/penalties.py @@ -18,12 +18,12 @@ def prox_1d(self, value, stepsize): return jnp.sign(value) * jnp.maximum(shifted_value, 0.) @jax_jit_method - def subdiff_dist(self, w, grad, ws): + def subdiff_dist_ws(self, w, grad_ws, ws): dist = jnp.zeros(len(ws)) for idx, j in enumerate(ws): w_j = w[j] - grad_j = grad[j] + grad_j = grad_ws[idx] dist_j = jax.lax.cond( w_j == 0., @@ -35,3 +35,6 @@ def subdiff_dist(self, w, grad, ws): dist = dist.at[idx].set(dist_j) return dist + + def generalized_support(self, w): + return w != 0. diff --git a/skglm/skglm_jax/tests/test_anderson_cd.py b/skglm/skglm_jax/tests/test_anderson_cd.py index 817bc91ee..5c28a5e47 100644 --- a/skglm/skglm_jax/tests/test_anderson_cd.py +++ b/skglm/skglm_jax/tests/test_anderson_cd.py @@ -10,7 +10,7 @@ def test_solver(): - n_samples, n_features = 100, 10 + n_samples, n_features = 100, 200 random_state = 135 X, y, _ = make_correlated_data(n_samples, n_features, random_state=random_state) @@ -20,7 +20,7 @@ def test_solver(): datafit = QuadraticJax() penalty = L1Jax(lmbd) - w = AndersonCD(max_iter=30, verbose=1).solve(X, y, datafit, penalty) + w = AndersonCD(max_iter=30, verbose=1, p0=2).solve(X, y, datafit, penalty) estimator = Lasso(alpha=lmbd, fit_intercept=False).fit(X, y) From 459ef0fda86d140fd42fc95d638d9a474eaf22da Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 27 Apr 2023 10:39:41 +0200 Subject: [PATCH 10/15] rv verbose from unittest --- skglm/skglm_jax/tests/test_anderson_cd.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/skglm/skglm_jax/tests/test_anderson_cd.py b/skglm/skglm_jax/tests/test_anderson_cd.py index 5c28a5e47..4c59cfff6 100644 --- a/skglm/skglm_jax/tests/test_anderson_cd.py +++ b/skglm/skglm_jax/tests/test_anderson_cd.py @@ -10,22 +10,22 @@ def test_solver(): - n_samples, n_features = 100, 200 random_state = 135 + n_samples, n_features = 100, 20 X, y, _ = make_correlated_data(n_samples, n_features, random_state=random_state) lmbd_max = norm(X.T @ y, ord=np.inf) / n_samples - lmbd = 1e-2 * lmbd_max + lmbd = 1e-1 * lmbd_max datafit = QuadraticJax() penalty = L1Jax(lmbd) - w = AndersonCD(max_iter=30, verbose=1, p0=2).solve(X, y, datafit, penalty) + w = AndersonCD().solve(X, y, datafit, penalty) estimator = Lasso(alpha=lmbd, fit_intercept=False).fit(X, y) - np.testing.assert_allclose(w, estimator.coef_, atol=1e-6) + np.testing.assert_allclose(w, estimator.coef_, atol=1e-5) if __name__ == "__main__": - test_solver() + pass From 160a4355827b69d99ebcfcd32d9526dbe582bff3 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 1 May 2023 12:40:36 +0200 Subject: [PATCH 11/15] freeze grad & ws --- skglm/skglm_jax/anderson_cd.py | 45 ++++++++++++++++------- skglm/skglm_jax/datafits.py | 18 ++++++--- skglm/skglm_jax/penalties.py | 27 ++++++++++---- skglm/skglm_jax/tests/test_anderson_cd.py | 3 +- 4 files changed, 65 insertions(+), 28 deletions(-) diff --git a/skglm/skglm_jax/anderson_cd.py b/skglm/skglm_jax/anderson_cd.py index e0825d775..4a3904117 100644 --- a/skglm/skglm_jax/anderson_cd.py +++ b/skglm/skglm_jax/anderson_cd.py @@ -2,6 +2,7 @@ import jax import jax.numpy as jnp + from skglm.skglm_jax.datafits import QuadraticJax from skglm.skglm_jax.penalties import L1Jax @@ -25,7 +26,7 @@ def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax): w = jnp.zeros(n_features) Xw = jnp.zeros(n_samples) - all_features = jnp.arange(n_features) + all_features = jnp.full(n_features, fill_value=True, dtype=bool) for it in range(self.max_iter): @@ -51,7 +52,11 @@ def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax): max(2 * gsupp_size, self.p0), n_features ) - ws = jnp.argsort(scores)[-ws_size:] + + ws = jnp.full(n_features, fill_value=False, dtype=bool) + ws_features = jnp.argsort(scores)[-ws_size:] + ws = ws.at[ws_features].set(True) + tol_in = AndersonCD.EPS_TOL * stop_crit w, Xw = self._solve_sub_problem(X, y, w, Xw, ws, lipschitz, tol_in, @@ -87,23 +92,35 @@ def _solve_sub_problem(self, X, y, w, Xw, ws, lipschitz, tol_in, @partial(jax.jit, static_argnums=(0, -2, -1)) def _cd_epoch(self, X, y, w, Xw, ws, lipschitz, datafit, penalty): - for j in ws: + for j, in_ws in enumerate(ws): + + w, Xw = jax.lax.cond( + in_ws, + lambda X, y, w, Xw, j, lipschitz: self._cd_epoch_j(X, y, w, Xw, j, lipschitz, datafit, penalty), # noqa + lambda X, y, w, Xw, j, lipschitz: (w, Xw), + *(X, y, w, Xw, j, lipschitz) + ) + + return w, Xw + + @partial(jax.jit, static_argnums=(0, -2, -1)) + def _cd_epoch_j(self, X, y, w, Xw, j, lipschitz, datafit, penalty): - # Null columns of X would break this functions - # as their corresponding lipschitz is 0 - # TODO: implement condition using lax - # if lipschitz[j] == 0.: - # continue + # Null columns of X would break this functions + # as their corresponding lipschitz is 0 + # TODO: implement condition using lax + # if lipschitz[j] == 0.: + # continue - step = 1 / lipschitz[j] + step = 1 / lipschitz[j] - grad_j = datafit.gradient_1d(X, y, w, Xw, j) - next_w_j = penalty.prox_1d(w[j] - step * grad_j, step) + grad_j = datafit.gradient_1d(X, y, w, Xw, j) + next_w_j = penalty.prox_1d(w[j] - step * grad_j, step) - delta_w_j = next_w_j - w[j] + delta_w_j = next_w_j - w[j] - w = w.at[j].set(next_w_j) - Xw = Xw + delta_w_j * X[:, j] + w = w.at[j].set(next_w_j) + Xw = Xw + delta_w_j * X[:, j] return w, Xw diff --git a/skglm/skglm_jax/datafits.py b/skglm/skglm_jax/datafits.py index 8e69312b4..08bf9b418 100644 --- a/skglm/skglm_jax/datafits.py +++ b/skglm/skglm_jax/datafits.py @@ -1,3 +1,4 @@ +import jax import jax.numpy as jnp from jax.numpy.linalg import norm as jnorm @@ -17,13 +18,20 @@ def gradient_1d(self, X, y, w, Xw, j): @jax_jit_method def gradient_ws(self, X, y, w, Xw, ws): - n_samples = X.shape[0] + n_features = X.shape[1] Xw_minus_y = Xw - y - grad_ws = jnp.zeros(len(ws)) - for idx, j in enumerate(ws): - grad_j = X[:, j] @ Xw_minus_y / n_samples - grad_ws = grad_ws.at[idx].set(grad_j) + grad_ws = jnp.empty(n_features) + for j, in_ws in enumerate(ws): + + grad_j = jax.lax.cond( + in_ws, + lambda X, Xw_minus_y, j: X[:, j] @ Xw_minus_y / len(Xw_minus_y), + lambda X, Xw_minus_y, j: 0., + *(X, Xw_minus_y, j) + ) + + grad_ws = grad_ws.at[j].set(grad_j) return grad_ws diff --git a/skglm/skglm_jax/penalties.py b/skglm/skglm_jax/penalties.py index e17eca0de..464dc567f 100644 --- a/skglm/skglm_jax/penalties.py +++ b/skglm/skglm_jax/penalties.py @@ -19,22 +19,33 @@ def prox_1d(self, value, stepsize): @jax_jit_method def subdiff_dist_ws(self, w, grad_ws, ws): - dist = jnp.zeros(len(ws)) + n_features = w.shape[0] + dist = jnp.empty(n_features) - for idx, j in enumerate(ws): + for j, in_ws in enumerate(ws): w_j = w[j] - grad_j = grad_ws[idx] + grad_j = grad_ws[j] dist_j = jax.lax.cond( - w_j == 0., - lambda w_j, grad_j, alpha: jnp.maximum(jnp.abs(grad_j) - alpha, 0.), - lambda w_j, grad_j, alpha: jnp.abs(grad_j + jnp.sign(w_j) * alpha), - w_j, grad_j, self.alpha + in_ws, + self._compute_subdiff_dist_j, + lambda w_j, grad_j: 0., + *(w_j, grad_j) ) - dist = dist.at[idx].set(dist_j) + dist = dist.at[j].set(dist_j) return dist def generalized_support(self, w): return w != 0. + + @jax_jit_method + def _compute_subdiff_dist_j(self, w_j, grad_j): + dist_j = jax.lax.cond( + w_j == 0., + lambda w_j, grad_j, alpha: jnp.maximum(jnp.abs(grad_j) - alpha, 0.), + lambda w_j, grad_j, alpha: jnp.abs(grad_j + jnp.sign(w_j) * alpha), + *(w_j, grad_j, self.alpha) + ) + return dist_j diff --git a/skglm/skglm_jax/tests/test_anderson_cd.py b/skglm/skglm_jax/tests/test_anderson_cd.py index 4c59cfff6..e6c7125bf 100644 --- a/skglm/skglm_jax/tests/test_anderson_cd.py +++ b/skglm/skglm_jax/tests/test_anderson_cd.py @@ -20,7 +20,7 @@ def test_solver(): datafit = QuadraticJax() penalty = L1Jax(lmbd) - w = AndersonCD().solve(X, y, datafit, penalty) + w = AndersonCD(verbose=1).solve(X, y, datafit, penalty) estimator = Lasso(alpha=lmbd, fit_intercept=False).fit(X, y) @@ -28,4 +28,5 @@ def test_solver(): if __name__ == "__main__": + test_solver() pass From 8e734e862e74408b903078709823d35962bdd4dd Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 1 May 2023 14:03:22 +0200 Subject: [PATCH 12/15] anderson acceleration --- skglm/skglm_jax/anderson_cd.py | 11 ++++++- skglm/skglm_jax/tests/test_anderson_cd.py | 2 +- skglm/skglm_jax/utils.py | 35 +++++++++++++++++++++++ 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/skglm/skglm_jax/anderson_cd.py b/skglm/skglm_jax/anderson_cd.py index 4a3904117..c9dbb6381 100644 --- a/skglm/skglm_jax/anderson_cd.py +++ b/skglm/skglm_jax/anderson_cd.py @@ -5,17 +5,20 @@ from skglm.skglm_jax.datafits import QuadraticJax from skglm.skglm_jax.penalties import L1Jax +from skglm.skglm_jax.utils import JaxAA class AndersonCD: EPS_TOL = 0.3 - def __init__(self, max_iter=100, max_epochs=100, tol=1e-6, p0=10, verbose=0): + def __init__(self, max_iter=100, max_epochs=100, tol=1e-6, p0=10, + use_acc=True, verbose=0): self.max_iter = max_iter self.max_epochs = max_epochs self.tol = tol self.p0 = p0 + self.use_acc = use_acc self.verbose = verbose def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax): @@ -67,11 +70,17 @@ def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax): def _solve_sub_problem(self, X, y, w, Xw, ws, lipschitz, tol_in, datafit, penalty): + if self.use_acc: + accelerator = JaxAA(K=5) + for epoch in range(self.max_epochs): w, Xw = self._cd_epoch(X, y, w, Xw, ws, lipschitz, datafit, penalty) + if self.use_acc: + w, Xw = accelerator.extrapolate(w, Xw) + # check convergence grad_ws = datafit.gradient_ws(X, y, w, Xw, ws) scores_ws = penalty.subdiff_dist_ws(w, grad_ws, ws) diff --git a/skglm/skglm_jax/tests/test_anderson_cd.py b/skglm/skglm_jax/tests/test_anderson_cd.py index e6c7125bf..c024dce0f 100644 --- a/skglm/skglm_jax/tests/test_anderson_cd.py +++ b/skglm/skglm_jax/tests/test_anderson_cd.py @@ -20,7 +20,7 @@ def test_solver(): datafit = QuadraticJax() penalty = L1Jax(lmbd) - w = AndersonCD(verbose=1).solve(X, y, datafit, penalty) + w = AndersonCD(verbose=2).solve(X, y, datafit, penalty) estimator = Lasso(alpha=lmbd, fit_intercept=False).fit(X, y) diff --git a/skglm/skglm_jax/utils.py b/skglm/skglm_jax/utils.py index 043e6d705..8c560214c 100644 --- a/skglm/skglm_jax/utils.py +++ b/skglm/skglm_jax/utils.py @@ -1,5 +1,40 @@ import jax +import jax.numpy as jnp from functools import partial jax_jit_method = partial(jax.jit, static_argnums=(0,)) + + +class JaxAA: + + def __init__(self, K): + self.K, self.current_iter = K, 0 + self.arr_w_, self.arr_Xw_ = None, None + + def extrapolate(self, w, Xw): + if self.arr_w_ is None or self.arr_Xw_ is None: + self.arr_w_ = jnp.zeros((w.shape[0], self.K+1)) + self.arr_Xw_ = jnp.zeros((Xw.shape[0], self.K+1)) + + if self.current_iter <= self.K: + self.arr_w_ = self.arr_w_.at[:, self.current_iter].set(w) + self.arr_Xw_ = self.arr_Xw_.at[:, self.current_iter].set(Xw) + self.current_iter += 1 + return w, Xw + + # compute residuals + U = jnp.diff(self.arr_w_, axis=1) + + # compute extrapolation coefs + try: + inv_UTU_ones = jnp.linalg.solve(U.T @ U, jnp.ones(self.K)) + except Exception: + return w, Xw + finally: + self.current_iter = 0 + + # extrapolate + C = inv_UTU_ones / jnp.sum(inv_UTU_ones) + + return self.arr_w_[:, 1:] @ C, self.arr_Xw_[:, 1:] @ C From 3e420d9ee80bc8b2c6690acb17a004e6710ee442 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 1 May 2023 21:14:32 +0200 Subject: [PATCH 13/15] fista solver --- skglm/skglm_jax/datafits.py | 8 ++++ skglm/skglm_jax/fista.py | 79 ++++++++++++++++++++++++++++++++++++ skglm/skglm_jax/penalties.py | 3 ++ 3 files changed, 90 insertions(+) create mode 100644 skglm/skglm_jax/fista.py diff --git a/skglm/skglm_jax/datafits.py b/skglm/skglm_jax/datafits.py index 08bf9b418..948133df6 100644 --- a/skglm/skglm_jax/datafits.py +++ b/skglm/skglm_jax/datafits.py @@ -38,3 +38,11 @@ def gradient_ws(self, X, y, w, Xw, ws): def get_features_lipschitz_cst(self, X, y): n_samples = X.shape[0] return jnorm(X, ord=2, axis=0) ** 2 / n_samples + + def get_global_lipschitz_cst(self, X, y): + n_samples = X.shape[0] + return jnorm(X, ord=2) ** 2 / n_samples + + def gradient(self, X, y, w): + n_samples = X.shape[0] + return X.T @ (X @ w - y) / n_samples diff --git a/skglm/skglm_jax/fista.py b/skglm/skglm_jax/fista.py new file mode 100644 index 000000000..f39b26a2e --- /dev/null +++ b/skglm/skglm_jax/fista.py @@ -0,0 +1,79 @@ +import numpy as np + +import jax +import jax.numpy as jnp + +from skglm.skglm_jax.datafits import QuadraticJax +from skglm.skglm_jax.penalties import L1Jax + + +class Fista: + + def __init__(self, max_iter=200, use_auto_diff=True, verbose=0): + self.max_iter = max_iter + self.use_auto_diff = use_auto_diff + self.verbose = verbose + + def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax): + n_samples, n_features = X.shape + X_gpu, y_gpu = jnp.asarray(X), jnp.asarray(y) + + # compute step + lipschitz = datafit.get_global_lipschitz_cst(X_gpu, y_gpu) + if lipschitz == 0.: + return np.zeros(n_features) + + step = 1 / lipschitz + all_features = jnp.full(n_features, fill_value=True, dtype=bool) + + # get grad func of datafit + if self.use_auto_diff: + auto_grad = jax.jit(jax.grad(datafit.value, argnums=-1)) + + # init vars in device + w = jnp.zeros(n_features) + old_w = jnp.zeros(n_features) + mid_w = jnp.zeros(n_features) + grad = jnp.zeros(n_features) + + t_old, t_new = 1, 1 + + for it in range(self.max_iter): + + # compute grad + if self.use_auto_diff: + grad = auto_grad(X_gpu, y_gpu, mid_w) + else: + grad = datafit.gradient(X_gpu, y_gpu, mid_w) + + # forward / backward + val = mid_w - step * grad + w = penalty.prox(val, step) + + if self.verbose: + p_obj = datafit.value(X_gpu, y_gpu, w) + penalty.value(w) + + if self.use_auto_diff: + grad = auto_grad(X_gpu, y_gpu, w) + else: + grad = datafit.gradient(X_gpu, y_gpu, w) + + scores = penalty.subdiff_dist_ws(w, grad, all_features) + stop_crit = jnp.max(scores) + + print( + f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={stop_crit:.4e}" + ) + + # extrapolate + mid_w = w + ((t_old - 1) / t_new) * (w - old_w) + + # update FISTA vars + t_old = t_new + t_new = 0.5 * (1 + jnp.sqrt(1. + 4. * t_old ** 2)) + old_w = jnp.copy(w) + + # transfer back to host + w_cpu = np.asarray(w, dtype=np.float64) + + return w_cpu diff --git a/skglm/skglm_jax/penalties.py b/skglm/skglm_jax/penalties.py index 464dc567f..b9a27c0fe 100644 --- a/skglm/skglm_jax/penalties.py +++ b/skglm/skglm_jax/penalties.py @@ -17,6 +17,9 @@ def prox_1d(self, value, stepsize): shifted_value = jnp.abs(value) - stepsize * self.alpha return jnp.sign(value) * jnp.maximum(shifted_value, 0.) + def prox(self, value, stepsize): + return self.prox_1d(value, stepsize) + @jax_jit_method def subdiff_dist_ws(self, w, grad_ws, ws): n_features = w.shape[0] From f974bd88219b3b24be9b80411211f68908d9f279 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 1 May 2023 21:14:59 +0200 Subject: [PATCH 14/15] fix bugs norm && type && transfer --- skglm/skglm_jax/__init__.py | 7 +++++++ skglm/skglm_jax/anderson_cd.py | 4 +++- skglm/skglm_jax/tests/test_anderson_cd.py | 19 +++++++++++++------ 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/skglm/skglm_jax/__init__.py b/skglm/skglm_jax/__init__.py index 294abd5c0..ca636c085 100644 --- a/skglm/skglm_jax/__init__.py +++ b/skglm/skglm_jax/__init__.py @@ -3,3 +3,10 @@ # side-effect: (perhaps) slow compilation time. import os os.environ['XLA_FLAGS'] = '--xla_gpu_force_compilation_parallelism=1' # noqa + +# set flag to resolve bug with `jax.linalg.norm` +# ref: https://github.com/google/jax/issues/8916#issuecomment-1101113497 +os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "False" # noqa + +import jax # noqa +jax.config.update("jax_enable_x64", True) # noqa diff --git a/skglm/skglm_jax/anderson_cd.py b/skglm/skglm_jax/anderson_cd.py index c9dbb6381..c32aa763b 100644 --- a/skglm/skglm_jax/anderson_cd.py +++ b/skglm/skglm_jax/anderson_cd.py @@ -1,6 +1,7 @@ from functools import partial import jax +import numpy as np import jax.numpy as jnp from skglm.skglm_jax.datafits import QuadraticJax @@ -65,7 +66,8 @@ def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax): w, Xw = self._solve_sub_problem(X, y, w, Xw, ws, lipschitz, tol_in, datafit, penalty) - return w + w_cpu = np.asarray(w) + return w_cpu def _solve_sub_problem(self, X, y, w, Xw, ws, lipschitz, tol_in, datafit, penalty): diff --git a/skglm/skglm_jax/tests/test_anderson_cd.py b/skglm/skglm_jax/tests/test_anderson_cd.py index c024dce0f..5f12a54d3 100644 --- a/skglm/skglm_jax/tests/test_anderson_cd.py +++ b/skglm/skglm_jax/tests/test_anderson_cd.py @@ -1,32 +1,39 @@ +import pytest + import numpy as np from numpy.linalg import norm from skglm.utils.data import make_correlated_data from skglm.skglm_jax.anderson_cd import AndersonCD +from skglm.skglm_jax.fista import Fista from skglm.skglm_jax.datafits import QuadraticJax from skglm.skglm_jax.penalties import L1Jax from skglm.estimators import Lasso -def test_solver(): +@pytest.mark.parametrize( + "solver", [AndersonCD(), + Fista(use_auto_diff=True), + Fista(use_auto_diff=False)]) +def test_solver(solver): random_state = 135 - n_samples, n_features = 100, 20 + n_samples, n_features = 50, 10 X, y, _ = make_correlated_data(n_samples, n_features, random_state=random_state) lmbd_max = norm(X.T @ y, ord=np.inf) / n_samples - lmbd = 1e-1 * lmbd_max + lmbd = 1e-2 * lmbd_max datafit = QuadraticJax() penalty = L1Jax(lmbd) - w = AndersonCD(verbose=2).solve(X, y, datafit, penalty) + w = solver.solve(X, y, datafit, penalty) estimator = Lasso(alpha=lmbd, fit_intercept=False).fit(X, y) - np.testing.assert_allclose(w, estimator.coef_, atol=1e-5) + np.testing.assert_allclose(w, estimator.coef_, atol=1e-4) if __name__ == "__main__": - test_solver() + test_solver(Fista()) pass From d972d840fdfc96253ac02f1c2ff9d226eb710177 Mon Sep 17 00:00:00 2001 From: Moufad Badr Date: Tue, 2 May 2023 14:45:18 +0200 Subject: [PATCH 15/15] fix errors XLA --- skglm/skglm_jax/README.md | 2 +- skglm/skglm_jax/__init__.py | 10 +++++----- skglm/skglm_jax/anderson_cd.py | 2 +- skglm/skglm_jax/tests/test_anderson_cd.py | 10 ++++++++-- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/skglm/skglm_jax/README.md b/skglm/skglm_jax/README.md index 976ea930d..52146bafc 100644 --- a/skglm/skglm_jax/README.md +++ b/skglm/skglm_jax/README.md @@ -4,7 +4,7 @@ 1. create then activate ``conda`` environnement ```shell # create -conda create -n skglm-jax python=3.7 +conda create -n skglm-jax python=3.10 # activate env conda activate skglm-jax diff --git a/skglm/skglm_jax/__init__.py b/skglm/skglm_jax/__init__.py index ca636c085..9f3937231 100644 --- a/skglm/skglm_jax/__init__.py +++ b/skglm/skglm_jax/__init__.py @@ -1,12 +1,12 @@ # if not set, raises an error related to CUDA linking API. # as recommended, setting the 'XLA_FLAGS' to bypass it. # side-effect: (perhaps) slow compilation time. -import os -os.environ['XLA_FLAGS'] = '--xla_gpu_force_compilation_parallelism=1' # noqa +# import os +# os.environ['XLA_FLAGS'] = '--xla_gpu_force_compilation_parallelism=1' # noqa # set flag to resolve bug with `jax.linalg.norm` # ref: https://github.com/google/jax/issues/8916#issuecomment-1101113497 -os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "False" # noqa +# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "False" # noqa -import jax # noqa -jax.config.update("jax_enable_x64", True) # noqa +import jax +jax.config.update("jax_enable_x64", True) diff --git a/skglm/skglm_jax/anderson_cd.py b/skglm/skglm_jax/anderson_cd.py index c32aa763b..a3b45f433 100644 --- a/skglm/skglm_jax/anderson_cd.py +++ b/skglm/skglm_jax/anderson_cd.py @@ -14,7 +14,7 @@ class AndersonCD: EPS_TOL = 0.3 def __init__(self, max_iter=100, max_epochs=100, tol=1e-6, p0=10, - use_acc=True, verbose=0): + use_acc=False, verbose=0): self.max_iter = max_iter self.max_epochs = max_epochs self.tol = tol diff --git a/skglm/skglm_jax/tests/test_anderson_cd.py b/skglm/skglm_jax/tests/test_anderson_cd.py index 5f12a54d3..f66834860 100644 --- a/skglm/skglm_jax/tests/test_anderson_cd.py +++ b/skglm/skglm_jax/tests/test_anderson_cd.py @@ -18,7 +18,7 @@ Fista(use_auto_diff=False)]) def test_solver(solver): random_state = 135 - n_samples, n_features = 50, 10 + n_samples, n_features = 10_000, 100 X, y, _ = make_correlated_data(n_samples, n_features, random_state=random_state) @@ -35,5 +35,11 @@ def test_solver(solver): if __name__ == "__main__": - test_solver(Fista()) + import time + + start = time.perf_counter() + test_solver(AndersonCD(verbose=2)) + end = time.perf_counter() + + print("Elapsed time:", end - start) pass