pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

implicit sampling statements #600

Closed. ross-h1 closed this issue 4 years ago.

ross-h1 commented 4 years ago

Hi guys, I'm looking at implementing this in NumPyro, which seems an excellent way of removing the indeterminate nature of probabilistic PCA:

https://github.com/RSNirwan/HouseholderBPCA/tree/master/py_stan_code

The original code is written in Stan.

The component variances (sigma) are 'implicitly' sampled in the Stan code as a positive ordered vector, and the model log probability is incremented directly. I'm wondering how something similar might be implemented in NumPyro?

Thank you in advance... Ross.

fehiepsi commented 4 years ago

Hi @ross-h1, about the positive ordered vector: I think it is a ComposeTransform of OrderedTransform and ExpTransform. To declare such a positive_ordered_vector constraint, you can do the same as for the ordered_vector constraint.

What do you mean by "implicit"? Do you mean an improper prior or something? NumPyro MCMC supports improper priors through the param primitive:

numpyro.param("sigma", init_value, constraint=positive_ordered_vector)
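
To make the composed transform concrete, here is a hand-rolled NumPy illustration of the same construction (not NumPyro's actual transform classes): an OrderedTransform-style map keeps the first coordinate free and adds exponentiated increments, and a final elementwise exp makes every entry positive while preserving the ordering.

```python
import numpy as np

def ordered(x):
    # First coordinate is free; later coordinates add strictly positive
    # increments, so the output is strictly increasing.
    y = np.empty_like(x, dtype=float)
    y[0] = x[0]
    y[1:] = x[0] + np.cumsum(np.exp(x[1:]))
    return y

def positive_ordered(x):
    # Elementwise exp is monotone, so it preserves the ordering and
    # additionally makes every entry positive.
    return np.exp(ordered(x))

z = positive_ordered(np.array([-1.0, 0.5, -2.0]))
print(z)
```

Any unconstrained vector fed through `positive_ordered` yields a valid positive ordered vector, which is exactly what a bijective constraint needs for HMC/NUTS.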
ross-h1 commented 4 years ago

Yep, it is improper and not explicitly sampled. OK, so declare this as a parameter with some initial value, and I can increment the model log probability via numpyro.factor?

Thank you, Ross

fehiepsi commented 4 years ago

Yes, you can add a custom log_prob using factor. The log_prob adjustment for the unconstrained sigma will be added automatically through the transform from the unconstrained domain to the positive ordered domain, so no need to worry about it. :)

Do you want to make a PR for positive_ordered_vector constraint? It would be nice to turn our discussion into something useful for other users. :)

vanAmsterdam commented 4 years ago

Is something going to change in the way param states are handled by MCMC? When installing from source I get different behavior than when installing through pip.

Aren't model1 and model2 supposed to be equivalent? (https://github.com/pyro-ppl/numpyro/blob/d52c7f0a8adde52408a45bf8fb145fbf28e676b8/notebooks/bayesian_regression.ipynb)

Now when I use model2 I get the message from here: https://github.com/pyro-ppl/numpyro/blob/67f3c4d014c2a87ee150041cbc70a360ba76bfd9/numpyro/infer/util.py#L228

from jax import numpy as np, random
import numpyro
from numpyro import distributions as dist, sample, param
from numpyro.distributions import constraints
from numpyro.infer.mcmc import NUTS, MCMC

num_chains = 2
numpyro.set_host_device_count(num_chains)

simkey  = random.PRNGKey(1)
nsim    = 100
ptrue   = 0.3
Y       = random.bernoulli(simkey, ptrue, shape=(nsim,))

def model1(Y):
    # theta as an explicit sample site with a Beta(1, 1) prior
    theta = sample('theta', dist.Beta(1, 1))
    with numpyro.plate('obs', Y.shape[0]):
        sample('Y', dist.Bernoulli(probs=theta), obs=Y)

def model2(Y):
    # theta as a constrained param, with the prior attached via an obs statement
    theta = param('theta', np.array(0.5), constraint=constraints.unit_interval)
    sample('theta_obs', dist.Beta(1, 1), obs=theta)
    with numpyro.plate('obs', Y.shape[0]):
        sample('Y', dist.Bernoulli(probs=theta), obs=Y)

mcmc_key = random.PRNGKey(12345)
kernel = NUTS(model1)
mcmc   = MCMC(kernel, num_warmup=250, num_samples=750, num_chains=num_chains)
mcmc.run(mcmc_key, Y)
mcmc.print_summary()

kernel = NUTS(model2)
mcmc   = MCMC(kernel, num_warmup=250, num_samples=750, num_chains=num_chains)
mcmc.run(mcmc_key, Y)
mcmc.print_summary()

jax.__version__ = 0.1.68, jaxlib.__version__ = 0.1.47, numpyro.__version__ = 0.2.4

fehiepsi commented 4 years ago

@vanAmsterdam Yes, param nodes will be treated as constant. We found that it is better to distinguish sample and param. To get the old behavior, you can use

theta = sample('theta', dist.Uniform(0, 1).mask(False))

What do you think about this change? We'd love to hear your opinion.

vanAmsterdam commented 4 years ago

In terms of consistency of writing code, I would say yes. I personally was a bit puzzled by this param behavior in the context of MCMC before, although it is more in line with the Stan way of declaring parameters first and then sampling priors later.

Will it be possible to add constraints to the sample statements? I need an ordered vector parameter (for the OrderedLogistic distribution). How would that work without using the param statement?

BTW, do I get this right that the dist.Uniform(0,1) in your example provides the restriction on the domain, and then the .mask(False) makes the MCMC sampler ignore this prior distribution? If that is the case, then for readability I guess I would prefer an interface where you could write sample('theta', constraints=unit_interval), but that comes pretty close to merging the functionality of param and sample...

fehiepsi commented 4 years ago

Whoa, good point!!! I thought that every constraint has a corresponding distribution, but you are right about ordered logistic. We should fix this soon. I like the idea of ImproperUniform in Pyro: https://github.com/pyro-ppl/pyro/pull/2495/files#diff-6ed267332a6a0424d3099a26160813c2 , so the code would be something like sample(..., dist.ImproperUniform(support)). How about that?

@fritzo @neerajprad I think this is a great use case for ImproperUniform. WDYT? For the sample method, we can simply draw random samples in (-2, 2) and transform them. We might add a warning that this is only useful for MCMC, not SVI.

vanAmsterdam commented 4 years ago

ImproperUniform(support) would make the code more explicit compared to dist.Uniform().mask(False) in my opinion, so that's a plus. Would it be able to incorporate every constraint though? Aren't constraints (like 'ordered_vector') more general?

I guess a solution that allows the user to combine any constraint with any distribution (like in Stan) would give the highest flexibility.

fehiepsi commented 4 years ago

Yes, it should work for any bijective constraint including ordered_vector.

a solution that allows the user to combine any constraint with any distribution (like in stan) would give the highest flexibility

I think using numpyro.factor or an observed statement as you did should work. I agree that such a solution would be more flexible, but I'm a bit worried that it would need some interface changes. Alternatively, we could have something like SomeNameDistribution(prior, support), but I am not sure it is worth introducing... I recall that some users even altered the support of the prior for the same purpose. Personally, I think we should have more words/examples/docs about using factor or obs statements. But let's be open to other alternatives.

ross-h1 commented 4 years ago

I think some examples with factor/ param in the context of sampling might be really helpful.

A tutorial going from a model set up for MCMC to running the same model in VI in Numpyro would also be great...

fehiepsi commented 4 years ago

Hi @ross-h1, I think you can find the usage of factor in the HMM example. FYI, we will move autoguide from contrib to the main infer module soon. For now, you can see how to use MCMC and SVI to get posteriors in the NeuTra example. But it would be great to have more user-contributed examples illustrating those functionalities. Would you mind creating feature requests for each functionality? I'll tag them properly so we can attract more contributors. :)

fehiepsi commented 4 years ago

Closing because all the questions have been addressed. I think we can have further discussions in the forum or in a new topic. Thanks for all the feedback so far, @ross-h1 and @vanAmsterdam!

tavin commented 2 years ago

As a newcomer it seems rather odd that I have to pass batch_shape=() and event_shape=() to ImproperUniform() when I don't have to do this for any other distribution. Is there a reason why those args don't have defaults of ()?

fehiepsi commented 2 years ago

Hi @tavin, there are many distributions that have non-trivial event_shape. We discussed previously whether we wanted to have batch_shape=() by default, and decided that making those arguments explicit seems better.

tavin commented 2 years ago

Thanks @fehiepsi. It's not so hard to deal with event_shape because it has to be dealt with implicitly anyway when using other distributions, e.g. when passing ndarrays of locations and scales. But I still don't understand batch_shape in the numpyro world. Of course I have a basic idea of what it must mean from general experience. But I wouldn't know how or why to assign it any value other than (). I can run inference on all kinds of models without knowing that, then I throw in an ImproperUniform, and just wonder if I'm doing the right thing with batch_shape.

fehiepsi commented 2 years ago

Yeah, most of the time you can leave batch_shape as (). For other distributions, the batch shape is inferred from the parameter inputs, but there is no input parameter for ImproperUniform. It would probably be less confusing to take a prototype_sample as input rather than a batch shape and an event shape. In NumPyro, the shape of a distribution's sample is batch_shape + event_shape. If you use ImproperUniform with real support, the event shape will be (), so if that random variable has shape (4,), the batch shape needs to be (4,).

tavin commented 2 years ago

@fehiepsi I appreciate your explanation but I'm a bit more confused now. And I can't find the docs on prototype_sample.

Let's say I want to replace the following with ImproperUniform (because I am going to do something weird with log_prob):

Normal(jnp.zeros(shape))

Which of these do I use?

ImproperUniform(constraints.real, (), shape)
ImproperUniform(constraints.real, shape, ())
fehiepsi commented 2 years ago

The second one is the one you need, because Normal(jnp.zeros(shape)) has batch shape shape and event shape (). In all distributions, the batch shape is the shape of the log probs.

Re prototype_sample: that is just a suggested API to replace the pair batch_shape and event_shape.

tavin commented 2 years ago

@fehiepsi I'm looking forward to seeing how the API evolves, with prototype_sample etc., and I appreciate the correction on the existing method signatures, as I had it backwards. Indeed, the usage notes concentrate on an empty batch_shape. Even the names batch_shape and event_shape suggest (to me at least) that you should focus on event_shape and leave batch_shape to the inference framework.

So I'd like to share two observations. One: my confusion would have been avoided if ImproperUniform had only a single required shape parameter, with () as the default for the other. Two: I tried swapping batch_shape and event_shape in my model and it made no difference; the final output of mcmc.get_samples() was exactly the same. I hope this helps, either as feedback for the future API or for the next person who is unsure how to specify these parameters.

fehiepsi commented 2 years ago

Good idea on using shape instead of the pair batch_shape and event_shape. Could you make a separate feature request for it? I think we'll need to go through several deprecation steps if we decide to make such a change.