lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License
11.03k stars 1.07k forks

EMA Bug #87

Closed: CiaoHe closed this issue 2 years ago

CiaoHe commented 2 years ago

Hi Phil,

This morning I tried to run the decoder training. I decided to use DecoderTrainer but ran into an issue with the EMA update.

After sampling with decoder_trainer, the next training step fails in the EMA update with a RuntimeError:

Traceback (most recent call last):
  File "/home/caohe/DPMs/dalle2/train_decoder.py", line 321, in <module>
    main()
  File "/home/caohe/DPMs/dalle2/train_decoder.py", line 318, in main
    train(decoder_trainer, train_dl, val_dl, train_config, device)
  File "/home/caohe/DPMs/dalle2/train_decoder.py", line 195, in train
    trainer.update(unet_number)
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 288, in update
    self.ema_unets[index].update()
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 119, in update
    self.update_moving_average(self.ema_model, self.online_model)
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 129, in update_moving_average
    ema_param.data = calculate_ema(self.beta, old_weight, up_weight)
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 125, in calculate_ema
    return old * beta + new * (1 - beta)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and CPU!

https://github.com/lucidrains/DALLE2-pytorch/blob/6f76652d118d3da2419bd12084abfff45772553b/dalle2_pytorch/train.py#L108-L118
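The last frames of the traceback show what the EMA update does: every EMA parameter is replaced by a weighted blend of its old value and the online model's current value, old * beta + new * (1 - beta). The sketch below reproduces that pattern with a plain nn.Module; SimpleEMA is a hypothetical class for illustration, not the repository's actual EMA implementation. Because the blend combines the two tensors elementwise, both must live on the same device, which is exactly what the RuntimeError complains about.

```python
import copy

import torch
from torch import nn


def calculate_ema(beta: float, old: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
    # Same blend as in the traceback: keep `beta` of the old EMA weight and
    # mix in (1 - beta) of the current online weight. Both tensors must be on
    # the same device, otherwise PyTorch raises a RuntimeError.
    return old * beta + new * (1 - beta)


class SimpleEMA(nn.Module):
    """Hypothetical minimal EMA wrapper, for illustration only."""

    def __init__(self, model: nn.Module, beta: float = 0.99):
        super().__init__()
        self.beta = beta
        self.online_model = model
        # Frozen copy that accumulates the moving average of the weights
        self.ema_model = copy.deepcopy(model)
        self.ema_model.requires_grad_(False)

    @torch.no_grad()
    def update(self) -> None:
        # Pair up parameters in registration order and blend them one by one
        for ema_param, online_param in zip(
            self.ema_model.parameters(), self.online_model.parameters()
        ):
            ema_param.data = calculate_ema(self.beta, ema_param.data, online_param.data)
```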

I checked up_weight.device (online model) and old_weight.device (EMA model) and found that the online model is on cuda:0 while the EMA model is on the CPU. It's really weird; I debugged for a long time, and I think it might be caused by DecoderTrainer.sample(): when it swaps the EMA and online models, something goes wrong with their devices. https://github.com/lucidrains/DALLE2-pytorch/blob/6021945fc8e1ec27bbebfa1e181e892a7c4d05fb/dalle2_pytorch/train.py#L298-L308
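One common way to sample with EMA weights is to swap the EMA copies in temporarily, run sampling, and then swap the originals back. The sketch below assumes a container that holds its networks in an nn.ModuleList attribute named unets; TinyDecoder and use_ema_unets are hypothetical names, not taken from the linked code. The context manager restores the trainable U-Nets even if sampling raises, and, as a defensive step, moves the EMA copies back to the training device on exit in case anything during sampling relocated them.

```python
from contextlib import contextmanager

import torch
from torch import nn


class TinyDecoder(nn.Module):
    """Hypothetical stand-in for a model that owns several U-Nets."""

    def __init__(self, num_unets: int = 2):
        super().__init__()
        self.unets = nn.ModuleList(nn.Linear(4, 4) for _ in range(num_unets))

    @torch.no_grad()
    def sample(self, x: torch.Tensor) -> torch.Tensor:
        for unet in self.unets:
            x = unet(x)
        return x


@contextmanager
def use_ema_unets(decoder: nn.Module, ema_unets: nn.ModuleList):
    # Hypothetical helper: temporarily sample with the EMA U-Nets
    online_unets = decoder.unets
    device = next(online_unets.parameters()).device
    decoder.unets = ema_unets.to(device)  # swap in the averaged weights
    try:
        yield decoder
    finally:
        decoder.unets = online_unets  # always restore the trainable U-Nets
        ema_unets.to(device)  # keep the EMA copies on the training device
```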

I fixed it by simply adding self.ema_model = self.ema_model.to(next(self.online_model.parameters()).device) before calling self.update_moving_average(self.ema_model, self.online_model) (pretty naive, haha).
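The same one-line fix can be factored into a small helper and called at the start of every EMA update. A minimal sketch, where align_ema_device is a hypothetical name and the online model is assumed to keep all its parameters on a single device:

```python
import torch
from torch import nn


def align_ema_device(ema_model: nn.Module, online_model: nn.Module) -> nn.Module:
    # Hypothetical helper. Device of the model being trained
    # (assumes all of its parameters share one device).
    device = next(online_model.parameters()).device
    # nn.Module.to() moves the parameters in place and returns the module itself
    if next(ema_model.parameters()).device != device:
        ema_model.to(device)
    return ema_model
```

The explicit check is not strictly needed, since .to() does not copy tensors that are already on the target device; it just makes the intent obvious. Calling it right before update_moving_average mirrors the fix described above.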

Looking forward to hearing your solution.

Enjoy!

lucidrains commented 2 years ago

@CiaoHe ohh yes, you are correct, thank you! i think this should fix it https://github.com/lucidrains/DALLE2-pytorch/commit/924455d97d0e0230ba6a34c5d3af792d272af481

CiaoHe commented 2 years ago

btw, how do you usually debug/test when adding new functions or starting a new repo? I find my workflow quite inefficient (either I run it from the command line and wait for an ERROR, or I copy code into a Jupyter notebook and test it again and again...)

lucidrains commented 2 years ago

@CiaoHe i've come full circle and just use a simple test.py in the root directory + print lol
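For a bug like the one in this issue, such a root-level test.py can be as small as a script that builds tiny models, runs the steps in the order that failed, and prints or asserts the device of every parameter. A minimal sketch, where check_same_device is a hypothetical helper and two nn.Linear layers stand in for the online and EMA U-Nets:

```python
import copy

import torch
from torch import nn


def check_same_device(*modules: nn.Module) -> torch.device:
    # Hypothetical helper: collect the device of every parameter in all modules
    devices = {p.device for m in modules for p in m.parameters()}
    print("parameter devices:", devices)
    assert len(devices) == 1, f"parameters are spread over {devices}"
    return devices.pop()


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    online = nn.Linear(8, 8).to(device)
    ema = copy.deepcopy(online)

    with torch.no_grad():
        ema(torch.randn(2, 8, device=device))  # stand-in for sampling with the EMA copy

    # A check like this, run before the EMA update, would flag the mismatch early
    check_same_device(online, ema)
    print("ok")
```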

CiaoHe commented 2 years ago


@lucidrains lol. But when I move to a cluster for training, things sometimes get out of control (I hate bugs)

lucidrains commented 2 years ago

🪰 🪱 🐞