More details on the acceleration.

According to the docstring of jaxopt's `ProximalGradient`, acceleration should activate the line search (FISTA), but looking at the code this is not what happens.
The line search is called even with `acceleration=False`, in the `ProximalGradient._iter` method. The `acceleration` parameter only toggles whether `_update` or `_update_accel` is called. `_update_accel` computes an auxiliary velocity variable that is used in the computation of the new parameters, instead of the parameter update itself.
The auxiliary variable is computed as follows:
```python
diff_x = tree_sub(next_x, x)  # next_x is the new parameter, computed in _iter, which performs a
                              # FISTA step based on the auxiliary parameter at the previous step
next_y = tree_add_scalar_mul(next_x, (t - 1) / next_t, diff_x)  # new auxiliary parameter
```
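For concreteness, here is a minimal, self-contained FISTA loop in JAX that mirrors this update (an illustrative sketch, not jaxopt's implementation; `soft_threshold` plays the role of the lasso proximal operator, and the quadratic loss is a stand-in for a sharply peaked log-likelihood):

```python
import jax.numpy as jnp
from jax import grad

def soft_threshold(x, thresh):
    # Proximal operator of thresh * ||x||_1 (soft-thresholding).
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - thresh, 0.0)

def fista(x0, loss, l1reg, stepsize, n_steps=200):
    x, y, t = x0, x0, 1.0  # y is the auxiliary ("velocity") variable
    g = grad(loss)
    for _ in range(n_steps):
        # The proximal gradient step is taken at the auxiliary point y, not at x.
        next_x = soft_threshold(y - stepsize * g(y), stepsize * l1reg)
        # Momentum sequence: t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2.
        next_t = 0.5 * (1.0 + jnp.sqrt(1.0 + 4.0 * t ** 2))
        diff_x = next_x - x
        # Extrapolation, matching the snippet above.
        y = next_x + ((t - 1.0) / next_t) * diff_x
        x, t = next_x, next_t
    return x

# Tiny least-squares problem with a sharp, well-defined minimum.
A = jnp.array([[10.0, 0.0], [0.0, 1.0]])
b = jnp.array([1.0, 1.0])
loss = lambda w: 0.5 * jnp.sum((A @ w - b) ** 2)
print(fista(jnp.zeros(2), loss, l1reg=0.1, stepsize=0.009))
```

The key point is that the gradient and proximal steps are evaluated at the auxiliary point `y`, while the coefficient `(t - 1) / next_t` supplies the momentum.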
All proximal methods have stepsize-adjustment steps that may cause instability in the parameter learning if the log-likelihood is sharply peaked. The parameter settings that trigger this are `acceleration=True` and `stepsize<=0`. I am posting some code (modified from Sam's notebook) that generates a simple example for which the first-order methods (proximal gradient for lasso, gradient descent for ridge and maximum likelihood) behave unstably.
This example is very easy (the log-likelihood is very sharp) and there is a clear maximum. I ran different configurations of the proximal gradient on the example below; a sketch of the kind of sweep is shown next.
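A hypothetical version of such a sweep (a sketch, not the original notebook code: `negative_log_likelihood`, `X`, `y`, and `init` below are placeholders for the actual GLM loss and data; `ProximalGradient` and `prox_lasso` are jaxopt's):

```python
import jax.numpy as jnp
from jaxopt import ProximalGradient
from jaxopt.prox import prox_lasso

def negative_log_likelihood(params, X, y):
    # Placeholder sharply peaked loss; the real example uses a GLM likelihood.
    return jnp.sum((X @ params - y) ** 2)

X = 10.0 * jnp.eye(3)
y = jnp.ones(3)
init = jnp.zeros(3)

for acceleration in (True, False):
    for stepsize in (0.0, 1e-3):  # stepsize=0.0 (the default) triggers the line search
        solver = ProximalGradient(
            fun=negative_log_likelihood,
            prox=prox_lasso,
            acceleration=acceleration,
            stepsize=stepsize,
        )
        # 0.1 is hyperparams_prox, i.e. the l1 regularization strength for prox_lasso.
        params, state = solver.run(init, 0.1, X, y)
        print(f"accel={acceleration}, stepsize={stepsize}: "
              f"|coef|={jnp.linalg.norm(params):.3g}")
```

Passing a positive `stepsize` disables the line search, which is one way to check whether the instability is tied to the line search or to the acceleration itself.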
The results seem to suggest that any time one sets `acceleration=True` or uses `stepsize=0`, which is the default, parameter learning may become unstable: the solver either returns NaNs or coefficients that are extremely large in norm. Below is the example code.