lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License
11.09k stars 1.09k forks

Decoder AMP training loss blow up #97

Closed CiaoHe closed 2 years ago

CiaoHe commented 2 years ago

Just pinning my training issue:

When amp = True is set, the decoder's loss (unet1 & unet2) blew up to 1K+. [The loss is < 0.01 when AMP is disabled.] Just curious, does half-precision training mode harm DDPM training?
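For context on the question above, here is a minimal sketch of the numeric limits of half precision, using only Python's standard `struct` module. It is not a diagnosis of this run; the helper names `to_fp16` and `fits_in_fp16` are made up for illustration. It only shows that fp16 overflows above 65504 and flushes very small values to zero, which are the usual reasons mixed-precision training needs loss scaling and well-bounded activations.

```python
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fits_in_fp16(x: float) -> bool:
    """Return True if x can be stored as a finite half-precision number."""
    try:
        struct.pack("<e", x)
        return True
    except OverflowError:
        return False


# A moderately sized value is representable.
print(fits_in_fp16(1000.0))   # True
# A large activation (e.g. an unnormalized attention logit) is not.
print(fits_in_fp16(70000.0))  # False
# A tiny gradient underflows to zero, which is what loss scaling guards against.
print(to_fp16(1e-8))          # 0.0
```
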

image

as a ref: https://wandb.ai/hcaoaf/dalle2-decoder/runs/3sb0pumh?workspace=user-hcaoaf

CiaoHe commented 2 years ago

Config


Unets:
  Unet1:
    dim: 128
    image_embed_dim: 512
    text_embed_dim: 512
    cond_dim: 128
    channels: 3
    dim_mults: [1, 2, 3, 4] 
    cond_on_text_encodings: True

  Unet2:
    dim: 64
    image_embed_dim: 512
    text_embed_dim: 512
    cond_dim: 128
    channels: 3
    dim_mults: [1, 2, 3, 4]
    cond_on_text_encodings: False

Decoder:
  image_sizes: [64, 256]
  timesteps: 1000
  image_cond_drop_prob: 0.1
  text_cond_drop_prob: 0.5
  condition_on_text_encodings: True
  # loss
  loss_type: 'l2'
  learned_variance: True
  # diff
  beta_schedule: 'cosine'
  predict_x_start: False
  predict_x_start_for_latent_diffusion: False # not to latent now
  lowres_downsample_first: True
  # blur
  blur_sigma: 0.1
  blur_kernel_size: 3
  clip_denoised: True
  clip_x_start: True

lucidrains commented 2 years ago

@CiaoHe how does it look without AMP? I also built in an unconditional feature for the decoder, so one can train it without cross-attention conditioning (I'm guessing it is probably the attention blocks blowing up, so I added some additional normalization that should help)

lucidrains commented 2 years ago

@CiaoHe I also realized the default learning rate I had (3e-4) is a bit too high compared to Ho's original DDPM (2e-5)

lucidrains commented 2 years ago

if worst comes to worst, I can bring in the cosine sim attention from SwinV2 for the cross-attention blocks, if we diagnose that to be the issue
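To illustrate why cosine-similarity attention is attractive under half precision, here is a rough pure-Python sketch comparing one standard scaled dot-product logit with one cosine-similarity logit. It is not code from this repository: the function names are hypothetical, and the temperature `tau` is fixed at an assumed 0.1, whereas SwinV2 learns it. The point is only that the cosine logit is bounded by `1 / tau` regardless of how large the query and key norms grow.

```python
import math


def dot_logit(q, k):
    """Standard scaled dot-product attention logit for one query/key pair."""
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q))


def cosine_logit(q, k, tau=0.1, eps=1e-12):
    """Cosine-similarity logit; its magnitude never exceeds 1 / tau."""
    dot = sum(a * b for a, b in zip(q, k))
    q_norm = math.sqrt(sum(a * a for a in q))
    k_norm = math.sqrt(sum(b * b for b in k))
    # eps guards against division by zero for all-zero vectors.
    return dot / (max(q_norm, eps) * max(k_norm, eps)) / tau


# Large-magnitude activations, chosen to exaggerate the effect.
q = [300.0] * 64
k = [300.0] * 64

print(dot_logit(q, k))     # 720000.0, beyond the fp16 maximum of 65504
print(cosine_logit(q, k))  # roughly 10.0, i.e. 1 / tau
```
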

CiaoHe commented 2 years ago

> @CiaoHe how does it look without AMP? I also built in an unconditional feature for the decoder, so one can train it without cross-attention conditioning (I'm guessing it is probably the attention blocks blowing up, so I added some additional normalization that should help)

here's w/o AMP.

image

@lucidrains let me try the current version

CiaoHe commented 2 years ago

@lucidrains wow cool built-in grad-accum support. Now, everything works fine (max-batch-size:32, total-batch-size:265, amp:True)
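As a rough illustration of what gradient accumulation with a maximum sub-batch size does, the sketch below splits a batch of 265 items into chunks of at most 32 and weights each chunk's mean loss by its share of the full batch. This is not the library's actual code; `split_batch` and `accumulated_mean_loss` are hypothetical names, and plain floats stand in for per-sample losses. In a real training loop one would call backward on each weighted chunk loss and step the optimizer once at the end.

```python
def split_batch(batch, max_batch_size):
    """Split a batch into consecutive chunks of at most max_batch_size items."""
    return [batch[i:i + max_batch_size] for i in range(0, len(batch), max_batch_size)]


def accumulated_mean_loss(per_sample_losses, max_batch_size):
    """Accumulate chunk losses, each weighted by its share of the full batch."""
    total = len(per_sample_losses)
    accumulated = 0.0
    for chunk in split_batch(per_sample_losses, max_batch_size):
        chunk_mean = sum(chunk) / len(chunk)
        # Weighting by chunk fraction makes the sum equal the full-batch mean,
        # even when the last chunk is smaller than the others.
        accumulated += chunk_mean * (len(chunk) / total)
    return accumulated


losses = [float(i) for i in range(265)]  # stand-in per-sample losses
chunks = split_batch(losses, 32)
print(len(chunks), len(chunks[-1]))  # 9 9
```

Weighting by chunk fraction rather than dividing by the number of chunks matters here because 265 is not a multiple of 32, so the final chunk holds only 9 samples.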

image

lucidrains commented 2 years ago

@CiaoHe hurray 🎊 🎉

CiaoHe commented 1 year ago

I missed some of the messages you replied with in my GitHub inbox. It will be a massive loss if you end your open-source work. (ಥ_ಥ) But anyway, I hope you have an excellent experience!
