flatironinstitute / nemos

NEural MOdelS, a statistical modeling framework for neuroscience.
https://nemos.readthedocs.io/
MIT License

Changing solver default configs for first order methods #160

Open BalzaniEdoardo opened 1 month ago

BalzaniEdoardo commented 1 month ago

All proximal methods have stepsize adjustment steps that may cause instability in the parameter learning when the log-likelihood is sharply peaked. The solver parameters that trigger this behavior are acceleration=True and stepsize<=0.

I am posting some code (modified from Sam's notebook) that generates a simple example for which the first-order methods (proximal gradient for Lasso, gradient descent for Ridge and ML) behave unstably.

The example is very easy (the log-likelihood is very sharp, with a clear maximum). I ran different configurations of the proximal gradient on the example below, and the results are the following:

fitting Lasso with {'stepsize': 0.001, 'acceleration': False, 'maxiter': 50000}
true w: [0 0 0 0 0 1]
recovered w: [-0.          0.         -0.         -0.          0.93998479]

fitting Lasso with {'stepsize': 0.001, 'acceleration': True, 'maxiter': 50000}
true w: [0 0 0 0 0 1]
recovered w: [ 0.00000000e+00 -1.41409541e+87  0.00000000e+00  0.00000000e+00 -1.13790427e+88]

fitting Lasso with {'acceleration': True, 'maxiter': 50000}
failed Lasso fit with exception Solver returned at least one NaN parameter, 
so solution is invalid! Try tuning optimization hyperparameters, specifically try decreasing the learning rate.

fitting Lasso with {'acceleration': False, 'maxiter': 50000}
failed Lasso fit with exception Solver returned at least one NaN parameter, 
so solution is invalid! Try tuning optimization hyperparameters, specifically try decreasing the learning rate.

This seems to suggest that any time one sets acceleration=True or uses stepsize<=0 (which is the default), parameter learning may become unstable: the solver either returns NaNs or coefficients that are extremely large in norm.
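For reference, here is a minimal sketch of those defaults at the jaxopt level (the toy loss is just for illustration; as far as I can tell this is the solver nemos instantiates under the hood for Lasso and forwards solver_kwargs to):

import jax.numpy as jnp
from jaxopt import ProximalGradient
from jaxopt.prox import prox_lasso

# toy quadratic loss, only to instantiate the solver
def toy_loss(w, X, y):
    return jnp.mean((X @ w - y) ** 2)

# stepsize=0.0 (i.e. <= 0) tells jaxopt to pick the stepsize via backtracking line search,
# and acceleration=True (FISTA-style momentum) is on by default
solver = ProximalGradient(fun=toy_loss, prox=prox_lasso,
                          stepsize=0.0, acceleration=True, maxiter=50000)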

Below is the full example code that reproduces the runs above.

import jax
import numpy as np

import nemos as nmo

# enable float64 precision (optional)
jax.config.update("jax_enable_x64", True)

np.random.seed(111)

# inverted-U speed profile over 5 positions
n_pos = 5
pos_l = np.arange(n_pos)
speed_per_pos = -4 * (pos_l - n_pos // 2) ** 2 + 16

# noisy per-trial speed samples, standardized against the mean profile and clipped at zero
n_trial = 100
speed_per_pos_per_trial = speed_per_pos + np.random.normal(size=(n_trial, n_pos)) * speed_per_pos
speed_per_pos_per_trial -= speed_per_pos.mean()
speed_per_pos_per_trial /= speed_per_pos.std()
speed_per_pos_per_trial[speed_per_pos_per_trial < 0] = 0

np.random.seed(111)
# design matrix: one-hot position columns plus a single speed column
pos_onehot_t = np.tile(np.eye(n_pos), (n_trial, 1))
pos_t = np.tile(pos_l, n_trial)
speed_t = speed_per_pos_per_trial.reshape(-1, 1)

X = np.concatenate([pos_onehot_t, speed_t], axis=1)
# only the speed column drives the firing rate
w_true = np.concatenate([[0] * n_pos, [1]])
rate = jax.numpy.exp(jax.numpy.einsum("k,tk->t", w_true, X) + 1.)
spikes = np.random.poisson(rate)

# sweep solver configurations: explicit vs. default (line-search) stepsize,
# with and without acceleration
for slv_kwargs in [dict(stepsize=0.001, acceleration=False, maxiter=50000),
                   dict(stepsize=0.001, acceleration=True, maxiter=50000),
                   dict(acceleration=True, maxiter=50000),
                   dict(acceleration=False, maxiter=50000)]:
    try:
        model = nmo.glm.GLM(
            regularizer=nmo.regularizer.Lasso(regularizer_strength=15., solver_kwargs=slv_kwargs)
        )
        print(f"\nfitting Lasso with {slv_kwargs}")
        model.fit(X[:, 1:], spikes)
        print(f"true w: {w_true}")
        print(f"recovered w: {model.coef_}")
    except Exception as e:
        print(f"failed Lasso fit with exception {e}")
BalzaniEdoardo commented 1 month ago

More details on the acceleration.

According to the docstring of jaxopt's ProximalGradient, acceleration should activate the line search (FISTA), but looking at the code, that is not what happens.

The line search is called with acceleration=False too, in the ProximalGradient._iter method. What the acceleration parameter actually toggles is whether _update or _update_accel is called: _update_accel computes an auxiliary velocity variable that is used to compute the new parameters, instead of the parameter update itself.

The auxiliary variable is computed as follows:

    diff_x = tree_sub(next_x, x)  # next_x is the new parameter, computed in _iter, which performs a FISTA step
                                  # based on the auxiliary param at the previous step
    next_y = tree_add_scalar_mul(next_x, (t - 1) / next_t, diff_x)  # new auxiliary parameter
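To make the momentum explicit, here is a schematic, array-based paraphrase of the accelerated step (the helper name and signature are mine, for illustration; the real jaxopt code works on pytrees via tree_sub / tree_add_scalar_mul):

import jax.numpy as jnp

def fista_momentum_step(x, y, t, grad_at_y, stepsize, prox, hyperparams_prox):
    # proximal gradient step taken from the auxiliary point y, not from the current x
    next_x = prox(y - stepsize * grad_at_y, hyperparams_prox, stepsize)
    # momentum schedule: t grows with the iterations, so (t - 1) / next_t approaches 1
    next_t = 0.5 * (1.0 + jnp.sqrt(1.0 + 4.0 * t ** 2))
    # extrapolated auxiliary point used at the next iteration; when the log-likelihood
    # is sharply peaked this extrapolation can overshoot instead of settling
    next_y = next_x + (t - 1.0) / next_t * (next_x - x)
    return next_x, next_y, next_t

That overshoot would be consistent with the 1e87-sized coefficients in the stepsize=0.001, acceleration=True run above.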