-
Notifications
You must be signed in to change notification settings - Fork 37
ENH flexible gram solver with penalty and using datafit #16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH flexible gram solver with penalty and using datafit #16
Conversation
@mathurinm ready for a quick review ;) |
skglm/solvers/cd_solver.py
Outdated
if (isinstance(datafit, (Quadratic, Quadratic_32)) and n_samples > n_features | ||
and n_features < 10_000) or solver in ("cd_gram", "fista"): | ||
# Gram matrix must fit in memory hence the restriction n_features < 1e5 | ||
if not isinstance(datafit, (Quadratic, Quadratic_32)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this bit is unreachable because the check is already performed L155
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've placed it because the first condition is "isinstance....
OR solver in ...
". If the user manually inputs "cd_gram", I think we enter the if statement and I want to catch a wrong datafit, hence L158. Overkill maybe? Should we even expose solver
? I think it is convenient for benchmarks.
WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok I understood, thanks.
Maybe we can indent the first if, breaking line before or solver
to make it more visible ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to make it more obvious. WDYT?
|
||
coefs : array, shape (n_features, n_alphas) | ||
Coefficients along the path. | ||
obj_out : array, shape (n_iter,) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we really return this? or the optimality condition violation instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do return this. See L371.
skglm/solvers/cd_utils.py
Outdated
|
||
|
||
@njit | ||
def prox_vec(penalty, z, stepsize, n_features): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
arf, I though we had penalty.prox
make this function private, remove n_features (access as z.shape[1])
we need a reflection on solvers, but probably all penalties will need to implement it. We can do so in basepenalty, but I fear looping over all coordinates will be slower than performing it in one step as ST_vec
does
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In [16]: %%time
...: out = _prox_vec(pen, z, 0.01)
CPU times: user 28 µs, sys: 1e+03 ns, total: 29 µs
Wall time: 34.1 µs
In [17]: %%time
...: out2 = ST_vec(z, 0.01)
CPU times: user 23 µs, sys: 0 ns, total: 23 µs
Wall time: 25.7 µs
not a big difference, from my experiments. I tried with different thresholds.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
with @QB3 we had an issue a while ago on flashcd with finance where this caused a big overhead. Just to keep it in mind
Co-authored-by: mathurinm <[email protected]>
Co-authored-by: mathurinm <[email protected]>
…kglm into gram_penalty_nogroup
…enalty_nogroup
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM.
Tests are missing for the solvers though, I can write some if needed.
@@ -52,6 +56,9 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, | |||
return_n_iter : bool, optional | |||
If True, number of iterations along the path are returned. | |||
|
|||
solver : ('cd_ws'|'cd_gram'|'fista'), optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FISTA is not a CD solver, it's confusing to expose it to the user like this.
@mathurinm WDYT?
@njit | ||
def _cd_epoch_gram(XtX, grad, w, datafit, penalty, n_samples, n_features): | ||
lc = datafit.lipschitz | ||
for j in range(n_features): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since we have complete access to grad
at each iteration, it would be interesting to use a greedy selection rule here: do not pick j
cyclically, but instead take j = np.argmax(np.abs(grad))
One "epoch" in this setting would only be the update of n_features
coordinates.
closing in favor of #59 |
This is a smaller version of #4 : only without groups, but reusing more code and supporting any penalty