jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

BUG: jax.scipy.stats.multivariate_normal.logpdf returns nan for high dimensionality matrices #11763

Status: Open. jhennawi opened this issue 2 years ago.

jhennawi commented 2 years ago

Description

The following code example works for ndim = 10, but JAX returns nan for ndim = 100. It starts failing at ndim = 15. I am working in double precision, enabled by setting the environment variable JAX_ENABLE_X64=True.
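For reference, double precision can also be enabled from Python rather than through the environment. A minimal sketch using `jax.config.update`, which should run at program startup before any arrays are created:

```python
import jax
import jax.numpy as jnp

# Same effect as JAX_ENABLE_X64=True; set this before creating any arrays.
jax.config.update("jax_enable_x64", True)

# With x64 enabled, floating-point arrays default to float64.
x = jnp.zeros(3)
print(x.dtype)  # float64
```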


```python
import jax
import jax.numpy as jnp
import jax.random as random
import jax.scipy.stats  # import explicitly so jax.scipy.stats is available

def logprob_analytic(x_samp, mu, covar):
    # Gaussian log-density using a general (LU-based) solve and slogdet
    nspec = x_samp.shape[0]
    diff = x_samp - mu
    x = jnp.linalg.solve(covar, diff)
    _, lndetC = jnp.linalg.slogdet(covar)
    lnL = -(jnp.dot(diff, x) + lndetC + nspec * jnp.log(2.0 * jnp.pi)) / 2.0
    return lnL

ndim = 10  # passes at 10; returns nan from ndim = 15 (e.g. ndim = 100)
seed = 42
key = random.PRNGKey(seed)
key, subkey = random.split(key)
# Create a random lower triangular matrix with random entries below the diagonal
L = random.normal(subkey, (ndim, ndim))
# Make the diagonal strictly positive
L = L.at[jnp.diag_indices_from(L)].set(jnp.sqrt(0.1 * jnp.exp(L[jnp.diag_indices_from(L)])))
# Zero out everything above the diagonal
L = L.at[jnp.triu_indices_from(L, k=1)].set(0.0)
# L is the Cholesky factor of the covariance: covar = L @ L.T
covar = jnp.matmul(L, L.T)
# Mean is zero, generate one sample
mu = jnp.zeros(ndim)
nsamp = 1
key, subkey = random.split(key)
x_samp = random.multivariate_normal(subkey, mu, covar, shape=(nsamp,), method='svd').squeeze()

# Evaluate the logprob in two ways
lnP_ana = logprob_analytic(x_samp, mu, covar)
lnP_jax = jax.scipy.stats.multivariate_normal.logpdf(x_samp, mu, covar)

assert jnp.isclose(lnP_ana, lnP_jax)
```
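For reference, both calls evaluate the same Gaussian log-density, with n = nspec the dimension, d = x − μ and C = covar. logprob_analytic computes the quadratic form with a general linear solve and the log-determinant with slogdet:

$$
\ln p(x) = -\tfrac{1}{2}\left[ d^\top C^{-1} d + \ln\det C + n \ln(2\pi) \right], \qquad d = x - \mu
$$

A Cholesky-based evaluation (the general technique, not a transcription of JAX's source) instead factors C = L Lᵀ once and reuses L for both terms, so a failed factorization spoils both:

$$
d^\top C^{-1} d = \lVert L^{-1} d \rVert_2^2, \qquad \ln\det C = 2 \sum_{i=1}^{n} \ln L_{ii}
$$

JAX's Cholesky does not raise when the factorization fails; it returns NaN entries, which then propagate into the log-density.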

What jax/jaxlib version are you using?

JAX version 0.3.13

Which accelerator(s) are you using?

Additional System Info

Python 3.9.7, macOS on Apple M1 (ARM)

apaszke commented 2 years ago

I traced the issue to our logpdf implementation, which uses a Cholesky decomposition; for some reason the decomposition fails in this case. Your code uses a more general (and more expensive) linear solver. I'm not a numerical stability expert, so I'm not sure what the best way to proceed is. @hawkinsp might be able to suggest some improvements here.
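A minimal sketch for checking this diagnosis, assuming the same seed and key-splitting as the reproduction (exact random values may still differ between JAX versions). make_covariance and diagnose are hypothetical helper names for illustration. Besides whether jnp.linalg.cholesky returns NaNs, it reports the 2-norm condition number of covar: a matrix that is positive definite in exact arithmetic can still fail a floating-point Cholesky factorization when its condition number roughly approaches 1/eps (about 4.5e15 for float64).

```python
import jax
import jax.numpy as jnp
import jax.random as random

# Assumption: double precision, matching the report's JAX_ENABLE_X64=True.
jax.config.update("jax_enable_x64", True)


def make_covariance(key, ndim):
    """Hypothetical helper: rebuild covar = L @ L.T as in the reproduction."""
    L = random.normal(key, (ndim, ndim))
    diag = jnp.diag_indices_from(L)
    L = L.at[diag].set(jnp.sqrt(0.1 * jnp.exp(L[diag])))   # positive diagonal
    L = L.at[jnp.triu_indices_from(L, k=1)].set(0.0)       # lower triangular
    return L @ L.T


def diagnose(ndim, seed=42):
    """Hypothetical helper: check Cholesky for NaNs next to slogdet and cond."""
    key = random.PRNGKey(seed)
    key, subkey = random.split(key)
    covar = make_covariance(subkey, ndim)
    chol = jnp.linalg.cholesky(covar)        # NaN-filled if factorization fails
    sign, logdet = jnp.linalg.slogdet(covar)  # LU-based, does not use Cholesky
    return {
        "ndim": ndim,
        "cholesky_has_nan": bool(jnp.isnan(chol).any()),
        "slogdet_sign": float(sign),
        "slogdet_logdet": float(logdet),
        "cond": float(jnp.linalg.cond(covar)),
    }


for n in (10, 15, 100):
    print(diagnose(n))
```

Comparing the printed condition numbers across ndim shows whether covar is already at the limit of float64 before logpdf ever runs.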

jhennawi commented 2 years ago

To be fair, the original SciPy scipy.stats.multivariate_normal.logpdf also fails in this case. It does, however, fail with a singular-matrix warning; if you set allow_singular=True it continues but gives wrong results. I think the issue is the accumulation of round-off error when performing a Cholesky decomposition of large matrices. It is clear from slogdet, and from how covar was constructed, that it is not singular. A better implementation would fall back to the more robust general linear solver for covariance matrices that are (numerically) close to singular.
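One way to sketch the suggested fallback, assuming that NaNs in the Cholesky factor are the trigger for switching (that criterion is an assumption; a condition-number threshold would be another option). logpdf_with_fallback is a hypothetical helper, not part of JAX:

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

# Assumption: double precision, as in the report (JAX_ENABLE_X64=True).
jax.config.update("jax_enable_x64", True)


def logpdf_with_fallback(x, mean, cov):
    """Hypothetical helper (not a JAX API): Gaussian log-density that uses a
    Cholesky factorization when it succeeds and falls back to a general
    linear solve plus slogdet when the factor contains NaNs."""
    n = x.shape[-1]
    diff = x - mean
    log2pi = n * jnp.log(2.0 * jnp.pi)

    # Fast path: cov = L @ L.T, so the quadratic form is ||L^{-1} diff||^2
    # and log det(cov) = 2 * sum(log(diag(L))).
    L = jnp.linalg.cholesky(cov)
    z = jax.scipy.linalg.solve_triangular(L, diff, lower=True)
    chol_val = -0.5 * (z @ z + 2.0 * jnp.sum(jnp.log(jnp.diag(L))) + log2pi)

    # Slow path: general LU-based solve, as in logprob_analytic.
    sol = jnp.linalg.solve(cov, diff)
    _, logdet = jnp.linalg.slogdet(cov)
    lu_val = -0.5 * (diff @ sol + logdet + log2pi)

    # Both paths are computed; keep the Cholesky result unless it failed.
    chol_failed = jnp.any(jnp.isnan(L))
    return jnp.where(chol_failed, lu_val, chol_val)
```

Because jnp.where evaluates both branches, this pays for the LU solve on every call, and gradients can still pick up NaNs from the unused Cholesky branch. Wrapping the two paths in jax.lax.cond would run only the selected branch under jit (outside vmap) and avoid that gradient issue.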