luser350 opened this issue 6 months ago
@chxy95 @hejingwenhejingwen @xinntao @WenlongZhang0517 can somebody answer my question, please?
Hello!
Feel free to contact me if you have any further questions.
Hi, thanks for replying. Please tell me which VAE you are referring to. Please guide me on how to verify the VAE.
Follow the instructions below:
Here is an example (not tested):
from model.cldm import ControlLDM
from utils.common import instantiate_from_config
from omegaconf import OmegaConf
import torch
cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/cldm.yaml"))
# VAE is contained in pretrained SD
sd = torch.load("path/to/pretrained_sd_v2.1", map_location="cpu")
unused = cldm.load_pretrained_sd(sd)
print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
cldm.eval().to("cuda")
# load image and convert to tensor
from PIL import Image
img = Image.open("xxx").convert("RGB")
x = ...
with torch.no_grad():
    z = cldm.vae_encode(x)
    x_decoded = cldm.vae_decode(z)
# convert x_decoded back to image
img_decoded = ...
# save image and take a look at them...
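If you prefer a number over eyeballing the saved images, a rough PSNR between the input and the reconstruction also works. This is a minimal sketch (the psnr helper below is not part of DiffBIR) and assumes both tensors have already been mapped back to [0, 1]:

import torch

def psnr(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-12) -> float:
    # a and b are expected to be in [0, 1]; higher PSNR means a closer reconstruction
    mse = torch.mean((a - b) ** 2)
    return float(10.0 * torch.log10(1.0 / (mse + eps)))

For example, if x and x_decoded are in [-1, 1], map them back with (t * 0.5 + 0.5).clamp(0, 1) before comparing.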
Hi @0x3f3f3f3fun, I have completed the above steps. Here is my script:
from model.cldm import ControlLDM
from utils.common import instantiate_from_config
from omegaconf import OmegaConf
import torch
import torchvision.transforms as transforms
from PIL import Image
# Define the transformation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/cldm.yaml"))
# VAE is contained in pretrained SD
sd = torch.load("pretrained/v2-1_512-ema-pruned.ckpt", map_location="cpu")
#unused = cldm.load_pretrained_sd(sd)
#print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
cldm.eval().to("cpu")
# load image and convert to tensor
img = Image.open("00006.png").convert("RGB")
x = transform(img)
x = x.unsqueeze(0)
#x = x.permute(0, 2, 3, 1)
with torch.no_grad():
    z = cldm.vae_encode(x)
    x_decoded = cldm.vae_decode(z)
# reverse normalization
x_decoded = x_decoded.squeeze()#.permute(1, 2, 0)
x_decoded = (x_decoded * 0.5) + 0.5
img_decoded = (transforms.ToPILImage()(x_decoded.cpu().clamp(0, 1)))
img_decoded.save("decoded_image.png")
I have commented out these two lines:
unused = cldm.load_pretrained_sd(sd)
print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
They were giving a KeyError:
Traceback (most recent call last):
  File "/home/luser350/Desktop/diffbir/DiffBIR/vae.py", line 16, in <module>
    unused = cldm.load_pretrained_sd(sd)
  File "/home/luser350/anaconda3/envs/diffbir/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/luser350/Desktop/diffbir/DiffBIR/model/cldm.py", line 51, in load_pretrained_sd
    init_sd[key] = sd[target_key].clone()
KeyError: 'model.diffusion_model.time_embed.0.weight'
[attached: my input image and the decoded image]
That result is expected, since cldm was unable to load the SD weights. How can I solve this KeyError?
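One possible cause (an assumption on my side, not verified against load_pretrained_sd): Stable Diffusion checkpoints such as v2-1_512-ema-pruned.ckpt usually wrap their weights in a top-level "state_dict" entry, so flat keys like model.diffusion_model.time_embed.0.weight sit one level deeper than where the loader looks. Unwrapping the dictionary before passing it in might avoid the KeyError:

import torch

sd = torch.load("pretrained/v2-1_512-ema-pruned.ckpt", map_location="cpu")
# SD v2.1 checkpoints typically nest the weights under "state_dict" (assumption)
if "state_dict" in sd:
    sd = sd["state_dict"]
unused = cldm.load_pretrained_sd(sd)
print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")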
Hi, thanks for sharing your awesome work. I want to train DiffBIR on the 1024x1024 FFHQR dataset, using FFHQ (1024x1024) images as input. I have modified CodeFormer's __getitem__().
I have generated the training.list and validation.list for FFHQR. Please guide: