probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

EM yields NaN after some iterations #316

Status: Open · atlaie opened this issue 1 year ago

atlaie commented 1 year ago

Hi,

I'm currently using Dynamax to implement a Switching Linear Regression (SLR; LinearRegressionHMM in Dynamax), and I'd like to understand the fit_em function better. In particular, I'd like to train the SLR on some data that looks like this:

[Screenshot (2023-05-16): plot of the target data]

i.e., it is constant over part of its domain. Here's a link to the .npz file containing this data (stored under the key "target"), as well as the predictors (key "predictors") that I'm passing to fit_em.
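To reproduce the setup, the two arrays can be read out of the .npz file by key. This is a minimal sketch, assuming "target" and "predictors" are stored with time on axis 0 and that the model wants 2-D (T, D) arrays. The helper name load_slr_data is hypothetical, and the file path is not given in the issue.

```python
import numpy as np

def load_slr_data(path):
    """Load the 'target' and 'predictors' arrays from an .npz file.

    Hypothetical helper. Assumes time is on axis 0 and returns
    (emissions, inputs) as float32 arrays of shape (T, D).
    """
    with np.load(path) as data:
        target = np.asarray(data["target"], dtype=np.float32)
        predictors = np.asarray(data["predictors"], dtype=np.float32)
    # Promote 1-D series to column vectors: (T,) -> (T, 1).
    if target.ndim == 1:
        target = target[:, None]
    if predictors.ndim == 1:
        predictors = predictors[:, None]
    # Both arrays must describe the same time steps.
    if target.shape[0] != predictors.shape[0]:
        raise ValueError("target and predictors must have the same length on axis 0")
    return target, predictors
```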

I've also tried adding a small amount of noise so that the variance is nonzero, but the EM algorithm still blows up after a few iterations:

[Plot: LogLikelihood per EM iteration]
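The sketch below shows why a nonzero variance matters. It assumes textbook EM for a switching linear regression: in the M-step, each state's weights, bias, and noise covariance are fitted by weighted least squares, and the covariance is the weighted mean of the squared residuals. Suppose a state explains only an exactly constant stretch. The fit is then perfect and the covariance estimate is zero, which can turn the Gaussian log-density into infinities or NaNs. The names weighted_regression_mstep and add_jitter, the synthetic data, and the noise scale are all hypothetical. This is not dynamax's implementation and not a confirmed diagnosis of this issue.

```python
import jax
import jax.numpy as jnp

def weighted_regression_mstep(inputs, emissions, weights):
    """Closed-form M-step for one discrete state of a switching linear regression.

    Textbook weighted least squares (hypothetical helper, not dynamax code).
    inputs: (T, D_in), emissions: (T, D_out), weights: (T,) posterior
    probabilities of this state from the E-step.
    Returns (W, b, Sigma) maximising sum_t w_t * log N(y_t | W x_t + b, Sigma).
    """
    T = inputs.shape[0]
    # Append a column of ones so the bias b is fitted jointly with W.
    X = jnp.concatenate([inputs, jnp.ones((T, 1), dtype=inputs.dtype)], axis=1)
    Xw = X * weights[:, None]
    # Weighted normal equations: (X^T diag(w) X) C = X^T diag(w) Y.
    coef = jnp.linalg.solve(Xw.T @ X, Xw.T @ emissions)
    W, b = coef[:-1].T, coef[-1]
    resid = emissions - X @ coef
    # Weighted residual covariance: zero whenever the fit is exact.
    Sigma = (resid * weights[:, None]).T @ resid / weights.sum()
    return W, b, Sigma

def add_jitter(key, target, scale=1e-3):
    """Add i.i.d. N(0, scale**2) noise; the default scale is arbitrary."""
    return target + scale * jax.random.normal(key, target.shape, dtype=target.dtype)

# Hypothetical flat stretch: the target is constant while the predictor varies.
x = jnp.linspace(0.0, 1.0, 500)[:, None]
y_flat = jnp.full((500, 1), 3.0)
w = jnp.ones(500)  # pretend the E-step assigned every sample to this state

_, _, sigma_exact = weighted_regression_mstep(x, y_flat, w)
_, _, sigma_jitter = weighted_regression_mstep(
    x, add_jitter(jax.random.PRNGKey(0), y_flat, scale=1e-2), w)
```

With the exactly constant target, sigma_exact is zero up to floating-point error. With the jitter, the estimate lands near scale**2 (about 1e-4 here), which is finite but small.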

Notice that after the 4th iteration (out of 100) the log-likelihood is nothing but NaNs. With SGD (fit_sgd), on the other hand, training works and converges to reasonable predictions.
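To find exactly where a run diverges, you can scan the per-iteration log-likelihoods that EM reports for the first non-finite value. This is a minimal sketch. The helper name first_nonfinite_iteration is hypothetical, and it accepts any array-like trace (a NumPy array, a JAX array, or a list).

```python
import numpy as np

def first_nonfinite_iteration(log_likelihoods):
    """Return the 0-based index of the first NaN/inf entry in a
    log-likelihood trace, or -1 if every entry is finite."""
    lls = np.asarray(log_likelihoods, dtype=float)
    bad = np.flatnonzero(~np.isfinite(lls))
    return int(bad[0]) if bad.size else -1
```

For example, `first_nonfinite_iteration([-10.0, -5.0, float("nan")])` returns 2.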

Let me know if any further details are needed,

Thanks!