XPixelGroup / BasicSR

Open Source Image and Video Restoration Toolbox for Super-resolution, Denoising, Deblurring, etc. Currently, it includes EDSR, RCAN, SRResNet, SRGAN, ESRGAN, EDVR, BasicVSR, SwinIR, ECBSR, etc. It also supports StyleGAN2 and DFDNet.
https://basicsr.readthedocs.io/en/latest/
Apache License 2.0

GAN training loss is abnormal and the model crashes! #333

Open yanghedada opened 4 years ago

yanghedada commented 4 years ago

The loss curves of my GAN look very strange. I am using a lightweight ESRGAN generator, so I also shrank the discriminator. The weights of the two models are about 2-3 MB. I use the TTUR learning-rate strategy, and nothing else has changed. I have trained with the GAN method many times, and every time the GAN loss rises.

Default ESRGAN configuration

image

Lightweight ESRGAN model crashes

image

Adjusted lr and perceptual loss weight = 0.05

image

Configuration used for the third picture:

# general settings
name: 052_ESRGAN_x4_Mix_HQ_hr_kernel_A10nose_ttur_smooth_sn_ms_192_inter
model_type: MSESRGANModel
scale: 4
num_gpu: 1  # set num_gpu: 0 for cpu mode
manual_seed: 0

# dataset and data loader settings
datasets:
  train:
    name: DIV2K
    type: PairedImageDataset
    aug: noise
    #aug: ~
    noise_data: /data/Data/Camera_noise
    dataroot_gt: /data/Data/Mix_HQ_hr
    dataroot_lq: /data/Data/Mix_HQ_A10_lr
    # (for lmdb)
    # dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
    # dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
    filename_tmpl: '{}'
    io_backend:
      type: disk
      # (for lmdb)
      # type: lmdb

    gt_size: 192
    use_flip: true
    use_rot: true

    # data loader
    use_shuffle: true
    num_worker_per_gpu: 8
    batch_size_per_gpu: 40
    dataset_enlarge_ratio: 100
    prefetch_mode: ~

  val:
    name: Set14
    type: PairedImageDataset
    dataroot_gt: /home/Data/Mix_hr_lr_landmark/hr
    dataroot_lq: /home/Data/Mix_hr_lr_landmark/lr
    io_backend:
      type: disk

# network structures
network_g:
  type: RRDBNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 64
  num_block: 4

network_d:
  type: Discriminator_VGG_128
  num_in_ch: 3
  num_feat: 32
# path
path:
  pretrain_network_g: ~
  strict_load_g: false
  resume_state: ~

  # pretrain_network_d: experiments/052_ESRGAN_x4_f64b23_DIV2K_400k_B16G1_051pretrain_wandb_mix/models/net_d_66000.pth
  # strict_load_d: true
  # pretrain_network_g: experiments/052_ESRGAN_x4_f64b23_DIV2K_400k_B16G1_051pretrain_wandb_mix/models/net_g_66000.pth
  # strict_load_g: true
  # resume_state: experiments/052_ESRGAN_x4_f64b23_DIV2K_400k_B16G1_051pretrain_wandb_mix/training_states/66000.state

# training settings
#
train:
  optim_g:
    type: Adam
    lr: !!float 5e-5
    weight_decay: 0
    betas: [0.9, 0.99]
  optim_d:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  scheduler:
    type: MultiStepLR
    milestones: [40000,  60000, 100000]
    gamma: 0.5

  total_iter: 140000
  warmup_iter: -1  # no warm up

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: !!float 1e-2
    reduction: mean
  perceptual_opt:
    type: PerceptualLoss
    layer_weights:
      'conv5_4': 1  # before relu
    vgg_type: vgg19
    use_input_norm: true
    #perceptual_weight: 1.0
    perceptual_weight: !!float 5e-2
    style_weight: 0
    norm_img: false
    criterion: l1
  gan_opt:
    type: GANLoss
    gan_type: vanilla
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: !!float 5e-3

  net_d_iters: 1
  net_d_init_iters: 0

# validation settings
val:
  #val_freq: !!float 5e3
  val_freq: 400
  save_img: true

  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      crop_border: 4
      test_y_channel: false

# logging settings
logger:
  print_freq: 60
  #save_checkpoint_freq: !!float 5e3
  save_checkpoint_freq: 3000
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500
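
To make the learning-rate settings in the configuration above easier to read, here is a minimal sketch that evaluates the step schedule in closed form. It assumes the usual step-decay rule (multiply the base rate by `gamma` once for every milestone already reached); the helper name `multistep_lr` is made up for illustration and is not part of BasicSR, and the exact iteration at which each drop lands may differ by one depending on when the trainer steps the scheduler.

```python
from bisect import bisect_right

def multistep_lr(base_lr, milestones, gamma, iteration):
    """Step decay in closed form: one factor of gamma per milestone already reached.

    Hypothetical helper for illustration only, not a BasicSR function.
    """
    return base_lr * gamma ** bisect_right(milestones, iteration)

# Values copied from the `train` section of the config above.
MILESTONES = [40000, 60000, 100000]
GAMMA = 0.5
LR_G = 5e-5   # optim_g.lr
LR_D = 1e-4   # optim_d.lr

for it in (0, 40000, 60000, 100000, 139999):
    lr_g = multistep_lr(LR_G, MILESTONES, GAMMA, it)
    lr_d = multistep_lr(LR_D, MILESTONES, GAMMA, it)
    print(f"iter {it:>6}: lr_g={lr_g:.3e}  lr_d={lr_d:.3e}  d/g ratio={lr_d / lr_g:.1f}")
```

Under these assumptions both rates are halved at the same milestones, so the discriminator keeps a learning rate twice that of the generator for the whole run, which is the TTUR part of the setup.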

KershawTien commented 3 days ago

Did you run into the following problem when using GANLoss? If so, how did you solve it?

Traceback (most recent call last):
  File "basicsr/train.py", line 216, in <module>
    train_pipeline(root_path)
  File "basicsr/train.py", line 169, in train_pipeline
    model.optimize_parameters(current_iter)
  File "/hdd/u202320081001026/DAT-main/basicsr/models/sr_model.py", line 105, in optimize_parameters
    l_percep, l_style = self.cri_perceptual(self.output, self.gt)
  File "/home/u202320081001026/anaconda3/envs/DAT/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/u202320081001026/anaconda3/envs/DAT/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/hdd/u202320081001026/DAT-main/basicsr/losses/losses.py", line 353, in forward
    target_label = self.get_target_label(input, target_is_real)
  File "/hdd/u202320081001026/DAT-main/basicsr/losses/losses.py", line 334, in get_target_label
    target_val = (self.real_label_val if target_is_real else self.fake_label_val)
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
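
A note on reading this traceback: the call `self.cri_perceptual(self.output, self.gt)` ends up in a `forward` that calls `get_target_label(input, target_is_real)`. That suggests the perceptual criterion was built as a GAN-style loss, so the ground-truth tensor `self.gt` arrives where a plain `True`/`False` flag is expected. If so, the `type` under `perceptual_opt` in the training config would be the first thing to check. The sketch below reproduces only the mechanism, without requiring PyTorch: `MultiElementTensor` is a hypothetical stand-in for a tensor with more than one value, and `get_target_label_value` and `describe_call` are illustrative names, not BasicSR functions.

```python
class MultiElementTensor:
    """Hypothetical stand-in for a torch.Tensor that holds more than one value."""

    def __bool__(self):
        # Same error type and message as in the traceback above.
        raise RuntimeError("Boolean value of Tensor with more than one value is ambiguous")


def get_target_label_value(target_is_real, real_label_val=1.0, fake_label_val=0.0):
    # Mirrors the conditional expression shown in the last traceback frame.
    return real_label_val if target_is_real else fake_label_val


def describe_call(target_is_real):
    """Run the label lookup and report either the value or the error raised."""
    try:
        return f"ok: {get_target_label_value(target_is_real)}"
    except RuntimeError as err:
        return f"RuntimeError: {err}"


print(describe_call(True))                  # a real boolean flag works
print(describe_call(False))
print(describe_call(MultiElementTensor()))  # a tensor-like object in the flag slot fails
```

The conditional expression has to convert its condition to a single boolean, which is why a multi-element tensor in that position fails regardless of the label values.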