lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

Decoder loss jumps to nan #138

Closed Veldrovive closed 2 years ago

Veldrovive commented 2 years ago

With small models there seems to be a large chance of the decoder loss becoming nan early in training. Using autograd.detect_anomaly(), torch reports the error Function 'ExpBackward0' returned nan values in its 0th output. for line 325 of dalle2_pytorch.py, which is 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)).

YUHANG-Ma commented 2 years ago

Thanks. I am trying to edit the code. My dataset contains about 2 million pics and their embeddings. Is this a small model?

Veldrovive commented 2 years ago

It mostly depends on the input dimensions of your unets.

Veldrovive commented 2 years ago

Trying to track down the cause of this has convinced me that I have no idea what is going on. With small models, with input dimensions around 16 to 64, at some point the values of logvar1 and logvar2 blow up within the space of a couple iterations; then it seems that the value of logvar1 gets clipped, and on the next iteration the unet outputs nan.

lucidrains commented 2 years ago

@Veldrovive i really don't think it is worth using the learned variance. i don't believe anyone uses it except Nichol. the learned variance was even mentioned in passing here https://arxiv.org/abs/2206.00364 , but i haven't read a single paper say anything positive about it

lucidrains commented 2 years ago

what we can try is constraining the output of the network for the interpolation fraction value to be from 0 to 1 with a sigmoid. they claim in the paper this isn't needed, but maybe the unconstrained output contributes to instability in scenarios they did not test
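The idea above can be sketched as follows. In the learned-variance parameterization (Nichol & Dhariwal), the network predicts a fraction used to interpolate between the two variance bounds in log space; squashing the raw output through a sigmoid keeps the fraction in [0, 1]. Function and argument names here are illustrative, not the repo's actual API:

```python
import torch

def interpolated_logvar(raw_frac, log_beta, log_beta_tilde, constrain=True):
    # Interpolate between the two log-variance bounds.
    # With constrain=True the raw network output is squashed into
    # [0, 1] by a sigmoid (the change proposed above), so the result
    # always stays between log_beta_tilde and log_beta.
    frac = torch.sigmoid(raw_frac) if constrain else raw_frac
    return frac * log_beta + (1.0 - frac) * log_beta_tilde

# With an extreme raw output, the constrained version stays bounded...
bounded = interpolated_logvar(torch.tensor(50.0),
                              torch.tensor(-2.0), torch.tensor(-5.0))

# ...while the unconstrained version extrapolates far past both bounds,
# which is exactly what can feed huge logvars into the KL term.
unbounded = interpolated_logvar(torch.tensor(50.0),
                                torch.tensor(-2.0), torch.tensor(-5.0),
                                constrain=False)
```

The paper argues the network learns to stay in range on its own, but that claim was only tested at their scales.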

lucidrains commented 2 years ago

I think it is working for @YUHANG-Ma now

Veldrovive commented 2 years ago

I think I might have found the deeper cause of this. I haven't verified it 100% yet, but when experimenting with deepspeed I only see this issue when fp16 is enabled. I have seen people complaining about LayerNorm causing this issue in some cases when using half precision, so that could be a possible cause. detect_anomaly shows the first NaN pops up in one of the convolutional layers, though.
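A common mitigation for the half-precision LayerNorm problem mentioned above is to run the normalization itself in fp32 and cast back, since the internal variance computation can under/overflow in fp16. A minimal sketch (this is a generic workaround, not the repo's code; whether it fixes this particular NaN is unverified):

```python
import torch
import torch.nn as nn

class FP32LayerNorm(nn.LayerNorm):
    # Upcast the input to fp32 for the normalization, then cast the
    # result back to the input's dtype. The mean/variance are computed
    # in full precision, avoiding fp16 overflow/underflow inside the norm.
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)

ln = FP32LayerNorm(8)
x = torch.randn(2, 8, dtype=torch.float16)
out = ln(x)  # fp16 in, fp16 out, but normalized in fp32
```

If the first NaN really originates in a conv layer rather than a norm, this would only rule one suspect out, but it is cheap to try.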