the-database / traiNNer-redux

Deep learning training framework for image super resolution and restoration.
Apache License 2.0

feature request: add support to calculate_dists for grayscale images #57

Closed: HYChou0515 closed this issue 2 months ago

HYChou0515 commented 2 months ago

Hi, I am trying to train HAT on a grayscale image dataset. Here's my HAT.yml, modified from the official HAT.yml:

my HAT.yml:

```yaml
####################
# General Settings
####################
name: 2x_HAT_L_gray
scale: 2  # 1, 2, 3, 4, 8
use_amp: true  # Speed up training and reduce VRAM usage.
amp_bf16: true  # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only.
fast_matmul: false  # Trade precision for performance.
num_gpu: auto
# manual_seed: 1024  # Deterministic mode, slows down training.

#################################
# Dataset and Dataloader Settings
#################################
datasets:
  # Settings for the training dataset.
  train:
    name: Train Dataset
    type: PairedImageDataset
    color: "y"
    # Path to the HR (high res) images in your training dataset.
    dataroot_gt: datasets/train/grayscale1/hr
    # Path to the LR (low res) images in your training dataset.
    dataroot_lq: datasets/train/grayscale1/lr
    # meta_info: data/meta_info/dataset1.txt
    io_backend:
      type: disk
    gt_size: 64  # During training, it will crop a square of this size from your HR images. Larger is usually better but uses more VRAM.
    use_hflip: true  # Randomly flip the images horizontally.
    use_rot: true  # Randomly rotate the images.
    num_worker_per_gpu: 8
    batch_size_per_gpu: 6  # Increasing stabilizes training but going too high can cause issues
    dataset_enlarge_ratio: auto  # Increase if the dataset is less than 1000 images to avoid slowdowns. Auto will automatically enlarge small datasets only.
    prefetch_mode: ~
  # Settings for your validation dataset (optional). These settings will
  # be ignored if val_enabled is false in the Validation section below.
  val:
    name: Val Dataset
    type: PairedImageDataset
    color: "y"
    dataroot_gt: datasets/val/grayscale1/hr
    dataroot_lq: datasets/val/grayscale1/lr
    io_backend:
      type: disk

#####################
# Network Settings
#####################
# Generator model settings
network_g:
  type: HAT  # HAT_L, HAT_M, HAT_S
  in_chans: 1
  img_size: 64
  window_size: 16
  compress_ratio: 3
  squeeze_factor: 30
  conv_scale: 0.01
  overlap_ratio: 0.5
  img_range: 1.0
  depths: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
  embed_dim: 180
  num_heads: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
  mlp_ratio: 2
  upsampler: "pixelshuffle"
  resi_connection: "1conv"
# Discriminator model settings
network_d:
  type: UNetDiscriminatorSN
  num_in_ch: 1
  num_feat: 64
  skip_connection: true

############################
# Pretrain and Resume Paths
############################
path:
  #pretrain_network_g: experiments/pretrained_models/pretrain.pth
  param_key_g: ~
  strict_load_g: true  # Disable strict loading to partially load a pretrain model with a different scale
  resume_state: ~

#####################
# Training Settings
#####################
train:
  ema_decay: 0.999
  # Optimizer for generator model
  optim_g:
    type: AdamW
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]
  # Optimizer for discriminator model
  optim_d:
    type: AdamW
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]
  scheduler:
    type: MultiStepLR
    milestones: [200000, 400000]
    gamma: 0.5
  total_iter: 500000  # Total number of iterations.
  warmup_iter: -1  # Gradually ramp up learning rates until this iteration, to stabilize early training. Use -1 to disable.

  # Losses - for any loss set the loss_weight to 0 to disable it.
  # MSSIM loss
  mssim_opt:
    type: MSSIMLoss
    loss_weight: 1.0
  # Perceptual loss
  perceptual_opt:
    type: PerceptualLoss
    layer_weights:
      # before relu
      "conv1_2": 0.1
      "conv2_2": 0.1
      "conv3_4": 1
      "conv4_4": 1
      "conv5_4": 1
    vgg_type: vgg19
    style_weight: 0
    range_norm: false
    criterion: charbonnier
    perceptual_weight: 0.03
  # DISTS loss
  # dists_opt:
  #   type: DISTSLoss  # DISTSLoss, ADISTSLoss
  #   use_input_norm: true
  #   loss_weight: 0.25
  # HSLuv loss (hue, saturation, lightness). Use one of HSLuv loss or ColorLoss + LumaLoss, not both.
  # hsluv_opt:
  #   type: HSLuvLoss
  #   criterion: charbonnier
  #   loss_weight: 1.0
  # GAN loss
  gan_opt:
    type: GANLoss
    gan_type: vanilla
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: 0.1
  # Color loss (ITU-R BT.601)
  # color_opt:
  #   type: ColorLoss
  #   criterion: charbonnier
  #   loss_weight: 1.0
  # Luma loss (RGB to CIELAB L*)
  # luma_opt:
  #   type: LumaLoss
  #   criterion: charbonnier
  #   loss_weight: 1.0

  # Mix of Augmentations (MoA)
  use_moa: false  # Whether to enable mixture of augmentations, which augments the dataset on the fly to create more variety and help the model generalize.
  moa_augs: ['none', 'mixup', 'cutmix', 'resizemix', 'cutblur']  # The list of augmentations to choose from, only one is selected per iteration.
  moa_probs: [0.4, 0.084, 0.084, 0.084, 0.348]  # The probability each augmentation in moa_augs will be applied. Total should add up to 1.
  moa_debug: false  # Save images before and after augment to debug/moa folder inside of the root training directory.
  moa_debug_limit: 100  # The max number of iterations to save augmentation images for.

#############
# Validation
#############
val:
  val_enabled: true  # Whether to enable validations. If disabled, all validation settings below are ignored.
  val_freq: 1  # How often to run validations, in iterations.
  save_img: true  # Whether to save the validation images during validation, in the experiments/<name>/visualization folder.

  metrics_enabled: true  # Whether to run metrics calculations during validation.
  metrics:
    psnr:
      type: calculate_psnr
      crop_border: 4
      test_y_channel: false
    ssim:
      type: calculate_ssim
      crop_border: 4  # Whether to crop border during validation.
      test_y_channel: false  # Whether to convert to Y(CbCr) for validation.
    dists:
      type: calculate_dists
      better: lower

##########
# Logging
##########
logger:
  print_freq: 100
  save_checkpoint_freq: 1000
  save_checkpoint_format: safetensors
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

dist_params:
  backend: nccl
  port: 29500
```

But during validation (more specifically, in calculate_dists), there's an error:

```
Exception has occurred: IndexError
tuple index out of range
  File "/home/hychou/kaggle/srgan/traiNNer-redux/traiNNer/utils/img_util.py", line 33, in img2tensor
    if img.shape[2] == 3 and bgr2rgb:
       ~~~~~~~~~^^^
  File "/home/hychou/kaggle/srgan/traiNNer-redux/traiNNer/utils/img_util.py", line 72, in <listcomp>
    return [img2tensor(img, color, bgr2rgb, float32) for img in imgs]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hychou/kaggle/srgan/traiNNer-redux/traiNNer/utils/img_util.py", line 72, in imgs2tensors
    return [img2tensor(img, color, bgr2rgb, float32) for img in imgs]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hychou/kaggle/srgan/traiNNer-redux/traiNNer/metrics/dists.py", line 19, in calculate_dists
    img_t, img2_t = imgs2tensors([img, img2], color=True, bgr2rgb=True, float32=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hychou/kaggle/srgan/traiNNer-redux/traiNNer/metrics/__init__.py", line 24, in calculate_metric
    metric = METRIC_REGISTRY.get(metric_type)(**data, **opt, device=device)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hychou/kaggle/srgan/traiNNer-redux/traiNNer/models/sr_model.py", line 570, in nondist_validation
    self.metric_results[name] += calculate_metric(
                                 ^^^^^^^^^^^^^^^^^
  File "/home/hychou/kaggle/srgan/traiNNer-redux/traiNNer/models/base_model.py", line 95, in validation
    self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
  File "/home/hychou/kaggle/srgan/traiNNer-redux/train.py", line 358, in train_pipeline
    model.validation(
  File "/home/hychou/kaggle/srgan/traiNNer-redux/train.py", line 397, in <module>
    train_pipeline(root_path)
IndexError: tuple index out of range
```

This is because the img passed to img2tensor has shape (H, W) rather than (H, W, 1), so indexing img.shape[2] raises an IndexError.
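For illustration, the failure reproduces standalone with any 2D array (this snippet is mine, not repo code):

```python
import numpy as np

rgb = np.zeros((64, 64, 3), dtype=np.uint8)  # color image: shape (H, W, 3)
gray = np.zeros((64, 64), dtype=np.uint8)    # grayscale image: shape (H, W)

print(rgb.shape[2])   # 3, as the check in img2tensor expects
print(gray.shape[2])  # IndexError: tuple index out of range
```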

https://github.com/the-database/traiNNer-redux/blob/8593b055d529d6ddf6ca2934d92a17c842eea6c4/traiNNer/utils/img_util.py#L33
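One way to make the check grayscale-safe is to promote 2D arrays to an explicit single channel before indexing shape[2]. A minimal sketch, simplified from the real function (the actual img2tensor also handles the color flag and other details):

```python
import numpy as np
import torch

def img2tensor(img: np.ndarray, color: bool = True, bgr2rgb: bool = True,
               float32: bool = True) -> torch.Tensor:
    # Proposed guard: give grayscale (H, W) arrays an explicit channel axis
    # so the img.shape[2] check below cannot raise IndexError.
    if img.ndim == 2:
        img = img[..., None]
    if img.shape[2] == 3 and bgr2rgb:
        img = img[..., ::-1]  # cv2 loads color images as BGR; convert to RGB
    # (the `color` handling of the real function is omitted here for brevity)
    tensor = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
    return tensor.float() if float32 else tensor
```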

Also, calculate_dists always calls imgs2tensors with color=True, so I can't set it to false from HAT.yml:

https://github.com/the-database/traiNNer-redux/blob/8593b055d529d6ddf6ca2934d92a17c842eea6c4/traiNNer/metrics/dists.py#L19

I suggest adding color as an argument to calculate_dists, so it can be set in HAT.yml like this:

```yaml
#############
# Validation
#############
val:
  val_enabled: true  # Whether to enable validations. If disabled, all validation settings below are ignored.
  val_freq: 1  # How often to run validations, in iterations.
  save_img: true  # Whether to save the validation images during validation, in the experiments/<name>/visualization folder.

  metrics_enabled: true  # Whether to run metrics calculations during validation.
  metrics:
    dists:
      type: calculate_dists
      better: lower
      color: false  # <- add this
```
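Since calculate_metric already forwards the metric options as keyword arguments (METRIC_REGISTRY.get(metric_type)(**data, **opt, device=device) in the traceback above), the Python side only needs to accept the new flag. Roughly like this (a sketch; the exact signature in the repo may differ):

```python
from traiNNer.utils.img_util import imgs2tensors

def calculate_dists(img, img2, color: bool = True, device=None, **kwargs) -> float:
    # `color` now comes from the YAML metric options instead of being
    # hardcoded to True, so grayscale validation can opt out.
    img_t, img2_t = imgs2tensors([img, img2], color=color, bgr2rgb=True, float32=True)
    ...  # rest of the DISTS computation unchanged
```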

Please see my PR here, thanks.

https://github.com/the-database/traiNNer-redux/pull/56

the-database commented 2 months ago

Closed by #56, thanks!