probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

NaN returned with large data and a large number of iterations #355

Open Yuxin-Ren-SZ opened 5 months ago

Yuxin-Ren-SZ commented 5 months ago

Hi, I'm using LinearGaussianSSM with a large time series dataset: 73 emission dimensions and 1000 time steps.

I noticed that both fit_em() and posterior_predictive() return NaN parameters or NaN posterior predictions if I run too many iterations (around 150). If I reduce the emission dimension or the state dimension, it lasts for more iterations before this happens.
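To pin down the iteration where things first go wrong, one option is to run EM one step at a time and stop at the first non-finite value, keeping the last finite parameters. This is a minimal sketch, assuming a hypothetical `em_step(params) -> (new_params, log_prob)` callable and parameters stored as a plain nested dict of arrays; neither is a dynamax API (dynamax parameters are structured objects, not dicts), and `toy_step` only simulates an overflow.

```python
import numpy as np

def all_finite(params):
    """True if every array in a (possibly nested) dict of parameters is finite."""
    for value in params.values():
        if isinstance(value, dict):
            if not all_finite(value):
                return False
        elif not np.all(np.isfinite(np.asarray(value))):
            return False
    return True

def run_em_with_guard(em_step, params, num_iters):
    """Call the hypothetical `em_step(params) -> (new_params, log_prob)` up to `num_iters` times.

    Stops at the first non-finite result and returns
    (last finite params, log-probs so far, failing iteration index or None).
    """
    log_probs = []
    for it in range(num_iters):
        new_params, log_prob = em_step(params)
        if not (all_finite(new_params) and np.isfinite(log_prob)):
            return params, log_probs, it
        params = new_params
        log_probs.append(float(log_prob))
    return params, log_probs, None

# Toy stand-in for an EM step: "scale" grows by a factor of 1e100 per call,
# so it overflows to inf on the fourth call (iteration index 3).
def toy_step(params):
    with np.errstate(over="ignore"):
        return {"scale": params["scale"] * 1e100}, 0.0

last_params, log_probs, failed_at = run_em_with_guard(
    toy_step, {"scale": np.array(1.0)}, num_iters=10)
```

Here `failed_at` is 3 and `last_params` holds the last finite value, so the parameters just before the blow-up can be inspected (for example, the eigenvalues of the covariance estimates).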

I suspect this is an overflow or underflow problem. Is there anything we can do in dynamax or JAX to prevent it from returning NaN?
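One general-purpose mitigation for this kind of drift (not confirmed to be the cause here, and not something dynamax is known to do) is to re-symmetrize covariance estimates and add a small diagonal jitter after each update. Rounding error accumulated over many iterations can leave a covariance slightly asymmetric or with tiny negative eigenvalues; a Cholesky factorization then fails (NumPy raises an error, while `jax.numpy.linalg.cholesky` returns NaNs). The function name `stabilize_covariance` and the default `jitter=1e-6` are illustrative assumptions; the same two lines work with `jax.numpy` in place of `numpy`.

```python
import numpy as np

def stabilize_covariance(cov, jitter=1e-6):
    """Hypothetical helper: symmetrize a covariance estimate and add diagonal jitter."""
    cov = 0.5 * (cov + cov.T)                   # remove asymmetry introduced by rounding
    return cov + jitter * np.eye(cov.shape[0])  # push eigenvalues away from zero

# A covariance that has drifted to be very slightly indefinite
# (eigenvalues of roughly 2 and -5e-13).
bad = np.array([[1.0, 1.0],
                [1.0, 1.0 - 1e-12]])

try:
    np.linalg.cholesky(bad)
    cholesky_failed = False
except np.linalg.LinAlgError:
    cholesky_failed = True      # not positive definite

fixed = stabilize_covariance(bad)
L = np.linalg.cholesky(fixed)   # succeeds after jitter
```

The jitter slightly biases the estimate, so it trades a small, controlled error for numerical safety; the right size depends on the scale of the data.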


A separate question: in dynamax.linear_gaussian_ssm.inference, the _step() function inside lgssm_posterior_sample() ends with:

return state, state

This pair is used only once, and the first element, which should be exactly the same as the second, is discarded. I wonder whether this is intended for later development or is just an oversight.
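If `_step` is the body function passed to `jax.lax.scan` (an assumption; the excerpt above does not show the call), the two copies play different roles. `lax.scan` expects the body to return `(carry, y)`: the first value is fed into the next step as the carry, and the second is stacked into the output array. The sketch below mirrors the reference semantics given in the `jax.lax.scan` documentation in pure Python (forward scan, `xs` given); the toy random-walk `_step` is hypothetical and only stands in for the sampler's body.

```python
import numpy as np

def scan(f, init, xs):
    """Pure-Python version of the documented jax.lax.scan semantics (forward, xs given)."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)  # first value: carry for the next step; second: recorded output
        ys.append(y)
    return carry, np.stack(ys)

def _step(state, noise):
    # Hypothetical toy step: a random walk driven by the given increments.
    state = state + noise
    return state, state          # same value used as the next carry AND as the output

final_state, states = scan(_step, 0.0, np.array([1.0, 2.0, 3.0]))
```

Under this reading, discarding the first return value of the scan only drops the final carry, which is the same as `states[-1]`; the first element of `_step`'s return is still needed inside the loop to pass the state forward.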

Yuxin-Ren-SZ commented 4 months ago

After reading #290, which I believe describes the same problem in the HMM code, I think this issue has the same cause. I will try to locate the origin and fix it.

In the meantime, I think it would be worth checking the smoother functions of the other models as well.