openai / vdvae

Repository for the paper "Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images"
MIT License

plateau during training #8

Open georgosgeorgos opened 3 years ago

georgosgeorgos commented 3 years ago

Hi Rewon. Cool work!

I tested the model with your checkpoint and everything works perfectly. Now I am training VDVAE on CIFAR-10 from scratch on a single GPU (reducing the batch size from 32 to 16 and the learning rate from 2e-4 to 1e-4). The model starts training without problems and then gets stuck in a plateau at around 4.7 nats/dim for a long time. I found a similar plateau with other configurations (smaller learning rate, smaller model).

Did you experience this plateau during training?

Thanks!

christopher-beckham commented 3 years ago

I've been trying this method on another dataset (Omniglot) and it does seem very sensitive to hyperparameters. I am not sure if what you're describing is the same as what I'm getting, but for me most experiments end up with the reconstruction loss plateauing around 0.16 (N.B.: using a Gaussian/L2 error here, not the logistic mixture). It seems like one may have to carefully balance the weightings of the reconstruction loss and the KL loss to get things to work. I'm actually surprised that the unweighted ELBO seemingly just works for the experiments in this paper.

georgosgeorgos commented 3 years ago

I guess that, because VDVAE has almost no normalization, some hyperparameters (like batch_size) matter more than we would expect. I managed to get past the plateau (posterior collapse?) by reducing the learning rate by a factor of 10 and reweighting the ELBO. Now I have exploding gradient norms instead.
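For concreteness, by "reweighting the ELBO" I mean a beta-style weighting of the KL term, roughly like the sketch below (the variable names are placeholders, not the actual ones in this repo):

```python
import torch

def weighted_elbo_loss(recon_loss, kl_losses, beta=0.5):
    """Beta-weighted negative ELBO (beta=1 recovers the unweighted objective).

    recon_loss: per-example reconstruction term, shape [batch]
    kl_losses:  list of per-layer KL terms, each of shape [batch]
    beta:       down-weights the KL term to fight posterior collapse
    """
    kl = torch.stack(list(kl_losses)).sum(dim=0)
    return (recon_loss + beta * kl).mean()
```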

Out of curiosity: I guess the model you are training on Omniglot is much smaller than the one for CIFAR-10. Did you try using a Bernoulli likelihood and training in nats per image?

christopher-beckham commented 3 years ago

Yeah, I think we're talking about the same thing: posterior collapse.

Yes, it seems like you have to upweight lambda or downweight beta. It's clear that posterior collapse is actually what happens most of the time, because if you look at the kl_loss it is essentially zero. It makes me wonder whether a more robust way of keeping it from going completely to zero would be a two-sided penalty, i.e. `(kl_loss - eps)**2`, where eps could be something like 0.05.
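As a sketch of what I mean (hypothetical helper; `eps` and the penalty weight would need tuning):

```python
def two_sided_kl_penalty(kl_loss, eps=0.05):
    """Penalize the KL for straying from a small target `eps` in either
    direction, so it cannot sit at exactly zero (posterior collapse)."""
    return (kl_loss - eps) ** 2

# e.g. total_loss = recon_loss + kl_loss + penalty_weight * two_sided_kl_penalty(kl_loss)
```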

Nope, I haven't tried a Bernoulli likelihood, though it would make sense. IIRC Omniglot has roughly the same number of examples as CIFAR-10, but way more classes.

If it helps any, when I first experimented with this code back when the paper was in review, it seemed like a lot of that instability was actually coming from the discretized mixture of logistics (DMoL) output layer (see my post here: https://openreview.net/forum?id=RLRXCV6DbEJ&noteId=cra1CWLY3U_). It is also still not clear to me what the difference is between this and a Gaussian distribution (i.e. L2 error) when it comes to the metrics we care about (such as the ELBO).
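One thing worth remembering is that plain MSE is only a Gaussian log-likelihood up to the variance and normalizing constant, so to compare ELBOs on an even footing you'd need the full density. A rough sketch, assuming a fixed, shared sigma (not code from this repo):

```python
import math
import torch

def gaussian_nll_per_dim(x, x_hat, log_sigma=0.0):
    """Per-dimension negative log-likelihood of x under N(x_hat, sigma^2).

    Plain MSE drops the 1/(2*sigma^2) scaling and the log_sigma +
    0.5*log(2*pi) constants, so it is only proportional to this when
    sigma is held fixed.
    """
    sigma_sq = math.exp(2.0 * log_sigma)
    nll = 0.5 * ((x - x_hat) ** 2 / sigma_sq + 2.0 * log_sigma + math.log(2.0 * math.pi))
    return nll.mean()
```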

rewonc commented 3 years ago

Hi @georgosgeorgos,

The model is certainly sensitive to batch size and LR, and typically (as with any ML model) these should be chosen via a hyperparameter search over several options. Since you're decreasing the batch size by a factor of 4, it's possible that an LR lower than 1/2 the original is required (as you seem to have found).
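(A common heuristic, not something specific to this repo, is to scale the LR roughly linearly with the batch size and then search around that value:)

```python
def scaled_lr(base_lr, base_batch, new_batch):
    """Linear-scaling starting point for a new batch size; still worth
    sweeping a few values around whatever this returns."""
    return base_lr * new_batch / base_batch

# e.g. scaled_lr(2e-4, 32, 16) -> 1e-4, and in practice possibly lower still
```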

In my experiments I did not find that reweighting the KL was actually useful (the default unweighted ELBO typically led to higher performance), so I'm surprised that you're finding it necessary for good performance. My assumption (without seeing your specific experiments) would be that the unweighted loss with better optimizer hyperparameters would do better.

christopher-beckham commented 3 years ago

Another note, which may be good for folks to know (even if you're using the code as-is, without changing anything):

I found that training broke entirely if I switched the output distribution from DMoL (the default) to a Gaussian (i.e. mean squared error). In that case the reconstruction error simply plateaued, and to get it unstuck I had to disable gradient clipping entirely and add instance norm to the Block class. From what I understand, @rewonc probably used DMoL to allow for an 'apples to apples' comparison with other methods in the literature, but in my experience it's full of frustration in the form of high gradient norms, NaNs, etc. It seems to me that a Gaussian output is much simpler and will probably give you images that look just as good, but the NLL may not be as good as what you get with DMoL.
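Roughly what I mean by a Gaussian output, as a sketch (this is not the repo's code; the channel counts and the separate head module are just placeholders):

```python
import torch
import torch.nn as nn

class GaussianOutputHead(nn.Module):
    """Sketch of swapping the DMoL output for a fixed-variance Gaussian:
    predict a mean image and use MSE (a Gaussian NLL with constant sigma,
    up to constants) as the reconstruction term."""

    def __init__(self, in_channels, out_channels=3):
        super().__init__()
        self.to_mean = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, decoder_features, x_target):
        # decoder_features: [B, C, H, W], x_target: [B, out_channels, H, W]
        x_mean = self.to_mean(decoder_features)
        recon_loss = ((x_mean - x_target) ** 2).mean(dim=(1, 2, 3))
        return x_mean, recon_loss
```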

rewonc commented 3 years ago

Yes, I think that's reasonable.

Another thing to note is that the gradient skipping threshold depends on the LR/batch size. @georgosgeorgos, it could be that your config produces much higher gradient norms, and thus many updates are being skipped (which would effectively stop training from progressing). You should be able to see the number of skipped updates in your logs. Ideally the threshold should be set to a value that results in very few skips.
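For anyone unfamiliar with the mechanism, the idea is just to skip the optimizer update whenever the global gradient norm exceeds a threshold. A minimal sketch (the function name and threshold value here are illustrative, not the repo's exact code):

```python
import torch

def step_with_grad_skipping(model, optimizer, loss, skip_threshold=400.0):
    """Backprop, then apply the update only if the global grad norm is
    below `skip_threshold`; otherwise count it as a skipped update."""
    optimizer.zero_grad()
    loss.backward()
    # max_norm=inf: compute the global grad norm without actually clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float("inf")).item()
    skipped = grad_norm >= skip_threshold
    if not skipped:
        optimizer.step()
    return grad_norm, skipped
```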