chuanyangjin / fast-DiT

Fast Diffusion Models with Transformers

Maybe need accelerator.reduce? Loss scale mismatch with DiT official code #16

Open ShenZhang-Shin opened 2 months ago

ShenZhang-Shin commented 2 months ago

DiT official code loss log (screenshot attached):

while the fast-DiT loss is much smaller than the official loss (screenshot attached).

I think the fast-DiT code may be missing a gather of the loss across all GPUs. After adding `avg_loss = accelerator.reduce(avg_loss, reduction="sum")`, the loss matches the official code's result (screenshot attached).
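For context, here is a minimal sketch (plain Python, no GPUs or `accelerate` required) of why the logged loss ends up roughly `world_size` times too small when each rank's running loss is divided by the number of processes without first summing across ranks. The variable and function names are illustrative, not copied from the fast-DiT source:

```python
# Sketch of the logging mismatch: each rank accumulates a local running
# loss, and the log step divides by the world size. Without a cross-rank
# sum (accelerator.reduce with reduction="sum"), the logged value is
# ~WORLD_SIZE times smaller than the true mean loss.

WORLD_SIZE = 8    # number of GPUs / processes (hypothetical)
LOG_STEPS = 100   # steps between log lines (hypothetical)

# Pretend every rank accumulated a running loss sum whose per-step mean is 0.15.
per_rank_running_loss = [0.15 * LOG_STEPS] * WORLD_SIZE

def buggy_log(rank: int) -> float:
    # No reduce: the rank only sees its local sum, yet still divides by
    # WORLD_SIZE, so the logged loss is ~WORLD_SIZE times too small.
    local_avg = per_rank_running_loss[rank] / LOG_STEPS
    return local_avg / WORLD_SIZE

def fixed_log() -> float:
    # Stand-in for accelerator.reduce(avg_loss, reduction="sum"):
    # sum the per-rank averages, then divide by WORLD_SIZE to get the
    # true mean, matching the official DiT logging.
    summed = sum(loss / LOG_STEPS for loss in per_rank_running_loss)
    return summed / WORLD_SIZE

print(buggy_log(0))  # ~0.019, much smaller than the real loss
print(fixed_log())   # 0.15, the true mean loss
```

Note the model's actual training is unaffected either way; only the logged number is wrong, which is why the loss curves disagree while final quality can still match.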

wangyanhui666 commented 1 month ago

I used a single node with 4 GPUs for training, and the loss looks normal, around 0.15. Is this reduce only needed for multi-node, multi-GPU setups?

Additionally, the model I trained has a very high FID: in the 256x256 setting, after training for 400k steps, the FID is 70. I'm not sure where the problem is.

ShenZhang-Shin commented 1 month ago

I use a single node with 8 GPUs, training DiT-S/2 at 256x256 resolution. After 400k steps, my FID is 69.76, about 1.3 higher than the DiT paper's result. Maybe it's FP16, or the VAE pre-extraction? @chuanyangjin

wangyanhui666 commented 1 month ago

> I use a single node with 8 GPUs, training DiT-S/2 at 256x256 resolution. After 400k steps, my FID is 69.76, about 1.3 higher than the DiT paper's result. Maybe it's FP16, or the VAE pre-extraction? @chuanyangjin

I use DiT-XL/2 at 256x256 resolution. After 400k steps, my FID is 70, and I don't know why. Maybe it's caused by bf16 training?

wangyanhui666 commented 1 month ago

Can I get your WeChat to discuss?

> I use a single node with 8 GPUs, training DiT-S/2 at 256x256 resolution. After 400k steps, my FID is 69.76, about 1.3 higher than the DiT paper's result. Maybe it's FP16, or the VAE pre-extraction? @chuanyangjin

ShenZhang-Shin commented 1 month ago

@wangyanhui666
Well, that FID is too high for DiT-XL/2; you should check your code. OK, you can send your WeChat ID to zhangshen1915@gmail.com and I'll add you.

dlsrbgg33 commented 1 month ago

@ShenZhang-Shin Did you fix the performance gap?

ShenZhang-Shin commented 3 weeks ago

@dlsrbgg33 Yes, I did. Using the fast-DiT tricks, I achieve the same performance as DiT without CFG.