pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

[FR] Implement a surrogate loss for discrete, e.g. categorical, distributions #896

FlorianWilhelm closed this issue 3 years ago.

FlorianWilhelm commented 3 years ago

As discussed in this thread, using SVI with categorical distributions (e.g. in Gaussian mixture models) currently leads to incorrect ELBO gradients, which can produce unexpected results.

One way to fix this, as explained in NumPyro's SVI docs, is to use a surrogate loss.
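For readers unfamiliar with the idea, here is a minimal sketch of a score-function surrogate loss in plain JAX (not NumPyro's API) for a single categorical latent variable. The toy model and the names `prior_probs`, `means`, `x_obs`, `elbo_term`, `surrogate_loss` and `exact_elbo` are hypothetical. The `stop_gradient` makes the surrogate's gradient equal the score-function (REINFORCE) estimator f(z) * grad log q(z) + grad f(z), which is unbiased for the ELBO gradient even though the sample z itself cannot be differentiated.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical toy model: z ~ Categorical(prior_probs), x | z ~ Normal(means[z], 1).
prior_probs = jnp.array([0.5, 0.5])
means = jnp.array([-2.0, 2.0])
x_obs = 1.5


def elbo_term(logits, z):
    """Single-sample ELBO integrand f(z) = log p(x, z) - log q(z)."""
    log_q = jax.nn.log_softmax(logits)[z]
    log_p = jnp.log(prior_probs[z]) + norm.logpdf(x_obs, means[z], 1.0)
    return log_p - log_q


def surrogate_loss(logits, z):
    """Loss whose gradient is minus the score-function ELBO gradient estimate."""
    log_q = jax.nn.log_softmax(logits)[z]
    f = elbo_term(logits, z)
    # stop_gradient keeps f constant inside the score term, so
    # grad(surrogate) = f(z) * grad log q(z) + grad f(z).
    surrogate = jax.lax.stop_gradient(f) * log_q + f
    return -surrogate


def exact_elbo(logits):
    """Exact ELBO by enumerating both values of z (feasible only for tiny supports)."""
    probs = jax.nn.softmax(logits)
    return sum(probs[k] * elbo_term(logits, k) for k in range(2))


# One stochastic gradient estimate from a single sample of z.
key = jax.random.PRNGKey(0)
logits = jnp.array([0.3, -0.2])
z = jax.random.categorical(key, logits)
grad_estimate = jax.grad(surrogate_loss)(logits, z)
```

Averaging `jax.grad(surrogate_loss)` over z ~ q recovers the negative gradient of the exact ELBO, which for two components can also be computed by enumeration with `exact_elbo`.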

fritzo commented 3 years ago

@eb8680 can we use your DiCE implementation from pyro.contrib.funsor?

fehiepsi commented 3 years ago

Just a reminder for the future: when this PR is resolved, we would like to change the following paragraph in the README: "There is also a basic Variational Inference implementation for reparameterized distributions..."

eb8680 commented 3 years ago

> can we use your DiCE implementation from pyro.contrib.funsor?

Something like this seems like the right way to go, since a DiCE/score-function-based ELBO implementation without any Rao-Blackwellization (even the coarse, plate-based kind we do in Pyro) would be pretty useless for practical purposes. For now, we should just raise a NotImplementedError in NumPyro when a discrete variable is encountered in SVI and point people to DiscreteHMCGibbs + HMCECS, as @fehiepsi did in the forum thread above. I'm not sure it's a good idea to implement something super simple and tell people to use it unless it's at least equivalent to pyro.infer.Trace_ELBO in variance reduction.
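The DiCE idea can be illustrated with its "magic box" operator. The following is a hypothetical plain-JAX sketch, not the pyro.contrib.funsor implementation; `magic_box`, `cost`, `dice_surrogate` and `exact_expectation` are made-up names. Multiplying a cost by `magic_box(log q(z))` leaves its value unchanged but adds the score-function term to its gradient. Note that this is the bare, non-Rao-Blackwellized estimator that the comment above considers insufficient on its own.

```python
import jax
import jax.numpy as jnp


def magic_box(x):
    """DiCE operator: evaluates to 1, but its gradient is magic_box(x) * grad(x)."""
    return jnp.exp(x - jax.lax.stop_gradient(x))


def cost(theta, z):
    """Hypothetical per-outcome objective that also depends directly on theta."""
    return jnp.sin(theta[z]) + 1.0 * z


def dice_surrogate(theta, z):
    log_q = jax.nn.log_softmax(theta)[z]
    # Value equals cost(theta, z); gradient is cost * grad log q + grad cost.
    return magic_box(log_q) * cost(theta, z)


def exact_expectation(theta):
    """E_q[cost] by enumerating the three outcomes, for checking."""
    q = jax.nn.softmax(theta)
    return sum(q[k] * cost(theta, k) for k in range(3))


theta = jnp.array([0.2, -0.5, 1.0])
z = jax.random.categorical(jax.random.PRNGKey(0), theta)
dice_grad = jax.grad(dice_surrogate)(theta, z)
```

In expectation over z ~ q, `jax.grad(dice_surrogate)` matches `jax.grad(exact_expectation)`; what DiCE by itself does not provide is the per-datapoint variance reduction discussed in this thread.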

FlorianWilhelm commented 3 years ago

@eb8680 From a user's perspective, I would appreciate something super simple that works, even if it is not as good as pyro.infer.Trace_ELBO at variance reduction; one could emit a UserWarning alongside it. Just raising a NotImplementedError might be a little too restrictive for my taste, since even the current approach of not handling categoricals at all might work in some models, although the results might of course be much worse. At least that's my understanding from the thread I posted, but correct me if you think the current implementation will almost always lead to completely wrong inferences.

eb8680 commented 3 years ago

@FlorianWilhelm I'm not sure what you mean when you say the current approach might work in some cases: right now there is no ELBO gradient estimate computed at all for non-reparametrized variables, not even a biased one, so the posterior obtained for those variables using NumPyro's SVI implementation would actually be completely independent of the data.
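To see why the approximate posterior ignores the data, consider a hand-written single-sample ELBO in plain JAX. This is a hypothetical sketch of the situation described, not NumPyro's code; `naive_elbo` and the two-component toy model are assumptions. Because z is an integer sample, autodiff can only differentiate the `-log q(z)` term with respect to the guide logits, so the observation `x` never reaches the gradient.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical two-component model: z ~ Categorical(0.5, 0.5), x | z ~ Normal(means[z], 1).
means = jnp.array([-2.0, 2.0])
log_prior = jnp.log(jnp.array([0.5, 0.5]))


def naive_elbo(logits, z, x):
    """Single-sample ELBO with no score-function term for the discrete z."""
    log_q = jax.nn.log_softmax(logits)[z]
    # z is an integer sample: the likelihood term depends on the logits only
    # through z, and autodiff cannot differentiate through that sample.
    return log_prior[z] + norm.logpdf(x, means[z], 1.0) - log_q


logits = jnp.array([0.1, -0.4])
z = jax.random.categorical(jax.random.PRNGKey(0), logits)
g_left = jax.grad(naive_elbo)(logits, z, -2.0)   # x near the left component
g_right = jax.grad(naive_elbo)(logits, z, 2.0)   # x near the right component
# Both equal -grad log q(z): the observation x never enters the gradient.
```

Moreover, averaged over z ~ q this gradient is exactly zero, because the score function grad log q(z) has mean zero, so the guide for z receives no systematic learning signal at all.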

The purpose of the extra complexity in pyro.infer.Trace_ELBO is to ensure that score-function gradient estimates for variational parameters of latent variables local to a particular datapoint (like the component ID in a mixture model) depend only on ELBO terms associated with that datapoint. As you can imagine, this is an absolute minimum requirement; otherwise the variance of these gradient estimators would grow with the number of datapoints, making inference impossible in all but the most trivial cases.
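A minimal sketch of this per-datapoint (plate-based) Rao-Blackwellization, written in plain JAX rather than taken from Pyro's Trace_ELBO; the mixture model and the names `local_terms`, `global_surrogate` and `local_surrogate` are hypothetical. Each local latent z_n has its own guide logits. In `global_surrogate`, every score term log q(z_n) is weighted by the ELBO summed over all datapoints, while `local_surrogate` weights it only by datapoint n's own term. Both are unbiased here, because the dropped cross terms have zero expectation when the z_n are independent under the guide.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical mixture: z_n ~ Categorical(0.5, 0.5), x_n | z_n ~ Normal(means[z_n], 1).
means = jnp.array([-2.0, 2.0])
log_prior = jnp.log(jnp.array([0.5, 0.5]))


def local_terms(logits, z, x):
    """Per-datapoint ELBO terms f_n and guide log-probs log q(z_n); logits has shape (N, 2)."""
    log_q = jax.nn.log_softmax(logits)[jnp.arange(x.shape[0]), z]
    f = log_prior[z] + norm.logpdf(x, means[z], 1.0) - log_q
    return f, log_q


def global_surrogate(logits, z, x):
    f, log_q = local_terms(logits, z, x)
    # Every score term is weighted by the ELBO summed over ALL datapoints.
    return jnp.sum(log_q) * jax.lax.stop_gradient(jnp.sum(f)) + jnp.sum(f)


def local_surrogate(logits, z, x):
    f, log_q = local_terms(logits, z, x)
    # Each score term is weighted only by its own datapoint's ELBO term.
    return jnp.sum(log_q * jax.lax.stop_gradient(f)) + jnp.sum(f)


logits = jnp.zeros((4, 2))
z = jax.random.categorical(jax.random.PRNGKey(0), logits)  # one z_n per datapoint
x = jnp.array([-1.0, 0.5, 2.0, 3.0])
x_changed = x.at[3].set(-5.0)  # perturb only the last datapoint

g_local = jax.grad(local_surrogate)(logits, z, x)
g_local_changed = jax.grad(local_surrogate)(logits, z, x_changed)
g_global = jax.grad(global_surrogate)(logits, z, x)
g_global_changed = jax.grad(global_surrogate)(logits, z, x_changed)
```

Perturbing only `x[3]` leaves the local gradient for datapoint 0 unchanged, whereas the global surrogate's gradient for datapoint 0 shifts. That shift is the extra noise which, summed over many datapoints, makes the variance grow with N.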

Note that this same minimum level of variance reduction for reparametrization gradients is trivially obtained through automatic differentiation, which is why the current ELBO implementation in NumPyro can be so simple.
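For comparison, with reparametrized local latents the same locality falls out of automatic differentiation with no extra bookkeeping. A hypothetical JAX sketch, assuming a unit-scale Gaussian prior, likelihood and guide: the gradient with respect to each guide mean `mu[n]` depends only on `x[n]`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def reparam_elbo(mu, eps, x):
    """Single-sample ELBO with reparametrized local latents z_n = mu_n + eps_n."""
    z = mu + eps
    log_lik = norm.logpdf(x, z, 1.0)      # hypothetical x_n | z_n ~ Normal(z_n, 1)
    log_prior = norm.logpdf(z, 0.0, 1.0)  # hypothetical z_n ~ Normal(0, 1)
    log_q = norm.logpdf(z, mu, 1.0)       # hypothetical guide q(z_n) = Normal(mu_n, 1)
    return jnp.sum(log_lik + log_prior - log_q)


mu = jnp.array([0.0, 0.5, -0.5, 1.0])
eps = jax.random.normal(jax.random.PRNGKey(1), (4,))
x = jnp.array([-1.0, 0.5, 2.0, 3.0])

g = jax.grad(reparam_elbo)(mu, eps, x)
g_changed = jax.grad(reparam_elbo)(mu, eps, x.at[3].set(-5.0))
```

Here the gradient works out to x_n - 2 z_n: the likelihood contributes x_n - z_n, the prior contributes -z_n, and the guide's log density does not depend on mu once z = mu + eps is substituted. Changing `x[3]` therefore only changes `g[3]`.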

We're working on porting some of Pyro's advanced variational inference features to NumPyro, but in the meantime your best options are either to try the new MCMC algorithms for discrete variables in NumPyro or to rewrite your model to use Pyro.

FlorianWilhelm commented 3 years ago

Thanks for your reply, @eb8680. Just to be 100% sure I got it: assume my model has $\phi_k \sim \mathrm{Dirichlet}(\boldsymbol{\alpha})$, then $\varphi \sim \mathrm{Categorical}(\boldsymbol{\phi})$, and my observed $r$ strongly depends on $\varphi$. I understood that there is no ELBO gradient estimate for $\varphi$, but since I parametrized $\phi_k$, would I get the right results for it, or is the chain of gradients just broken at $\varphi$? In the latter case, even $\phi_k$ would be wrongly inferred, since $\varphi$ would wrongly depend on all data points.

Regarding the porting of Pyro's advanced variational inference features to NumPyro, thanks a lot for that! Do you have a rough time estimate for that just to help me decide if it's worth the work to port my whole model to Pyro when eventually I would switch back to NumPyro for speed reasons once the feature is implemented?

Thanks for your great work.

eb8680 commented 3 years ago

> there is no ELBO gradient estimate for $\varphi$ but since I parametrized $\phi_k$ would I get the right results for it or is the chain of gradients just broken at the point $\varphi$?

The approximate posterior for $\phi_k$ obtained from NumPyro's current ELBO implementation would be completely independent of the data because it only affects the likelihood terms in the ELBO through the non-reparametrizable variable $\varphi$.

> Do you have a rough time estimate for that just to help me decide if it's worth the work to port my whole model to Pyro when eventually I would switch back to NumPyro for speed reasons once the feature is implemented?

That depends on the particular combination of features you need, but it may be a couple of months, so I don't recommend waiting. I also suspect that you will have a lot of trouble getting SVI to work for your problem if enumeration isn't applicable (as seemed to be the case from your forum post), and that you would be much better off with the new MCMC algorithms for discrete variables and subsampling in NumPyro anyway.

FlorianWilhelm commented 3 years ago

> The approximate posterior for $\phi_k$ obtained from NumPyro's current ELBO implementation would be completely independent of the data because it only affects the likelihood terms in the ELBO through the non-reparametrizable variable $\varphi$.

Okay, that's tough. Then I would also go for just raising a NotImplementedError.

Thanks for your reply regarding my options. So I have:

  1. Porting my code to Pyro and using Trace_ELBO,
  2. Using MCMC algorithms for discrete variables and subsampling in NumPyro.

Since you recommend Option 2 and @fehiepsi has also mentioned it several times, I will look at it first. However, as far as I understood @fehiepsi, DiscreteHMCECS is not yet implemented in NumPyro, and I would have to use the plate directive for subsampling. These two restrictions make Option 1 attractive as well, since I need to resample the dataset constantly, as is common in collaborative filtering use cases.

FlorianWilhelm commented 3 years ago

Thanks, @fehiepsi!

FlorianWilhelm commented 3 years ago

@fehiepsi: In the end, I decided to go for TraceEnum_ELBO in Pyro, and my code is now publicly available here. Since this issue is now closed, I guess I could also switch back to NumPyro and use TraceGraph_ELBO, correct? However, exhaustive enumeration as in Pyro's TraceEnum_ELBO would still be missing in NumPyro, so inference after a migration could potentially be much slower. Is this correct?
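For context on why enumeration helps: in a mixture model, exhaustive enumeration sums the component ID out of the log joint exactly, so no score-function estimator is needed for it. Below is a minimal plain-JAX sketch of that marginalization, not Pyro's or NumPyro's TraceEnum_ELBO implementation. The 1-D unit-variance mixture, `enumerated_log_lik` and the `params` layout are assumptions, and the sketch covers only the likelihood part; a full ELBO would add guide terms for any continuous latents.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def enumerated_log_lik(params, x):
    """log p(x) for a 1-D Gaussian mixture with the component ID summed out exactly."""
    log_weights = jax.nn.log_softmax(params["weight_logits"])  # shape (K,)
    # log p(x_n, z_n = k) for every datapoint n and component k: shape (N, K).
    log_joint = log_weights + norm.logpdf(x[:, None], params["locs"], 1.0)
    # Summing over z_n in log space replaces sampling z_n.
    return jnp.sum(logsumexp(log_joint, axis=-1))


params = {"weight_logits": jnp.zeros(2), "locs": jnp.array([-2.0, 2.0])}
x = jnp.array([-1.0, 0.5, 2.0, 3.0])
grads = jax.grad(enumerated_log_lik)(params, x)  # exact gradients, no score-function noise
```

The cost of this exactness is evaluating every component for every datapoint, which is cheap for small K but is why enumeration is not always applicable.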

fehiepsi commented 3 years ago

Hi @FlorianWilhelm, you can track the progress for TraceEnum_ELBO in #741. I'll take a look later in the week.