CompVis / latent-diffusion

High-Resolution Image Synthesis with Latent Diffusion Models

Semantic Image Synthesis #110

Open KyriaAnnwyn opened 1 year ago

KyriaAnnwyn commented 1 year ago

I'm trying to implement inference for these semantic image synthesis nets, and I'm getting this error: `expected input[1, 3, 301, 240] to have 182 channels, but got 3 channels instead`

It seems we have to make a binary channel for each class in our segmentation map. Since the model expects 182 channels, this looks like COCO-Stuff segmentation (182 classes). Is this correct?
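For reference, a minimal sketch of that per-class binary encoding (assuming 182 COCO-Stuff classes, a single-channel label map, and an illustrative helper name `to_onehot`):

```python
import numpy as np
import torch
import torch.nn.functional as F

# Illustrative sketch: turn an (H, W) label map with values in [0, 181] into
# a (1, 182, H, W) tensor with one binary channel per COCO-Stuff class.
def to_onehot(label_map: np.ndarray, num_classes: int = 182) -> torch.Tensor:
    labels = torch.from_numpy(label_map.astype(np.int64))  # (H, W)
    onehot = F.one_hot(labels, num_classes)                # (H, W, C)
    return onehot.permute(2, 0, 1).unsqueeze(0).float()    # (1, C, H, W)
```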

Are you planning to share an inference script for semantic image synthesis?

KyriaAnnwyn commented 1 year ago

This doesn't crash, but I get a really strange picture (see the attached image).

che808 commented 1 year ago

@KyriaAnnwyn Hi, I'm also working on semantic image synthesis with LDM. Could you share your script? I'd like to check it.

KyriaAnnwyn commented 1 year ago

@che808 This is my semantic.py:

```python
import argparse, os, sys, glob
import time

import numpy as np
import torch
import torchvision
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm

from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

def ismap(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)

@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
                    mask=None, x0=None, quantize_x0=False, img_callback=None, temperature=1.,
                    noise_dropout=0., score_corrector=None, corrector_kwargs=None, x_T=None,
                    log_every_t=None):
    ddim = DDIMSampler(model)
    bs = shape[0]      # batch size is the leading dim of the latent shape
    shape = shape[1:]  # cut batch dim
    print(f"Sampling with eta = {eta}; steps: {steps}")
    samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
                                         normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
                                         mask=mask, x0=x0, temperature=temperature, verbose=False,
                                         score_corrector=score_corrector,
                                         corrector_kwargs=corrector_kwargs, x_T=x_T)
    return samples, intermediates

@torch.no_grad()
def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False,
                              masked=False, invert_mask=True, quantize_x0=False, custom_schedule=None,
                              decode_interval=1000, resize_enabled=False, custom_shape=None, temperature=1.,
                              noise_dropout=0., corrector=None, corrector_kwargs=None, x_T=None,
                              save_intermediate_vid=False, make_progrow=True, ddim_use_x0_pred=False):
    log = dict()

    z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                        return_first_stage_outputs=True,
                                        force_c_encode=not (hasattr(model, 'split_input_params')
                                                            and model.cond_stage_key == 'coordinates_bbox'),
                                        return_original_cond=True)

    log_every_t = 1 if save_intermediate_vid else None

    if custom_shape is not None:
        z = torch.randn(custom_shape)
        print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")

    z0 = None

    log["input"] = x
    log["reconstruction"] = xrec

    if ismap(xc):
        log["original_conditioning"] = model.to_rgb(xc)
        if hasattr(model, 'cond_stage_key'):
            log[model.cond_stage_key] = model.to_rgb(xc)
    else:
        log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
        if model.cond_stage_model:
            log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
            if model.cond_stage_key == 'class_label':
                log[model.cond_stage_key] = xc[model.cond_stage_key]

    with model.ema_scope("Plotting"):
        t0 = time.time()
        img_cb = None

        sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
                                                eta=eta,
                                                quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0,
                                                temperature=temperature, noise_dropout=noise_dropout,
                                                score_corrector=corrector, corrector_kwargs=corrector_kwargs,
                                                x_T=x_T, log_every_t=log_every_t)
        t1 = time.time()

        if ddim_use_x0_pred:
            sample = intermediates['pred_x0'][-1]

    x_sample = model.decode_first_stage(sample)

    try:
        x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
        log["sample_noquant"] = x_sample_noquant
        log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
    except Exception:
        pass

    log["sample"] = x_sample
    log["time"] = t1 - t0

    return log

def get_cond(mode, selected_path):
    example = dict()
    if mode == "superresolution":
        up_f = 4
        # visualize_cond_img(selected_path)  # helper from notebook_helpers.py; not defined in this script

        c = Image.open(selected_path)
        c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
        c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)
        c_up = rearrange(c_up, '1 c h w -> 1 h w c')
        c = rearrange(c, '1 c h w -> 1 h w c')
        c = 2. * c - 1.

        c = c.to(torch.device("cuda"))
        example["LR_image"] = c
        example["image"] = c_up
    if mode == "semantic":
        up_f = 1
        # visualize_cond_img(selected_path)

        c = Image.open(selected_path)
        c1 = c.convert('RGB')
        c1 = torch.unsqueeze(torchvision.transforms.ToTensor()(c1), 0)
        c_up = torchvision.transforms.functional.resize(c1, size=[up_f * c1.shape[2], up_f * c1.shape[3]], antialias=True)
        c_up = rearrange(c_up, '1 c h w -> 1 h w c')
        imarr = np.array(c)

        # one binary channel per COCO-Stuff class (182 total);
        # assumes label values in [0, 180] so the +1 shift below stays in range
        zeros_tensor = torch.zeros(182, imarr.shape[0], imarr.shape[1])
        print(zeros_tensor.shape)
        for (x, y), value in np.ndenumerate(imarr):
            v = imarr[x][y]
            zeros_tensor[v + 1][x][y] = 1.0  # +1 because my segmentation alg doesn't include the unlabeled class 0
        c = torch.unsqueeze(zeros_tensor, 0)
        c = rearrange(c, '1 c h w -> 1 h w c')

        c = c.to(torch.device("cuda"))
        example["segmentation"] = c
        example["image"] = c_up

    return example

def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
    example = get_cond(task, selected_path)

    save_intermediate_vid = False
    n_runs = 1
    masked = False
    guider = None
    ckwargs = None
    mode = 'ddim'
    ddim_use_x0_pred = False
    temperature = 1.
    eta = 1.
    make_progrow = True
    custom_shape = None

    height, width = example["image"].shape[1:3]
    split_input = height >= 128 and width >= 128

    if split_input:
        ks = 128
        stride = 64
        vqf = 4  # first-stage downsampling factor
        model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
                                    "vqf": vqf,
                                    "patch_distributed_vq": True,
                                    "tie_braker": False,
                                    "clip_max_weight": 0.5,
                                    "clip_min_weight": 0.01,
                                    "clip_max_tie_weight": 0.5,
                                    "clip_min_tie_weight": 0.01}
    else:
        if hasattr(model, "split_input_params"):
            delattr(model, "split_input_params")

    invert_mask = False

    x_T = None
    for n in range(n_runs):
        if custom_shape is not None:
            x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
            x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])

        logs = make_convolutional_sample(example, model,
                                         mode=mode, custom_steps=custom_steps,
                                         eta=eta, swap_mode=False, masked=masked,
                                         invert_mask=invert_mask, quantize_x0=False,
                                         custom_schedule=None, decode_interval=10,
                                         resize_enabled=resize_enabled, custom_shape=custom_shape,
                                         temperature=temperature, noise_dropout=0.,
                                         corrector=guider, corrector_kwargs=ckwargs, x_T=x_T,
                                         save_intermediate_vid=save_intermediate_vid,
                                         make_progrow=make_progrow, ddim_use_x0_pred=ddim_use_x0_pred)
    return logs

def load_model_from_config(config, ckpt):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    global_step = pl_sd["global_step"]
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    model.cuda()
    model.eval()
    return {"model": model}, global_step

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--indir", type=str, nargs="?",
                        help="dir containing image-mask pairs (example.png and example_mask.png)")
    parser.add_argument("--outdir", type=str, nargs="?",
                        help="dir to write results to")
    parser.add_argument("--steps", type=int, default=50,
                        help="number of ddim sampling steps")
    opt = parser.parse_args()

    os.makedirs(opt.outdir, exist_ok=True)

    path_conf = "models/ldm/semantic_synthesis512/config.yaml"
    path_ckpt = "models/ldm/semantic_synthesis512/model.ckpt"
    config = OmegaConf.load(path_conf)
    model, step = load_model_from_config(config, path_ckpt)

    # sampler = DDIMSampler(model)

    custom_steps = 200
    fname = "./SemanticDataCoco/285178355_735479634247196_5946430906048693462_n.png"
    logs = run(model["model"], fname, "semantic", custom_steps)

    sample = logs["sample"]
    sample = sample.detach().cpu()
    sample = torch.clamp(sample, -1., 1.)
    sample = (sample + 1.) / 2. * 255
    sample = sample.numpy().astype(np.uint8)
    sample = np.transpose(sample, (0, 2, 3, 1))
    print(sample.shape)
    a = Image.fromarray(sample[0])

    outname = "./semantic_results/285178355_735479634247196_5946430906048693462_n.png"
    a.save(outname)
```
che808 commented 1 year ago

Thanks for sharing, @KyriaAnnwyn. I'll reply if I make progress.

myyy777 commented 1 year ago

Hi, I'm also working on semantic image synthesis with LDM. How is it going now? I'm looking for a GitHub repo that can do semantic image synthesis; any model is fine. Thank you very much!

Feanor007 commented 1 year ago

Hi, does anyone know how to train this semantic synthesis model?

Wang-Wenqing commented 1 year ago

@KyriaAnnwyn Thanks for sharing. I'm not very familiar with semantic image synthesis tasks, so may I ask what `split_input_params` means and what it does? Thanks!

Wang-Wenqing commented 1 year ago

> @KyriaAnnwyn Thanks for sharing. I'm not very familiar with semantic image synthesis tasks, so may I ask what `split_input_params` means and what it does? Thanks!

Sorry, it's not specific to semantic image synthesis; it appears throughout ddpm.py, and I'm still quite confused about it.
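For anyone else wondering: in `ddpm.py`, setting `split_input_params` on the model switches the first-stage encode/decode (and the diffusion model's forward pass) to patch-wise processing. The input is cut into overlapping `ks` x `ks` patches with the given `stride`, each patch is processed independently, and the results are folded back together with per-pixel weighting; `vqf` is the first-stage downsampling factor used to translate patch sizes between pixel and latent space. A rough sketch of the unfold/fold mechanism, simplified and not the exact `ddpm.py` code:

```python
import torch
import torch.nn.functional as F

# Simplified illustration of patch-wise processing: split an image into
# overlapping ks x ks patches, then fold them back, averaging overlaps.
x = torch.randn(1, 3, 512, 512)
ks, stride = 128, 64

patches = F.unfold(x, kernel_size=ks, stride=stride)  # (1, 3*ks*ks, L), L = 49 here

# ... ddpm.py would run the first-stage (or diffusion) model on each
# patch independently at this point ...

out = F.fold(patches, output_size=(512, 512), kernel_size=ks, stride=stride)
counts = F.fold(torch.ones_like(patches), output_size=(512, 512),
                kernel_size=ks, stride=stride)
out = out / counts  # plain averaging; ddpm.py uses smoothed, clipped weights
```

The `clip_*_weight` and `tie_braker` entries control that per-patch weighting, which is what avoids visible seams at patch borders.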