lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

Potential EMA Model Issues #131

Closed (nousr closed this issue 2 years ago)

nousr commented 2 years ago

So, I've noticed a potential bug related to the exponential moving averaged priors (I haven't tested the other models).

Essentially, the EMA model seemingly refuses to "learn". I have to believe that the weights aren't being properly updated with the online_model, for whatever reason.

image

As you can see, as soon as the EMA kicks in at step 1000, the loss plateaus. I've never worked with an EMA model before, so I guess this could be expected behavior?

This issue also shows up when running evaluation on the EMA model: the text<->predicted_image similarity should be at least 0.1 even after a few hundred steps.

image

lucidrains commented 2 years ago

it is normal i believe, you can always try lowering the ema_beta on the Trainer to 0.995 or even 0.99 to have it update faster
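
To give a feel for why a lower beta "updates faster", here is a minimal sketch in plain Python. It is a toy scalar EMA, not the Trainer's code: `ema_update` and `steps_to_close_gap` are hypothetical helpers, the starting value 0.0 and the fixed target 1.0 are assumptions made for illustration, and 0.9999 is included only as an example of a slower setting.

```python
def ema_update(ema_value, online_value, beta):
    """One EMA step: keep `beta` of the old average, blend in (1 - beta) of the new value."""
    return beta * ema_value + (1.0 - beta) * online_value


def steps_to_close_gap(beta, fraction=0.5):
    """Count EMA updates until the average has covered `fraction` of a fixed gap.

    Assumption for illustration: the EMA starts at 0.0 and the online value
    stays fixed at 1.0, so the remaining gap shrinks by a factor of `beta`
    on every update.
    """
    ema, target, steps = 0.0, 1.0, 0
    while ema < fraction * target:
        ema = ema_update(ema, target, beta)
        steps += 1
    return steps


if __name__ == "__main__":
    for beta in (0.9999, 0.995, 0.99):
        print(f"beta={beta}: {steps_to_close_gap(beta)} updates to cover half the gap")
```

Under these assumptions, beta = 0.99 needs 69 updates and beta = 0.995 needs 139 to cover half the gap, while beta = 0.9999 needs several thousand. If the EMA is only updated every few training steps, the lag measured in training steps is longer by that factor, which is consistent with the EMA curve looking flat early in training.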

nousr commented 2 years ago

gotcha, glad I asked then! just have to be patient then 😄

nousr commented 2 years ago

@lucidrains hey! when you get a chance, do you think you could explain the intuition for this line?

https://github.com/lucidrains/DALLE2-pytorch/blob/1cc288af39f171b3e7f77fdb4252682af05e17e9/dalle2_pytorch/trainer.py#L191

specifically, why the // self.update_every? I guess I don't see the purpose of moving self.update_after "up" by a factor of update_every?

nousr commented 2 years ago

Re-opening this issue as @Veldrovive has confirmed that the EMA is not giving expected results for the decoder. A discussion with Katherine has confirmed my suspicions as well.

Will try to do some more debugging after my class final today.

lucidrains commented 2 years ago

@nousr oh crap, yes, you are right! i thought that i wasn't incrementing the EMA steps in accordance with the global training step, but i was

https://github.com/lucidrains/DALLE2-pytorch/commit/9cc475f6e7990b3d978128902ee0ea90614451f6 should fix the update every issue
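
For anyone following along, the interaction between the step counter and `// update_every` can be sketched with a toy schedule. This is not the code from trainer.py: `first_ema_call` and its `counter_per_call` flag are hypothetical, and the schedule (one `update()` call per training step, work done on every `update_every`-th call, warm-up ending once the counter exceeds `update_after // update_every`) is an assumption chosen to illustrate the arithmetic.

```python
def first_ema_call(update_after, update_every, counter_per_call, total_calls=5000):
    """Return the 1-based training call at which the first real EMA blend happens.

    Hypothetical schedule, not the library's code:
      * update() is called once per training step;
      * work is only done on every `update_every`-th call;
      * warm-up ends once the internal counter exceeds
        `update_after // update_every`.
    `counter_per_call` selects whether the internal counter advances on every
    call (True) or only on the calls where work is done (False).
    """
    threshold = update_after // update_every
    counter = 0
    for call in range(1, total_calls + 1):
        if counter_per_call:
            counter += 1  # counter tracks the global training step
        if call % update_every != 0:
            continue  # skipped call, no EMA work
        if not counter_per_call:
            counter += 1  # counter tracks the number of EMA updates
        if counter > threshold:
            return call
    return None


if __name__ == "__main__":
    print(first_ema_call(1000, 10, counter_per_call=True))
    print(first_ema_call(1000, 10, counter_per_call=False))
```

With `update_after=1000` and `update_every=10`, the first blend happens at call 110 when the counter follows the global step, and at call 1010 when it only counts EMA updates. So in this toy version the division keeps the warm-up boundary in place only if the counter advances once per EMA update; if the counter already follows the global training step, the division ends the warm-up roughly `update_every` times too early.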