flatironinstitute / nemos

NEural MOdelS, a statistical modeling framework for neuroscience.
https://nemos.readthedocs.io/
MIT License

Solver returned at least one NaN parameter #83

Closed samdeoxys1 closed 5 months ago

samdeoxys1 commented 7 months ago

problem_data.p.zip

I found certain situations (data attached) that would lead to:

[image: screenshot of the error]

code:

import pickle

import jax
import nemos as nmo
import numpy as np

problem_data = pickle.load(open('problem_data.p', 'rb'))
X = problem_data['X']
spikes = problem_data['spikes']
w_true = problem_data['w_true']

model = nmo.glm.GLM()
model.set_params(
    regularizer__regularizer_strength=0.,  # 0.01,
    regularizer=nmo.regularizer.Lasso(),
    observation_model__inverse_link_function=jax.numpy.exp
)
model.fit(X, spikes)

Yet the statsmodels GLM would give a result:

import statsmodels.api as sm

xx = sm.add_constant(X[:, 0, :])
model = sm.GLM(spikes[:, 0], xx, family=sm.families.Poisson())
res = model.fit()
res.params

(I suspect there is an overflow somewhere, since my design matrix is collinear and the rates are quite high.)

BalzaniEdoardo commented 7 months ago

Have you checked the conditioning of X.T @ X? The condition number should be very large if you have collinearity; you can compute it with numpy.linalg.cond. It could be due to the design, as you said. Do you think the error message is confusing?

We could check the conditioning when this happens, so that we can raise a more specific error message.
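A minimal sketch of the conditioning check suggested above. The design matrix here is a made-up example (one-hot position columns plus a constant column), not the attached data; only numpy.linalg.cond comes from the comment.

```python
import numpy as np

# Hypothetical design: 5 one-hot position columns plus a constant column.
# The one-hot columns sum to the constant column, so the design is collinear.
n_trial, n_pos = 100, 5
onehot = np.tile(np.eye(n_pos), (n_trial, 1))
const = np.ones((n_trial * n_pos, 1))

X_collinear = np.hstack([onehot, const])
X_reduced = np.hstack([onehot[:, :-1], const])  # drop one one-hot column

# Condition number of the Gram matrix X.T @ X for each design.
cond_collinear = np.linalg.cond(X_collinear.T @ X_collinear)
cond_reduced = np.linalg.cond(X_reduced.T @ X_reduced)

print(f"collinear design: {cond_collinear:.3e}")
print(f"reduced design:   {cond_reduced:.3e}")
```

The collinear design should report a condition number that is astronomically large (or inf), while dropping one redundant column brings it down to a modest value.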

ahwillia commented 7 months ago

Is this not fixed by ridge regularization?

On Mon, Jan 22, 2024, 9:52 AM Edoardo Balzani @.***> wrote:

Have you checked the conditioning of X.T @ X? The condition number should be very large if you have collinearity; you can compute it with numpy.linalg.cond. It could be due to the design, as you said. Do you think the error message is confusing?

We could check the conditioning when this happens, so that we can raise a more specific error message.


BalzaniEdoardo commented 7 months ago

It should definitely help, but in my experience it might not be enough if the collinearity is too extreme.

BalzaniEdoardo commented 7 months ago

Elastic net (which combines the lasso and ridge penalties) sometimes works better for collinear data. Some things worth trying:

1) Change the precision to float64, which has to be done explicitly in jax.
2) Compare what happens with scipy.optimize.minimize (that's what sklearn calls) on the same data against our jaxopt-based implementation.
3) Check whether the initial conditions are too far off, making the initial optimization steps too large, which can cause the log-likelihood to return -inf. This happens more often with float32 precision.
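To illustrate why precision matters for an exponential inverse link, here is a small NumPy sketch (NumPy is used as a stand-in so the example runs without jax; the values of the linear predictor are made up). In jax itself, 64-bit precision is typically enabled with `jax.config.update("jax_enable_x64", True)` before any arrays are created.

```python
import numpy as np

# Made-up linear predictor values fed to the exponential inverse link.
eta = np.array([50.0, 88.0, 89.0, 100.0])

with np.errstate(over="ignore"):  # silence the overflow warning for the demo
    rate32 = np.exp(eta.astype(np.float32))
    rate64 = np.exp(eta.astype(np.float64))

print("float32:", rate32)  # overflows to inf once eta exceeds roughly 88.7
print("float64:", rate64)  # still finite for these values
```

In float64 the exponential only overflows once the linear predictor exceeds roughly 709, which leaves much more room for a bad early step.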

BalzaniEdoardo commented 7 months ago

It's a GradientDescent issue: if you start with a second-order method, the fit works. And if you then use that solution as the initial parameters (a much better initial condition), ProximalGradient works too. The failure comes from the first gradient updates being too large and probably making the likelihood inf. Very annoying. We should think of a better way to initialize our parameters.

import jax
import numpy as np
import nemos as nmo

# First fit without regularization using LBFGS to get a good starting point.
model0 = nmo.glm.GLM()
model0.set_params(
    regularizer=nmo.regularizer.UnRegularized(solver_name="LBFGS"),
    observation_model__inverse_link_function=jax.numpy.exp
)
model0.fit(X, np.asarray(spikes, float))

# Then fit the Lasso model, initialized at the LBFGS solution.
model = nmo.glm.GLM()
model.set_params(
    regularizer=nmo.regularizer.Lasso(regularizer_strength=0.1),
    observation_model__inverse_link_function=jax.numpy.exp
)
model.fit(X, np.asarray(spikes, float), init_params=(model0.coef_, model0.intercept_))

samdeoxys1 commented 7 months ago

When the rates are high, I again run into the NaN issue, even with a high lasso penalty. The issue goes away if I use LBFGS or if I set the lasso penalty extremely high, but the latter zeroes out everything. By contrast, statsmodels recovers the correct weights given a small penalty.

Here is the code for generating the data. Increasing w_speed would increase the rate of the simulated spikes.

import jax
import numpy as np


def gen_data(w_speed=0.08):
    n_pos = 5
    pos_l = np.arange(n_pos)
    # Mean speed per position: an inverted parabola peaking at the middle position.
    speed_per_pos = -4 * (pos_l - n_pos // 2) ** 2 + 16

    n_trial = 500
    speed_per_pos_std = 2 * speed_per_pos
    # Per-trial speeds: Gaussian noise around the mean, clipped at zero.
    speed_per_pos_per_trial = speed_per_pos + np.random.normal(size=(n_trial, n_pos)) * speed_per_pos_std
    speed_per_pos_per_trial[speed_per_pos_per_trial < 0] = 0

    # Design matrix: one-hot position columns plus one speed column.
    pos_onehot_t = np.tile(np.eye(n_pos), (n_trial, 1))
    pos_t = np.tile(pos_l, n_trial)
    speed_t = speed_per_pos_per_trial.reshape(-1, 1)

    X = np.concatenate([pos_onehot_t, speed_t], axis=1)
    X = X[:, None, :]  # add a singleton middle axis, matching the einsum below

    # Only the speed feature drives the rate.
    w_true = np.concatenate([[0] * n_pos, [w_speed]])[None, :]
    b_true = 0

    rate = jax.numpy.exp(jax.numpy.einsum("ik,tik->ti", w_true, X) + b_true)

    spikes = np.random.poisson(rate)

    return X, rate, spikes, w_true

problem_data_lasso.p.zip

Attached is one sampled dataset that gave me troubles.

samdeoxys1 commented 7 months ago

After some digging, I think the problem comes from the exponential term in the gradient of the log-likelihood. The log-likelihood is $L = y(\mathbf{x}^T \beta) - \exp(\mathbf{x}^T \beta) - \log(y!)$, and its gradient is $\nabla_{\beta} L = y\mathbf{x}^T - \exp(\mathbf{x}^T \beta)\mathbf{x}^T$. When the step size of the proximal gradient descent is not small enough, the term inside the exponential can get quite big and the result becomes NaN. Once I set the step size to a really small value, the NaN problem went away.
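The effect can be reproduced with a toy one-parameter Poisson regression. This is a sketch in plain NumPy with made-up data and a fixed-step gradient descent, not the nemos or jaxopt solver; it minimizes the mean negative log-likelihood, so the gradient has the opposite sign to the formula above.

```python
import numpy as np

def grad_neg_loglik(beta, x, y):
    """Gradient of the mean Poisson negative log-likelihood (exp link)."""
    return np.mean((np.exp(x * beta) - y) * x)

def gradient_descent(x, y, step, n_iter=200):
    """Fixed-step gradient descent starting from beta = 0."""
    beta = 0.0
    with np.errstate(all="ignore"):  # silence overflow/invalid warnings
        for _ in range(n_iter):
            beta = beta - step * grad_neg_loglik(beta, x, y)
    return beta

# Toy data (made up): counts are roughly exp(0.2 * x).
x = np.array([0.0, 5.0, 10.0, 20.0])
y = np.array([1.0, 3.0, 7.0, 55.0])

beta_large = gradient_descent(x, y, step=1.0)   # first update overshoots
beta_small = gradient_descent(x, y, step=1e-4)  # small enough to converge
print(beta_large, beta_small)
```

In this toy run, the large step pushes the parameter so far that the exponential overflows to inf, the next update sends the parameter to -inf, and the product 0 * inf then yields NaN. The small step settles near 0.2.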

Perhaps that should be emphasized in the error message, or the step size could be reduced automatically when the NaN problem shows up?
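One way to picture the automatic step reduction is a simple backtracking rule: shrink the step whenever the candidate loss is non-finite or fails to decrease. The sketch below uses the same kind of toy one-parameter Poisson problem in NumPy; the function names, the shrink factor, and the data are all assumptions for illustration, not nemos or jaxopt code.

```python
import numpy as np

def neg_loglik(beta, x, y):
    """Mean Poisson negative log-likelihood, dropping the constant log(y!)."""
    eta = x * beta
    return np.mean(np.exp(eta) - y * eta)

def grad_neg_loglik(beta, x, y):
    return np.mean((np.exp(x * beta) - y) * x)

def descend_with_backoff(x, y, step=1.0, shrink=0.5, n_iter=200, max_shrinks=60):
    """Gradient descent that shrinks the step on NaN/inf or non-decreasing loss."""
    beta = 0.0
    loss = neg_loglik(beta, x, y)
    with np.errstate(all="ignore"):
        for _ in range(n_iter):
            g = grad_neg_loglik(beta, x, y)
            for _ in range(max_shrinks):
                candidate = beta - step * g
                cand_loss = neg_loglik(candidate, x, y)
                # Accept only a finite loss that does not increase.
                if np.isfinite(cand_loss) and cand_loss <= loss:
                    break
                step *= shrink
            else:
                break  # no acceptable step found; stop iterating
            beta, loss = candidate, cand_loss
    return beta, step

# Toy data (made up): counts are roughly exp(0.2 * x).
x = np.array([0.0, 5.0, 10.0, 20.0])
y = np.array([1.0, 3.0, 7.0, 55.0])

beta, final_step = descend_with_backoff(x, y)
print(beta, final_step)
```

Checking the loss at every candidate matters here: a retry that only tests whether the final parameter is finite can accept a value that is finite but has been thrown very far from the optimum.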

BalzaniEdoardo commented 5 months ago

In the error message, when one encounters NaNs, we tell the user to try a smaller step size.

The jaxopt ProximalGradient optimizer has a "fista_line_search" step in which the step size is reduced exponentially until the loss decreases; this is applied at every iteration.

https://github.com/google/jaxopt/blob/main/jaxopt/_src/proximal_gradient.py

The issue seems to be the very first initialization, which happens before the while loop that does the step size reduction (line 76). If, at the very first iteration, the "init_x" (the initial parameters after one step) contains NaNs, then the line search condition won't be matched and the algorithm terminates.
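One plausible reading of why NaNs defeat a line search is that any comparison involving NaN evaluates to False, so a loop of the form "shrink while the new loss exceeds the old loss" exits immediately. The sketch below is a generic illustration in plain Python with hypothetical function names and a stand-in loss; it is not the jaxopt implementation.

```python
def backtrack(loss_fn, x, direction, step, shrink=0.5, max_shrinks=50, nan_safe=False):
    """Shrink `step` until the trial loss no longer exceeds the current loss."""
    old_loss = loss_fn(x)
    shrinks = 0
    while shrinks < max_shrinks:
        new_loss = loss_fn(x + step * direction)
        if nan_safe:
            # NaN-safe form: `not (nan <= old_loss)` is True, so NaN keeps shrinking.
            keep_shrinking = not (new_loss <= old_loss)
        else:
            # Naive form: `nan > old_loss` is False, so NaN stops the loop.
            keep_shrinking = new_loss > old_loss
        if not keep_shrinking:
            break
        step *= shrink
        shrinks += 1
    return step, shrinks

def nan_when_far(z):
    """Stand-in loss: quadratic near the origin, NaN far away from it."""
    return float("nan") if abs(z) > 10 else z * z

naive = backtrack(nan_when_far, 1.0, -1.0, 100.0)
safe = backtrack(nan_when_far, 1.0, -1.0, 100.0, nan_safe=True)
print(naive)  # step is left untouched because the NaN comparison is False
print(safe)   # step is shrunk until the trial loss is finite and lower
```

Writing the acceptance test so that NaN counts as a failure (as in the `nan_safe` branch) is one way such a search could keep shrinking the step instead of terminating.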