LTH14 / mar

PyTorch implementation of MAR+DiffLoss https://arxiv.org/abs/2406.11838
MIT License

VAE decoded as NaN in early stages of training #25

Closed — xiazhi1 closed this issue 2 weeks ago

xiazhi1 commented 2 weeks ago

Hi @LTH14, thanks for the awesome work! Recently I tried to train MAR-B on the CelebA-HQ dataset, but I found that the VAE decodes to NaN when evaluating the checkpoints saved during the first 200 epochs of training. When I tested the training results after 400 epochs, the VAE was able to decode somewhat reasonable face images. I observed that the loss had already dropped to a low level after 100 epochs of training. There is no obvious loss drop from 200 to 400 epochs, but the FID of the VAE decoding result improves significantly. It is also worth mentioning that the loss of MAR-B on CelebA-HQ oscillates violently. Is this normal? Any suggestions are welcome; looking forward to your reply!
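As an aside, a quick way to catch the decoding failure described above is to check the decoded batch for NaN before computing FID. A minimal stdlib sketch (nested lists stand in for tensors here; with PyTorch you would call `torch.isnan(decoded).any()` instead):

```python
import math

def batch_has_nan(decoded):
    """Return True if any value in the decoded batch is NaN.

    `decoded` is a nested list of floats standing in for an image
    tensor; with torch, use torch.isnan(decoded).any() directly.
    """
    return any(math.isnan(px) for img in decoded for px in img)

# A clean batch vs. one containing a NaN pixel:
clean = [[0.1, 0.5], [0.3, 0.9]]
broken = [[0.1, float("nan")], [0.3, 0.9]]
print(batch_has_nan(clean))   # False
print(batch_has_nan(broken))  # True
```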

LTH14 commented 2 weeks ago

Thanks for your interest. The VAE in this repo is trained only on ImageNet, so it might not be able to tokenize face images well. You could consider using LDM's VAE (or Stable Diffusion's VAE), which are trained on large-scale datasets.

We also faced a similar NaN issue during the early stage of training. This is because the EMA decay we use is 0.9999, so the model needs around 100k training iterations before the EMA weights reach reasonable performance.
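The ~100k-iteration figure follows directly from the decay arithmetic (the sketch below is just that arithmetic, not code from the repo): after n EMA updates with decay d, a fraction d ** n of the random initialization is still mixed into the averaged weights.

```python
# Residual weight of the random initialization in an EMA after
# n updates: ema = d * ema + (1 - d) * param, so the init term
# is multiplied by d at every step and shrinks as d ** n.
def init_residual(decay: float, steps: int) -> float:
    return decay ** steps

for steps in (10_000, 50_000, 100_000):
    print(f"{steps:>7} steps: {init_residual(0.9999, steps):.2e}")
```

With d = 0.9999, roughly 37% of the initialization survives after 10k steps, and it only decays to the order of 1e-5 around 100k steps, which is why early checkpoints look broken even when the raw training loss is already low.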

Such a loss curve is common in diffusion training (see Fig. 13 of the DiT paper). Although the loss drop is minor, performance actually improves a lot during those slow drops. If the loss oscillates a lot, you could consider a smaller learning rate -- we did not observe such oscillation in our ImageNet training.

xiazhi1 commented 2 weeks ago

Thanks, this is very helpful to me.