lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch

KL divergence term in DiscreteVAE #250

Open richcmwang opened 3 years ago

richcmwang commented 3 years ago

Hi Phil @lucidrains, I noticed a KL divergence term (default weight set to 0) in DiscreteVAE. The commonly cited paper (Neural Discrete Representation Learning) instead has two extra stop-gradient terms. Can you point me to the reference for the definition used in the code?
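For reference, the term in question appears to be a KL penalty between the encoder's per-position distribution over codebook entries and a uniform prior, with a weight that defaults to 0. A minimal sketch of how such a term can be computed - not the repo's exact code, and `kl_to_uniform` is just a hypothetical helper:

```python
import math

import torch.nn.functional as F


def kl_to_uniform(logits, num_tokens):
    # logits: (batch, num_tokens, h, w) -- encoder scores over the codebook
    # at every spatial position
    log_qy = F.log_softmax(logits, dim=1)   # log q(y | x)
    log_uniform = -math.log(num_tokens)     # log of the uniform prior 1/K
    # KL(q || uniform) = sum_y q(y) * (log q(y) - log(1/K)),
    # summed over the codebook, averaged over batch and spatial positions
    return (log_qy.exp() * (log_qy - log_uniform)).sum(dim=1).mean()


# e.g. loss = recon_loss + kl_div_loss_weight * kl_to_uniform(logits, num_tokens)
```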

afiaka87 commented 3 years ago

I believe we've had some earlier discussion on this topic @richcmwang.

https://github.com/lucidrains/DALLE-pytorch/issues/74#issuecomment-794331986

A podcast called "DeepLearningDeepDive" was able to get one of the main researchers for the DALL-E paper in for an interview - they go over the entire paper.

The researcher is questioned on this very topic in this video

YouTube video: https://www.youtube.com/watch?v=PtdpWC7Sr98&t=2544s

Podcast: https://podcasts.apple.com/us/podcast/deep-learning-deep-dive/id1555309024 (Apple Podcasts), https://open.spotify.com/show/1zqRuymMjxXGKYMmEeTarz (Spotify)

Some have suggested that the researcher was somewhat coy about the equation in "Figure 1" of the paper. Having rewatched the video, I'm fairly certain it's not terribly important. The researcher claims:

First discussion (heavily paraphrased)

Discussing Differences Between VQGAN paper and DALLE's VAE

Interviewer: "in this case - the encoder for the vae is actually predicting a true distribution over the tokens. this regularization term D_kl that pushes the distribution towards a uniform prior actually matters. this Beta parameter gives you this tradeoff between the reconstruction objective and this matches the prior objective. We didn't expect the beta to be set so high to 6.6 after a short annealing from 0 very early in training."

Aditya Ramesh: "It's a trick from Sam Bowman's old paper Learning Continuous Representations. It's quite likely if it is set to a constant it may not matter. "

Second discussion

Interviewer: "The description in section 2.2 doesn't seem to match the ELBO equation in figure 1. Being rigorous - when learning the transformer, you should also be including this "KL" term that's trying to get it to go towards the prior. And that seems like it's not really done in any explicit way. Perhaps it doesn't matter? This seems more straightforward. But the formalism for matching the ELBO didn't quite click with me." Aditya Ramesh: I guess it's because we're maximizing it in two stages right? In the first stage we train the encoder/decoder... Yeah I see what you mean - the transformer should get the gradient from the "KL" term. Aditya Ramesh: Actually isn't the cross entropy loss for the transformer "VKL" term in the ELBO? Interviewer: Yes but it's not quite the same because you're using argmax sampling on the image tokens rather than using the full predicted distributions. Aditya Ramesh: Yes, that's true. In the underfitting regime, the discrete vae gives you soft targets which can be used for the Transformer's cross-entropy loss which is useful for regularization. I didn't do that in this case because we had 250 million image-text pairs which placed us well in the underfitting regime we wanted and, in experiments it didn't fit well with the soft codes.

(Aside: the whole video is incredibly useful and I highly recommend watching it. The OpenAI team made many decisions based on their specific goals which aren't necessarily hard requirements for an implementation.)

Other notable quotes (again - paraphrasing)

"the other thing that struck out ot me in this section is that the visual codebook is a CNN (resnet) - using some type of vision transformer for the visual codebook instead could help."

"I didn't try batchnorm - if I can avoid using I tend not to. But it may help. With generative models though - the model may overfit to the statistics of the batch instead of the features of each image. some tricks involve using batch norm for the first few epochs but this seems unprincipled."

@janEbert @mehdidc @robvanvolt @rom1504

" We found proper composition of visual elements still occurs with the smaller dataset 'Conceptual Captions'"

Discussing the attention architecture: ("row", "column", "row", "row")

@lucidrains

"That's copied directly from sparse transformers. Scott found adding convolutional attention mask as the very last layer helped the loss some small amount compared to just using the ("row", "column", "row", "row") pattern."

Sparse Attention for Hardware Constraints? Or for Loss?

Interviewer: "Column attention may be a bit slower due to a transpose requirement. Is this why you used more row attention than column attention?"

"It's actually because it helped with the loss. "You would expect there to be a symmetry in both x and y directions. "It seems that column attention only attends to the previous column whereas row attention can attend to more of the image." "We used sparse not just for performance - but also because we get lower loss as noted in Sparse transformers."

Later in the video they revisit the subject and intuit that row and column attention are indeed helping more than dense attention alone would, because they let the transformer take advantage of the two-dimensional nature of images and learn a more efficient representation.
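In DALLE-pytorch this pattern is exposed through the `attn_types` keyword, which cycles the listed attention types across the depth of the transformer. A sketch with illustrative hyperparameters (the names below follow the README; double-check them against the version you have installed):

```python
from dalle_pytorch import DALLE, DiscreteVAE

# Small, illustrative dVAE just so the example is self-contained.
vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 512,
    hidden_dim = 64
)

dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 16,
    heads = 8,
    dim_head = 64,
    # the ("row", "column", "row", "row") pattern from Sparse Transformers;
    # 'conv_like' and 'full' are also available if you want to mix in the
    # convolutional mask or dense attention
    attn_types = ('axial_row', 'axial_col', 'axial_row', 'axial_row')
)
```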

richcmwang commented 3 years ago

@afiaka87 Thanks for all the information! The in-depth video discussion is really interesting. I was training VQGAN and found it difficult to train compared to DiscreteVAE, so I started looking into the loss function. Somehow I thought DiscreteVAE used the loss defined in the Neural Discrete Representation Learning paper. Now I realize that the ELBO in equation (1) is the objective function.

I'm also wondering about people's experience training VQGAN. I find DiscreteVAE gives fairly good reconstructions quickly (within 1 epoch), but VQGAN is still completely blurred after training for double that time. It makes me wonder whether it simply takes much longer to train, or whether other issues such as hyperparameters or loss weighting are the problem.

afiaka87 commented 3 years ago

I know @bob80333 and @gabriel_syme have done a bit of training with the VQGAN. Perhaps they can chime in?

My first intuition is that training may indeed be slower due to the adversarial network, but that the eventual results will still be quite good relative to the number of parameters needed.

bob80333 commented 3 years ago

According to the authors, they needed something like 5+ epochs on ImageNet to get good results, which would translate to >600k iterations at batch size 8. I would highly suggest fine-tuning their pretrained models if possible to save GPU time.

https://github.com/CompVis/taming-transformers/issues/31#issuecomment-809382665

richcmwang commented 3 years ago

@afiaka87 @bob80333 Thanks for pointing this out. Yes, I would probably need to train for a very long time from scratch before seeing reasonable reconstructions. I loaded the pretrained checkpoint and realized that the model already gives sharp reconstructions right from the beginning on my custom dataset. That saves a lot of trouble and time!
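For anyone else taking the warm-start route, the loading step looks roughly like the sketch below. The paths are placeholders for the config and checkpoint downloaded from the taming-transformers repo (see the issue linked above):

```python
import torch
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel

# Instantiate the VQGAN from its config, then load the pretrained weights.
config = OmegaConf.load("vqgan_imagenet_f16_1024/configs/model.yaml")
model = VQModel(**config.model.params)

ckpt = torch.load("vqgan_imagenet_f16_1024/checkpoints/last.ckpt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"], strict=False)  # Lightning checkpoints keep weights under "state_dict"
model.eval()

# From here you can fine-tune on a custom dataset, or wrap the model for use
# with DALLE-pytorch's VQGAN VAE class.
```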