pyro-ppl / brmp

Bayesian Regression Models in Pyro
Apache License 2.0

Initialize numpyro NUTS with SVI #80

Open stanbiryukov opened 4 years ago

stanbiryukov commented 4 years ago

Hi there, great initial work wrapping numpyro and pyro into a more user-friendly interface! I'm having an issue with a few simple models where the numpyro backend raises the following error: "Cannot find valid initial parameters. Please check your model again." It seems to occur when there are many predictors in the formula.

Pyro sampling and SVI both work fine for this model with the default Cauchy beta priors. Any thoughts on better initializing numpyro NUTS, for example from an SVI fit or from maximum a posteriori estimates? It's hard to tell exactly what causes MCMC to fail immediately, but I'm assuming it's the initial values. Full traceback:

RuntimeError                              Traceback (most recent call last)
<ipython-input-185-9da937b2eeba> in <module>
----> 1 fit = model.fit(backend=numpyro, seed=8877, iter=1000, warmup=500)

/opt/conda/lib/python3.6/site-packages/brmp/__init__.py in fit(self, algo, **kwargs)
    173         """
    174         assert algo in ['prior', 'nuts', 'svi']
--> 175         return getattr(self, algo)(**kwargs)
    176 
    177     def nuts(self, iter=10, warmup=None, num_chains=1, seed=None, backend=numpyro_backend):

/opt/conda/lib/python3.6/site-packages/brmp/__init__.py in nuts(self, iter, warmup, num_chains, seed, backend)
    200         """
    201         warmup = iter // 2 if warmup is None else warmup
--> 202         return self.run_algo('nuts', backend, iter, warmup, num_chains, seed)
    203 
    204     def svi(self, iter=10, num_samples=10, seed=None, backend=pyro_backend, **kwargs):

/opt/conda/lib/python3.6/site-packages/brmp/__init__.py in run_algo(self, name, backend, df, *args, **kwargs)
    154         data = self.model.encode(df) if df is not None else self.data
    155         assets_wrapper = self.model.gen(backend)
--> 156         return assets_wrapper.run_algo(name, data_from_numpy(backend, data), *args, **kwargs)
    157 
    158     def fit(self, algo='nuts', **kwargs):

/opt/conda/lib/python3.6/site-packages/brmp/__init__.py in run_algo(self, name, data, *args, **kwargs)
     75 
     76     def run_algo(self, name, data, *args, **kwargs):
---> 77         samples = getattr(self.backend, name)(data, self.assets, *args, **kwargs)
     78         return Fit(self.model.formula, self.model.metadata,
     79                    self.model.contrasts, data,

/opt/conda/lib/python3.6/site-packages/brmp/numpyro_backend.py in nuts(data, assets, iter, warmup, num_chains, seed)
     86     # `num_chains` > 1 to achieve parallel chains.
     87     mcmc = MCMC(kernel, warmup, iter, num_chains=num_chains)
---> 88     mcmc.run(rng, **data)
     89     samples = mcmc.get_samples(group_by_chain=True)
     90 

/opt/conda/lib/python3.6/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, collect_warmup, init_params, *args, **kwargs)
    639         if self.num_chains == 1:
    640             states_flat = self._single_chain_mcmc((rng_key, init_params), collect_fields, collect_warmup,
--> 641                                                   args, kwargs)
    642             states = tree_map(lambda x: x[np.newaxis, ...], states_flat)
    643         else:

/opt/conda/lib/python3.6/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, collect_fields, collect_warmup, args, kwargs)
    582         rng_key, init_params = init
    583         init_state, constrain_fn = self.sampler.init(rng_key, self.num_warmup, init_params,
--> 584                                                      model_args=args, model_kwargs=kwargs)
    585         if self.constrain_fn is None:
    586             constrain_fn = identity if constrain_fn is None else constrain_fn

/opt/conda/lib/python3.6/site-packages/numpyro/infer/mcmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    409                 rng_key, rng_key_init_model = np.swapaxes(vmap(random.split)(rng_key), 0, 1)
    410             init_params_, self.potential_fn, constrain_fn = initialize_model(
--> 411                 rng_key_init_model, self.model, *model_args, init_strategy=self.init_strategy, **model_kwargs)
    412             if init_params is None:
    413                 init_params = init_params_

/opt/conda/lib/python3.6/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, *model_args, **model_kwargs)
    413     if not_jax_tracer(is_valid):
    414         if device_get(~np.all(is_valid)):
--> 415             raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
    416     return init_params, potential_fn, constrain_fun
    417 

RuntimeError: Cannot find valid initial parameters. Please check your model again.

null-a commented 4 years ago

Thanks for reporting this. I guess you're correct, and maybe we need to think more about initialization strategies at some point. If possible, could you share a simple example that reproduces the problem, so we can check that this isn't some other bug? Thanks.

stanbiryukov commented 4 years ago

I realized that part of the problem here is that brms automatically centers the data, whereas here we need to either specify more accurate priors or standardize the data before fitting with numpyro. Nevertheless, the pyro backend does work with non-centered columns in the example below. Agreed that it will be good to think through some initialization strategies.

import numpy as np  # needed below, missing from the original snippet
import pandas as pd

import brmp
from brmp import brm
from brmp.numpyro_backend import backend as numpyro
from brmp.pyro_backend import backend as pyro

df = pd.read_csv('https://stats.idre.ucla.edu/stat/data/hdp.csv')
# Integer-encode the non-numeric columns; numeric columns are left unchanged.
df = df.apply(lambda x: x if np.issubdtype(x.dtype, np.number) else pd.factorize(x)[0])
df['remission'] = df['remission'].astype(int)  # np.int was removed in NumPy 1.24
df['DID'] = df['DID'].astype('category')  # grouping factor for the (1 | DID) term
model = brm('remission ~ IL6 + CRP + CancerStage + LengthofStay + Experience + FamilyHx + SmokingHx + Sex + CancerStage + LengthofStay + WBC + BMI + (1 | DID)', df = df, family = brmp.family.Binomial(num_trials=1))
fit = model.fit(backend=numpyro, seed=8877, iter=1000, warmup=500) # fails
fit = model.fit(backend=pyro, seed=8877, iter=1000, warmup=500) # works
fit = model.fit(algo='svi', seed=8877, iter=10000, num_samples=1000) # works
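
As a rough sketch of the standardization workaround mentioned above, the helper below z-scores selected numeric columns before the frame is passed to `brm`. The function name `standardize_columns` and the toy values are assumptions made for illustration; they are not part of brmp, and the toy frame only borrows column names from the example.

```python
import pandas as pd


def standardize_columns(df, columns):
    """Return a copy of df with the given numeric columns z-scored."""
    out = df.copy()
    for col in columns:
        values = out[col].astype(float)
        std = values.std(ddof=0)  # population standard deviation
        if std == 0:
            # Constant column: only centre it, to avoid dividing by zero.
            out[col] = values - values.mean()
        else:
            out[col] = (values - values.mean()) / std
    return out


# Toy frame with made-up values standing in for a few hdp.csv columns.
toy = pd.DataFrame({
    "WBC": [5000.0, 6000.0, 7000.0, 8000.0],
    "BMI": [20.0, 25.0, 30.0, 35.0],
    "remission": [0, 1, 0, 1],
})
scaled = standardize_columns(toy, ["WBC", "BMI"])
```

Applied to the example, this would mean calling the helper on the continuous predictors (for instance IL6, CRP, WBC and BMI) and leaving the response and the factor columns untouched. Whether that alone resolves the numpyro failure has not been checked here.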