CompVis / latent-diffusion

High-Resolution Image Synthesis with Latent Diffusion Models
MIT License

the stability of training(a collapse loss) #52

Open sunsq-blue opened 2 years ago

sunsq-blue commented 2 years ago

Thanks for the great work. I am trying to train the LDM model on ImageNet with 8 V100s, but I get a bad result. The loss is normal at first, but soon collapses:

[screenshots of the training loss curves]

and the sampled images are all noise at 5000 steps: [screenshot of sampled images]

How can I solve this problem? Thank you very much.

sunsq-blue commented 2 years ago

update: The configuration file has hardly been modified:

```yaml
model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: class_label
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 256
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial resolution of the latents is 32 for f8
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 32
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 512
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 4
        n_embed: 16384
        ckpt_path: models/first_stage_models/vq-f8/model.ckpt
        ddconfig:
          double_z: false
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions:
          - 32
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: ldm.modules.encoders.modules.ClassEmbedder
      params:
        embed_dim: 512
        key: class_label

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 64
    num_workers: 12
    wrap: false
    train:
      target: ldm.data.imagenet.ImageNetTrain
      params:
        data_root: /data/imagenet/ImageNet/data
        config:
          size: 256
    validation:
      target: ldm.data.imagenet.ImageNetValidation
      params:
        data_root: /data/imagenet/ImageNet/data
        config:
          size: 256

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
```

sunsq-blue commented 2 years ago

update:
When I train on a single GPU, the loss collapse does not seem to happen. Another possible fix is changing the optimizer from AdamW to Adam. I hope others will share their solutions, thanks.
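For reference, the optimizer in this repo is created inside `configure_optimizers` in `ldm/models/diffusion/ddpm.py` using `torch.optim.AdamW`. A minimal sketch of the suggested swap to Adam, with the surrounding class reduced to a toy stand-in (the `LatentDiffusionSketch` class below is illustrative, not the real model), might look like:

```python
import torch


class LatentDiffusionSketch(torch.nn.Module):
    """Toy stand-in for ldm.models.diffusion.ddpm.LatentDiffusion."""

    def __init__(self, learning_rate=1.0e-06):
        super().__init__()
        self.learning_rate = learning_rate
        self.model = torch.nn.Linear(4, 4)  # placeholder for the UNet

    def configure_optimizers(self):
        params = list(self.model.parameters())
        # Original code uses: torch.optim.AdamW(params, lr=self.learning_rate)
        # Suggested workaround: plain Adam (no decoupled weight decay)
        return torch.optim.Adam(params, lr=self.learning_rate)


opt = LatentDiffusionSketch().configure_optimizers()
print(type(opt).__name__)  # Adam
```

Since AdamW with its default `weight_decay=0.01` applies decoupled decay while Adam's default decay is 0, the swap effectively removes weight decay; whether that is what stabilizes training here is not confirmed in this thread.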

CrossLee1 commented 2 years ago

@sunsq-blue I ran into the same problem. Did you manage to make it work using the solutions above? Thanks

Ir1d commented 2 years ago

@CrossLee1 @sunsq-blue I figured out that you shouldn't scale the lr that much. Use `--scale_lr False` and it works well. In Lightning the effective batch size is `batch_size * n_gpus * n_nodes`, and in the code they scale the lr by `batch_size * n_gpus`.
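The scaling described above happens in `main.py`. A simplified sketch of that logic (function name and defaults here are illustrative) shows why 8 GPUs blow up the learning rate in this setup:

```python
def effective_lr(base_lr, batch_size, ngpu, accumulate_grad_batches=1, scale_lr=True):
    """Simplified mirror of the lr scaling in main.py.

    With scale_lr enabled (the default), the configured base_learning_rate
    is multiplied by the effective batch size across all GPUs.
    """
    if scale_lr:
        return accumulate_grad_batches * ngpu * batch_size * base_lr
    return base_lr


# The setup from this issue: base_lr=1e-6, batch_size=64, 8 V100s
print(effective_lr(1.0e-06, 64, 8))                  # 0.000512, i.e. 512x the base lr
print(effective_lr(1.0e-06, 64, 8, scale_lr=False))  # 1e-06 with --scale_lr False
```

So with `batch_size: 64` on 8 GPUs, the configured `1.0e-06` becomes roughly `5.12e-04`, which is plausibly large enough to destabilize training; `--scale_lr False` keeps the lr at the configured value.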

wtliao commented 1 year ago

@Ir1d Thanks a lot! I had been suffering from the same problem for several days. Finally solved it with your suggestion!