fehiepsi closed this issue 5 years ago.
Hi @fehiepsi, I have been using pyro for multivariate normal (MVN) posterior estimation with HMC/NUTS, and I wanted to try numpyro. The following code works fine with `num_chains=1`, but fails with `num_chains=2`. I installed numpyro via `pip install git+https://github.com/pyro-ppl/numpyro` because the version from a plain `pip install numpyro` doesn't have `MultivariateNormal` in its distributions module:
`AttributeError: module 'numpyro.distributions' has no attribute 'MultivariateNormal'`
CODE:
```python
import jax
import jax.numpy as np
from jax import random
from jax.config import config; config.update("jax_platform_name", "cpu")
from jax.scipy.special import logsumexp
import numpy as onp
import pickle
import warnings; warnings.filterwarnings("ignore")
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro.handlers import sample, seed, substitute, trace
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import mcmc
import time

rng = random.PRNGKey(0)

with open('data_unit.pickle', 'rb') as f:
    data_unit = pickle.load(f)

def mvn_model(y):
    nu = sample('nu', dist.Uniform(0., 4.))
    sigma = sample('sigma', dist.LogNormal(loc=np.zeros(4), scale=np.ones(4)))
    # LKJ prior on the Cholesky factor of a 4x4 correlation matrix
    L_omega = sample("L_omega", dist.LKJCholesky(dimension=4, concentration=nu))
    L_Omega = np.matmul(L_omega, np.diag(np.sqrt(sigma)))
    mu = sample('mu', dist.MultivariateNormal(loc=np.zeros(4), scale_tril=L_Omega))
    obs = sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
    return obs

# Initial parameters from a single (unbatched) key
rng, rng_ = random.split(rng)
init_params, potential_fn, constrain_fn = initialize_model(rng_, mvn_model, y=data_unit)

start_time = time.time()
# Works with num_chains=1; fails with num_chains=2 (traceback below)
posterior_samples = mcmc(num_warmup=10,
                         num_samples=250,
                         init_params=init_params,
                         num_chains=2,
                         potential_fn=potential_fn,
                         trajectory_length=10,
                         target_accept_prob=0.9,
                         constrain_fn=constrain_fn,
                         print_summary=False)
print(f'It took {time.time() - start_time}')
```
ERROR:
```
Traceback (most recent call last):
  File "numpyro-test.py", line 44, in <module>
    print_summary=False)
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/numpyro/mcmc.py", line 435, in mcmc
    init_params_i = tree_map(lambda x: x[i], init_params)
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/jax/tree_util.py", line 62, in tree_map
    new_children = [tree_map(f, child) for child in children]
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/jax/tree_util.py", line 62, in <listcomp>
    new_children = [tree_map(f, child) for child in children]
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/jax/tree_util.py", line 65, in tree_map
    return f(tree)
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/numpyro/mcmc.py", line 435, in <lambda>
    init_params_i = tree_map(lambda x: x[i], init_params)
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 2076, in _rewriting_take
    return lax.index_in_dim(arr, idx, axis, False)
  File "/home/myntra/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/jax/lax/lax.py", line 1152, in index_in_dim
    axis_size = operand.shape[axis]
IndexError: tuple index out of range
```
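The traceback ends in `tree_map(lambda x: x[i], init_params)`. In other words, `mcmc` takes chain `i`'s starting point by indexing the leading axis of every leaf of `init_params`. The minimal JAX sketch below reproduces that pattern and shows why an unbatched scalar site such as `nu` breaks it. It uses a hypothetical hand-built params dict with the same site shapes as `mvn_model`, not numpyro's actual `initialize_model` output, and `take_chain` is an illustrative helper:

```python
import jax
import jax.numpy as jnp

def take_chain(params, i):
    # Same slicing pattern as in the traceback: index axis 0 of every leaf.
    return jax.tree_util.tree_map(lambda x: x[i], params)

# Hypothetical single-chain init params: no leading chain axis.
unbatched = {"nu": jnp.array(2.0), "sigma": jnp.ones(4), "mu": jnp.zeros(4)}

try:
    take_chain(unbatched, 1)
    unbatched_ok = True
except (IndexError, TypeError):
    # A 0-d leaf like `nu` has no axis 0 to index into.
    unbatched_ok = False

# The same sites stacked for 2 chains: every leaf gets a leading axis of size 2.
batched = jax.tree_util.tree_map(lambda x: jnp.stack([x, x]), unbatched)
chain_1 = take_chain(batched, 1)
print(unbatched_ok, {k: v.shape for k, v in chain_1.items()})
```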
Hi @arunpatro, because you use 2 chains, I think you need to give `init_params` a batch size of 2. To do that, just feed a batch of 2 rng keys to `initialize_model`:
```python
rngs = random.split(rng, 2)  # one key per chain
init_params, potential_fn, constrain_fn = initialize_model(rngs, mvn_model, y=data_unit)
```
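Here is one way to see why a batch of keys helps. `random.split(rng, 2)` returns a `(2, 2)` array of raw keys, and initializing once per key stacks each site along a new leading chain axis, so per-chain indexing succeeds. The sketch below uses a hypothetical `init_one` that draws an arbitrary starting value per site from a key. It only illustrates the batching pattern and is not how numpyro's `initialize_model` actually finds initial values:

```python
import jax
from jax import random

def init_one(key):
    # Hypothetical per-chain initializer; shapes mirror mvn_model's sites.
    k1, k2, k3 = random.split(key, 3)
    return {
        "nu": random.uniform(k1, (), minval=0.0, maxval=4.0),
        "sigma": random.normal(k2, (4,)),
        "mu": random.normal(k3, (4,)),
    }

rng = random.PRNGKey(0)
rngs = random.split(rng, 2)             # shape (2, 2): one key per chain
init_params = jax.vmap(init_one)(rngs)  # every leaf gains a leading axis of size 2

# Per-chain slicing, the same pattern mcmc uses for chain i, now works.
chain_1 = jax.tree_util.tree_map(lambda x: x[1], init_params)
print({k: v.shape for k, v in init_params.items()})
print({k: v.shape for k, v in chain_1.items()})
```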
Please let me know if that doesn't work out for you. A small (not important) note: you don't need to set `trajectory_length=10`, because it is only used for `algo='HMC'`. Also, since you are running your script on CPU, try it with `num_warmup=10000` and `num_samples=10000`; I think inference will still be fast. :)
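Some context on the `trajectory_length` remark: plain HMC integrates for a fixed time, so the number of leapfrog steps follows from `trajectory_length / step_size`. NUTS instead grows the trajectory adaptively, so it has no use for that parameter. The sketch below is a generic fixed-length leapfrog trajectory in JAX on a standard normal potential. It is not numpyro's implementation, and the `hmc_trajectory` helper and its rounding rule for the step count are assumptions:

```python
import jax
import jax.numpy as jnp

def potential(q):
    # Standard normal target, U(q) = 0.5 * ||q||^2 (stand-in for a model's potential_fn).
    return 0.5 * jnp.sum(q ** 2)

grad_u = jax.grad(potential)

def hmc_trajectory(q, p, step_size, trajectory_length):
    # Fixed-length HMC: the step count is derived from trajectory_length
    # (rounding to the nearest integer is an assumption of this sketch).
    num_steps = max(1, int(round(trajectory_length / step_size)))
    p = p - 0.5 * step_size * grad_u(q)      # initial half step in momentum
    for _ in range(num_steps - 1):
        q = q + step_size * p                # full step in position
        p = p - step_size * grad_u(q)        # full step in momentum
    q = q + step_size * p                    # last full position step
    p = p - 0.5 * step_size * grad_u(q)      # final half step in momentum
    return q, p, num_steps

q0 = jnp.array([1.0, -0.5])
p0 = jnp.array([0.3, 0.8])
q1, p1, n = hmc_trajectory(q0, p0, step_size=0.1, trajectory_length=1.0)
print(n, q1, p1)
```

Leapfrog nearly conserves the total energy `U(q) + 0.5 * ||p||^2`, which is what keeps HMC acceptance rates high.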
By the way, could you let us know what you think about the current interface? Is it hard to use, counterintuitive, or not documented well enough somewhere? Any feedback from you is very welcome! :)
@neerajprad In case you have access to a 2-GPU machine, could you help me test the examples on multiple GPUs? I think it should work natively, but I don't have 2 GPUs to test on. Thanks!
Sure, I'll post an update on this tomorrow.
@fehiepsi - I can confirm that this works fine with multiple GPUs too. Feel free to close this issue. At some point, it would be nice to be able to do this dynamically rather than having to change the XLA args.
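The exact XLA args aren't spelled out in this thread. One commonly used option for running chains in parallel on a CPU-only machine is the XLA flag `--xla_force_host_platform_device_count`, which has to be set before JAX is imported. Treat its use here as an assumption. The minimal sketch below maps one chain per device with `jax.pmap`, where `run_chain` is a hypothetical toy stand-in for a single MCMC chain:

```python
import os

# Assumption: ask XLA for 2 virtual CPU devices; this must run before importing JAX.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
from jax import random

def run_chain(key):
    # Hypothetical stand-in for one MCMC chain: a 100-step Gaussian random walk.
    steps = random.normal(key, (100,))
    return jnp.cumsum(steps)[-1]

num_devices = jax.local_device_count()               # 2 if the flag took effect
keys = random.split(random.PRNGKey(0), num_devices)  # one key per device/chain
results = jax.pmap(run_chain)(keys)                  # each chain runs on its own device
print(num_devices, results.shape)
```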
Thanks, @neerajprad!
With #198, we can use `mcmc` with `num_chains > 1`. Let's test whether it works for all of our examples:
- `num_chains` args to current examples and test it in Travis
- ~~`num_chains` args to time series forecasting notebook~~ just make a separate issue #216 for this work

Let's report issues here too:
- `standard_gamma` to support batched key and non-batched `alpha`. The current version only works with batched key and batched `alpha`.
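The last item asks `standard_gamma` to accept a batch of keys together with a single, unbatched `alpha`. Independently of numpyro's internals, which this thread doesn't show, that combination can be expressed by mapping over the key axis while holding `alpha` fixed. A sketch with `jax.vmap` and `jax.random.gamma`, where `gamma_batched_key` is a hypothetical helper name:

```python
import jax
import jax.numpy as jnp
from jax import random

def gamma_batched_key(keys, alpha):
    # keys: (batch, 2) raw PRNG keys; alpha: one concentration array shared by all keys.
    # in_axes=(0, None) maps over the key axis and passes alpha through unbatched.
    return jax.vmap(lambda k, a: random.gamma(k, a), in_axes=(0, None))(keys, alpha)

keys = random.split(random.PRNGKey(0), 3)
alpha = jnp.array([0.5, 1.0, 2.0, 5.0])
samples = gamma_batched_key(keys, alpha)
print(samples.shape)  # (3, 4)
```

Mapping only over keys gives each batch element an independent draw for the whole `alpha` vector, without materializing a copy of `alpha` per key.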