Anima-Lab / MaskDiT

Code for Fast Training of Diffusion Models with Masked Transformers
MIT License

Training log and loss curve #4

Open Jarvis73 opened 1 year ago

Jarvis73 commented 1 year ago

Nice work! I've been training MaskDiT on my own dataset recently, but I'm unsure whether the loss is decreasing normally because the generated images are not yet satisfactory. Could you please share the training logs for ImageNet 256x256, or the loss curve?

Thanks very much!

devzhk commented 1 year ago

Hi,

Thanks for your interest in our work. I'm not sure how much it will help, but here is the training loss curve. How different is your dataset from ImageNet256?

[Image: training loss curve (training_loss_mask_training)]

Jarvis73 commented 1 year ago

Hi, devzhk.

I am training MaskDiT on a self-collected text-to-image dataset (cross-attention is used to inject the text condition). My loss curve looks similar to the one you provided. However, after training for about 70 epochs, the loss starts to produce NaN values, and they become more and more frequent as training progresses. I am using PyTorch's mixed precision. Have you encountered this before?
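For context, a minimal, self-contained sketch of how the first non-finite loss can be caught under PyTorch AMP (the model, data, and hyperparameters below are toy placeholders, not the MaskDiT setup; only the autocast/GradScaler pattern is the point):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

for step in range(100):
    x = torch.randn(8, 64, device=device)
    target = torch.randn(8, 64, device=device)

    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        loss = nn.functional.mse_loss(model(x), target)

    # Fail loudly at the first non-finite loss instead of letting NaNs
    # silently propagate into later updates.
    if not torch.isfinite(loss):
        raise RuntimeError(f"Non-finite loss {loss.item()} at step {step}")

    scaler.scale(loss).backward()
    # Unscale before clipping so gradients are inspected/clipped in fp32.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()
```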

devzhk commented 1 year ago

Hi,

Maybe try turning off mixed precision for debugging? If disabling mixed precision fixes it, then try disabling AMP for each individual module to see which one is causing the problem.
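A minimal sketch of one way to do that (this is an assumed approach, not code from this repo): wrap a suspect submodule so it always runs in fp32, even inside an autocast region, and swap the wrapper onto one module at a time. `Fp32Wrapper` is a hypothetical helper, and which attribute to wrap depends on your model.

```python
import torch
import torch.nn as nn

class Fp32Wrapper(nn.Module):
    """Run the wrapped module in float32 even inside an autocast region."""
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        # Promote floating-point tensor inputs to fp32; pass everything else through.
        def to_fp32(x):
            return x.float() if torch.is_tensor(x) and torch.is_floating_point(x) else x
        # Locally disable autocast so this module's math stays in fp32.
        with torch.autocast(device_type="cuda", enabled=False):
            return self.module(*[to_fp32(a) for a in args],
                               **{k: to_fp32(v) for k, v in kwargs.items()})

# Hypothetical usage: force one suspect block (e.g. a cross-attention layer,
# a common fp16 overflow culprit) to run in fp32 while the rest keeps AMP.
# model.blocks[0].cross_attn = Fp32Wrapper(model.blocks[0].cross_attn)
```

Moving modules to fp32 one by one like this usually narrows the NaNs down to a specific layer (attention logits and normalization layers are frequent suspects), which is easier than disabling AMP globally for the whole run.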