lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
MIT License

getting KL divergence to work #92

Open CDitzel opened 3 years ago

CDitzel commented 3 years ago

In the train_vae script the kl_loss is effectively disabled by setting its weight parameter to zero, and in my own extensive experiments I also found that including the KL term does more harm than good. @karpathy also mentioned trouble getting it to work properly.

Has anyone made any progress on this?
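For reference, here is a minimal sketch (not the repo's exact code, names are illustrative) of how the KL term against a uniform prior over codebook entries is typically computed and weighted in a Gumbel-softmax discrete VAE; `kl_div_loss_weight` is the knob that train_vae currently sets to zero:

```python
import math
import torch
import torch.nn.functional as F

def dvae_loss(recon, images, logits, num_tokens, kl_div_loss_weight = 0.):
    # logits: (batch, num_tokens, height, width) raw encoder outputs over the codebook
    recon_loss = F.mse_loss(recon, images)

    # posterior q(y|x) over codebook entries, one distribution per spatial position
    b, n, h, w = logits.shape
    log_qy = F.log_softmax(logits.permute(0, 2, 3, 1).reshape(b, h * w, n), dim = -1)

    # KL(q(y|x) || Uniform); with log_target = True the target here is log q(y|x)
    log_uniform = torch.full_like(log_qy, -math.log(num_tokens))
    kl_div = F.kl_div(log_uniform, log_qy, reduction = 'batchmean', log_target = True)

    # a weight of zero reduces the objective to a plain autoencoder loss
    return recon_loss + kl_div_loss_weight * kl_div
```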

Also, shouldn't this line

https://github.com/lucidrains/DALLE-pytorch/blob/7658e60f3a15c74ce96ac5cb661e7dd7101b50b6/dalle_pytorch/dalle_pytorch.py#L196

rather use the soft_one_hot values than the raw logits?
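To make the question concrete, here is a rough illustration of the two alternatives (shapes and values are placeholders, not copied from the repo): the linked line derives the KL from a softmax over the raw encoder logits, while the suggestion is to use the relaxed `soft_one_hot` samples that are actually multiplied with the codebook.

```python
import torch
import torch.nn.functional as F

num_tokens, temperature = 8192, 0.9
logits = torch.randn(4, num_tokens, 32, 32)          # raw encoder outputs

# relaxed one-hot samples, i.e. what is actually used for the codebook lookup
soft_one_hot = F.gumbel_softmax(logits, tau = temperature, dim = 1)

log_uniform = torch.log(torch.tensor([1. / num_tokens]))

# option A (what the linked line does): KL of softmax(logits) against the uniform prior
log_qy = F.log_softmax(logits.permute(0, 2, 3, 1), dim = -1)
kl_from_logits = F.kl_div(log_uniform, log_qy, reduction = 'batchmean', log_target = True)

# option B (the suggestion): KL of the sampled soft_one_hot distributions instead
log_qy_sampled = torch.log(soft_one_hot.permute(0, 2, 3, 1) + 1e-10)
kl_from_samples = F.kl_div(log_uniform, log_qy_sampled, reduction = 'batchmean', log_target = True)
```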

Also, I find it a little confusing that we are annealing the temperature of the Gumbel-softmax, thus steering it towards one-hot sampling, while at the same time we are trying to encourage the distribution to be close to a uniform prior. Isn't this a contradiction?
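To make the tension concrete, here is a toy example (the constants are placeholders, not the repo's defaults): exponentially annealing the Gumbel-softmax temperature drives the relaxed samples towards one-hot, i.e. towards low entropy, which is the opposite direction from the uniform prior that the KL term rewards.

```python
import math
import torch
import torch.nn.functional as F

def annealed_temperature(step, start = 1.0, floor = 1. / 16, anneal_rate = 1e-4):
    # exponential decay of the Gumbel-softmax temperature, clipped at a floor
    return max(start * math.exp(-anneal_rate * step), floor)

logits = torch.randn(1, 8192, 32, 32)

for step in (0, 10_000, 50_000):
    tau = annealed_temperature(step)
    soft_one_hot = F.gumbel_softmax(logits, tau = tau, dim = 1)
    # entropy per spatial position: high tau -> spread out, low tau -> close to one-hot
    entropy = -(soft_one_hot * (soft_one_hot + 1e-10).log()).sum(dim = 1).mean()
    print(f'step {step}: tau = {tau:.4f}, mean sample entropy = {entropy:.3f}')
```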

LinLanbo commented 3 years ago

I agree with you, so I temporarily set kl_weight to zero; otherwise the recon_loss cannot be reduced. In this version, the kl_loss works against the recon_loss.

lucidrains commented 3 years ago

maybe someone can email the paper authors to see if this loss was used at all?

CDitzel commented 3 years ago

It must have been used, as they mention an increasing weight parameter in the paper.
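A minimal sketch of the kind of warm-up that suggests, linearly increasing the KL weight from zero; the number of warm-up steps and the final weight here are placeholders, not values from the paper.

```python
def kl_weight_schedule(step, warmup_steps = 5000, max_weight = 1.0):
    # linearly ramp the KL weight from 0 up to max_weight over warmup_steps
    return max_weight * min(step / warmup_steps, 1.0)

# inside the training loop:
# loss = recon_loss + kl_weight_schedule(global_step) * kl_div
```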

Still, I am trying, but I can't seem to figure out his e-mail address. The paper lists

Aditya Ramesh <_@adityaramesh.com>

so I tried Aditya_Ramesh@adityaramesh.com, Aditya.Ramesh@adityaramesh.com

but they don't exist...

samringer commented 2 years ago

Has anyone had any more insights/updates on this? I'm running into the exact same issue (on an independent DALL-E repro) and bashing my head against the wall trying to understand the behaviour!