Improving LM implementation #92
Hi Joshua, hi Patrick, I follow the discussions on this repository and also use the optimistix version of LM quite a bit, so I'm interested in this discussion.

For 2.: if I understand correctly, then why would a line search at some initial point be helpful for modifying the step size throughout the solve? Wouldn't this merely scale the step size by some value if done once, after which you'd go back to multiplying with high and low cutoffs and essentially end up with more-or-less identical behaviour? Maybe I am missing something :) I'd be curious how this improves the convergence rates, and whether making the step size more dimensionally correct "matters" to the solver, which only sees floats.

And to your other point: re-using Jacobians would make step-size computations less accurate, right? To what extent can this be counteracted by making the step-size selection criteria more stringent?
@Joshuaalbert -- thanks for getting involved! To address your various points:
Hi both, thanks for the replies.

@johannahaffner the reason the step-size selection only needs to be performed once is that all it's doing is finding a suitable initial damping factor: a damping parameter that would lead to a successful steepest-descent iteration in the asymptotic case. By asymptotic case I mean the regime in which the damping dominates, so the damped Newton step reduces to a (scaled) steepest-descent step (a minimal sketch of this one-off initialisation is below).

Hi @patrick-kidger, we could make a list of possible improvements and then do them in one fell PR.
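A minimal sketch of that one-off initialisation, assuming a scalar-valued objective `f` (this is a hypothetical helper for illustration, not optimistix code; it just halves `mu` until a steepest-descent step improves the objective):

```python
import jax
import jax.numpy as jnp


def initial_mu(f, x0, max_halvings=30):
    """Hypothetical one-off search: halve mu until a steepest-descent step improves f."""
    f0, grad = jax.value_and_grad(f)(x0)
    unit_grad = grad / jnp.linalg.norm(grad)
    mu0 = jnp.linalg.norm(grad)  # starting guess for the step size

    def keep_halving(carry):
        mu, i = carry
        return (f(x0 - mu * unit_grad) >= f0) & (i < max_halvings)

    def halve(carry):
        mu, i = carry
        return mu / 2.0, i + 1

    mu, _ = jax.lax.while_loop(keep_halving, halve, (mu0, 0))
    return mu
```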
Worth mentioning, in my code I use
Another thing I've wanted to do for a while is collect a large number of least-squares problems, of varying difficulty, and then do a grid search over hyper-parameters to choose defaults that lead to fastest convergence on average. WDYT?
thanks for clarifying, @Joshuaalbert :)
Multiple smaller PRs are much easier to review :)
I'd like it to be, though! I think the argument you're making here is really that we should make similar changes elsewhere -- to improve unitlessness -- beyond just the case of LM.
So the loop of a solver is over individual function evaluations. At each step we then decide what to do with that information. Cf. the earlier discussion here: using this as an example, it already includes a dynamic choice about when to compute the gradient -- in this case, that we are finishing one line search and starting another. You could adjust this logic to match whatever condition you most prefer.
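For illustration, that kind of per-step branching might look like this in generic JAX (a sketch only, not the actual optimistix internals; `finished_line_search` and the `state` dict are made up):

```python
import jax


def maybe_recompute_gradient(f, x, state, finished_line_search):
    """Only recompute the gradient when one line search ends and another begins;
    otherwise reuse the cached value (a generic illustration)."""
    grad = jax.lax.cond(
        finished_line_search,
        lambda y: jax.grad(f)(y),   # new line search: fresh gradient
        lambda y: state["grad"],    # mid line search: reuse cached gradient
        x,
    )
    return {**state, "grad": grad}
```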
I think having some JAX-compatible benchmarks sounds pretty useful to me!
Just to be clear, this is not possible without some arbitrary imposition of a "default scale"; even a seemingly neutral default still amounts to choosing one. All solvers require some knowledge of the search domain. Dimensionlessness is a property of the model, not the solver. So, if you want dimensionless parameters, reparametrise:

```python
from tensorflow_probability.substrates import jax as tfp

tfpd = tfp.distributions

# Your unconstrained variable in [-inf, inf].
unconstrained_param = ...
# Any measure-preserving map to [0, 1] (the CDF of any prob. dist. is fine).
U = tfpd.Normal(0.0, 1.0).cdf(unconstrained_param)
# Now apply the quantile of some dist. that encapsulates your prior knowledge about the variable.
param = tfpd.LogNormal(mu, sigma).quantile(U)
```

The solver operates on the unconstrained space, which is dimensionless. Note, I have a fast bijection for this already.

However, it's not a problem that any solver is inherently dimensionful, as long as you choose the right problem-specific scale. This is why an automatic determination of an initial step size in Gauss-Newton would be so helpful to general users. I'm not sure about doing this for every solver, as not all solvers operate in the same way, so you'd need to determine what the effect of scale is on each solver and treat each one specifically. Certainly, all variable-metric methods that employ a line search already endeavour to find the correct scales by using some form of search. These can all be made more robust by a dimensionless parametrisation, and also by ensuring the line search variable is in units of the parameter.

Sorry for the long reply. I love this stuff.
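To make the reparameterisation above concrete end-to-end, a least-squares solve in the unconstrained space could look roughly like this (the decay model and the prior are made up for illustration; the optimistix calls reflect the public API as far as I'm aware):

```python
import jax.numpy as jnp
import optimistix as optx
from tensorflow_probability.substrates import jax as tfp

tfpd = tfp.distributions


def residuals(unconstrained_param, data):
    # Map the dimensionless unconstrained variable to the physical parameter.
    u = tfpd.Normal(0.0, 1.0).cdf(unconstrained_param)
    rate = tfpd.LogNormal(0.0, 1.0).quantile(u)  # hypothetical prior on a decay rate
    t, y = data
    return y - jnp.exp(-rate * t)  # hypothetical exponential-decay model


t = jnp.linspace(0.0, 1.0, 20)
y = jnp.exp(-2.0 * t)
solver = optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8)
sol = optx.least_squares(residuals, solver, jnp.array(0.0), args=(t, y))
# sol.value is the optimal unconstrained parameter; map it back through the same bijection.
```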
Hi both :) I think the issue of scaling is specific to regularisation. For instance, see `optimistix/_solver/gauss_newton.py`, line 133 (commit `ef86ef5`), where the step size is used as a regularisation parameter.
@johannahaffner I see what you're saying there, however dimensionality and scale are two different things, and neither is exactly the same as regularisation (which is about adding some extra information to make an ill-posed system better-posed). There are two things at play: 1) dimensional analysis, which looks at the units of the function and the units of the parameters, and tracks how dimensions ripple through the analysis; this allows one to say things like "in any fixed-point iteration, the quantity added to the parameter must carry the units of the parameter"; and 2) scale, which is about the typical magnitudes of those quantities.
I take your point that dimensionality and scale are different things! As long as we subtract something from the Hessian approximation, though, the point is that by subtracting it we are already making a scale-dependent choice.
Don't you mean to say that Gauss-Newton and BFGS are scale-invariant? Is your goal to figure out what to subtract from the Hessian approximation so as to preserve scale-invariance? I'm wondering if a little scale variance is not what we want here - and to what extent the robustness of Levenberg-Marquardt depends on being able to interpolate between two different optimisation regimes with different strengths.
Hi @Joshuaalbert, I'm going to write a parser for optimisation problems specified in the SIF format, to make CUTEst benchmark problems usable in JAX. I will implement this as-needed for downstream use cases, but if you're interested, the relevant code will live here: https://github.com/johannahaffner/sif2jax
Hi @patrick-kidger, as promised I wanted to help improve some of the optimisation methods in optimistix. I'd like to start with the LM implementation.
Trust region acceptance
The implemented approach has two thresholds: one for determining whether there was any improvement at all, and one for whether the improvement is sufficient to warrant taking larger steps. The damping parameter is then taken to be `1/step_size` in the damped Newton iteration. There are several points here:

- When `actual_reduction/pred_reduction` is much greater than one, the local quadratic model has underestimated the true reduction, so that signal does not justify making the step more Newton. Therefore, you should have a third cutoff sensing when `actual_reduction/pred_reduction` is sufficiently greater than one (1.1 is usually fine). In this case, accept but do not make the step more Newton. Basically, only make the iterations more Newton if the gain ratio is within a region around 1.
- The damping parameter should carry the units of the (approximate) Hessian, i.e. `[f]/[x]^2`. For normalised least-squares that Hessian is `J^T.J`, so it's `[f]^2/[x]^2`, which is consistent because we normalised the equations. Anyway, choosing `lambda = 1/step_size` is not dimensionally correct. Much better is to let `lambda = |grad(f)| / mu` (or `lambda = |J^T.F| / mu` for LM). Note, the units are now correct when `mu` has units `[x]`. The intuition behind this is that in the asymptotic steepest-descent case `x -> x - mu * grad(f) / |grad(f)|`, i.e. a step size times the gradient unit vector.

Therefore, you can improve the damping in two ways (a rough sketch follows this list):

i. setting `lambda = |grad(f)| / step_size` for minimisation, and `lambda = |J^T.F| / step_size` for LM;
ii. choosing the initial value of `step_size` by a line search for a value of `mu` that leads to a reduction in the objective. This only needs to be done once, and thereafter `step_size` is modified following the normal approach. A good approach is to start from `mu = |grad(f)|` and halve until `x - mu * grad(f) / |grad(f)|` leads to an objective improvement. You don't need to satisfy any other particular conditions to accept the value of `mu`.
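A rough sketch of what point (i) plus the extra acceptance cutoff could look like (dense `J`, made-up threshold values `lo`/`hi`/`over`, a simple doubling/halving of `mu`; this is an illustration, not optimistix's actual implementation):

```python
import jax
import jax.numpy as jnp


def lm_step(residual_fn, x, mu, lo=0.25, hi=0.75, over=1.1):
    """One damped step with lambda = |J^T F| / mu and a third 'gain >> 1' cutoff."""
    F = residual_fn(x)
    J = jax.jacfwd(residual_fn)(x)
    grad = J.T @ F                                    # J^T F
    lam = jnp.linalg.norm(grad) / mu                  # dimensionally correct damping
    step = jnp.linalg.solve(J.T @ J + lam * jnp.eye(x.size), -grad)

    f0 = 0.5 * jnp.sum(F**2)
    f1 = 0.5 * jnp.sum(residual_fn(x + step) ** 2)
    pred = -(grad @ step + 0.5 * step @ (J.T @ J) @ step)  # predicted reduction
    gain = (f0 - f1) / pred

    accept = gain > lo
    # Only make the iteration "more Newton" (increase mu, i.e. decrease lambda)
    # when the gain ratio sits in a region around 1 -- not when it is >> 1.
    more_newton = (gain > hi) & (gain < over)
    mu = jnp.where(more_newton, 2.0 * mu, jnp.where(accept, mu, 0.5 * mu))
    x = jnp.where(accept, x + step, x)
    return x, mu
```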
Reusing J/JVP

Multi-step "approximate" LM is easily implemented by first linearising the JVP operator around the current parameters, performing one exact LM step, and then taking a number of approximate steps that reuse the same JVP operator. In the dense-`J` case this is really valuable, as you only form the `J` matrix once per `1 + num_approx_steps` steps. It's also still helpful in the sparse case, where using `jax.linearize` helps. The literature shows this significantly reduces the amount of computation while only requiring a few more iterations to converge. There are simple criteria for determining when `J` should be recomputed, however JAX precludes these dynamic decisions; the simplest option is a fixed number of approximate steps per exact step (see the sketch below).

I didn't have time to attach literature, but hopefully this gets the ball rolling. I also suggest that a suite of simple but difficult benchmarks be written first, to assess any improvement to the algorithm.
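A minimal sketch of the fixed-schedule variant (dense `J`, a fixed damping `lam` held for the whole cycle; the function name and arguments are made up for illustration, not an optimistix API):

```python
import jax
import jax.numpy as jnp


def lm_cycle(residual_fn, x, lam, num_approx_steps=3):
    """One exact LM step followed by approximate steps that reuse the same J."""
    _, jvp_fn = jax.linearize(residual_fn, x)        # linearise once at the current x
    J = jax.vmap(jvp_fn)(jnp.eye(x.size)).T          # materialise J column by column
    A = J.T @ J + lam * jnp.eye(x.size)

    def one_step(x, _):
        F = residual_fn(x)                           # fresh residuals...
        step = jnp.linalg.solve(A, -(J.T @ F))       # ...but the stale J
        return x + step, None

    x, _ = jax.lax.scan(one_step, x, None, length=1 + num_approx_steps)
    return x
```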