Config:

```yaml
Unets:
  Unet1:
    dim: 128
    image_embed_dim: 512
    text_embed_dim: 512
    cond_dim: 128
    channels: 3
    dim_mults: [1, 2, 3, 4]
    cond_on_text_encodings: True
  Unet2:
    dim: 64
    image_embed_dim: 512
    text_embed_dim: 512
    cond_dim: 128
    channels: 3
    dim_mults: [1, 2, 3, 4]
    cond_on_text_encodings: False
Decoder:
  image_sizes: [64, 256]
  timesteps: 1000
  image_cond_drop_prob: 0.1
  text_cond_drop_prob: 0.5
  condition_on_text_encodings: True
  # loss
  loss_type: 'l2'
  learned_variance: True
  # diffusion
  beta_schedule: 'cosine'
  predict_x_start: False
  predict_x_start_for_latent_diffusion: False  # not using latent diffusion for now
  lowres_downsample_first: True
  # blur
  blur_sigma: 0.1
  blur_kernel_size: 3
  clip_denoised: True
  clip_x_start: True
```
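For reference, a minimal sketch of how a config like this might map onto dalle2-pytorch's `Unet` and `Decoder` constructors. The kwarg names are taken from the config above; exact signatures may differ between repo versions, and a CLIP model or precomputed embeddings would normally be supplied as well:

```python
from dalle2_pytorch import Unet, Decoder

# base unet: generates 64x64, conditioned on text encodings
unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    text_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 3, 4),
    cond_on_text_encodings = True
)

# super-resolution unet: 64 -> 256, no text conditioning
unet2 = Unet(
    dim = 64,
    image_embed_dim = 512,
    text_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 3, 4),
    cond_on_text_encodings = False
)

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (64, 256),
    timesteps = 1000,
    image_cond_drop_prob = 0.1,   # classifier-free guidance dropout on image embeds
    text_cond_drop_prob = 0.5,    # ... and on text encodings
    condition_on_text_encodings = True,
    loss_type = 'l2',
    learned_variance = True,
    beta_schedule = 'cosine',
    predict_x_start = False
)
```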
@CiaoHe how does it look without AMP? i also built in an unconditional feature for the decoder, so one can train it without cross attention conditioning (im guessing it is probably the attention blocks blowing up, so i added some additional normalization that should help)
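A minimal sketch of what the unconditional path would look like from the user side, assuming it is driven by the same flags as in the config above (the exact kwarg names here are an assumption, not the repo's confirmed API):

```python
from dalle2_pytorch import Unet, Decoder

# assumption: turning off text conditioning drops the cross-attention path entirely
unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    channels = 3,
    dim_mults = (1, 2, 3, 4),
    cond_on_text_encodings = False  # no cross attention over text encodings
)

decoder = Decoder(
    unet = unet,
    image_sizes = (64,),
    timesteps = 1000,
    condition_on_text_encodings = False
)
```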
@CiaoHe also realized the default learning rate i had (3e-4) is a bit too high compared to Ho's original ddpm (2e-5)
worst comes to worst, i can bring in the cosine sim attention from SwinV2 for the cross attention blocks, if we diagnose that to be the issue
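For reference, dropping to the DDPM learning rate is just an optimizer setting (plain torch shown here; the repo's trainer may expose this differently):

```python
from torch.optim import Adam

# 2e-5 as in Ho et al.'s original DDPM, instead of the previous 3e-4 default
optimizer = Adam(decoder.parameters(), lr = 2e-5)
```

And a minimal sketch of the SwinV2-style cosine-sim attention idea: queries and keys are l2-normalized and scaled by a learned temperature, so attention logits stay bounded regardless of feature magnitude, which is friendlier to fp16. Names and details here are illustrative, not the repo's implementation:

```python
import torch
import torch.nn.functional as F
from torch import nn

class CosineSimAttention(nn.Module):
    """Scaled cosine-similarity attention, in the spirit of SwinV2.
    Since q and k are unit vectors, logits are bounded by the learned
    scale, avoiding the fp16 overflow that raw dot products can hit."""
    def __init__(self, dim, heads = 8, dim_head = 64, init_scale = 10.):
        super().__init__()
        inner = heads * dim_head
        self.heads = heads
        self.to_qkv = nn.Linear(dim, inner * 3, bias = False)
        self.to_out = nn.Linear(inner, dim)
        # learned per-head logit scale (stored in log space)
        self.scale = nn.Parameter(torch.full((heads, 1, 1), init_scale).log())

    def forward(self, x):
        b, n, _ = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (t.reshape(b, n, self.heads, -1).transpose(1, 2) for t in (q, k, v))
        q, k = F.normalize(q, dim = -1), F.normalize(k, dim = -1)
        # cosine similarity times a clamped learned temperature
        logits = (q @ k.transpose(-2, -1)) * self.scale.exp().clamp(max = 100.)
        attn = logits.softmax(dim = -1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
```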
here's w/o AMP.
@lucidrains let me try the current version
@lucidrains wow, cool built-in grad-accum support.
Now everything works fine (max-batch-size: 32, total-batch-size: 265, amp: True).
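For context, built-in gradient accumulation conceptually splits each oversized batch into chunks of at most max-batch-size and accumulates gradients before a single optimizer step. A hand-rolled sketch of the idea (the repo's trainer does this internally; the decoder forward signature here is assumed):

```python
def train_step(decoder, optimizer, images, text_encodings, max_batch_size = 32):
    optimizer.zero_grad()
    total = images.shape[0]
    # process the full batch in GPU-sized chunks, accumulating gradients
    for img, txt in zip(images.split(max_batch_size), text_encodings.split(max_batch_size)):
        loss = decoder(img, text_encodings = txt)   # assumed forward signature
        # weight each chunk so accumulated grads match one full-batch step
        (loss * img.shape[0] / total).backward()
    optimizer.step()
    return loss.item()
```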
@CiaoHe hurray 🎊 🎉
I missed some messages you replied to in the GitHub inbox. It would be a massive loss if you ended your open-source work. (ಥ_ಥ) But anyway, I hope you have an excellent experience!
Just to pin down my training issue: when `amp = True` is set, the decoder's loss (unet1 & unet2) blew up to 1K+ (the loss stayed below 0.01 with AMP disabled). Just curious: does half-precision training harm DDPM training? As a reference: https://wandb.ai/hcaoaf/dalle2-decoder/runs/3sb0pumh?workspace=user-hcaoaf
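For reference, a standard torch AMP loop looks like the following (a generic sketch, not the repo's trainer; `decoder` and `optimizer` as in the sketches above, `dataloader` assumed). If the loss diverges only under autocast, the usual suspects are fp16 overflow in the attention logits or missing gradient scaling:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for images, text_encodings in dataloader:
    optimizer.zero_grad()
    with autocast():                       # forward + loss computed in mixed precision
        loss = decoder(images, text_encodings = text_encodings)
    scaler.scale(loss).backward()          # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)                 # unscales grads; skips the step on inf/nan
    scaler.update()
```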