patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0
333 stars 14 forks

Improving LM implementation #92

Open Joshuaalbert opened 2 weeks ago

Joshuaalbert commented 2 weeks ago

Hi @patrick-kidger, as promised I wanted to help improve some of the optimisation methods in optimistix. I'd like to start with the LM implementation.

Trust region acceptance

The implemented approach uses two thresholds: one to decide whether a step gives any improvement at all, and one to decide whether the improvement is sufficient to warrant taking larger steps. The damping parameter is then taken to be 1/step_size in the damped Newton iteration.

There are several points here:

  1. The point of the trust-region approach is to determine when the function is acting locally quadratic, and to take more Newton-like steps in that case. You cover the case where the actual reduction is less than predicted, but you miss the case where it is more than predicted. Intuitively you might think that if the function improves more than predicted then you should take even bigger steps, but this is wrong: it means the function is not behaving quadratically.

Therefore, you should have a third cutoff that detects when actual_reduction/pred_reduction is sufficiently greater than one (1.1 is usually fine). In this case, accept the step but do not make the iteration more Newton-like. In short, only make the iterations more Newton-like if the ratio lies within a band around 1.
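A minimal plain-Python sketch of the three-cutoff acceptance rule, under stated assumptions: the function name `update_step_size` and all default values (cutoffs 0.01/0.99/1.1, growth factor 3.5, shrink factor 0.25) are hypothetical placeholders for illustration, not necessarily Optimistix's defaults. Setting `upper_cutoff=math.inf` recovers the two-threshold behaviour.

```python
import math


def update_step_size(step_size, actual_reduction, predicted_reduction,
                     low_cutoff=0.01, high_cutoff=0.99, upper_cutoff=1.1,
                     grow=3.5, shrink=0.25):
    """Return (accept, new_step_size) for one trust-region iteration.

    Assumes predicted_reduction > 0 (the quadratic model predicts a decrease).
    """
    ratio = actual_reduction / predicted_reduction
    if ratio < low_cutoff:
        # Poor agreement with the quadratic model: reject and shrink.
        return False, step_size * shrink
    if high_cutoff <= ratio <= upper_cutoff:
        # Ratio close to 1: the model is trustworthy, so move towards Newton.
        return True, step_size * grow
    # Moderate improvement, or much better than predicted (ratio > upper_cutoff):
    # accept the step but keep the step size, since the function is not
    # behaving quadratically.
    return True, step_size


# An infinite upper cutoff reproduces the two-threshold rule.
two_threshold = update_step_size(1.0, 2.0, 1.0, upper_cutoff=math.inf)
```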

  2. The units of the damping parameter in the damped Newton step are those of the Hessian-like operator. For minimisation this is the actual Hessian, with units [f]/[x]^2. For normalised least squares it is J^T.J, with units [f]^2/[x]^2, which is consistent because we normalised the equations. Either way, choosing lambda=1/step_size is not dimensionally correct. Much better is to let lambda = |grad(f)| / mu (or lambda = |J^T.F| / mu for LM). The units are now correct when mu has units [x]. The intuition is that in the asymptotic steepest-descent case the update becomes x -> x - mu * grad(f) / |grad(f)|, i.e. a step size times the unit vector along the gradient.

Therefore, you can improve the damping in two ways.

i. setting lambda = |grad(f)|/step_size for minimisation, and lambda = |J^T.F| / step_size for LM.

ii. The initial value of step_size can be chosen by a line search for a value of mu that leads to a reduction in the objective. This only needs to be done once; thereafter step_size is modified following the normal approach. A good approach is to start from mu = |grad(f)| and halve it until x - mu * grad(f) / |grad(f)| improves the objective. You don't need to satisfy any other particular conditions to accept the value of mu.
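A NumPy sketch of points i and ii together, assuming a least-squares objective 0.5*|F(x)|^2 so that grad = J^T F. The helper names (`lm_step`, `initial_mu`) and the Rosenbrock test problem are hypothetical choices for illustration, not Optimistix code.

```python
import numpy as np


def residuals(x):
    # Rosenbrock function written as a least-squares problem (test problem only).
    return np.array([10.0 * (x[1] - x[0] ** 2), 1.0 - x[0]])


def jacobian(x):
    return np.array([[-20.0 * x[0], 10.0], [-1.0, 0.0]])


def objective(x):
    F = residuals(x)
    return 0.5 * F @ F


def lm_step(J, F, mu):
    """Damped Newton (LM) step with lambda = |J^T F| / mu, so mu has units of x."""
    g = J.T @ F
    lam = np.linalg.norm(g) / mu
    return -np.linalg.solve(J.T @ J + lam * np.eye(J.shape[1]), g)


def initial_mu(x, max_halvings=60):
    """Start from mu = |grad| and halve until the normalised steepest-descent
    step x - mu * grad / |grad| reduces the objective."""
    g = jacobian(x).T @ residuals(x)
    direction = -g / np.linalg.norm(g)
    mu = np.linalg.norm(g)
    f0 = objective(x)
    for _ in range(max_halvings):
        if objective(x + mu * direction) < f0:
            break
        mu = mu / 2
    return mu


x0 = np.array([-1.2, 1.0])
mu0 = initial_mu(x0)
step = lm_step(jacobian(x0), residuals(x0), mu0)
```

Because every eigenvalue of J^T J + lambda*I is at least lambda, the step satisfies |step| <= |J^T F| / lambda = mu, so mu behaves like a length in parameter space; scaling the residuals by a constant leaves the step unchanged.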

Reusing J/JVP

Multi-step "approximate" LM is easily implemented by first linearising the JVPop around the current parameters, then performing one exact LM step followed by a number of approximate steps using the same JVPop. In the dense-J case this is really valuable, as you only form the J matrix once per 1 + num_approx_steps steps. It is still helpful in the sparse case, where jax.linearize is useful. The literature shows that this significantly reduces the amount of computation while requiring only a few more iterations to converge. There are simple criteria to determine when J should be recomputed; however, JAX precludes these dynamic decisions. The simplest option is a fixed number of approximate steps per exact step.
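A NumPy sketch of the fixed-count variant, for illustration only: the Jacobian is formed once per outer iteration and reused for one exact step and `num_approx_steps` approximate steps, while the residuals are re-evaluated at every step. It uses a materialised dense J rather than a linearised JVP, and to keep it short it uses a fixed damping `lam` with no acceptance test, so it is not a robust LM. The function names and the Rosenbrock test problem are hypothetical.

```python
import numpy as np


def residuals(x):
    # Rosenbrock function written as a least-squares problem (test problem only).
    return np.array([10.0 * (x[1] - x[0] ** 2), 1.0 - x[0]])


def jacobian(x):
    return np.array([[-20.0 * x[0], 10.0], [-1.0, 0.0]])


def multistep_lm(x, lam=1e-3, num_outer=20, num_approx_steps=2):
    """Form J once per outer iteration; take one exact step with it, then
    `num_approx_steps` approximate steps that reuse the same J."""
    jac_evals = 0
    for _ in range(num_outer):
        J = jacobian(x)  # the only Jacobian evaluation in this outer iteration
        jac_evals += 1
        A = J.T @ J + lam * np.eye(x.size)
        for _ in range(1 + num_approx_steps):
            F = residuals(x)  # residuals are re-evaluated at every step
            x = x - np.linalg.solve(A, J.T @ F)
    return x, jac_evals


x_final, n_jac = multistep_lm(np.array([-1.2, 1.0]))
```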

I didn't have time to attach literature, but hopefully this gets the ball rolling. I also suggest that a suite of simple but difficult benchmarks be written first to assess an improvement to the algorithm.

johannahaffner commented 2 weeks ago

Hi Joshua, Hi Patrick,

I follow discussions on this repository and also use the optimistix version of LM quite a bit, so I'm interested in this discussion.

For 2.: If I understand correctly, $\mu$ is an auxiliary parameter used to find the step size. Where I do not follow is

This only needs to be done once, and thereafter step_size is modified following the normal approach.

Why would a line search at some initial point be helpful for modifying the step size throughout the solve? Wouldn't this merely scale the step size by some value if done once, after which you'd go back to multiplying with the high and low cutoffs and essentially end up with a more-or-less identical $\lambda$ after a few iterations?

Maybe I am missing something :) I'd be curious how this improves the convergence rates, and if making the step size more dimensionally correct "matters" to the solver, which only sees floats.

And to your other point - re-using Jacobians would make step-size computations less accurate, right? To what extent can this be counteracted by making the step size selection criteria more stringent?

patrick-kidger commented 2 weeks ago

@Joshuaalbert -- thanks for getting involved!

To address your various points:

  1. So I believe we actually experimented with this exact idea during development. (Certainly it's a standard one that one sees floating around.) I think on balance we elected not to include this -- IIRC it didn't help overall in the average case. But I think if this is important to you then I'd definitely be happy to take that as a PR, just with a default cutoff of jnp.inf so as to preserve backward compatibility.

  2. To make explicit what you are suggesting, you are considering multiplying this quantity:

    https://github.com/patrick-kidger/optimistix/blob/ef86ef5d3dcd1d4be61611137157d19bf7944118/optimistix/_solver/levenberg_marquardt.py#L57

    by |grad(f)| (or |J^T f| in the least-squares case).

    I think something like that is reasonable, although note that to be dimensionally correct I think things may need to change further, as step_size is unitless. (It is an arbitrary scalar whose value is entirely up to the search to choose.) What would you suggest in light of that?

    (I probably wouldn't introduce the initial search -- as Johanna comments this is unlikely to affect things dramatically and it would impede efficient compilation.)

  3. On reusing the Jacobian evaluations. We actually already do this in the Gauss-Newton case: notice how we form it here in query rather than in step:

    https://github.com/patrick-kidger/optimistix/blob/ef86ef5d3dcd1d4be61611137157d19bf7944118/optimistix/_solver/gauss_newton.py#L127

    Indeed, for our LM implementation we elected to use the exact form, in which this is computed on every step. If you wanted to implement an approximate form in which this is re-evaluated based on some criterion (even a dynamic one), then you could already implement this yourself without any changes to Optimistix: implement your own descent and/or search. Depending on how this goes, I'd be happy to consider upstreaming it if that was something you wanted.

Joshuaalbert commented 2 weeks ago

Hi both, thanks for the replies. @johannahaffner the reason the step-size selection only needs to be performed once is that all it does is find a suitable initial damping factor: a damping parameter that would lead to a successful steepest-descent iteration in the asymptotic case. By the asymptotic case I mean that as lambda -> inf we have (A + lambda*I)^-1 v ≈ lambda^-1 v. For those coming from SGD, lambda=1/learning_rate in the asymptotic case, so choosing an initial step size is a bit like choosing a good initial learning rate. The initial selection does two important things: 1) it removes the need for the user to choose a suitable initial damping parameter, and 2) it automatically ensures that the first iteration shows an improvement (in most problems). This means fewer iterations are wasted while the solver looks for a good initial step size.
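The asymptotic statement can be made precise with a Neumann-series expansion, a standard identity valid once lambda exceeds the spectral norm of the Hessian-like operator A (here g denotes the gradient, i.e. J^T F for LM). The leading terms are exactly the SGD update with learning rate eta = 1/lambda, and the normalised steepest-descent update x - mu * g / |g| discussed in this thread.

```latex
\begin{aligned}
(A + \lambda I)^{-1} v
  &= \lambda^{-1} \left( I + \lambda^{-1} A \right)^{-1} v
   = \lambda^{-1} v - \lambda^{-2} A v + O(\lambda^{-3}),
   \qquad \lambda > \|A\|, \\
\lambda = 1/\eta:\quad x - (A + \lambda I)^{-1} g
  &= x - \eta\, g + O(\eta^{2}), \\
\lambda = \|g\| / \mu:\quad x - (A + \lambda I)^{-1} g
  &= x - \mu\, \frac{g}{\|g\|} + O(\mu^{2}).
\end{aligned}
```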

Hi, @patrick-kidger we could make a list of possible improvements and then do them in one fell PR.

  1. On adding an upper trust-region limit, it prevents the step-size increasing too quickly and then needing to wait several iterations while it shrinks it again.
  2. Yes, that's what it would be equivalent to. step_size isn't unitless. In the asymptotic case, the steepest-descent update is x -> x - grad / lm_param = x - step_size * grad, so step_size has units [x]^2/[f]^2. Note, this is because grad=J^T.F for LM. With the parametrisation that I suggest, the asymptotic update is x -> x - mu * grad / |grad|, and mu has units [x]. Of course, this only makes sense when all x_i components have the same units, but that's just a good reason to use homogeneous spaces to parametrise parameters. This is also why different problems require different initial damping parameters: the right value depends on the scale of the problem. Choosing lm_param = |grad|/mu has the nice property that as the stationary point is approached, |grad| shrinks, which automatically makes the system more Newton-like and less sensitive to mu.

Worth mentioning: in my code I use lambda(mu) = mu * |grad| rather than dividing, so mu has units [x]^-1 in that case.

  3. I suspected this was already something you're doing. I'm not sure exactly how I would implement dynamic re-evaluation of J in the case where we don't materialise the Jacobian. In my implementation I simply use a static inner for-loop for the approximate steps, reusing the same linearised JVPop.
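The unit bookkeeping in point 2 can be checked numerically. This NumPy sketch (random problem and scale factors `a` and `c` are arbitrary illustrative choices) changes parameter units via x = a*z and residual units via F -> c*F, then maps the LM step computed in the new units back to x. With lambda = 1/step_size, step_size must be rescaled by 1/(a^2 c^2), consistent with units [x]^2/[f]^2; with lambda = |J^T F|/mu, mu only needs rescaling by 1/a, consistent with units [x], and the residual scale c drops out.

```python
import numpy as np


def damped_step(J, F, lam):
    """LM step -(J^T J + lam * I)^{-1} J^T F."""
    return -np.linalg.solve(J.T @ J + lam * np.eye(J.shape[1]), J.T @ F)


rng = np.random.default_rng(0)
J = rng.normal(size=(5, 3))  # Jacobian of the residuals with respect to x
F = rng.normal(size=5)       # residuals at the current point
g = J.T @ F

# Change units: parameters x = a * z, residuals F -> c * F.
a, c = 10.0, 3.0
Jz, Fz = c * a * J, c * F    # chain rule: dF_z/dz = c * J * a
gz = Jz.T @ Fz

# lambda = 1 / step_size: step_size must be rescaled like [x]^2 / [f]^2.
step_size = 0.5
dx = damped_step(J, F, 1.0 / step_size)
dx_via_z = a * damped_step(Jz, Fz, 1.0 / (step_size / (a**2 * c**2)))
dx_unscaled = a * damped_step(Jz, Fz, 1.0 / step_size)  # same number, new units

# lambda = |J^T F| / mu: mu only needs rescaling like [x]; c drops out.
mu = 0.5
dx_mu = damped_step(J, F, np.linalg.norm(g) / mu)
dx_mu_via_z = a * damped_step(Jz, Fz, np.linalg.norm(gz) / (mu / a))
```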

Another thing I've wanted to do for a while is collect a large number of least-squares problems, of varying difficulty, and then do a grid search over hyper-parameters to choose defaults that lead to fastest convergence on average. WDYT?

johannahaffner commented 2 weeks ago

thanks for clarifying, @Joshuaalbert :)

patrick-kidger commented 1 week ago

we could make a list of possible improvements and then do them in one fell PR

Multiple smaller PRs are much easier to review :)

step_size isn't unitless

I'd like it to be, though! I think the argument you're making here is really that we should make similar changes elsewhere -- to improve unitlessness -- beyond just the case of LM.

I'm not sure exactly how I would implement dynamic re-evaluation of J, in the case where we don't materialise the Jacobian.

So the loop of a solver is over individual function evaluations. At each step we then decide what to do with that information. Cf. the earlier discussion here:

https://github.com/patrick-kidger/optimistix/issues/89#issuecomment-2447669714

Taking this as an example, it already includes a dynamic choice about when to compute the gradient -- in this case, that we are finishing one line search and starting another. You could adjust this logic to match whatever condition you most prefer.
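As a rough illustration of deciding dynamically when to re-evaluate J, here is a NumPy sketch of an LM loop that reuses a dense Jacobian across accepted steps and refreshes it only when a step computed with a stale J is rejected, or after `max_reuse` accepted steps. The refresh rule, the damping update factors and all names are hypothetical choices, not Optimistix's logic; it is written with Python control flow for clarity, and under `jit` the branches would become traced conditions (e.g. `jax.lax.cond`).

```python
import numpy as np


def residuals(x):
    # Rosenbrock function written as a least-squares problem (test problem only).
    return np.array([10.0 * (x[1] - x[0] ** 2), 1.0 - x[0]])


def jacobian(x):
    return np.array([[-20.0 * x[0], 10.0], [-1.0, 0.0]])


def lm_dynamic_refresh(x, lam=1.0, max_steps=200, max_reuse=5):
    F = residuals(x)
    J = jacobian(x)
    jac_evals = 1
    reuse = 0  # accepted steps taken since J was last evaluated
    for _ in range(max_steps):
        step = -np.linalg.solve(J.T @ J + lam * np.eye(x.size), J.T @ F)
        F_new = residuals(x + step)
        accepted = F_new @ F_new < F @ F
        if accepted:
            x, F = x + step, F_new
            lam /= 3.0
            reuse += 1
        else:
            lam *= 2.0
        # J is stale once we have moved away from where it was evaluated.
        # Refresh it if a stale-J step was rejected or the reuse budget is spent.
        if reuse > 0 and (not accepted or reuse >= max_reuse):
            J = jacobian(x)
            jac_evals += 1
            reuse = 0
    return x, jac_evals


x_opt, n_jac = lm_dynamic_refresh(np.array([-1.2, 1.0]))
```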

Another thing I've wanted to do for a while is collect a large number of least-squares problems, of varying difficulty, and then do a grid search over hyper-parameters to choose defaults that lead to fastest convergence on average. WDYT?

I think having some JAX-compatible benchmarks sounds pretty useful to me!

Joshuaalbert commented 1 week ago

I'd like it to be, though! I think the argument you're making here is really that we should make similar changes elsewhere -- to improve unitlessness -- beyond just the case of LM.

Just to be clear, this is not possible without some arbitrary imposition of a "default scale". Even using something like damping=step_size*diag(diag(hessian)), which would make step_size dimensionless, imposes a choice of scale: you still need to find a step_size that works for that choice. Note, I find damping=|grad|/step_size performs much better than damping=step_size*diag(diag(hessian)), as it automatically shrinks as you get closer to a stationary point, which makes the algorithm less sensitive to step_size.
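The difference between the two damping rules near a stationary point can be seen on a toy linear least-squares problem (hypothetical, chosen so that the Gauss-Newton step lands exactly on the solution x* = (1, 1), with J^T J standing in for the Hessian). The sketch measures the relative deviation of each damped step from the Gauss-Newton step: with damping = |grad|/step_size the deviation vanishes as x approaches x* for any fixed step_size, whereas with damping = step_size*diag(diag(hessian)) it stays at step_size/(1 + step_size) for this diagonal problem, however close x is.

```python
import numpy as np

# Toy linear residuals F(x) = M x - b with solution x* = (1, 1).
M = np.array([[2.0, 0.0], [0.0, 1.0]])
b = np.array([2.0, 1.0])


def steps(x, step_size):
    """Return the Gauss-Newton step and the two damped variants at x."""
    F = M @ x - b
    A, g = M.T @ M, M.T @ F  # Gauss-Newton Hessian approximation and gradient
    gn = -np.linalg.solve(A, g)
    grad_damped = -np.linalg.solve(A + np.linalg.norm(g) / step_size * np.eye(2), g)
    diag_damped = -np.linalg.solve(A + step_size * np.diag(np.diag(A)), g)
    return gn, grad_damped, diag_damped


def rel_dev(step, ref):
    return np.linalg.norm(step - ref) / np.linalg.norm(ref)


near = np.array([1.0 + 1e-4, 1.0 - 1e-4])
far = np.array([3.0, -1.0])
for step_size in (0.1, 10.0):
    for name, x in (("near", near), ("far", far)):
        gn, grad_damped, diag_damped = steps(x, step_size)
        print(name, step_size, rel_dev(grad_damped, gn), rel_dev(diag_damped, gn))
```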

All solvers require some knowledge of the search domain. Dimensionlessness is a property of the model, not the solver. So, if you want a dimensionless step_size, then you should parametrise your models from the unit cube, using quantiles to impose your prior knowledge. This is actually how my probabilistic programming framework, jaxns, works. The support of U[0,1] has the units of probability, i.e. it is dimensionless. E.g. if you have the distance to the Sun as a parameter, this could be a LogNormal variable, with prior knowledge encoded as the mean mu and standard deviation sigma of the log-distance.

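import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfpd = tfp.distributions
# prior mean and std of the log-distance (illustrative values; 1 au is about 1.496e11 m)
mu, sigma = jnp.log(1.496e11), 0.1
# your unconstrained variable in (-inf, inf) (example value)
unconstrained_param = jnp.asarray(0.5)
# any measure-preserving map to [0, 1] (the CDF of any continuous distribution is fine).
U = tfpd.Normal(0.0, 1.0).cdf(unconstrained_param)
# Now apply the quantile of some distribution that encapsulates your prior knowledge about the variable.
param = tfpd.LogNormal(mu, sigma).quantile(U)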

The solver operates on the unconstrained space, which is dimensionless. Note, I have a fast bijection [-inf, inf) -> [0,1) here which doesn't require any transcendentals.
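For readers without TensorFlow Probability, the same map can be sketched with the standard library's `statistics.NormalDist`, assuming a LogNormal(mu, sigma) prior. Since the LogNormal quantile is the exponential of the Normal(mu, sigma) quantile, composing it with the standard-normal CDF reduces to exp(mu + sigma * z), which the sketch uses as a check. The function name `to_param` is hypothetical.

```python
import math
from statistics import NormalDist


def to_param(unconstrained, mu, sigma):
    """Map an unconstrained value to a LogNormal(mu, sigma) parameter via U(0, 1)."""
    u = NormalDist(0.0, 1.0).cdf(unconstrained)  # measure-preserving map to (0, 1)
    # LogNormal(mu, sigma) quantile = exp(Normal(mu, sigma) quantile).
    return math.exp(NormalDist(mu, sigma).inv_cdf(u))


# Closed form of the same composition, for comparison.
z, mu, sigma = 0.3, 1.0, 0.5
direct = math.exp(mu + sigma * z)
```

In double precision the CDF rounds to exactly 1.0 for large positive inputs, at which point `inv_cdf` raises `StatisticsError`, so this naive composition is only suitable for moderate values.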

However, it's not a problem that every solver is inherently dimensionful, as long as you choose the right problem-specific scale. This is why an automatic determination of an initial step size in Gauss-Newton would be so helpful to general users. I'm not sure about doing this for every solver, as not all solvers operate in the same way, so you'd need to determine the effect of scale on each solver and treat each one specifically. Certainly, all variable-metric methods that employ a line search already endeavour to find the correct scales through some form of search. These can all be made more robust by a dimensionless parametrisation, and also by ensuring the line-search variable is in units of the parameter.

Sorry for long reply. I love this stuff.

johannahaffner commented 1 week ago

Hi both :)

I think the issue of scaling is specific to regularisation.

For instance, in NewtonDescent as used by Gauss-Newton and BFGS,

https://github.com/patrick-kidger/optimistix/blob/ef86ef5d3dcd1d4be61611137157d19bf7944118/optimistix/_solver/gauss_newton.py#L133

step_size is obviously unitless.

ClassicalTrustRegion side-steps the problem of dimensionality by taking the ratio f_diff / predicted reduction and updating the step size parameter heuristically.

The use of the step-size as a regularisation parameter in the computation of y_diff is then what raises the question of dimensionality.

Joshuaalbert commented 1 week ago

@johannahaffner I see what you're saying there; however, dimensionality and scale are two different things, and neither is exactly the same as regularisation (which is about adding some extra information to make ill-posed systems better-posed). There are two things at play. 1) Dimensional analysis, which looks at the units of the function and of the parameters, and tracks how dimensions ripple through the analysis. This allows one to say things like: in any fixed-point iteration x -> f(x, theta), the units of f are the same as those of x. E.g. in classic SGD the "learning rate" has units of the inverse Hessian. The simplest consequence is that it often makes things more intuitive to reason in terms of the units of the parameters. 2) Most algorithms are scale dependent, i.e. linear reparametrisations of the parameters lead to different algorithm behaviour. Gauss-Newton and BFGS (with inexact linesearch) for example. Ideally this would not be the case, because otherwise it matters whether I chose to use centimetres or millimetres for length, etc. The best sorts of optimisation algorithms are invariant to monotonic transformations of the objective function and to linear transformations of the parameter space, but they are rare.

johannahaffner commented 1 week ago

I take your point that dimensionality and scale are different things!

As long as we subtract $\lambda \mathbb{I}$, I would view this as a regularisation, even though (Tikhonov) regularisation parameters are likely smaller.

The point is that by subtracting lm_param, we introduce the step size into the solve for the next step in a way we do not with other descents. So this touches on the effect of this parameter in identifying the next step.

most algorithms are scale dependent, [...] Gauss-Newton and BFGS (with inexact linesearch) for example.

Don't you mean to say that Gauss-Newton and BFGS are scale-invariant? Introducing lm_param, which makes the steps more gradient-descent-like, actually introduces scale-variance. Scaling the computed trust-region radius by a scalar value such as |J^T f| would either lessen or increase that, depending on whether the norm of the Jacobian is less than or greater than one.

Is your goal to figure out what to subtract from the Hessian approximation so as to preserve scale-invariance?

I'm wondering if a little scale variance is not what we want here - and to what extent the robustness of Levenberg-Marquardt depends on being able to interpolate between two different optimisation regimes with different strengths.
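On the question of which damping term preserves scale invariance: one classical answer is the diagonal scaling mentioned earlier in the thread, damping = step_size*diag(diag(hessian)) (Marquardt's variant, here with J^T J as the Hessian approximation). The NumPy sketch below (random problem and scale factors chosen arbitrarily) changes parameter units per coordinate via x = S z with S diagonal, computes the step in z, and maps it back to x. Undamped Gauss-Newton and diagonally scaled damping give the same step in x; isotropic damping lambda*I does not. This only covers diagonal rescalings, not general linear reparametrisations.

```python
import numpy as np


def damped_step(J, F, lam, scaling):
    """Step -(J^T J + lam * D)^{-1} J^T F with D = I or D = diag(J^T J)."""
    A = J.T @ J
    D = np.diag(np.diag(A)) if scaling == "diagonal" else np.eye(A.shape[0])
    return -np.linalg.solve(A + lam * D, J.T @ F)


rng = np.random.default_rng(1)
J = rng.normal(size=(6, 3))        # Jacobian with respect to x
F = rng.normal(size=6)             # residuals at the current point
S = np.diag([1.0, 10.0, 1000.0])   # per-coordinate change of units x = S z

invariant = {}
for name, lam, scaling in [("gauss-newton", 0.0, "identity"),
                           ("isotropic", 0.1, "identity"),
                           ("diagonal", 0.1, "diagonal")]:
    dx = damped_step(J, F, lam, scaling)
    # The Jacobian with respect to z is J @ S; map the z-step back to x via S.
    dx_via_z = S @ damped_step(J @ S, F, lam, scaling)
    invariant[name] = bool(np.allclose(dx, dx_via_z))
```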