Open KyriaAnnwyn opened 1 year ago
This doesn't crash, but I get a really strange picture.
@KyriaAnnwyn hi, I'm also cracking the semantic image synthesis with LDM. Could you share your script? I want to check it.
@che808 This is my semantic.py `import argparse, os, sys, glob from ctypes import resize from omegaconf import OmegaConf from PIL import Image from tqdm import tqdm import numpy as np import torch from main import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler import time import torch, torchvision from einops import rearrange, repeat
def ismap(x):
    """Heuristic check for a conditioning map: a 4-D tensor with more than 3 channels."""
    if not isinstance(x, torch.Tensor):
        return False
    return x.ndim == 4 and x.shape[1] > 3
@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
                    mask=None, x0=None, quantize_x0=False, img_callback=None, temperature=1.,
                    noise_dropout=0., score_corrector=None, corrector_kwargs=None, x_T=None,
                    log_every_t=None):
    """Draw DDIM samples from `model` conditioned on `cond`.

    `shape` includes the batch dimension at position 0; it is split off here
    and handed to the sampler separately.  Returns (samples, intermediates).

    NOTE(review): `img_callback`, `noise_dropout` and `log_every_t` are accepted
    but never forwarded to the sampler — confirm whether that is intentional.
    """
    sampler = DDIMSampler(model)
    batch_size = shape[0]  # leading dim is the batch size
    sample_shape = shape[1:]  # the sampler wants the per-sample shape only
    print(f"Sampling with eta = {eta}; steps: {steps}")
    samples, intermediates = sampler.sample(
        steps,
        batch_size=batch_size,
        shape=sample_shape,
        conditioning=cond,
        callback=callback,
        normals_sequence=normals_sequence,
        quantize_x0=quantize_x0,
        eta=eta,
        mask=mask,
        x0=x0,
        temperature=temperature,
        verbose=False,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
        x_T=x_T,
    )
    return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0,
                              swap_mode=False, masked=False, invert_mask=True,
                              quantize_x0=False, custom_schedule=None, decode_interval=1000,
                              resize_enabled=False, custom_shape=None, temperature=1.,
                              noise_dropout=0., corrector=None, corrector_kwargs=None,
                              x_T=None, save_intermediate_vid=False, make_progrow=True,
                              ddim_use_x0_pred=False):
    """Encode `batch`, sample a latent with DDIM, decode it, and return a log dict.

    The returned dict contains (at least) "input", "reconstruction",
    "original_conditioning", "sample" and "time"; "sample_noquant"/"sample_diff"
    are added when the first stage supports non-quantized decoding.

    NOTE(review): several parameters (mode, swap_mode, masked, invert_mask,
    custom_schedule, decode_interval, resize_enabled, make_progrow) are accepted
    but unused here — kept for signature compatibility with callers.
    """
    log = dict()
    z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                        return_first_stage_outputs=True,
                                        force_c_encode=not (hasattr(model, 'split_input_params')
                                                            and model.cond_stage_key == 'coordinates_bbox'),
                                        return_original_cond=True)
    log_every_t = 1 if save_intermediate_vid else None
    if custom_shape is not None:
        # Override the encoded latent with fresh noise of the requested shape.
        z = torch.randn(custom_shape)
        print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
    z0 = None

    log["input"] = x
    log["reconstruction"] = xrec
    if ismap(xc):
        # Conditioning is a map (e.g. a segmentation); render it to RGB for logging.
        log["original_conditioning"] = model.to_rgb(xc)
        if hasattr(model, 'cond_stage_key'):
            log[model.cond_stage_key] = model.to_rgb(xc)
    else:
        log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
        if model.cond_stage_model:
            log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
            if model.cond_stage_key == 'class_label':
                log[model.cond_stage_key] = xc[model.cond_stage_key]

    with model.ema_scope("Plotting"):
        t0 = time.time()
        img_cb = None
        sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
                                                eta=eta, quantize_x0=quantize_x0,
                                                img_callback=img_cb, mask=None, x0=z0,
                                                temperature=temperature,
                                                noise_dropout=noise_dropout,
                                                score_corrector=corrector,
                                                corrector_kwargs=corrector_kwargs,
                                                x_T=x_T, log_every_t=log_every_t)
        t1 = time.time()

    if ddim_use_x0_pred:
        # Use the final predicted x0 instead of the last sampled latent.
        sample = intermediates['pred_x0'][-1]
    x_sample = model.decode_first_stage(sample)
    try:
        # Best-effort: some first stages don't accept force_not_quantize.
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
        x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
        log["sample_noquant"] = x_sample_noquant
        log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
    except Exception:
        pass

    log["sample"] = x_sample
    log["time"] = t1 - t0
    return log
def get_cond(mode, selected_path):
    """Build a conditioning example dict for `mode` from the image at `selected_path`.

    Supported modes: "superresolution" (keys "LR_image", "image") and
    "semantic" (keys "segmentation", "image").  Conditioning tensors are moved
    to CUDA and laid out as (1, H, W, C).
    """
    example = dict()
    if mode == "superresolution":
        up_f = 4  # 4x upscaling factor
        c = Image.open(selected_path)
        c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
        c_up = torchvision.transforms.functional.resize(
            c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)
        c_up = rearrange(c_up, '1 c h w -> 1 h w c')
        c = rearrange(c, '1 c h w -> 1 h w c')
        c = 2. * c - 1.  # rescale from [0, 1] to [-1, 1]
        c = c.to(torch.device("cuda"))
        example["LR_image"] = c
        example["image"] = c_up
    if mode == "semantic":
        up_f = 1  # no upscaling for semantic synthesis
        c = Image.open(selected_path)
        c1 = c.convert('RGB')
        c1 = torch.unsqueeze(torchvision.transforms.ToTensor()(c1), 0)
        c_up = torchvision.transforms.functional.resize(
            c1, size=[up_f * c1.shape[2], up_f * c1.shape[3]], antialias=True)
        c_up = rearrange(c_up, '1 c h w -> 1 h w c')
        # One-hot encode the label map: channel v+1 is set where the pixel label is v.
        # assumes the segmentation PNG is single-channel, so imarr is (H, W) — TODO confirm
        imarr = np.array(c)
        zeros_tensor = torch.zeros(182, imarr.shape[0], imarr.shape[1])
        print(zeros_tensor.shape)
        # Vectorized replacement for the original per-pixel Python loop: same
        # result, but one C-level scatter instead of H*W interpreter iterations.
        labels = torch.from_numpy(imarr.astype(np.int64)) + 1  # +1 because my segmentation alg doesnt include unlabeled class 0
        zeros_tensor.scatter_(0, labels.unsqueeze(0), 1.0)
        c = torch.unsqueeze(zeros_tensor, 0)
        c = rearrange(c, '1 c h w -> 1 h w c')
        c = c.to(torch.device("cuda"))
        example["segmentation"] = c
        example["image"] = c_up
    return example
def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
    """Run one sampling pass for `task` on the conditioning image at `selected_path`.

    Returns the log dict produced by make_convolutional_sample (contains "sample").
    """
    example = get_cond(task, selected_path)

    # Fixed sampling configuration for this script.
    save_intermediate_vid = False
    n_runs = 1
    masked = False
    guider = None
    ckwargs = None
    mode = 'ddim'
    ddim_use_x0_pred = False
    temperature = 1.
    eta = 1.
    make_progrow = True
    custom_shape = None

    height, width = example["image"].shape[1:3]
    # Large inputs are processed patch-wise by the model.
    split_input = height >= 128 and width >= 128
    if split_input:
        ks = 128     # patch (kernel) size
        stride = 64  # patch stride
        vqf = 4      # first-stage downsampling factor
        model.split_input_params = {
            "ks": (ks, ks),
            "stride": (stride, stride),
            "vqf": vqf,
            "patch_distributed_vq": True,
            "tie_braker": False,
            "clip_max_weight": 0.5,
            "clip_min_weight": 0.01,
            "clip_max_tie_weight": 0.5,
            "clip_min_tie_weight": 0.01,
        }
    elif hasattr(model, "split_input_params"):
        delattr(model, "split_input_params")

    invert_mask = False
    x_T = None
    logs = None
    for _ in range(n_runs):
        if custom_shape is not None:
            x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
            x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])
        logs = make_convolutional_sample(example, model,
                                         mode=mode, custom_steps=custom_steps,
                                         eta=eta, swap_mode=False, masked=masked,
                                         invert_mask=invert_mask, quantize_x0=False,
                                         custom_schedule=None, decode_interval=10,
                                         resize_enabled=resize_enabled, custom_shape=custom_shape,
                                         temperature=temperature, noise_dropout=0.,
                                         corrector=guider, corrector_kwargs=ckwargs,
                                         x_T=x_T, save_intermediate_vid=save_intermediate_vid,
                                         make_progrow=make_progrow,
                                         ddim_use_x0_pred=ddim_use_x0_pred)
    return logs
def load_model_from_config(config, ckpt):
    """Instantiate the model described by `config` and load the checkpoint at `ckpt`.

    The model is moved to CUDA and put in eval mode.
    Returns ({"model": model}, global_step).
    """
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    global_step = pl_sd["global_step"]
    state = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates missing/unexpected keys in the checkpoint.
    missing, unexpected = model.load_state_dict(state, strict=False)
    model.cuda()
    model.eval()
    return {"model": model}, global_step
# Fix: markdown formatting stripped the dunders — `if name == "main":` would
# raise NameError; the guard must compare __name__ against "__main__".
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--indir",
        type=str,
        nargs="?",
        help="dir containing image-mask pairs (example.png and example_mask.png)",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    opt = parser.parse_args()
    os.makedirs(opt.outdir, exist_ok=True)

    path_conf = "models/ldm/semantic_synthesis512/config.yaml"
    path_ckpt = "models/ldm/semantic_synthesis512/model.ckpt"

    config = OmegaConf.load(path_conf)
    model, step = load_model_from_config(config, path_ckpt)

    # NOTE(review): opt.indir and opt.steps are parsed but never used — the
    # input path and step count below are hard-coded; wire them up if desired.
    custom_steps = 200
    fname = "./SemanticDataCoco/285178355_735479634247196_5946430906048693462_n.png"
    logs = run(model["model"], fname, "semantic", custom_steps)

    # Map the sample from [-1, 1] to uint8 [0, 255], NCHW -> NHWC, save first image.
    sample = logs["sample"]
    sample = sample.detach().cpu()
    sample = torch.clamp(sample, -1., 1.)
    sample = (sample + 1.) / 2. * 255
    sample = sample.numpy().astype(np.uint8)
    sample = np.transpose(sample, (0, 2, 3, 1))
    print(sample.shape)
    a = Image.fromarray(sample[0])
    outname = "./semantic_results/285178355_735479634247196_5946430906048693462_n.png"
    a.save(outname)
thanks, @KyriaAnnwyn for your sharing, I will reply to you if I have progress.
hi, I'm also working on semantic image synthesis with LDM. How is it going now? I'm looking for a GitHub repo that can achieve semantic image synthesis — any model is ok. Thank you very much!
Hi, does anyone know how to train this semantic synthesis model?
@KyriaAnnwyn Thanks for your sharing. I'm not very familiar with Semantic Image Synthesis tasks, so can I ask what "split_input_params" means and what it works for? Thanks
@KyriaAnnwyn Thanks for your sharing. I'm not very familiar with Semantic Image Synthesis tasks, so can I ask what "split_input_params" means and what it works for? Thanks
Sorry, it's not just in Semantic Image Synthesis tasks, and it exists a lot in ddpm.py, and I am still very confused about it.
Trying to implement inference for these semantic image synthesis nets. I'm getting this error: expected input[1, 3, 301, 240] to have 182 channels, but got 3 channels instead
Seems like we have to make a binary channel for each class in our segmentation. This seems to be cocostuff segmentation as it has 182 classes. Is this correct?
Are you planning to share inference script for semantic image synthesis?