pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Rhat and NEff are NaNs with NUTS #1340

Closed rusty-robin closed 2 years ago

rusty-robin commented 2 years ago

Hi,

I am testing some models with NumPyro and trying out various combinations to evaluate the performance of MCMC algorithms. I observed that running a linear regression model with NumPyro causes the Rhat and Neff values to be NaN for some parameters.

I think this happens when the distribution parameters are large (e.g. for the StudentT prior they are on the order of 10^31). If I change them, I don't see NaN values. Are the NaNs expected in this case, or should there be better handling for such extreme values?

Let me know what you guys think.

Code:

import numpy as np
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

numpyro.set_platform("cpu")

data = dict()
data['x'] = np.array([
    0.0970, 2.1020, 0.5840, 1.0394, 3.4375, 1.3102, 1.2863, 4.5382, 1.2539, 2.9319,
    4.7777, 4.5937, 4.0403, 0.7749, 1.8342, 1.5008, 3.9557, 1.6095, 0.5602, 4.7997,
    1.6436, 4.5236, 1.9404, 1.1508, 3.1447, 3.3551, 0.2669, 3.2483, 2.5293, 0.1596,
    2.8992, 2.5016, 4.7589, 2.5648, 1.9871, 0.3201, 4.4863, 0.5675, 0.8897, 0.5689,
    0.5621, 0.5701, 0.7998, 0.3166, 4.8490, 2.7849, 2.6858, 4.2393, 1.4939, 1.7230,
    0.4202, 0.3591, 2.5912, 3.9753, 3.5417, 0.3198, 2.7192, 1.9892, 1.9525, 4.7293,
    1.7719, 3.7889, 2.3613, 4.6994, 2.0155, 3.8840, 0.8289, 2.1164, 1.3342, 1.4850,
    0.5512, 0.3919, 2.4001, 3.1949, 4.9759, 1.4158, 0.6088, 2.1599, 0.3643, 0.6800,
    2.4031, 1.7706, 3.7669, 4.1080, 4.9707, 1.9139, 4.7615, 3.3075, 3.6872, 4.0946,
    0.3447, 2.7292, 0.2768, 3.7641, 3.3987, 0.1833, 4.9353, 1.9730, 2.2034, 2.8598,
])
data['y'] = np.array([
    4.9404, 45.0399, 14.6792, 23.7890, 71.7504, 29.2036, 28.7251, 93.7642, 28.0774, 61.6389,
    98.5535, 94.8738, 83.8054, 18.4976, 39.6836, 33.0162, 82.1149, 35.1902, 14.2033, 98.9948,
    35.8724, 93.4728, 41.8081, 26.0160, 65.8950, 70.1028, 8.3385, 67.9659, 53.5864, 6.1922,
    60.9839, 53.0324, 98.1783, 54.2960, 42.7421, 9.4019, 92.7252, 14.3490, 20.7931, 14.3781,
    14.2424, 14.4019, 18.9951, 9.3322, 99.9797, 58.6983, 56.7166, 87.7854, 32.8781, 37.4594,
    11.4034, 10.1829, 54.8241, 82.5061, 73.8340, 9.3964, 57.3845, 42.7835, 42.0501, 97.5864,
    38.4374, 78.7779, 50.2264, 96.9873, 43.3109, 80.6796, 19.5773, 45.3275, 29.6839, 32.6994,
    14.0248, 10.8371, 51.0026, 66.8986, 102.5188, 31.3158, 15.1756, 46.1974, 10.2855, 16.6008,
    51.0629, 38.4128, 78.3385, 85.1597, 102.4138, 41.2770, 98.2303, 69.1497, 76.7434, 84.8919,
    9.8933, 57.5845, 8.5367, 78.2816, 70.9734, 6.6665, 101.7066, 42.4600, 47.0682, 60.1959,
])

def model():
    # StudentT(df, loc, scale): all three parameters are on the order of 1e29 to 1e31
    b_param_0=numpyro.sample("b_param_0", dist.StudentT(2.5454194259963054E31,-.205185055113823E30,4.117323288668916E31))
    w=numpyro.sample("w", dist.Normal(10,1))
    b=numpyro.sample("b", dist.Normal(b_param_0,1))
    sigma=numpyro.sample("sigma", dist.Gamma(1,2))
    with numpyro.plate("size", np.size(data['y'])):
        numpyro.sample("obs67", dist.Normal(w*data['x']+b,sigma), obs=data['y'])

mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0))
params = mcmc.get_samples()
mcmc.print_summary()

Output:

            mean       std    median      5.0%     95.0%    n_eff     r_hat
b          -0.68      0.00     -0.68     -0.68     -0.68      0.50     1.00
b_param_0   1.60      0.00      1.60      1.60      1.60      0.50     1.00
sigma       2.45      0.00      2.45      2.45      2.45      nan      nan
w           1.93      0.00      1.93      1.93      1.93      nan      nan    
Number of divergences: 1000 

Environment:

Python 3.7
NumPyro 0.9.0
Ubuntu 18.04

fehiepsi commented 2 years ago

Hi @rusty-robin, this typically happens when you have constant samples. It looks like with the default initialization strategy, we start MCMC in the tail of the Student t prior, where the geometry makes it tricky to guide the HMC sampler. You can try a different initialization strategy such as init_to_sample, but even so, working with such large values is not robust. You might also want to call numpyro.enable_x64() at the beginning of your program to help the inference a bit.
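Why constant samples turn into nan can be seen from the split-R-hat formula itself: it divides a pooled variance estimate by the within-chain variance W, so when every draw in a chain is identical both quantities are zero and the ratio is 0/0. The sketch below is a minimal NumPy version of the textbook split-R-hat; `split_rhat` is a hypothetical helper written for illustration, not NumPyro's `split_gelman_rubin`, whose implementation details may differ.

```python
import numpy as np


def split_rhat(x):
    """Textbook split-R-hat for draws of shape (num_chains, num_samples).

    Illustrative sketch only; not NumPyro's implementation.
    """
    num_chains, num_samples = x.shape
    half = num_samples // 2
    # Split each chain into a first and a second half so drift within a chain shows up.
    chains = np.concatenate([x[:, :half], x[:, -half:]], axis=0)
    n = chains.shape[1]
    chain_means = chains.mean(axis=1)
    W = chains.var(axis=1, ddof=1).mean()    # mean within-chain variance
    B = n * chain_means.var(ddof=1)          # between-chain variance
    var_plus = (n - 1) / n * W + B / n       # pooled variance estimate
    with np.errstate(divide="ignore", invalid="ignore"):
        # For a constant chain W == 0 and var_plus == 0, so this is 0 / 0 -> nan.
        return np.sqrt(var_plus / W)


stuck = np.full((1, 1000), 1.5)              # one chain that never moves
mixing = np.random.default_rng(0).normal(size=(4, 1000))
print(split_rhat(stuck), split_rhat(mixing))
```

The constant value 1.5 is chosen because it is exactly representable, so the variance is exactly zero; a chain that barely moves can instead give a finite but meaningless value.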

rusty-robin commented 2 years ago

Thanks for the response @fehiepsi

I see. What do you mean by constant samples here? The data points: x and y?

It seems there should be better error messages here? Otherwise, there is no way to guide the user. Ideally, I feel the statistics should also not be NaN.

I will be happy to contribute if you have any suggestions on addressing these points.

fehiepsi commented 2 years ago

I just looked at the std values of your b, sigma, and w samples.
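Checking the std column by eye, as done here, can also be automated. A minimal sketch, assuming a dict of arrays whose leading axis indexes draws (the shape returned by `mcmc.get_samples()`); `constant_sites` and its `atol` threshold are hypothetical choices for illustration, and the input below is synthetic data, not output from the model above.

```python
import numpy as np


def constant_sites(samples, atol=1e-8):
    """Return the sorted names of sites whose draws (almost) never change.

    `samples` maps site names to arrays with draws along axis 0.
    `atol` is an arbitrary illustrative threshold on the standard deviation.
    """
    flagged = []
    for name, draws in samples.items():
        draws = np.asarray(draws)
        # Std over the draw axis, then the largest value over any event dimensions.
        if np.max(np.std(draws, axis=0)) < atol:
            flagged.append(name)
    return sorted(flagged)


rng = np.random.default_rng(0)
synthetic = {
    "w": np.full(1000, 1.93),        # stuck site
    "sigma": np.full(1000, 2.45),    # stuck site
    "b": rng.normal(size=1000),      # site that actually moves
}
print(constant_sites(synthetic))
```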

> It seems there should be better error messages here? Otherwise, there is no way to guide the user. Ideally, I feel the statistics should also not be NaN.

I'm not sure we need to raise an error here. Sometimes we simply get constant samples (e.g. when drawing discrete latent values). In those cases the r_hat and n_eff diagnostics can be nan, but I'm not sure whether getting constant samples is good or bad. How about adding a note to the docstring of print_summary mentioning that nan/inf values might indicate that the variance of your samples is very small?
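The same 0/0 shows up in n_eff: effective sample size estimators of this kind normalize the autocovariance by the lag-0 autocovariance (the variance) to get autocorrelations. The sketch below is a deliberately crude single-chain estimator, n / (1 + 2 * sum of the leading positive autocorrelations), truncated at the first non-positive lag; `autocorrelation` and `crude_ess` are hypothetical helpers, simpler than NumPyro's `effective_sample_size`, and are only meant to show where the nan comes from.

```python
import numpy as np


def autocorrelation(x):
    """Normalized autocorrelation of a 1-D chain via FFT (lag 0 is 1 unless var is 0)."""
    n = len(x)
    centered = x - x.mean()
    # Zero-pad to length 2n so the FFT gives linear, not circular, autocovariance.
    f = np.fft.rfft(centered, n=2 * n)
    acov = np.fft.irfft(f * np.conj(f))[:n] / n
    with np.errstate(divide="ignore", invalid="ignore"):
        return acov / acov[0]            # constant chain: 0 / 0 -> nan at every lag


def crude_ess(x):
    """Crude ESS: n / (1 + 2 * sum of autocorrelations up to the first non-positive lag)."""
    x = np.asarray(x, dtype=float)
    rho = autocorrelation(x)
    if np.isnan(rho[0]):
        return np.nan                    # zero variance: ESS is undefined
    tau = 1.0
    for r in rho[1:]:
        if r <= 0:
            break
        tau += 2.0 * r
    return len(x) / tau


print(crude_ess(np.full(1000, 1.5)))                            # constant chain
print(crude_ess(np.random.default_rng(0).normal(size=1000)))    # independent draws
```

For independent draws the estimate stays close to the number of samples, while a strongly autocorrelated chain gives a much smaller value.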

fehiepsi commented 2 years ago

Closed because the behavior is expected. Please feel free to open a PR that enhances the documentation to mention it.