openai / DALL-E

PyTorch package for the discrete VAE used for DALL·E.

Discrete Bottleneck Method #1

Closed · cfoster0 closed this issue 3 years ago

cfoster0 commented 3 years ago

Thank you for letting us take an advance peek at the model! I'm sure folks will have fun doing some interesting art and research with this.

The blog post references the Gumbel-Softmax trick, the Concrete distribution, and previous VQ-VAEs, but it leaves ambiguous which of these are used in the discrete VAE and how. Looking through the model in this repo, I don't see anything that fully clarifies this: there's an encoder that predicts distributions over the vocab (z_logits), which are argmax-ed to get codes z and then run through a decoder. That doesn't explain, though, how the logits predicted by the encoder are translated into codes during training.
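
For concreteness, the inference path described above looks roughly like this, following the repo's usage notebook (a sketch, not verbatim; URLs and shapes as in the notebook, and a random tensor stands in for a preprocessed image):

```python
import torch
import torch.nn.functional as F
from dall_e import load_model

# Released dVAE encoder/decoder from the usage notebook.
enc = load_model("https://cdn.openai.com/dall-e/encoder.pkl", "cpu")
dec = load_model("https://cdn.openai.com/dall-e/decoder.pkl", "cpu")

x = torch.rand(1, 3, 256, 256)                 # stand-in for a preprocessed 256x256 image
z_logits = enc(x)                              # [1, 8192, 32, 32] distributions over the vocab
z = torch.argmax(z_logits, dim=1)              # hard codes, [1, 32, 32]
z = F.one_hot(z, num_classes=8192)             # back to one-hot (8192 = vocab size)
z = z.permute(0, 3, 1, 2).float()              # [1, 8192, 32, 32]
x_rec = torch.sigmoid(dec(z).float()[:, :3])   # reconstructed image
```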

How does this work in training? Are the codes hard-sampled (argmax) with noise added to the logits, soft-sampled (softmax) with noise added to the logits, or something else?

adityaramesh commented 3 years ago

When training the discrete VAE, we apply the Gumbel-softmax relaxation during both the forward and backward passes. That is, rather than taking the argmax as is done in the notebook, we add Gumbel noise to the logits, divide by the relaxation temperature, and take the softmax. Both the temperature and the step size for optimization are annealed according to schedules that we describe in Appendix A of our paper, which should be online in a few hours.
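
For anyone reading along, here is a minimal sketch of that relaxation (names and shapes are illustrative only; the actual training code and annealing schedules are in Appendix A of the paper):

```python
import torch
import torch.nn.functional as F

def relaxed_codes(z_logits, tau):
    """Gumbel-softmax relaxation: add Gumbel noise to the logits, divide by the
    temperature tau, and softmax over the vocabulary dimension (dim=1).
    Equivalent to F.gumbel_softmax(z_logits, tau=tau, hard=False, dim=1)."""
    gumbel = -torch.log(-torch.log(torch.rand_like(z_logits) + 1e-20) + 1e-20)
    return F.softmax((z_logits + gumbel) / tau, dim=1)

# Toy example: batch of 2, vocab of 8192, 32x32 grid of codes.
z_logits = torch.randn(2, 8192, 32, 32)
z_soft = relaxed_codes(z_logits, tau=1.0)   # soft codes fed to the decoder during training
z_hard = F.one_hot(z_logits.argmax(dim=1), 8192).permute(0, 3, 1, 2).float()  # inference-time codes
```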

When training the transformer, we take the argmax of the logits, rather than sampling from them. We could optionally train the transformer model by sampling from the logits and using the softmax of the logits as the targets for the cross-entropy loss. We found this to be useful in the overfitting regime, but not in the underfitting regime.
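
Sketched in code, the two choices of target might look like this (all names and shapes here are hypothetical, just to illustrate hard vs. soft targets; this is not the actual training code):

```python
import torch
import torch.nn.functional as F

# z_logits from the dVAE encoder: [B, vocab, H, W]; flatten so each grid
# position is one token with a distribution over the vocabulary.
B, V, H, W = 2, 8192, 32, 32
z_logits = torch.randn(B, V, H, W)
flat = z_logits.permute(0, 2, 3, 1).reshape(-1, V)     # [B*H*W, vocab]

# Default: hard targets via argmax of the encoder logits.
hard_targets = flat.argmax(dim=-1)                     # [B*H*W] integer codes

# Optional variant: sample the input tokens from the logits and use the
# softmax of the logits as soft targets for the cross-entropy loss.
sampled_tokens = torch.multinomial(F.softmax(flat, dim=-1), num_samples=1).squeeze(-1)
soft_targets = F.softmax(flat, dim=-1)                 # [B*H*W, vocab]

# With hypothetical transformer output logits `pred` of shape [B*H*W, vocab]:
pred = torch.randn(flat.shape)
loss_hard = F.cross_entropy(pred, hard_targets)
loss_soft = -(soft_targets * F.log_softmax(pred, dim=-1)).sum(dim=-1).mean()
```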