pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Test mcmc with chains #199

Closed fehiepsi closed 5 years ago

fehiepsi commented 5 years ago

With #198, we can use mcmc with num_chains > 1. Let's test whether it works for all of our examples.

Let's report any issues here too.

arunpatro commented 5 years ago

Hi @fehiepsi, I have been using Pyro for MVN posterior estimation with HMC/NUTS, and I wanted to try NumPyro.

The following code works fine with num_chains=1 but fails with num_chains=2. I installed numpyro via pip install git+https://github.com/pyro-ppl/numpyro because the release from a plain pip install numpyro does not have MultivariateNormal in numpyro.distributions: AttributeError: module 'numpyro.distributions' has no attribute 'MultivariateNormal'

CODE:

import jax
import jax.numpy as np
from jax import random
from jax.config import config; config.update("jax_platform_name", "cpu")
from jax.scipy.special import logsumexp
import numpy as onp
import pickle
import warnings; warnings.filterwarnings("ignore")

from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro.handlers import sample, seed, substitute, trace
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import mcmc
import time

rng = random.PRNGKey(0)

with open('data_unit.pickle' , 'rb') as f:
    data_unit = pickle.load(f)

def mvn_model(y):
    nu = sample('nu', dist.Uniform(0., 4.))
    sigma = sample('sigma', dist.LogNormal(loc=np.zeros(4), scale=np.ones(4)))
    L_omega = sample("L_omega", dist.LKJCholesky(dimension=4, concentration=nu))
    L_Omega = np.matmul(L_omega, np.diag(np.sqrt(sigma))) 
    mu = sample('mu', dist.MultivariateNormal(loc=np.zeros(4), scale_tril=L_Omega))
    obs = sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)   
    return obs

rng, rng_ = random.split(rng)
init_params, potential_fn, constrain_fn = initialize_model(rng_, mvn_model, y=data_unit)

start_time = time.time()
posterior_samples = mcmc(num_warmup=10, 
                         num_samples=250, 
                         init_params=init_params,
                         num_chains=2,
                         potential_fn=potential_fn,
                         trajectory_length=10,
                         target_accept_prob=0.9,
                         constrain_fn=constrain_fn,
                         print_summary=False)
print(f'It took {time.time() - start_time}')

ERROR:

Traceback (most recent call last):
  File "numpyro-test.py", line 44, in <module>
    print_summary=False)
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/numpyro/mcmc.py", line 435, in mcmc
    init_params_i = tree_map(lambda x: x[i], init_params)
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/jax/tree_util.py", line 62, in tree_map
    new_children = [tree_map(f, child) for child in children]
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/jax/tree_util.py", line 62, in <listcomp>
    new_children = [tree_map(f, child) for child in children]
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/jax/tree_util.py", line 65, in tree_map
    return f(tree)
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/numpyro/mcmc.py", line 435, in <lambda>
    init_params_i = tree_map(lambda x: x[i], init_params)
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 2076, in _rewriting_take
    return lax.index_in_dim(arr, idx, axis, False)
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/jax/lax/lax.py", line 1152, in index_in_dim
    axis_size = operand.shape[axis]
IndexError: tuple index out of range

fehiepsi commented 5 years ago

Hi @arunpatro, because you use 2 chains, I think init_params needs to have batch size 2. To do that, just pass a batch of 2 rngs to initialize_model:

rngs = random.split(rng, 2)
init_params, potential_fn, constrain_fn = initialize_model(rngs, mvn_model, y=data_unit)

Please let me know if that does not work for you. A small (not important) note: you don't need to set trajectory_length=10 because it is only used for algo='HMC'. In addition, because you are running your script on CPU, try it with num_warmup=10000, num_samples=10000; I think the inference will still be fast. :)

By the way, could you let us know what you think about the current interface? Is it hard to use, counterintuitive, or not documented well enough? Any feedback from you is very welcome! :)
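
For readers following along, here is a minimal JAX-only sketch of the batch-size requirement on init_params discussed above. It does not call NumPyro; init_one_chain and the parameter names are hypothetical stand-ins for whatever initialize_model returns. It only illustrates why per-chain indexing with tree_map(lambda x: x[i], init_params), as seen in the traceback, needs a leading chain dimension on every leaf.

```python
import jax
from jax import random
from jax.tree_util import tree_map

NUM_CHAINS = 2


def init_one_chain(key):
    # Hypothetical stand-in for per-chain initialization:
    # one scalar parameter and one length-4 vector parameter.
    k1, k2 = random.split(key)
    return {"nu": random.uniform(k1, ()), "sigma": random.normal(k2, (4,))}


key = random.PRNGKey(0)

# Unbatched: a single key gives leaves with no chain dimension.
single = init_one_chain(key)

# Batched: one key per chain, mapped over the leading axis with vmap.
keys = random.split(key, NUM_CHAINS)
batched = jax.vmap(init_one_chain)(keys)
batched_shapes = tree_map(lambda x: x.shape, batched)

# Per-chain slicing works when every leaf has a leading chain axis.
chain_0 = tree_map(lambda x: x[0], batched)

# The same slicing fails on unbatched params because "nu" is 0-dimensional.
try:
    tree_map(lambda x: x[0], single)
    unbatched_failed = False
except IndexError:
    unbatched_failed = True
```

With a single key, the scalar leaf cannot be indexed per chain, which matches the IndexError reported above; with a batch of keys, every leaf gains a leading axis of size NUM_CHAINS.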

fehiepsi commented 5 years ago

@neerajprad In case you have access to a 2-GPU machine, could you help me test the examples on multiple GPUs? I think that it should work natively, but I don't have 2 GPUs to test with. Thanks!

neerajprad commented 5 years ago

Sure, I'll post an update on this tomorrow.

neerajprad commented 5 years ago

@fehiepsi - I can confirm that this works fine with multiple GPUs too. Feel free to close this issue. At some point, it would be nice to be able to do this dynamically rather than having to change the XLA args.
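
For context, one way XLA arguments are commonly changed is through the XLA_FLAGS environment variable. The thread does not say which arguments were meant, so the sketch below is an assumption: it uses the --xla_force_host_platform_device_count flag, which asks XLA to expose several host (CPU) devices, and it must run before JAX is first imported.

```python
import os

# Assumption: the "XLA args" are flags passed via XLA_FLAGS.
# This flag requests 2 host (CPU) devices; set it before importing jax.
existing = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (
    existing + " --xla_force_host_platform_device_count=2"
).strip()

import jax

# Number of devices visible to this process on the default backend.
n_devices = jax.local_device_count()
print(n_devices)
```

The reported count depends on the backend in use, so on a GPU machine it reflects the GPUs rather than the forced host devices.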

fehiepsi commented 5 years ago

Thanks, @neerajprad!