daib13 / TwoStageVAE


Optimizing gamma to zero and mode collapse #8

Open chanshing opened 4 years ago

chanshing commented 4 years ago

Thank you for your very nice work. It was a very good read. If you don't mind, I have a few questions:

  1. You argued for the importance of optimizing gamma, and showed that as gamma goes to zero, the VAE reconstructs the same x for any z~q(z|x). But do we really want this scenario? Isn't this mode collapse?
  2. If I understood correctly, the above happens because the injected noise is rapidly scaled down by the encoder variance, which goes to zero as gamma goes to zero. I think this means that q(z|x) is converging to a delta on some of its dimensions (specifically, on the nonsuperfluous ones). Then how does this have nonzero measure?
  3. Some works use pixel-wise gammas, instead of a scalar one. Does your work easily generalize to this case?
  4. Given the interesting insights from your paper, I now wonder what we should optimize for. What metric should we track, e.g. for early stopping or model selection? (a) The VAE loss; (b) the expectation (under z~q(z|x)) of the reconstruction loss (since you argue that perfect reconstruction happens at the optimum); (c) the deterministic reconstruction loss (i.e. using the mean of q(z|x)); (d) waiting until gamma gets below a certain small threshold. (A rough sketch of what I mean by (b) vs (c) is right below.)
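To make the (b)/(c) distinction concrete, here is a rough NumPy sketch of what I mean; the encoder/decoder here are toy stand-ins, not real models:

```python
import numpy as np

rng = np.random.default_rng(0)
x_dim, kappa, n = 200, 200, 16           # toy sizes, only for illustration

# Toy stand-ins for a trained encoder/decoder pair (hypothetical).
def encoder(x):                           # mean and std of q(z|x)
    return x, np.full_like(x, 0.1)

def decoder(z):                           # reconstruction mean f(z)
    return z

x = rng.normal(size=(n, x_dim))
mu, sigma = encoder(x)

# (b) expected reconstruction loss under z ~ q(z|x), Monte-Carlo estimate
z = mu[None] + sigma[None] * rng.normal(size=(32, n, kappa))   # 32 samples per x
expected_recon = np.mean((x[None] - decoder(z)) ** 2)

# (c) deterministic reconstruction loss, using the mean of q(z|x)
det_recon = np.mean((x - decoder(mu)) ** 2)

print(expected_recon, det_recon)          # here (b) ≈ (c) + mean encoder variance
```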
daib13 commented 4 years ago

Hi @chanshing

  1. About mode collapse. \gamma going to zero does not correspond to the mode collapse issue. Let the latent space be R^\kappa, the data manifold be \chi and the generation function be f. Mode collapse means the generated manifold \chi^\prime = {f(z) | z \in R^\kappa} is a proper subset of \chi, i.e. some parts of \chi are never generated. In the \gamma -> 0 scenario, for every x \in \chi there exists a z \in R^\kappa such that f(z) = x, so as long as the network capacity is sufficient, the VAE will not have the mode collapse issue. However, it will have a different issue when \chi is not diffeomorphic to a Euclidean space: \chi becomes a proper subset of \chi^\prime = {f(z) | z \in R^\kappa}, i.e. the decoder also generates points off the data manifold. Our paper does not discuss this case, but we believe it is one of the key reasons why VAEs cannot generate samples as good as GAN models.

  2. About the nonzero measure. Yes, your understanding is correct. Note that in our paper corresponding to this repository, the objective function is integrated over the whole manifold. So {\mu(x) | x \in \chi} occupies r latent dimensions, where r is the manifold dimension of \chi, and the noise fills up the remaining \kappa - r dimensions. So q(z|x) will occupy the whole R^\kappa latent space.

  3. About pixel-wise gammas. If \chi is a noiseless manifold as assumed in our paper, there is no need to use a pixel-wise \gamma at all, since it is easy to prove that all the \gammas will converge to 0. However, if \chi is contaminated by some noise, using a pixel-wise \gamma could be helpful. In our JMLR paper [1], we proved that VAE is a nonlinear extension of the robust PCA model, which can decompose the contaminated data into a low-rank component and a sparse noise component. Of course there are many other works using pixel-wise \gammas in different scenarios; it is difficult to give a general comment on them. (A minimal sketch contrasting the scalar and pixel-wise parameterizations is at the end of this comment.)

  4. About what should be optimized for. I think one of the key points in our paper is that there is no single metric you can track to obtain good generation performance. There are two equally important things: 1) detect the manifold in the ambient space and 2) learn the distribution within the manifold. You mentioned four candidates to track in your question. The VAE loss matters for both purposes 1) and 2). (b) and (c) are the same thing (see Theorem 5 in the paper); they serve the first purpose. (d) also serves the first purpose. For the first VAE, training will push the VAE loss to negative infinity, \gamma to 0 and the reconstruction error to 0; all of these happen together. However, even when they are achieved, it does not mean the VAE can generate good samples. We have to use a second VAE for the second purpose.

[1] Dai B, Wang Y, Aston J, et al. Connections with robust PCA and the role of emergent sparsity in variational autoencoder models[J]. The Journal of Machine Learning Research, 2018, 19(1): 1573-1614.
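To illustrate point 3 concretely, here is a minimal NumPy sketch of how a scalar versus a pixel-wise \gamma enters the Gaussian reconstruction NLL (generic illustration code, not the code in this repository):

```python
import numpy as np

def gauss_nll_scalar_gamma(x, x_hat, gamma):
    """-log p(x|z) with p = N(x_hat, gamma^2 * I); gamma is a single scalar.
    The d*log(gamma) term is what drives gamma towards 0 on noiseless data."""
    d = x.size
    return d * np.log(gamma) + 0.5 * np.sum((x - x_hat) ** 2) / gamma ** 2 \
        + 0.5 * d * np.log(2 * np.pi)

def gauss_nll_pixelwise_gamma(x, x_hat, gamma):
    """Same, but gamma has one entry per pixel (diagonal covariance)."""
    return np.sum(np.log(gamma)) + 0.5 * np.sum(((x - x_hat) / gamma) ** 2) \
        + 0.5 * x.size * np.log(2 * np.pi)

rng = np.random.default_rng(0)
x = rng.random(32 * 32 * 3)
x_hat = x + 0.01 * rng.standard_normal(x.size)

print(gauss_nll_scalar_gamma(x, x_hat, 0.05))
print(gauss_nll_pixelwise_gamma(x, x_hat, np.full(x.size, 0.05)))   # same value here
```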

chanshing commented 4 years ago

@daib13 Thank you so much for the thorough response.

Regarding your last sentence in point 2: did you mean to say q(z) (not q(z|x)) will occupy the whole R^\kappa?... since q(z|x) will be a delta on the nonsuperfluous dimensions...

If you don't mind, I would like to share a few observations from my experiments. I trained a VAE on a custom dataset with a learnable gamma, but thresholded it at 1e-4 (otherwise I get NaNs due to precision errors). The plot below shows gamma during training: [plot: gamma]

My latent dimension kappa is 200. The plot below shows the variance of each latent dimension during training: [plot: z_logvar]

We observe that while gamma is decreasing, all my variances decrease with it. This may suggest that I am not in the setting r < kappa (i.e. I don't have enough latent dimensions). However, when gamma finally stalls at the threshold of 1e-4, I see that some of the variances start to go up. My intuition is that in this regime, with gamma effectively fixed, the model is finally allowed to find a parsimonious representation within this gamma. In other words, we are accepting a noise level of 1e-4, and the model is therefore allowed to discard some information.
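For reference, the gamma clamping and the per-dimension variance tracking I described look roughly like this (a simplified NumPy sketch of my setup, not the code in this repo; the helper names are mine):

```python
import numpy as np

GAMMA_MIN = 1e-4                      # the threshold I use to avoid NaNs

def clamp_gamma(log_gamma):
    """gamma is learned as log_gamma; I floor it so log(gamma) and 1/gamma**2 stay finite."""
    return max(np.exp(log_gamma), GAMMA_MIN)

def per_dim_variance(z_logvar):
    """z_logvar: (batch, kappa) log-variances of q(z|x).
    I average over the batch and log one curve per latent dimension."""
    return np.exp(z_logvar).mean(axis=0)          # shape (kappa,)

# toy check
print(clamp_gamma(-10.0))                          # exp(-10) < 1e-4, so the floor kicks in
print(per_dim_variance(np.random.randn(8, 200) * 0.1 - 3.0).shape)   # (200,), kappa = 200
```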

I would love to hear your insights on this.

daib13 commented 4 years ago

@chanshing

  1. About point 2 in my last response. Yes, I meant that q(z) will occupy the whole R^\kappa.

  2. About your experiment results. These results are interesting and I agree with your intuition. There are two phases during training. In the first phase, the reconstruction term (including the d*log\gamma term) dominates, because pushing \gamma to a marginally smaller value makes d*log\gamma dramatically smaller. This comes at a cost: \sigma_z also becomes very small, introducing a -log\sigma_z term that goes to infinity. But note that the dimension of x is much larger than the dimension of z, so the d*log\gamma term overwhelms the -log\sigma_z term. In the second phase, when \gamma is thresholded, we can remove the d*log\gamma term from the loss function, assuming \gamma always equals 1e-4. Then -log\sigma_z becomes really important. If the model can make \sigma_z slightly larger at the cost of a slightly worse reconstruction, it can reduce the loss further. In this phase, the model is indeed trying to find a parsimonious representation within this \gamma, as you said. (A back-of-the-envelope comparison of the two terms is sketched below.)
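To see the scale argument numerically, here is a back-of-the-envelope sketch; the dimensions d and \kappa below are just example values and all other loss terms are omitted:

```python
import numpy as np

d, kappa = 64 * 64 * 3, 64       # example ambient and latent dimensions (d >> kappa)
gammas = np.array([1e-1, 1e-2, 1e-3, 1e-4])
sigma_z = gammas                 # suppose sigma_z has to shrink roughly with gamma

for g, s in zip(gammas, sigma_z):
    recon_part = d * np.log(g)    # piece of the reconstruction term rewarded by small gamma
    kl_part = -kappa * np.log(s)  # piece of the KL term penalising small sigma_z
    print(f"gamma={g:.0e}:  d*log(gamma)={recon_part:10.0f}   -kappa*log(sigma_z)={kl_part:7.0f}")

# Because d >> kappa, the reward for shrinking gamma dominates the penalty on sigma_z
# (phase 1). Once gamma is clamped, only the -log(sigma_z) penalty is left to reduce,
# so some sigma_z can grow back (phase 2).
```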

chanshing commented 4 years ago

Thank you so much for your insights @daib13... really appreciated it!