Open Jarvis73 opened 1 year ago
Hi,
Thanks for your interest in our work. I don't know how much it will help, but here is the training loss curve. How different is your dataset from ImageNet 256?
Hi, devzhk.
I am training MaskDiT on a self-collected text-to-image dataset (cross-attention is used to inject the text condition). My loss is similar to the curve you provided. However, after training for about 70 epochs, the loss starts to produce NaN values, and they become more and more frequent as training progresses. I am using PyTorch's mixed precision (AMP). Have you encountered this situation before?
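For reference, here is a minimal sketch of how I check for the first non-finite loss under AMP. This is not MaskDiT's actual training loop; `model`, `loader`, `optimizer`, and `scaler` are placeholders for my own setup:

```python
import torch

scaler = torch.cuda.amp.GradScaler()
for step, batch in enumerate(loader):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(batch)
    if not torch.isfinite(loss):
        # Stop at the first bad step so the checkpoint/activations can be inspected
        # before the run degrades further.
        print(f"non-finite loss at step {step}: {loss.item()}")
        break
    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```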
Hi,
Maybe try turning off mixed precision for debugging? If disabling mixed precision fixes it, re-enable AMP and then disable it for individual modules one at a time to see which one is causing the problem.
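As a sketch of the per-module approach: you can locally disable autocast inside a wrapper and force that module to run in float32. The `cross_attn` attribute below is a hypothetical example, not MaskDiT's actual module name:

```python
import torch

class FP32Wrapper(torch.nn.Module):
    """Runs the wrapped module in full precision even inside an autocast region."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        # Turn off autocast locally and upcast floating-point tensor inputs.
        with torch.autocast(device_type="cuda", enabled=False):
            args = [a.float() if isinstance(a, torch.Tensor) and torch.is_floating_point(a) else a
                    for a in args]
            kwargs = {k: (v.float() if isinstance(v, torch.Tensor) and torch.is_floating_point(v) else v)
                      for k, v in kwargs.items()}
            return self.module(*args, **kwargs)

# Hypothetical usage: wrap a suspect block (e.g. the cross-attention) and re-run training.
# model.cross_attn = FP32Wrapper(model.cross_attn)
```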
Nice work! I've been training MaskDiT on my own dataset recently, but I'm unsure whether the loss is decreasing normally because the generated images are not yet satisfactory. Could you please share the training logs on ImageNet 256x256, or the loss curve?
Thanks very much!