jhennawi opened 2 years ago
I traced the issue to the fact that our logpdf uses a Cholesky decomposition, which for some reason fails in this case. Your code uses a more general (and more expensive) linear solver. I'm not a numerical-stability expert, so I'm not sure what the best way to proceed is. @hawkinsp might be able to suggest some improvements here.
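To illustrate the failure mode: JAX's Cholesky does not raise an error when the factorization fails. Instead it returns NaNs, which then propagate into anything computed from the factor, such as a log-density. The minimal sketch below uses a small, deliberately indefinite matrix as a stand-in (an assumption for illustration, not the reporter's covariance, which is non-singular in exact arithmetic) and contrasts it with the general LU-based routines `jnp.linalg.slogdet` and `jnp.linalg.solve`:

```python
import jax

jax.config.update("jax_enable_x64", True)  # double precision, as in the report

import jax.numpy as jnp

# Hypothetical stand-in matrix: symmetric but indefinite (eigenvalues 3 and -1),
# so no Cholesky factor exists and the factorization is forced to fail
cov = jnp.array([[1.0, 2.0],
                 [2.0, 1.0]])

# A failed Cholesky does not raise; the returned factor contains NaNs
L = jnp.linalg.cholesky(cov)
print(L)

# The general LU-based routines still return finite numbers for the same matrix
sign, logabsdet = jnp.linalg.slogdet(cov)
y = jnp.linalg.solve(cov, jnp.array([1.0, 0.0]))
print(sign, logabsdet, y)
```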
To be fair, the old scipy.stats.multivariate_normal.logpdf also fails in this case. Unlike JAX, though, it flags the problem with a singular-matrix warning; if you set allow_singular=True, it continues but gives bad results. I think the issue here is the accumulation of roundoff error when performing a Cholesky decomposition of large matrices. It is clear from slogdet, and from how the covariance was constructed, that the matrix is not singular. A better implementation would switch to the more robust linear-solver approach for covariance matrices that may be (numerically) singular.
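One possible shape for the suggested fallback, sketched under assumptions: `logpdf_with_fallback` is a hypothetical helper, not part of JAX. It tries the Cholesky path first and switches to `slogdet` plus an LU solve when the factor contains non-finite values. It ignores the determinant's sign on the fallback path, and for a genuinely ill-conditioned covariance the LU path can still lose accuracy, so treat it as a starting point rather than a fix.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.scipy.linalg as jsl
from jax import lax


def logpdf_with_fallback(x, mean, cov):
    """Hypothetical helper (not a JAX API): multivariate normal log-density that
    uses Cholesky when it succeeds and falls back to slogdet + LU solve otherwise."""
    d = x - mean
    n = d.shape[-1]
    const = n * jnp.log(2.0 * jnp.pi)

    L = jnp.linalg.cholesky(cov)
    # A failed factorization shows up as NaNs in L rather than as an exception
    chol_ok = jnp.all(jnp.isfinite(L))

    def cholesky_path(r):
        # Forward substitution: z = L^{-1} r, so z @ z = r^T cov^{-1} r
        z = jsl.solve_triangular(L, r, lower=True)
        logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
        return -0.5 * (const + logdet + z @ z)

    def solve_path(r):
        # General LU-based path; the sign returned by slogdet is ignored here,
        # so callers should check it separately if cov might be indefinite
        _, logdet = jnp.linalg.slogdet(cov)
        return -0.5 * (const + logdet + r @ jnp.linalg.solve(cov, r))

    return lax.cond(chol_ok, cholesky_path, solve_path, d)
```

Using `lax.cond` keeps the helper traceable under `jax.jit`. Under `jit` only the selected branch runs, while under `vmap` the condition becomes a select and both branches are evaluated.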
Description
The following code example works perfectly for ndim=10, but JAX returns nan for ndim=100. It starts failing at ndim=15. I am working in double precision, having set the environment variable JAX_ENABLE_X64=True.
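The report enables double precision through the environment variable `JAX_ENABLE_X64=True`. As a minimal sketch, the same flag can also be set from Python via `jax.config.update`, provided it runs at startup before any arrays are created:

```python
import jax

# Equivalent to setting JAX_ENABLE_X64=True in the environment;
# run this at startup, before creating any JAX arrays
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.ones(3)
print(x.dtype)  # float64 once x64 is enabled
```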
What jax/jaxlib version are you using?
Jax version 0.3.13
Which accelerator(s) are you using?
Additional System Info
Python 3.9.7, macOS on Apple M1 (ARM)