XPixelGroup / DiffBIR

Official code for DiffBIR: Towards Blind Image Restoration with Generative Diffusion Prior
Apache License 2.0

How to train DiffBIR for face retouching? #113

Open luser350 opened 5 months ago

luser350 commented 5 months ago

Hi, thanks for sharing your awesome work. I want to train DiffBIR on the 1024x1024 FFHQR dataset, using the original FFHQ images (1024x1024) as input. I have modified `__getitem__()` of the CodeFormer dataset:

```python
import random
from typing import Tuple

import cv2
import numpy as np

def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray, str]:
    # load gt image
    img_gt = None
    while img_gt is None:
        # load meta file
        image_file = self.image_files[index]
        gt_path = image_file["image_path"]
        prompt = image_file["prompt"]
        img_gt = self.load_gt_image(gt_path)
        if img_gt is None:
            print(f"failed to load {gt_path}, try another image")
            index = random.randint(0, len(self) - 1)

    # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
    img_gt = (img_gt[..., ::-1] / 255.0).astype(np.float32)
    h, w, _ = img_gt.shape
    # drop the text prompt for half of the samples
    if np.random.uniform() < 0.5:
        prompt = ""

    # ------------------------ load lq image ------------------------ #
    # the LQ (FFHQ) path is derived from the GT (FFHQR) path
    lq_path = gt_path.replace("ffhqr", "ffhq")
    img_lq = cv2.imread(lq_path)  # BGR uint8, or None if the file cannot be read
    if img_lq is None:
        raise FileNotFoundError(f"failed to load LQ image {lq_path}")
    img_lq = img_lq.astype(np.float32) / 255.0

    # BGR to RGB, [-1, 1]
    gt = (img_gt[..., ::-1] * 2 - 1).astype(np.float32)
    # BGR to RGB, [0, 1]
    lq = img_lq[..., ::-1].astype(np.float32)

    return gt, lq, prompt
```
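
Note that `gt_path.replace("ffhqr", "ffhq")` rewrites every occurrence of the substring, so a path such as `/data/ffhqr_work/ffhqr/00001.png` would also have its parent folder renamed. A slightly more defensive mapping is sketched below, assuming (hypothetically) that the GT images sit in a directory named exactly `ffhqr` and the LQ images in a directory named `ffhq` with identical file names; `gt_to_lq_path` is an illustrative helper, not part of DiffBIR.

```python
from pathlib import PurePath


def gt_to_lq_path(gt_path: str, gt_dir: str = "ffhqr", lq_dir: str = "ffhq") -> str:
    """Map a GT (FFHQR) image path to its LQ (FFHQ) counterpart.

    Only the directory component closest to the file that equals ``gt_dir``
    is swapped for ``lq_dir``; other occurrences of the substring are kept.
    Assumes GT and LQ images share the same file name (hypothetical layout).
    """
    parts = list(PurePath(gt_path).parts)
    # search the directory components from the end, skipping the file name
    for i in range(len(parts) - 2, -1, -1):
        if parts[i] == gt_dir:
            parts[i] = lq_dir
            return str(PurePath(*parts))
    raise ValueError(f"no '{gt_dir}' directory component in {gt_path!r}")
```

Inside `__getitem__`, the call would become `lq_path = gt_to_lq_path(gt_path)`; raising instead of silently returning a wrong path makes layout mistakes show up on the first batch.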

I have generated training.list and validation.list for FFHQR. Could you advise me on the following?

  1. Which pretrained model should I use: face_swinir_v1.ckpt or v1_face.pth?
  2. Do I need to train stage_1, or can I train stage_2 directly?
  3. I want both the input and the output to be 1024x1024.

luser350 commented 5 months ago

@chxy95 @hejingwenhejingwen @xinntao @WenlongZhang0517 Could somebody answer my question, please?

0x3f3f3f3fun commented 5 months ago

Hello!

  1. v1_face.pth. This checkpoint contains the weights of IRControlNet, which takes a smooth face image as its condition and outputs a high-quality restoration result.
  2. I am not familiar with face retouching, but I think you can train stage_2 directly, using the original face images as conditions.
  3. If you want to train at a resolution of 1024, you first need to make sure that the pretrained VAE can encode and decode your images well, because it was originally trained at a resolution of 256. After that, you can use 1024x1024 images as inputs directly; nothing else needs to change.
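
To put item 3 in numbers: the Stable Diffusion VAE downsamples each spatial side by a factor of 8 into a 4-channel latent, so a 1024x1024 image becomes a 128x128 latent, compared with 32x32 for a 256x256 image. The helper below is a hypothetical illustration of that arithmetic (not a DiffBIR function); it also rejects sizes that are not divisible by the downsampling factor.

```python
from typing import Tuple


def sd_latent_shape(height: int, width: int, factor: int = 8, channels: int = 4) -> Tuple[int, int, int]:
    """Return the (C, H, W) latent shape an SD-style VAE produces for an image.

    Assumes a spatial downsampling factor of 8 and 4 latent channels,
    as in the Stable Diffusion VAE.
    """
    if height % factor or width % factor:
        raise ValueError(f"image size {height}x{width} is not divisible by {factor}")
    return (channels, height // factor, width // factor)


print(sd_latent_shape(1024, 1024))  # (4, 128, 128)
print(sd_latent_shape(256, 256))    # (4, 32, 32)
```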

Feel free to contact me if you have any further questions.

luser350 commented 5 months ago

Hi, thanks for replying. Which VAE are you referring to, and how can I verify it?

0x3f3f3f3fun commented 5 months ago

Follow the steps below:

  1. Load ControlLDM.
  2. Load a 1024x1024 image and convert it to a tensor $x$ (NCHW, RGB, range [-1, 1]).
  3. Call ControlLDM.vae_encode() to encode $x$ into a latent code $z$.
  4. Call ControlLDM.vae_decode() to decode $z$ and get the result $x'$ (NCHW, RGB, range [-1, 1]).
  5. $x'$ should be very close to $x$.

Here is an example (not tested):

```python
from model.cldm import ControlLDM
from utils.common import instantiate_from_config
from omegaconf import OmegaConf
import torch
from PIL import Image

cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/cldm.yaml"))
# VAE is contained in pretrained SD
sd = torch.load("path/to/pretrained_sd_v2.1", map_location="cpu")
unused = cldm.load_pretrained_sd(sd)
print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
cldm.eval().to("cuda")

# load image and convert to tensor
img = Image.open("xxx").convert("RGB")
# x: float tensor of shape (1, 3, H, W), RGB, range [-1, 1], on the same device as cldm ("cuda")
x = ...

with torch.no_grad():
    z = cldm.vae_encode(x)
    x_decoded = cldm.vae_decode(z)

# convert x_decoded back to an image (map [-1, 1] to [0, 1], clamp, convert to PIL)
img_decoded = ...

# save the images and compare them visually
```

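
To make step 5 measurable rather than purely visual, one option is to compute the PSNR between $x$ and $x'$. The sketch below assumes both tensors have been moved to the CPU and converted with `.numpy()`, and that they are in [-1, 1] as described in the steps; `roundtrip_psnr` is a hypothetical helper. There is no official acceptance threshold, so a practical baseline is to run the same check on a 256x256 version of the image and compare the two numbers.

```python
import numpy as np


def roundtrip_psnr(x: np.ndarray, x_rec: np.ndarray) -> float:
    """PSNR in dB between an image and its VAE reconstruction.

    Both arrays are expected in [-1, 1] (e.g. NCHW tensors converted with
    ``.cpu().numpy()``). The reconstruction is clipped to [-1, 1], then both
    are mapped to [0, 1] so the peak signal value is 1.
    """
    if x.shape != x_rec.shape:
        raise ValueError(f"shape mismatch: {x.shape} vs {x_rec.shape}")
    a = (x.astype(np.float64) + 1.0) / 2.0
    b = (np.clip(x_rec.astype(np.float64), -1.0, 1.0) + 1.0) / 2.0
    mse = float(np.mean((a - b) ** 2))
    if mse == 0.0:
        return float("inf")
    return float(10.0 * np.log10(1.0 / mse))
```

Usage: `psnr = roundtrip_psnr(x.cpu().numpy(), x_decoded.cpu().numpy())` after the `torch.no_grad()` block.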
luser350 commented 5 months ago

Hi @0x3f3f3f3fun, I have followed the steps above. Here is my script:

```python
from model.cldm import ControlLDM
from utils.common import instantiate_from_config
from omegaconf import OmegaConf
import torch
import torchvision.transforms as transforms
from PIL import Image

# Define the transformation
transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 
])

cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/cldm.yaml"))
# VAE is contained in pretrained SD
sd = torch.load("pretrained/v2-1_512-ema-pruned.ckpt", map_location="cpu")
#unused = cldm.load_pretrained_sd(sd)
#print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
cldm.eval().to("cpu")

# load image and convert to tensor
img = Image.open("00006.png").convert("RGB")
x = transform(img)
x = x.unsqueeze(0)
#x = x.permute(0, 2, 3, 1)

with torch.no_grad():
    z = cldm.vae_encode(x)
    x_decoded = cldm.vae_decode(z)

# reverse normalization
x_decoded = x_decoded.squeeze()#.permute(1, 2, 0) 
x_decoded = (x_decoded * 0.5) + 0.5 
img_decoded = (transforms.ToPILImage()(x_decoded.cpu().clamp(0, 1)))

img_decoded.save("decoded_image.png")
```

I commented out these two lines:

```python
unused = cldm.load_pretrained_sd(sd)
print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
```

because they raised a KeyError:

```
Traceback (most recent call last):
  File "/home/luser350/Desktop/diffbir/DiffBIR/vae.py", line 16, in <module>
    unused = cldm.load_pretrained_sd(sd)
  File "/home/luser350/anaconda3/envs/diffbir/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/luser350/Desktop/diffbir/DiffBIR/model/cldm.py", line 51, in load_pretrained_sd
    init_sd[key] = sd[target_key].clone()
KeyError: 'model.diffusion_model.time_embed.0.weight'
```

(Attached: my input image `00006.png` and the decoded image `decoded_image.png`.)

The poor result is expected, since cldm could not load the SD weights. How can I solve this KeyError?
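
One common cause of this kind of KeyError is that the checkpoint wraps its weights: PyTorch Lightning-style checkpoints store the model weights under a top-level `"state_dict"` key, so looking up `"model.diffusion_model.time_embed.0.weight"` on the outer dict fails. Whether `v2-1_512-ema-pruned.ckpt` is laid out this way is an assumption worth confirming by printing its top-level keys. A minimal, hypothetical helper (`unwrap_state_dict` is not part of DiffBIR) that returns the flat weight dict in either case:

```python
from typing import Any, Mapping

# A key that load_pretrained_sd looked up, taken from the traceback above.
PROBE_KEY = "model.diffusion_model.time_embed.0.weight"


def unwrap_state_dict(ckpt: Mapping[str, Any], probe_key: str = PROBE_KEY) -> Mapping[str, Any]:
    """Return the flat weight dict stored in a loaded checkpoint.

    If ``probe_key`` is already a top-level key, the checkpoint is returned
    unchanged; otherwise the weights are taken from ``ckpt["state_dict"]``.
    """
    if probe_key in ckpt:
        return ckpt
    inner = ckpt.get("state_dict")
    if isinstance(inner, Mapping) and probe_key in inner:
        return inner
    raise KeyError(f"{probe_key!r} not found; top-level keys: {sorted(ckpt)[:10]}")
```

Usage: `sd = unwrap_state_dict(torch.load("pretrained/v2-1_512-ema-pruned.ckpt", map_location="cpu"))`, then `cldm.load_pretrained_sd(sd)` as before. If the helper still raises, the printed top-level keys show how the file is actually organized.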