FirasGit / medicaldiffusion

Medical Diffusion: This repository contains the code for our paper "Medical Diffusion: Denoising Diffusion Probabilistic Models for 3D Medical Image Synthesis"

Loss plateau and potential mode collapse in VQ-GAN training on the MRNet dataset #15

Closed WhenMelancholy closed 4 months ago

WhenMelancholy commented 1 year ago

When training the VQ-GAN on the MRNet dataset, the loss stopped decreasing after a certain point and then started to increase. We trained the model with the following parameters:

CUDA_VISIBLE_DEVICES=3 PL_TORCH_DISTRIBUTED_BACKEND=gloo PYTHONPATH=.:$PYTHONPATH python \
    train/train_vqgan.py \
    dataset=mrnet \
    dataset.root_dir="~/medicaldiffusion/data/MRNet-v1.0/" \
    model=vq_gan_3d \
    model.gpus=1 \
    model.default_root_dir="~/medicaldiffusion/when/checkpoints/vq_gan2" \
    model.default_root_dir_postfix="mrnet" \
    model.precision=32 \
    model.embedding_dim=8 \
    model.n_hiddens=16 \
    model.downsample=[4,4,4] \
    model.num_workers=32 \
    model.gradient_clip_val=1.0 \
    model.lr=3e-4 \
    model.discriminator_iter_start=10000 \
    model.perceptual_weight=4 \
    model.image_gan_weight=1 \
    model.video_gan_weight=1 \
    model.gan_feat_weight=4 \
    model.batch_size=2 \
    model.n_codes=16384 \
    model.accumulate_grad_batches=1 

An excerpt of the abnormal loss behavior during training is shown below:

......
Epoch 0:   1%|          | 4/565 [00:04<10:59,  1.18s/it, loss=2.48, v_num=0, train/perceptual_loss_step=2.650, train/recon_loss_step=2.390, train/aeloss_step=0.000, train/commitment_loss_step=0.00397, train/perplexity_step=8.32e+3, train/discloss_step=0.000]
......
Epoch 3:  44%|████▎     | 247/565 [02:13<02:52,  1.84it/s, loss=1.11, v_num=0, train/perceptual_loss_step=1.790, train/recon_loss_step=0.387, train/aeloss_step=0.000, train/commitment_loss_step=0.00353, train/perplexity_step=4.45e+3, train/discloss_step=0.000, val/recon_loss=0.478, val/perceptual_loss=1.600, val/perplexity=5.95e+3, val/commitment_loss=0.00394, train/perceptual_loss_epoch=1.690, train/recon_loss_epoch=0.481, train/aeloss_epoch=0.000, train/commitment_loss_epoch=0.00367, train/perplexity_epoch=5.88e+3, train/discloss_epoch=0.000]
......
Epoch 11:   7%|▋         | 41/565 [00:24<05:17,  1.65it/s, loss=4.18, v_num=0, train/perceptual_loss_step=1.160, train/recon_loss_step=0.319, train/aeloss_step=0.326, train/commitment_loss_step=0.00809, train/perplexity_step=5.84e+3, train/discloss_step=1.930, val/recon_loss=0.378, val/perceptual_loss=1.360, val/perplexity=6.32e+3, val/commitment_loss=0.00797, train/perceptual_loss_epoch=1.340, train/recon_loss_epoch=0.382, train/aeloss_epoch=0.000, train/commitment_loss_epoch=0.00779, train/perplexity_epoch=6.86e+3, train/discloss_epoch=0.000]
......
Epoch 14:  49%|████▉     | 278/565 [02:30<02:35,  1.85it/s, loss=5.4, v_num=0, train/perceptual_loss_step=1.440, train/recon_loss_step=0.397, train/aeloss_step=-.0933, train/commitment_loss_step=0.0129, train/perplexity_step=6.43e+3, train/discloss_step=1.940, val/recon_loss=0.422, val/perceptual_loss=1.600, val/perplexity=6.38e+3, val/commitment_loss=0.0126, train/perceptual_loss_epoch=1.640, train/recon_loss_epoch=0.429, train/aeloss_epoch=0.679, train/commitment_loss_epoch=0.012, train/perplexity_epoch=6.49e+3, train/discloss_epoch=1.660]
......

Is this caused by mode collapse in the GAN, or by the training configuration? Are there any good ways to fix it? I would greatly appreciate any suggestions.

benearnthof commented 1 year ago

@WhenMelancholy What worked for me was starting the discriminator only after 50,000+ steps and also decreasing the GAN loss weights by a factor of 4-5. This makes the discriminator train a lot more slowly (sample quality will quickly drop for a couple of thousand iterations after the discriminator starts training), but it should recover after about 5-6k iterations and lead to a further improvement in sample quality. Also, a precision of 32 was needed for me in all cases, but you already have that in your config.
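
Applied to the command in the original post, that suggestion would look roughly like the sketch below. The reduced weight values (0.25 / 0.25 / 1) are only an illustration of "a factor of 4-5" and were not confirmed in this thread; everything else is left unchanged:

# Same command as in the original post, with three changes: discriminator_iter_start raised
# to 50000, and the GAN loss weights (image_gan_weight, video_gan_weight, gan_feat_weight)
# reduced by roughly 4x. The exact reduced values are illustrative, not confirmed settings.
CUDA_VISIBLE_DEVICES=3 PL_TORCH_DISTRIBUTED_BACKEND=gloo PYTHONPATH=.:$PYTHONPATH python \
    train/train_vqgan.py \
    dataset=mrnet \
    dataset.root_dir="~/medicaldiffusion/data/MRNet-v1.0/" \
    model=vq_gan_3d \
    model.gpus=1 \
    model.default_root_dir="~/medicaldiffusion/when/checkpoints/vq_gan2" \
    model.default_root_dir_postfix="mrnet" \
    model.precision=32 \
    model.embedding_dim=8 \
    model.n_hiddens=16 \
    model.downsample=[4,4,4] \
    model.num_workers=32 \
    model.gradient_clip_val=1.0 \
    model.lr=3e-4 \
    model.discriminator_iter_start=50000 \
    model.perceptual_weight=4 \
    model.image_gan_weight=0.25 \
    model.video_gan_weight=0.25 \
    model.gan_feat_weight=1 \
    model.batch_size=2 \
    model.n_codes=16384 \
    model.accumulate_grad_batches=1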

hieuaka47 commented 5 months ago

> @WhenMelancholy What worked for me was starting the discriminator only after 50,000+ steps and also decreasing the GAN loss weights by a factor of 4-5. This makes the discriminator train a lot more slowly (sample quality will quickly drop for a couple of thousand iterations after the discriminator starts training), but it should recover after about 5-6k iterations and lead to a further improvement in sample quality. Also, a precision of 32 was needed for me in all cases, but you already have that in your config.

Hello, could you tell me what the GAN loss weights include?
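
For reference, the only overrides in the original command that weight GAN terms of the loss appear to be the ones below; the per-flag descriptions are a reading of the flag names and the usual VQ-GAN setup, not something confirmed in this thread:

    # Presumably the "GAN loss weights" referred to above (assumption, not confirmed here):
    model.image_gan_weight=1   # likely the weight of the 2D (per-slice) discriminator loss
    model.video_gan_weight=1   # likely the weight of the 3D (volume) discriminator loss
    model.gan_feat_weight=4    # likely the weight of the discriminator feature-matching loss

perceptual_weight, by contrast, appears to scale the perceptual reconstruction loss rather than a GAN term.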

WhenMelancholy commented 4 months ago

> @WhenMelancholy What worked for me was starting the discriminator only after 50,000+ steps and also decreasing the GAN loss weights by a factor of 4-5. This makes the discriminator train a lot more slowly (sample quality will quickly drop for a couple of thousand iterations after the discriminator starts training), but it should recover after about 5-6k iterations and lead to a further improvement in sample quality. Also, a precision of 32 was needed for me in all cases, but you already have that in your config.
>
> Hello, could you tell me what the GAN loss weights include?

Sorry, but I can no longer access the environment where I encountered this problem >_< I will close the issue.