pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

[Feature Request] Analogue to TraceELBO class, but with MMD instead of KL #1780

Open varenick opened 5 years ago

varenick commented 5 years ago

Feature:

A new MMDTraceELBO class that computes a Maximum Mean Discrepancy (MMD) between samples from the guide and from the model, instead of the KL divergence used in the Trace_ELBO class.

Motivation:

The ELBO is the sum of an expected log-likelihood and a negative KL divergence between the posterior distribution and the prior. To compute the KL term, we must either be able to evaluate the log-probabilities of both the prior and the posterior at posterior samples, or train a classifier to distinguish between prior and posterior samples. The second alternative has not been implemented in Pyro yet; moreover, using a classifier to estimate density ratios leads to a minimax-game objective and seems quite unreliable.

In the Wasserstein Auto-Encoder paper (https://arxiv.org/abs/1711.01558), the authors propose two alternatives for comparing the prior and posterior distributions: the first is training a classifier, as discussed above; the second is using a Maximum Mean Discrepancy (MMD) instead of the KL divergence.

Advantages of MMD:

  1. It requires only samples from the prior and posterior distributions and does not require explicit log-probabilities;
  2. It does not produce a minimax-game objective.

The main disadvantage of using MMD instead of KL is that the former does not yield a valid variational lower bound on the evidence. However, it does yield an approximation of an optimal transport cost between the training data distribution and the model distribution.
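
For concreteness, here is a minimal sketch of an MMD estimator between two sample sets in plain PyTorch (the RBF kernel and its bandwidth are illustrative choices, not part of any existing Pyro API):

```python
import torch

def rbf_kernel(x, y, bandwidth=1.0):
    # Gram matrix of k(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2)).
    return torch.exp(-torch.cdist(x, y) ** 2 / (2 * bandwidth ** 2))

def mmd_squared(x, y, kernel=rbf_kernel):
    # Biased (V-statistic) estimate of MMD^2 between P and Q from
    # samples x ~ P and y ~ Q. Note that only samples are needed --
    # no log-probabilities of either distribution.
    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()

# Example: compare prior samples with (approximate) posterior samples.
prior_samples = torch.randn(256, 2)
posterior_samples = 0.5 * torch.randn(256, 2) + 1.0
print(mmd_squared(prior_samples, posterior_samples))
```

(The unbiased U-statistic variant additionally drops the diagonal terms of the within-sample Gram matrices.)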

If this looks acceptable, I would like to try to implement this.

eb8680 commented 5 years ago

@varenick sure, PRs are welcome! You should be able to use some of the existing kernels in pyro.contrib.gp. How are you thinking of computing the MMD for models with multiple variables? Are you planning to use one additive kernel per latent variable and combine them with a sum?

karalets commented 5 years ago

Also check out this line of literature https://arxiv.org/abs/1608.04471 for Stein estimators.


varenick commented 5 years ago

@eb8680 Thanks for the tip about the existing kernels; I didn't know about them.

I was thinking of using a (generally) different kernel k_i(·,·) per latent variable z_i and combining them with a (weighted) sum. Since the latent variables live in different spaces, the kernel k(·,·) for the joint latent variable z = (z_1, ..., z_n) decomposes into a sum of per-variable kernels: k(z, z') = Σ_{i=1}^n k_i(z_i, z_i'). And since c·k(·,·) is a kernel whenever k(·,·) is a kernel and c ≥ 0, the weighted sum over latent variables is also a valid kernel: k(z, z') = Σ_{i=1}^n c_i k_i(z_i, z_i').
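
For illustration, a minimal sketch of this weighted-sum construction in plain PyTorch (site names, dimensions, and weights are all placeholders; the per-site kernels could equally come from pyro.contrib.gp.kernels):

```python
import torch

def rbf_kernel(x, y, bandwidth=1.0):
    # Gram matrix of an RBF kernel; bandwidth is a placeholder choice.
    return torch.exp(-torch.cdist(x, y) ** 2 / (2 * bandwidth ** 2))

def joint_kernel(z, z_prime, kernels, weights):
    # z, z_prime: dicts mapping latent site names to batches of samples.
    # A non-negative weighted sum of kernels is itself a valid kernel,
    # so the joint Gram matrix is the weighted sum of per-site Gram matrices.
    return sum(weights[name] * kernels[name](z[name], z_prime[name])
               for name in kernels)

# Hypothetical latents living in different spaces:
z = {"z1": torch.randn(50, 2), "z2": torch.randn(50, 5)}
z_prime = {"z1": torch.randn(50, 2), "z2": torch.randn(50, 5)}
kernels = {"z1": rbf_kernel, "z2": rbf_kernel}
weights = {"z1": 1.0, "z2": 0.5}
gram = joint_kernel(z, z_prime, kernels, weights)  # shape (50, 50)
```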

varenick commented 5 years ago

@eb8680 I've recently made a working prototype and plan to open a PR soon. I have a small problem: I don't know what to name the corresponding class.

Candidates:

  1. MMD_ELBO. Intuitive, but incorrect: it is not a valid variational lower bound on the evidence.
  2. MMD_PseudoELBO. Better, but PseudoELBO is not a commonly used term.
  3. MMD_VAE_Loss. Refers to the Ermon Group blog post based on the InfoVAE paper. Not ideal, since it explicitly names the VAE model.
  4. MMD_Based_ELBO_Approximation. Formally correct, but mentions ELBO explicitly and is too long.
  5. MMD_Based_Evidence_Variational_Approximation. Formally correct, but far too long.

Could you please suggest a name? I've only seen such an objective in the context of VAEs: see the Ermon Group blog post, where it is called MMD-VAE, and the InfoVAE paper, where it is called InfoVAE.

fritzo commented 5 years ago

How about Trace_MMD?

eb8680 commented 5 years ago

@varenick great! Looking forward to seeing your PR. I agree with @fritzo's suggestion of Trace_MMD.

varenick commented 5 years ago

@fritzo @eb8680 Hmm, Trace_MMD suggests that there is only an MMD term in the objective; however, there are two terms: the expected log-likelihood of the observed data, and the MMD between the marginal variational posterior and the prior. Maybe Trace_MMD_Variational_Loss or Trace_MMD_Variational_Objective?
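
In symbols, the objective under discussion is (informally, following the WAE paper): loss = −E_{q(z|x)}[log p(x|z)] + λ · MMD(q(z), p(z)), where q(z) is the marginal (aggregated) variational posterior, p(z) is the prior, and λ is a scale factor.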

varenick commented 5 years ago

@fritzo @eb8680 I've tried to push my branch to the remote repo, but git returns error 403: permission denied.

eb8680 commented 5 years ago

You'll need to fork Pyro on GitHub and push your branch to that fork.


wthrif commented 5 years ago

Hey @varenick, would you be willing to share an example VAE code with your Trace_MMD class? I'm getting a shape error that I don't understand when I try to run it: ValueError: Shape mismatch inside plate('num_particles_vectorized') at site obs dim -2, 32 vs 1024. Full disclosure: I'm new to Pyro and don't know what I'm doing. I set num_particles to be the same as my batch size, 32, and used an RBF kernel with its input dimension set to match my latent space (2). The 1024 is num_particles * batch_size. Otherwise I'm just replacing the usual Trace_ELBO with Trace_MMD in a regular VAE code that already works. The code runs if I set num_particles to 1, although it doesn't converge to a useful latent space the way the same InfoVAE implemented in plain PyTorch does.
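
For reference, here is a minimal sketch of that kind of swap on a toy (non-amortized) model, assuming a Trace_MMD constructor that accepts a dict of per-site kernels plus the usual ELBO particle options; the model, guide, and all hyperparameters are placeholders, and the exact constructor signature should be checked against the merged implementation:

```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.contrib.gp.kernels import RBF
from pyro.infer import SVI, Trace_MMD
from pyro.optim import Adam

data = torch.randn(32, 2) + 1.0  # toy "dataset" of 32 points

def model(x):
    # Standard normal prior over a 2-d latent, Gaussian likelihood.
    z = pyro.sample("z", dist.Normal(torch.zeros(2), 1.0).to_event(1))
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Normal(z, 1.0).to_event(1), obs=x)

def guide(x):
    loc = pyro.param("loc", torch.zeros(2))
    scale = pyro.param("scale", torch.ones(2), constraint=constraints.positive)
    pyro.sample("z", dist.Normal(loc, scale).to_event(1))

# One kernel per latent site; input_dim matches the latent dimension.
loss = Trace_MMD(
    kernel={"z": RBF(input_dim=2)},
    num_particles=10,
    vectorize_particles=True,
    max_plate_nesting=1,
)
svi = SVI(model, guide, Adam({"lr": 1e-2}), loss=loss)
for step in range(100):
    svi.step(data)
```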

eb8680 commented 5 years ago

Although it doesn't converge to a useful latent space the way the same InfoVAE implemented in plain PyTorch does.

@wthrif this is to be expected, since the Trace_MMD loss is not the same as the InfoVAE loss. See #1818 for discussion. Unfortunately, implementing general-purpose inference algorithms is difficult, and I'm not sure Trace_MMD as it currently exists is very useful. Without additional examples I'm inclined to remove it (at least temporarily) before the upcoming 0.4 release, since we don't have spare cycles to maintain or improve it.

@varenick also wrote a nice example notebook in #1818 with a version of Trace_MMD that specifically replicates the InfoVAE loss, but which is not correct for arbitrary models. If you're up for it (and @varenick doesn't mind), you could take that notebook, turn it into an example script similar to the others in the examples/ folder, and submit it as a PR - I think @varenick did a great job with that notebook, and lots of users would appreciate it as an example of a more complicated custom loss function.

wthrif commented 5 years ago

Thanks for the link, @eb8680. I'll work on implementing it.