Open brijow opened 2 years ago
Check the code of the notebook at:
Not sure if this is too late to be helpful, but the following is about twice as fast as looping over a list of captions and seems to be what you want. I've cleaned it up from my own file, so I haven't had a chance to run it and there may be an error lurking somewhere. The basic idea is to replace tiling the tokens coming from a single prompt—see the multiplications by batch_size
in the original code—with additional tokens.
from PIL import Image
from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
create_model_and_diffusion,
model_and_diffusion_defaults,
model_and_diffusion_defaults_upsampler
)
import torch
import matplotlib.pyplot as plt
# Run on the GPU when CUDA is available, otherwise on the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')

# Create base glide.
options = model_and_diffusion_defaults()
options.update({
    'use_fp16': has_cuda,  # half precision only when running on CUDA
    'timestep_respacing': '100',  # use 100 diffusion steps for fast sampling
})
glide, diffusion = create_model_and_diffusion(**options)
glide.eval()
if has_cuda:
    glide.convert_to_fp16()
glide.to(device)
glide.load_state_dict(load_checkpoint('base', device))
base_param_count = sum(p.numel() for p in glide.parameters())
print('total base parameters', base_param_count)

# Sampling constants. guidance_scale is read by model_fn below;
# upsample_temp is not referenced anywhere in this snippet.
guidance_scale = 3.0
upsample_temp = 0.997
# Create a classifier-free guidance sampling function
def model_fn(x_t, ts, **kwargs):
    """Run ``glide`` with classifier-free guidance applied to its eps output.

    The first half of ``x_t`` is duplicated so that both halves of the batch
    share the same input; the two halves are expected to differ only through
    ``kwargs`` (assumed: first half conditional tokens, second half
    unconditional tokens -- set up by the caller).

    Args:
        x_t -- batched input tensor; only its first half along dim 0 is used
        ts -- timesteps, forwarded to ``glide`` unchanged
        **kwargs -- extra keyword arguments forwarded to ``glide``

    Returns:
        The model output with its first 3 channels replaced by the guided
        eps (repeated for both halves); remaining channels are untouched.
    """
    half_count = len(x_t) // 2
    shared_input = x_t[:half_count]
    doubled_input = torch.cat((shared_input, shared_input), dim=0)

    output = glide(doubled_input, ts, **kwargs)
    # Channels 0..2 are treated as eps; everything after is passed through.
    eps = output[:, :3]
    passthrough = output[:, 3:]

    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    # Move from the unconditional prediction toward the conditional one.
    guided_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)

    guided_full = torch.cat((guided_eps, guided_eps), dim=0)
    return torch.cat((guided_full, passthrough), dim=1)
def glide_generate(prompts):
    """Returns a tensor of images where the ith image is generated from the ith prompt in [prompts].

    All prompts are sampled together in one batch: the first half of the
    batch carries the prompt tokens and the second half carries empty
    (classifier-free guidance) tokens.

    Args:
        prompts -- list (or other iterable) of string prompts; a single bare
            string is treated as a one-element list instead of being
            iterated character by character

    Returns:
        Tensor holding the first len(prompts) samples, rescaled with
        (samples + 1) / 2 (presumably mapping [-1, 1] to [0, 1]).

    Raises:
        ValueError -- if no prompts are given
    """
    # A bare string would otherwise yield one image per character.
    if isinstance(prompts, str):
        prompts = [prompts]
    # Materialize so generators work and len() is well defined.
    prompts = list(prompts)
    if not prompts:
        raise ValueError("glide_generate requires at least one prompt")

    batch_size = len(prompts)
    text_ctx = options['text_ctx']

    # Tokenize and pad each prompt in a single pass.
    tokens = []
    masks = []
    for prompt in prompts:
        padded, mask = glide.tokenizer.padded_tokens_and_mask(
            glide.tokenizer.encode(prompt), text_ctx
        )
        tokens.append(padded)
        masks.append(mask)

    # Create the classifier-free guidance tokens (empty)
    full_batch_size = batch_size * 2
    uncond_tokens, uncond_mask = glide.tokenizer.padded_tokens_and_mask([], text_ctx)

    # Pack the tokens together into glide kwargs.
    model_kwargs = dict(
        tokens=torch.tensor(tokens + [uncond_tokens] * batch_size, device=device),
        mask=torch.tensor(masks + [uncond_mask] * batch_size, dtype=torch.bool, device=device),
    )

    glide.del_cache()
    # Only the first batch_size samples (the prompt-conditioned half) are kept.
    samples = diffusion.p_sample_loop(
        model_fn,
        (full_batch_size, 3, options["image_size"], options["image_size"]),
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:batch_size]
    glide.del_cache()

    # Uncomment what's below to validate the function
    # scaled = ((samples + 1)*127.5).round().clamp(0,255).to(torch.uint8).cpu()
    # for s in scaled:
    #     plt.imshow(s.permute(1, 2, 0) )
    #     plt.show()
    return (samples + 1) / 2
# Uncomment what's below to validate the function
# glide_generate(["a painting of a blue bird", "a painting of a red cat", "a painting of a purple apple"])
Hi, in the example notebook text2im.ipynb, I'm not clear on how to use a batch size larger than 1, or on the recommended way to generate many images.
I'd like to play around with the model and generate several thousand images for some captions I have collected and evaluate the overall quality of results... however, I'm not clear on the best way to do this, other than something along the lines of the pseudo-code below:
Would there be a faster way to do this than (more/less) following the recipe above?