POC - Jax implementation of AndersonCD solver #155

Draft · wants to merge 15 commits into base: main

22 changes: 22 additions & 0 deletions skglm/skglm_jax/README.md
@@ -0,0 +1,22 @@
## Installation


1. create, then activate a ``conda`` environment
```shell
# create
conda create -n skglm-jax python=3.10

# activate env
conda activate skglm-jax
```

2. install ``skglm`` in editable mode
```shell
pip install -e .
```

3. install dependencies
```shell
# jax
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
```
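
For context, here is a minimal usage sketch of the solver this PR introduces, wiring together `AndersonCD`, `QuadraticJax`, and `L1Jax` from the diff below; the random data, `alpha` value, and solver settings are illustrative only.

```python
import numpy as np

from skglm.skglm_jax.anderson_cd import AndersonCD
from skglm.skglm_jax.datafits import QuadraticJax
from skglm.skglm_jax.penalties import L1Jax

# illustrative data: 100 samples, 20 features
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 20))
y = rng.standard_normal(100)

# Lasso: quadratic datafit + L1 penalty, solved by Anderson coordinate descent
datafit = QuadraticJax()
penalty = L1Jax(alpha=0.1)
solver = AndersonCD(max_iter=100, verbose=1)

w = solver.solve(X, y, datafit, penalty)  # returned as a NumPy array on host
```
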
12 changes: 12 additions & 0 deletions skglm/skglm_jax/__init__.py
@@ -0,0 +1,12 @@
# if not set, JAX raises an error related to the CUDA linking API.
# as recommended, set 'XLA_FLAGS' to bypass it.
# side effect: (perhaps) slower compilation time.
# import os
# os.environ['XLA_FLAGS'] = '--xla_gpu_force_compilation_parallelism=1' # noqa

# set flag to resolve bug with `jax.numpy.linalg.norm`
# ref: https://github.com/google/jax/issues/8916#issuecomment-1101113497
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "False" # noqa

import jax
jax.config.update("jax_enable_x64", True)
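
If the CUDA-linking or GPU preallocation issues referenced in the comments above resurface, the commented-out flags can be re-enabled before importing `jax`; a minimal sketch (values taken from the commented lines, their impact on compilation time is uncertain):

```python
import os

# bypass the CUDA linking error (may slow down compilation)
os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1"
# work around the norm-related bug referenced above
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"

import jax  # noqa: E402  (imported after setting the flags on purpose)

jax.config.update("jax_enable_x64", True)
```
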
141 changes: 141 additions & 0 deletions skglm/skglm_jax/anderson_cd.py
@@ -0,0 +1,141 @@
from functools import partial

import jax
import numpy as np
import jax.numpy as jnp

from skglm.skglm_jax.datafits import QuadraticJax
from skglm.skglm_jax.penalties import L1Jax
from skglm.skglm_jax.utils import JaxAA


class AndersonCD:

EPS_TOL = 0.3

def __init__(self, max_iter=100, max_epochs=100, tol=1e-6, p0=10,
use_acc=False, verbose=0):
self.max_iter = max_iter
self.max_epochs = max_epochs
self.tol = tol
self.p0 = p0
self.use_acc = use_acc
self.verbose = verbose

def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax):
X, y = self._transfer_to_device(X, y)

n_samples, n_features = X.shape
lipschitz = datafit.get_features_lipschitz_cst(X, y)

w = jnp.zeros(n_features)
Xw = jnp.zeros(n_samples)
all_features = jnp.full(n_features, fill_value=True, dtype=bool)

for it in range(self.max_iter):

# check convergence
grad = datafit.gradient_ws(X, y, w, Xw, all_features)
scores = penalty.subdiff_dist_ws(w, grad, all_features)
stop_crit = jnp.max(scores)

if self.verbose:
p_obj = datafit.value(X, y, w) + penalty.value(w)

print(
f"Iteration {it}: p_obj_in={p_obj:.8f} "
f"stop_crit_in={stop_crit:.4e}"
)

if stop_crit <= self.tol:
break

# build ws
gsupp_size = penalty.generalized_support(w).sum()
ws_size = min(
max(2 * gsupp_size, self.p0),
n_features
)

ws = jnp.full(n_features, fill_value=False, dtype=bool)
ws_features = jnp.argsort(scores)[-ws_size:]
ws = ws.at[ws_features].set(True)

tol_in = AndersonCD.EPS_TOL * stop_crit

w, Xw = self._solve_sub_problem(X, y, w, Xw, ws, lipschitz, tol_in,
datafit, penalty)

w_cpu = np.asarray(w)
return w_cpu

def _solve_sub_problem(self, X, y, w, Xw, ws, lipschitz, tol_in,
datafit, penalty):

if self.use_acc:
accelerator = JaxAA(K=5)

for epoch in range(self.max_epochs):

w, Xw = self._cd_epoch(X, y, w, Xw, ws, lipschitz,
datafit, penalty)

if self.use_acc:
w, Xw = accelerator.extrapolate(w, Xw)

# check convergence
grad_ws = datafit.gradient_ws(X, y, w, Xw, ws)
scores_ws = penalty.subdiff_dist_ws(w, grad_ws, ws)
stop_crit_in = jnp.max(scores_ws)

if max(self.verbose - 1, 0):
p_obj_in = datafit.value(X, y, w) + penalty.value(w)

print(
f"Epoch {epoch}: p_obj_in={p_obj_in:.8f} "
f"stop_crit_in={stop_crit_in:.4e}"
)

if stop_crit_in <= tol_in:
break

return w, Xw

@partial(jax.jit, static_argnums=(0, -2, -1))
def _cd_epoch(self, X, y, w, Xw, ws, lipschitz, datafit, penalty):
for j, in_ws in enumerate(ws):

w, Xw = jax.lax.cond(
in_ws,
lambda X, y, w, Xw, j, lipschitz: self._cd_epoch_j(X, y, w, Xw, j, lipschitz, datafit, penalty), # noqa
lambda X, y, w, Xw, j, lipschitz: (w, Xw),
*(X, y, w, Xw, j, lipschitz)
)

return w, Xw

@partial(jax.jit, static_argnums=(0, -2, -1))
def _cd_epoch_j(self, X, y, w, Xw, j, lipschitz, datafit, penalty):

# Null columns of X would break this function
# as their corresponding Lipschitz constant is 0
# TODO: implement condition using lax
# if lipschitz[j] == 0.:
# continue

step = 1 / lipschitz[j]

grad_j = datafit.gradient_1d(X, y, w, Xw, j)
next_w_j = penalty.prox_1d(w[j] - step * grad_j, step)

delta_w_j = next_w_j - w[j]

w = w.at[j].set(next_w_j)
Xw = Xw + delta_w_j * X[:, j]

return w, Xw

def _transfer_to_device(self, X, y):
# TODO: other checks
# - skip transfer if they are already JAX arrays
return jnp.asarray(X), jnp.asarray(y)
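
As a reading aid, the update performed by `_cd_epoch_j` is the standard proximal coordinate descent step, with $L_j$ the per-feature Lipschitz constant returned by `get_features_lipschitz_cst`:

```math
w_j \leftarrow \operatorname{prox}_{\alpha|\cdot| / L_j}\!\left(w_j - \tfrac{1}{L_j}\,\nabla_j f(w)\right),
\qquad
\nabla_j f(w) = \tfrac{1}{n} X_{:, j}^\top (Xw - y)
```

Each outer iteration then restricts these updates to the working set made of the `ws_size` features with the largest violation `scores`, with `ws_size = min(max(2 * gsupp_size, p0), n_features)`.
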
48 changes: 48 additions & 0 deletions skglm/skglm_jax/datafits.py
@@ -0,0 +1,48 @@
import jax
import jax.numpy as jnp
from jax.numpy.linalg import norm as jnorm

from skglm.skglm_jax.utils import jax_jit_method


class QuadraticJax:
"""1 / (2 n_samples) ||y - Xw||^2"""

def value(self, X, y, w):
n_samples = X.shape[0]
return ((X @ w - y) ** 2).sum() / (2. * n_samples)

def gradient_1d(self, X, y, w, Xw, j):
n_samples = X.shape[0]
return X[:, j] @ (Xw - y) / n_samples

@jax_jit_method
def gradient_ws(self, X, y, w, Xw, ws):
n_features = X.shape[1]
Xw_minus_y = Xw - y

grad_ws = jnp.empty(n_features)
for j, in_ws in enumerate(ws):

grad_j = jax.lax.cond(
in_ws,
lambda X, Xw_minus_y, j: X[:, j] @ Xw_minus_y / len(Xw_minus_y),
lambda X, Xw_minus_y, j: 0.,
*(X, Xw_minus_y, j)
)

grad_ws = grad_ws.at[j].set(grad_j)

return grad_ws

def get_features_lipschitz_cst(self, X, y):
n_samples = X.shape[0]
return jnorm(X, ord=2, axis=0) ** 2 / n_samples

def get_global_lipschitz_cst(self, X, y):
n_samples = X.shape[0]
return jnorm(X, ord=2) ** 2 / n_samples

def gradient(self, X, y, w):
n_samples = X.shape[0]
return X.T @ (X @ w - y) / n_samples
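
For reference, the constants returned by `get_features_lipschitz_cst` and `get_global_lipschitz_cst` are the coordinate-wise and global Lipschitz constants of the gradient of the quadratic datafit:

```math
L_j = \frac{\lVert X_{:, j} \rVert_2^2}{n},
\qquad
L = \frac{\lVert X \rVert_2^2}{n},
```

where $\lVert X \rVert_2$ is the spectral norm of $X$.
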
79 changes: 79 additions & 0 deletions skglm/skglm_jax/fista.py
@@ -0,0 +1,79 @@
import numpy as np

import jax
import jax.numpy as jnp

from skglm.skglm_jax.datafits import QuadraticJax
from skglm.skglm_jax.penalties import L1Jax


class Fista:

def __init__(self, max_iter=200, use_auto_diff=True, verbose=0):
self.max_iter = max_iter
self.use_auto_diff = use_auto_diff
self.verbose = verbose

def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax):
n_samples, n_features = X.shape
X_gpu, y_gpu = jnp.asarray(X), jnp.asarray(y)

# compute step
lipschitz = datafit.get_global_lipschitz_cst(X_gpu, y_gpu)
if lipschitz == 0.:
return np.zeros(n_features)

step = 1 / lipschitz
all_features = jnp.full(n_features, fill_value=True, dtype=bool)

# get grad func of datafit
if self.use_auto_diff:
auto_grad = jax.jit(jax.grad(datafit.value, argnums=-1))

# init vars in device
w = jnp.zeros(n_features)
old_w = jnp.zeros(n_features)
mid_w = jnp.zeros(n_features)
grad = jnp.zeros(n_features)

t_old, t_new = 1, 1

for it in range(self.max_iter):

# compute grad
if self.use_auto_diff:
grad = auto_grad(X_gpu, y_gpu, mid_w)
else:
grad = datafit.gradient(X_gpu, y_gpu, mid_w)

# forward / backward
val = mid_w - step * grad
w = penalty.prox(val, step)

if self.verbose:
p_obj = datafit.value(X_gpu, y_gpu, w) + penalty.value(w)

if self.use_auto_diff:
grad = auto_grad(X_gpu, y_gpu, w)
else:
grad = datafit.gradient(X_gpu, y_gpu, w)

scores = penalty.subdiff_dist_ws(w, grad, all_features)
stop_crit = jnp.max(scores)

print(
f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={stop_crit:.4e}"
)

# extrapolate
mid_w = w + ((t_old - 1) / t_new) * (w - old_w)

# update FISTA vars
t_old = t_new
t_new = 0.5 * (1 + jnp.sqrt(1. + 4. * t_old ** 2))
old_w = jnp.copy(w)

# transfer back to host
w_cpu = np.asarray(w, dtype=np.float64)

return w_cpu
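
The extrapolation in the loop above follows the usual FISTA recursion (written here with $z$ for `mid_w` and $t_0 = 1$):

```math
t_{k+1} = \frac{1 + \sqrt{1 + 4 t_k^2}}{2},
\qquad
z_{k+1} = w_{k+1} + \frac{t_k - 1}{t_{k+1}}\,(w_{k+1} - w_k)
```
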
54 changes: 54 additions & 0 deletions skglm/skglm_jax/penalties.py
@@ -0,0 +1,54 @@
import jax
import jax.numpy as jnp

from skglm.skglm_jax.utils import jax_jit_method


class L1Jax:
"""alpha ||w||_1"""

def __init__(self, alpha):
self.alpha = alpha

def value(self, w):
return (self.alpha * jnp.abs(w)).sum()

def prox_1d(self, value, stepsize):
shifted_value = jnp.abs(value) - stepsize * self.alpha
return jnp.sign(value) * jnp.maximum(shifted_value, 0.)

def prox(self, value, stepsize):
return self.prox_1d(value, stepsize)

@jax_jit_method
def subdiff_dist_ws(self, w, grad_ws, ws):
n_features = w.shape[0]
dist = jnp.empty(n_features)

for j, in_ws in enumerate(ws):
w_j = w[j]
grad_j = grad_ws[j]

dist_j = jax.lax.cond(
in_ws,
self._compute_subdiff_dist_j,
lambda w_j, grad_j: 0.,
*(w_j, grad_j)
)

dist = dist.at[j].set(dist_j)

return dist

def generalized_support(self, w):
return w != 0.

@jax_jit_method
def _compute_subdiff_dist_j(self, w_j, grad_j):
dist_j = jax.lax.cond(
w_j == 0.,
lambda w_j, grad_j, alpha: jnp.maximum(jnp.abs(grad_j) - alpha, 0.),
lambda w_j, grad_j, alpha: jnp.abs(grad_j + jnp.sign(w_j) * alpha),
*(w_j, grad_j, self.alpha)
)
return dist_j
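
For reference, `prox_1d` is soft-thresholding and `_compute_subdiff_dist_j` computes the distance of the negative gradient to the subdifferential of $\alpha\lvert\cdot\rvert$ at $w_j$, i.e. the optimality violation used as stopping criterion:

```math
\operatorname{prox}_{\gamma\alpha|\cdot|}(x) = \operatorname{sign}(x)\,\max(\lvert x\rvert - \gamma\alpha,\ 0),
\qquad
\operatorname{dist}\!\big({-\nabla_j f(w)},\ \partial(\alpha\lvert w_j\rvert)\big) =
\begin{cases}
\max(\lvert \nabla_j f(w)\rvert - \alpha,\ 0) & \text{if } w_j = 0, \\
\lvert \nabla_j f(w) + \alpha \operatorname{sign}(w_j)\rvert & \text{otherwise.}
\end{cases}
```
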