lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

Loss stuck issue #72

Closed · goldiusleonard closed this issue 2 years ago

goldiusleonard commented 2 years ago

I have trained the DALL-E 2 model, but the loss is stuck at 0.1 - 0.2. I use the AdamW optimizer with an ExponentialLR scheduler. Is there any solution for this problem? I have attached my training code below; please review it. Thanks!

https://github.com/goldiusleonard/Dalle2_pytorch_project
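For context, here is a hedged sketch of the optimizer/scheduler combination described above, reduced to a toy model (the linear layer and random data are stand-ins, not code from the linked repository). One common cause of a plateaued loss with ExponentialLR is a gamma that decays the learning rate too aggressively, especially if the scheduler is stepped per batch instead of per epoch:

```python
import torch
from torch import nn

# Toy stand-in for the real model (assumption for illustration only)
model = nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
# gamma close to 1.0 keeps the LR from collapsing within a few epochs
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

data = torch.randn(64, 16)
target = torch.randn(64, 1)

for epoch in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(data), target)
    loss.backward()
    optimizer.step()
    scheduler.step()  # step once per epoch, not once per batch
    print(epoch, scheduler.get_last_lr()[0])
```

With gamma=0.95 and per-epoch stepping, the LR after 5 epochs is 3e-4 * 0.95^5, a gentle decay; stepping the same scheduler per batch over thousands of batches would drive the LR toward zero and make the loss appear stuck.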

lucidrains commented 2 years ago

@goldiusleonard hey Goldius, the repository is still about 3-4 weeks away from being trainable by the general public. there's a fair amount of engineering work to be done, considering it involves 3+ models (multiple unets in the cascade in the decoder)

however, if you want to continue tinkering for education, i highly recommend trying out the training wrapper i designed at https://github.com/lucidrains/dalle2-pytorch#training-wrapper-wip

Amiineh commented 2 years ago

@lucidrains I'm using the wrapper provided in train_diffusion_prior.py for training. However, I get stuck on scaler.scale(loss).backward(). I tried getting rid of the scaler, but it still doesn't proceed. The output looks like this:

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1992.54it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s]
  0%|                                                                                                                                                  | 0/42 [00:00<?, ?it/s]
  0%|                                                                                                                                                  | 0/42 [00:00<?, ?it/s]

The run doesn't return any errors, but gets stuck there. Any idea why this could be happening?
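For reference, the standard mixed-precision pattern that the hang occurs in can be exercised in isolation (a hedged sketch with a toy linear model and random tensors, not code from train_diffusion_prior.py). If this loop completes on the same machine, the hang at scaler.scale(loss).backward() likely originates elsewhere, e.g. DataLoader worker processes or a distributed barrier waiting on other ranks:

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy stand-in for the diffusion prior (assumption for illustration only)
model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# GradScaler is a no-op passthrough when disabled (e.g. on CPU)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(4, 8, device=device)
y = torch.randn(4, 1, device=device)

with torch.autocast(device_type=device, enabled=(device == "cuda")):
    loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # the call the report hangs on
scaler.step(optimizer)
scaler.update()
print("backward completed, loss =", float(loss))
```

Since tqdm shows 0/42 with no progress, it may also be worth running the real script with num_workers=0 on the DataLoader to rule out a stuck worker process before suspecting the scaler itself.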