explainingai-code / StableDiffusion-PyTorch

This repo implements a Stable Diffusion model in PyTorch with all the essential components.

Can this code be used for image super-resolution or restoration? #27

Open wendeyy opened 1 week ago

wendeyy commented 1 week ago

Hi, I would like to apply this model to image super-resolution or restoration. Specifically, I want to try enhancing images that are blurred due to adverse weather conditions. I think this should be feasible; what adjustments should I make in the code? Thank you so much!

explainingai-code commented 6 days ago

Hello @wendeyy, I think you can use the mask-conditioned generation code to perform super-resolution without requiring too many changes. So say you want to train a model which, given a 32x32 image, generates an 8x resolution (256x256) image.

Here are the steps that I believe should enable you to get something working out of the box (with the CelebA-HQ dataset as an example).

  1. Start with the config created for CelebA-HQ: https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/config/celebhq_text_image_cond.yaml
  2. Remove the text-conditioning parameters from the config (and keep mask conditioning): https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/config/celebhq_text_image_cond.yaml#L24
  3. Modify the input channels to be 3 in the mask config: https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/config/celebhq_text_image_cond.yaml#L32
  4. In the celeb_dataset file, ensure the actual images (256x256) are added here: https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/dataset/celeb_dataset.py#L84
  5. In the get_mask method, simply return a resized 32x32 version of this image (see the sketch after this list): https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/dataset/celeb_dataset.py#L102
  6. Train the LDM with this config using the train_ddpm_cond script.
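
A rough sketch of step 5 (the method name follows celeb_dataset.py; the exact resize call and the assumption that self.masks now holds paths to the full-resolution images are illustrative, not the repo's exact code):

```python
from PIL import Image

def get_mask(self, index):
    # Assumes step 4 above: self.masks[index] now points to the actual
    # 256x256 image rather than a mask annotation.
    mask_im = Image.open(self.masks[index])
    # Return a 32x32 downsized copy to serve as the low-resolution condition.
    return mask_im.resize((32, 32))
```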

Could you please try these and let me know if you run into any issues?

In case of any confusion, I also talk about mask conditioning, super-resolution, and the exact inputs the repo uses for mask conditioning in the conditional LDM video mentioned in the README. Maybe just look at those parts to get a better understanding of the repo's code: Mask Conditioning @ 18:47, Super Resolution @ 25:50.

wendeyy commented 6 days ago

@explainingai-code Thank you for your thorough explanation. I realize I may not have described my task clearly. My goal is to deblur the image, or enhance it to make it clearer than the original. Maybe it's a little different from mask-conditioned generation and super-resolution, but I will try the mask-conditioned generation code first to see the results. If you have any suggestions or further ideas, I would greatly appreciate it!

wendeyy commented 3 days ago

Hi @explainingai-code, I've made changes to the code following your instructions and tried to run it. Here’s what I did.

  1. Removed the text conditioning, kept the mask-conditioning parameters, and changed the input channels to 3:

```yaml
condition_config:
  condition_types: ['image']
  image_condition_config:
    image_condition_input_channels: 3
    image_condition_output_channels: 3
    image_condition_h: 512
    image_condition_w: 512
    cond_drop_prob: 0.1
```

  2. My data file structure is like the following:

```
StableDiffusion-PyTorch
└── data
    └── CelebAMask-HQ
        ├── CelebA-HQ-img
        ├── CelebA-HQ-img copy
        ├── CelebAMask-HQ-mask-anno
        │   └── 0/1/2/3.../14
        │       └── *.png
        └── CelebAMask-HQ-mask
            └── *.png
```

I load some CelebA-HQ images, resize them to 32x32, and save them in CelebA-HQ-img. I also resize them to 256x256 and save them in CelebAMask-HQ-mask, to make sure this code reads the 256x256 actual images:

```python
if 'image' in self.condition_types:
    im_name = int(os.path.split(fname)[1].split('.')[0])
    masks.append(os.path.join(im_path, 'CelebAMask-HQ-mask', '{}.png'.format(im_name)))
```

  3. In the get_mask method, simply return a resized 32x32 image:

```python
def get_mask(self, index):
    mask_im = Image.open(self.masks[index])
    mask_im = mask_im.resize((32, 32))
    return mask_im
```

I'm not sure if these changes are correct, and an error came out; it seems I have to train the VAE first? I would greatly appreciate any suggestions you have.

explainingai-code commented 3 days ago

Yes, since this is a latent diffusion model, we would need to train a VAE first (but a VAE on CelebA-HQ should not require more than 4-5 epochs to get a decent result). I have a trained VAE checkpoint, but that is for 128x128 images with a downscale factor of 4 (so the latent size will be 32x32). If it helps, I can share that checkpoint as well as the changes you would need in the config to work with it.
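
To make the size bookkeeping concrete (this just restates the arithmetic above, assuming latent spatial size = image size / VAE downscale factor, as described in this thread):

```python
# Latent spatial size follows image size divided by the VAE downscale factor.
def latent_size(image_size: int, downscale_factor: int) -> int:
    return image_size // downscale_factor

assert latent_size(128, 4) == 32   # the checkpoint mentioned above
assert latent_size(256, 4) == 64   # training at 256x256 with the same factor
```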

Regarding the conditioning changes you mentioned, they seem fine. But if you are looking specifically for restoration, then wouldn't it be better to first resize the 256x256 image to 16x16 and then up to 32x32, rather than resizing it directly to 32x32? That should blur the condition image. You would then be training the model to denoise a latent image conditioned on a blurry 32x32 image in pixel space. That way, during inference, when you pass in your blurry image with a random noise sample, the model should be able to generate a latent which, when passed to the decoder of the VAE, gives a deblurred version of the blurry condition image.
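
A minimal sketch of that degradation, reusing the get_mask hook from above (the bicubic filter choice is my assumption; the thread does not specify one):

```python
from PIL import Image

def get_mask(self, index):
    # Load the full-resolution (256x256) image used as the condition source.
    mask_im = Image.open(self.masks[index])
    # Downsample to 16x16 to discard high-frequency detail, then upsample
    # back to 32x32 so the condition is a blurry version of the target.
    mask_im = mask_im.resize((16, 16), Image.BICUBIC)
    mask_im = mask_im.resize((32, 32), Image.BICUBIC)
    return mask_im
```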

As a disclaimer, I haven't ever trained a deblurring model or read papers on that topic, so this is just something that I think should intuitively work (but I'm not sure).