zsyOAOA / ResShift

ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting (NeurIPS@2023 Spotlight, TPAMI@2024)

Training sf:1 (deblurring) #22

Closed cyprian closed 10 months ago

cyprian commented 11 months ago

Thank you for providing your code. I have already tested the super-resolution, and it works great. Is it possible to adapt the config to do deblurring on lq: 256 x gt: 256, i.e., without any super-resolution? What would I have to change?

zsyOAOA commented 11 months ago

You only need to rewrite the dataset class. Please refer to this bicubic dataset.

cyprian commented 11 months ago

I have changed the Dataset implementation. However, I am getting the following UNet size-mismatch error when I run it.

Traceback (most recent call last):
  File "/content/ResShift/main.py", line 53, in <module>
    trainer.train()
  File "/content/ResShift/trainer.py", line 275, in train
    self.training_step(data)
  File "/content/ResShift/trainer.py", line 638, in training_step
    losses, z_t, z0_pred = compute_losses()
  File "/content/ResShift/models/respace.py", line 47, in training_losses
    return super().training_losses(self._wrap_model(model), *args, **kwargs)
  File "/content/ResShift/models/gaussian_diffusion.py", line 537, in training_losses
    model_output = model(self._scale_input(z_t, t), t, **model_kwargs)
  File "/content/ResShift/models/respace.py", line 63, in __call__
    return self.model(x, new_ts, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/ResShift/models/unet.py", line 846, in forward
    x = th.cat([x, lq], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 32 for tensor number 1 in the list.

I am trying to test it on input sizes GT = 64x64 and LQ = 64x64.

Here is my implementation of the Dataset:

# Imports assumed to mirror ResShift's datapipe/datasets.py (adjust to your checkout);
# get_transforms is assumed to come from the repo's datapipe code.
import random

from torch.utils.data import Dataset
from basicsr.data.transforms import augment
from albumentations import SmallestMaxSize

from utils import util_common
from utils import util_image


class PairedData(Dataset):
    def __init__(
            self,
            sf,
            dir_path=None,
            dir_path_lq=None,
            txt_file_path=None,
            txt_file_path_lq=None,
            mean=0.5,
            std=0.5,
            hflip=False,
            rotation=False,
            resize_back=False,
            length=None,
            need_path=False,
            im_exts=['png', 'jpg', 'jpeg', 'JPEG', 'bmp'],
            recursive=False,
            use_sharp=False,
            rescale_gt=True,
            gt_size=256,
            ):

        if txt_file_path is None:
            assert dir_path is not None
            self.file_paths_all = util_common.scan_files_from_folder(dir_path, im_exts, recursive)
        else:
            self.file_paths_all = util_common.readline_txt(txt_file_path)

        # Load low-quality (lq) file paths
        if txt_file_path_lq is None:
            assert dir_path_lq is not None
            self.file_paths_all_lq = util_common.scan_files_from_folder(dir_path_lq, im_exts, recursive)
        else:
            self.file_paths_all_lq = util_common.readline_txt(txt_file_path_lq)

        if length is None:
            self.file_paths = self.file_paths_all
            self.file_paths_lq = self.file_paths_all_lq
        else:
            assert len(self.file_paths_all) == len(self.file_paths_all_lq)
            assert len(self.file_paths_all) >= length
            # Sample the same indices for GT and LQ so the pairs stay aligned
            indices = random.sample(range(len(self.file_paths_all)), length)
            self.file_paths = [self.file_paths_all[ii] for ii in indices]
            self.file_paths_lq = [self.file_paths_all_lq[ii] for ii in indices]

        self.sf = sf
        self.mean = mean
        self.std = std
        self.hflip = hflip
        self.rotation = rotation
        self.length = length
        self.need_path = need_path
        self.resize_back = resize_back
        self.use_sharp = use_sharp
        self.rescale_gt = rescale_gt
        self.gt_size = gt_size

        self.transform = get_transforms('default', {'mean': mean, 'std': std})
        if rescale_gt:
            self.smallest_rescaler = SmallestMaxSize(max_size=gt_size)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        # Load ground truth image
        im_path = self.file_paths[index]
        im_gt = util_image.imread(im_path, chn='rgb', dtype='float32')

        # Load low quality image
        im_path_lq = self.file_paths_lq[index]
        im_lq = util_image.imread(im_path_lq, chn='rgb', dtype='float32')

        # Augmentation for training: augment GT and LQ together so that the random
        # flip/rotation keeps the pair spatially aligned
        im_gt, im_lq = augment([im_gt, im_lq], hflip=self.hflip, rotation=self.rotation, return_status=False)

        # im_lq = np.clip(im_lq, 0.0, 1.0)

        out = {'lq': self.transform(im_lq), 'gt': self.transform(im_gt)}
        if self.need_path:
            out['path'] = im_path  # or you can include both im_path and im_path_lq

        return out
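As a quick sanity check of the pairing, the class can be exercised like this (the directory paths are the hypothetical ones from the config below; the default transform is assumed to return 3 x 64 x 64 tensors):

# Minimal usage sketch, not part of the original snippet
ds = PairedData(
    sf=1,
    dir_path='datasets/clean/64/train/clean',
    dir_path_lq='datasets/clean/64/train/dirty',
    rescale_gt=False,
    gt_size=64,
    )
sample = ds[0]
print(sample['gt'].shape, sample['lq'].shape)   # expected: (3, 64, 64) for both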

And here is the config I am testing with:

model:
  target: models.unet.UNetModelSwin
  ckpt_path: ~
  params:
    image_size: 64
    in_channels: 6
    model_channels: 160
    out_channels: 3
    cond_lq: True
    attention_resolutions: [64,32,16,8]
    dropout: 0
    channel_mult: [1, 2, 2, 4]
    num_res_blocks: [2, 2, 2, 2]
    conv_resample: True
    dims: 2
    use_fp16: False
    num_head_channels: 32
    use_scale_shift_norm: True
    resblock_updown: False
    swin_depth: 2
    swin_embed_dim: 192
    window_size: 8
    mlp_ratio: 4

diffusion:
  target: models.script_util.create_gaussian_diffusion
  params:
    sf: 1
    schedule_name: exponential
    schedule_kwargs:
      power: 0.3
    etas_end: 0.99
    steps: 15
    min_noise_level: 0.04
    kappa: 1.0
    weighted_mse: False
    predict_type: xstart
    timestep_respacing: ~
    scale_factor: 1.0
    normalize_input: True
    latent_flag: True

autoencoder:
  target: ldm.models.autoencoder.VQModelTorch
  ckpt_path: weights/autoencoder_vq_f4.pth
  use_fp16: True
  params:
    embed_dim: 3
    n_embed: 8192
    ddconfig:
      double_z: False
      z_channels: 3
      resolution: 64
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
      padding_mode: zeros

data:
  train:
    type: paired
    params:
      sf: 1
      dir_path: datasets/clean/64/train/clean
      dir_path_lq: datasets/clean/64/train/dirty
      txt_file_path: ~
      txt_file_path_lq: ~
      mean: 0.5
      std: 0.5
      hflip: False
      rotation: False
      resize_back: False
      length: ~
      need_path: False
      im_exts: ['png', 'jpg', 'jpeg', 'JPEG', 'bmp']
      recursive: False
      use_sharp: False
      rescale_gt: False
      gt_size: 64
  val:
    type: paired
    params:
      sf: 1
      dir_path: datasets/clean/64/val/clean
      dir_path_lq: datasets/clean/64/val/dirty
      txt_file_path: ~
      txt_file_path_lq: ~
      mean: 0.5
      std: 0.5
      hflip: False
      rotation: False
      resize_back: False
      length: 5 # number of images to be evaluated
      need_path: False
      im_exts: ['png', 'jpg', 'jpeg', 'JPEG', 'bmp']
      recursive: False
      use_sharp: False
      rescale_gt: False
      gt_size: 64

train:
  lr: 5e-5
  batch: [64, 8]   # batchsize for training and validation
  use_fp16: False
  microbatch: 16
  seed: 123456
  global_seeding: False
  prefetch_factor: 4
  num_workers: 8
  ema_rate: 0.999
  iterations: 500000
  milestones: [5000, 500000]
  weight_decay: 0
  save_freq: 100
  val_freq: 100
  log_freq: [10, 100, 1] #[training loss, training images, val images]
  save_images: True  # save the images of tensorboard logging
  use_ema_val: True

cyprian commented 11 months ago

Hi @zsyOAOA, would you be able to provide any guidance on where to look in the code to enable training with same-size inputs and outputs?

zsyOAOA commented 11 months ago

@cyprian For same-size inputs and outputs, try setting diffusion.params.sf = 1 and model.params.in_channels = 51 (51 = 3 + 48) in the config file.

Note that the VQGAN downsamples the input by a factor of 4, so we pixel-unshuffle the LQ input to dimension H/4 x W/4 x 48.
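For reference, a minimal sketch of the channel arithmetic behind in_channels = 51 (it uses PyTorch's pixel_unshuffle to illustrate the shapes; whether ResShift performs the unshuffle with exactly this call is an assumption):

import torch
import torch.nn.functional as F

# f4 VQGAN latent of a 64x64 GT image: 3 channels at 16x16
z_t = torch.randn(1, 3, 16, 16)

# With sf = 1 the LQ is also 64x64; unshuffling by 4 packs it into 16x16 with 3 * 4 * 4 = 48 channels
lq = torch.randn(1, 3, 64, 64)
lq_unshuffled = F.pixel_unshuffle(lq, downscale_factor=4)   # -> (1, 48, 16, 16)

# The UNet concatenates latent and condition along the channel axis, hence in_channels = 3 + 48 = 51
x = torch.cat([z_t, lq_unshuffled], dim=1)
print(x.shape)                                              # torch.Size([1, 51, 16, 16])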

cyprian commented 10 months ago

@zsyOAOA Setting model.params.in_channels=51 still gave me a tensor size mismatch on the VQGAN downsampling. I went with a different approach of encoding both LQ and GT via the VQGAN, and that seems to work, but I am losing some fidelity with the 4x downsampling. Is there a VQGAN model that does 2x downsampling? I could not find where you got autoencoder_vq_f4.pth from.
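A rough sketch of that alternative (the exact return value of VQModelTorch.encode is an assumption and may need unpacking; the shapes assume the f4 autoencoder):

import torch

def make_unet_input(z_t, im_lq, autoencoder):
    """Encode the LQ with the same VQGAN and concatenate the latents, so the
    UNet input has 3 + 3 = 6 channels (matching in_channels: 6 in the config above)."""
    with torch.no_grad():
        z_lq = autoencoder.encode(im_lq)     # assumed to return a (B, 3, H/4, W/4) latent
    return torch.cat([z_t, z_lq], dim=1)     # (B, 6, H/4, W/4)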

zsyOAOA commented 10 months ago

I don't have a VQGAN model with 2x downsampling. For the 4x model, please find it via this link.

cyprian commented 10 months ago

@zsyOAOA I would like to train my own VQGAN model to get a better image representation for my image class. Can you tell me how you trained your VQGAN?

zsyOAOA commented 10 months ago

I haven't trained the VQGAN myself. If you want to train one, please refer to this repo. @cyprian

cyprian commented 10 months ago

Thank you for the quick reply. In that repo I don't see the weights for the VQGAN f4 that you are using in your code. I am trying to find the config for that training so that I could just plug in my own dataset. Did you download your weights from that repo? @zsyOAOA

zsyOAOA commented 10 months ago

The checkpoint I used was extracted from the latent diffusion model for image super-resolution.

cyprian commented 10 months ago

OK. So if I understand correctly, you just extracted the first_stage_model from that super-resolution checkpoint? @zsyOAOA (BTW, I really appreciate your help.)

zsyOAOA commented 10 months ago

Yes. @cyprian
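For anyone following along, the extraction described above can be sketched like this (filenames are placeholders; it assumes the LDM checkpoint stores its weights under 'state_dict' with the autoencoder keys prefixed by 'first_stage_model.'):

import torch

# Placeholder filenames for the LDM super-resolution checkpoint and the extracted VQGAN weights
ckpt = torch.load('ldm_sr_checkpoint.ckpt', map_location='cpu')
state = ckpt.get('state_dict', ckpt)

prefix = 'first_stage_model.'
vq_state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
torch.save(vq_state, 'autoencoder_vq_f4.pth')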