CompVis / stable-diffusion

A latent text-to-image diffusion model
https://ommer-lab.com/research/latent-diffusion-models/

Interpolation of latent representation is not semantic interpolation? #263

Open ProkopHapala opened 2 years ago

ProkopHapala commented 2 years ago

In this paper (Fig. 6), https://arxiv.org/pdf/2010.02502.pdf, they show that it is possible to interpolate images semantically in latent space.

I tried it with the Colab version of Stable Diffusion here: https://colab.research.google.com/drive/11xRHNFskeBse0J4m5U3-FhUyw4c1mNch?usp=sharing

The simple code looks like this:

  prompt="Ent woman is a mythological female character with branches of oak tree growing from her head and trunk of tree in place of torso. Yet she has narrow waist and wide hips, Fantasy, Illustration, Craig Mullins, Octane Render"

  text_embeds = get_text_embeds(prompt)
  latents1   = denoise_lantents( text_embeds ) # generate latent encoding for image 1
  latents2   = denoise_lantents( text_embeds ) # generate latent encoding for image 2
  latents_comb = (latents1+latents2)*0.5           # interpolate the two images in latent space
  imgs1      = decode_img_latents(latents1)      # decode image1 from latent representation
  imgs2      = decode_img_latents(latents2)      # decode image2 from latent representation
  imgs_comb  = decode_img_latents(latents_comb)  # decode interpolated image from latent representation
  display(imgs1[0],imgs2[0],imgs_comb[0]) 

Rather than semantic interpolation, it seems to do just a simple interpolation in image space, like this: https://ibb.co/n7RrBrn

Why? What am I doing wrong? Is it somehow possible to achieve semantic interpolation like that described in the paper (Fig. 6)? https://arxiv.org/pdf/2010.02502.pdf

================= For completeness, here are the functions ===============

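import torch
from torch import autocast
from tqdm.auto import tqdm
from PIL import Image

# tokenizer, text_encoder, unet, vae, scheduler and device are assumed to be set up earlier in the notebook
def get_text_embeds(prompt):
  # Accept a single string as well as a list; len() of a bare string would count characters
  if isinstance(prompt, str):
    prompt = [prompt]
  text_input = tokenizer(prompt, padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt')
  with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
  uncond_input = tokenizer([''] * len(prompt), padding='max_length',max_length=tokenizer.model_max_length, return_tensors='pt')
  with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
  return text_embeddings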

def denoise_lantents(text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, start_step=0 ):
  if latents is None:
    latents = torch.randn((text_embeddings.shape[0] // 2, unet.in_channels, height // 8, width // 8))
  latents = latents.to(device)
  scheduler.set_timesteps(num_inference_steps)
  latents = latents * scheduler.sigmas[0]
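  with autocast('cuda'):
    # Count from start_step so that i indexes scheduler.sigmas consistently with t when start_step > 0
    for i, t in tqdm(enumerate(scheduler.timesteps[start_step:], start=start_step)):
      # Duplicate the latents: one copy for the unconditional pass, one for the text-conditioned pass
      latent_model_input = torch.cat([latents] * 2)
      sigma = scheduler.sigmas[i]
      latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
      with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
      noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
      # Classifier-free guidance: uncond + guidance_scale * (text - uncond)
      noise_pred = noise_pred_uncond*(1-guidance_scale) + noise_pred_text*guidance_scale
      latents = scheduler.step(noise_pred, i, latents)['prev_sample']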
  return latents

def decode_img_latents(latents):
  latents = 1 / 0.18215 * latents
  with torch.no_grad():
    imgs = vae.decode(latents)
  imgs = (imgs / 2 + 0.5).clamp(0, 1)
  imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
  imgs = (imgs * 255).round().astype('uint8')
  pil_images = [Image.fromarray(image) for image in imgs]
  return pil_images

gregturk commented 2 years ago

You should be interpolating the text embeddings, not the latent space versions of the images. The latent space of an image is a small image (size 64 x 64 x 4). The text embedding space is 77 x 768, and encodes text semantics instead of pixels.
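
As a rough illustration of the suggestion above (not code from the thread), linear interpolation between two text embeddings can be sketched as follows. The helper name `lerp_embeddings` and the variables `emb_a` and `emb_b` are hypothetical, and random NumPy arrays stand in for real 77 x 768 embeddings, which would normally come from the text encoder for two different prompts.

```python
import numpy as np

def lerp_embeddings(emb_a, emb_b, num_steps):
    """Return num_steps embeddings evenly spaced from emb_a to emb_b (both ends included)."""
    weights = np.linspace(0.0, 1.0, num_steps)
    return [(1.0 - w) * emb_a + w * emb_b for w in weights]

# Stand-in arrays with the 77 x 768 shape of a text embedding (assumption:
# real embeddings would be produced by the text encoder, not sampled randomly).
rng = np.random.default_rng(0)
emb_a = rng.standard_normal((77, 768))
emb_b = rng.standard_normal((77, 768))

frames = lerp_embeddings(emb_a, emb_b, num_steps=5)
```

With the functions from the question, the two embeddings would presumably come from `get_text_embeds` called on two prompts, and each interpolated embedding would be passed to `denoise_lantents` together with the same starting `latents`, so that only the text conditioning changes between frames. The same arithmetic applies to torch tensors.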

Most people recommend using spherical interpolation (slerp), but just regular linear interpolation seems to give fairly reasonable images.
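
For reference, a minimal sketch of spherical interpolation in NumPy is given below. It is one common formulation, not code from the thread or from any particular library: the function name `slerp` and the `dot_threshold` parameter are illustrative choices, and the inputs are treated as single flattened vectors.

```python
import numpy as np

def slerp(t, v0, v1, dot_threshold=0.9995):
    """Spherical interpolation between arrays v0 and v1 at fraction t in [0, 1]."""
    v0 = np.asarray(v0, dtype=np.float64)
    v1 = np.asarray(v1, dtype=np.float64)
    # Cosine of the angle between the two inputs, viewed as flattened vectors
    dot = np.sum(v0 * v1) / (np.linalg.norm(v0) * np.linalg.norm(v1))
    dot = np.clip(dot, -1.0, 1.0)
    # Nearly parallel inputs: fall back to linear interpolation to avoid dividing by ~0
    if abs(dot) > dot_threshold:
        return (1.0 - t) * v0 + t * v1
    theta = np.arccos(dot)
    sin_theta = np.sin(theta)
    w0 = np.sin((1.0 - t) * theta) / sin_theta
    w1 = np.sin(t * theta) / sin_theta
    return w0 * v0 + w1 * v1

# Halfway between two orthogonal unit vectors
a = np.array([1.0, 0.0])
b = np.array([0.0, 1.0])
mid = slerp(0.5, a, b)
```

Unlike plain averaging, which would give a midpoint of norm about 0.71 here, the slerp midpoint keeps unit norm. That norm-preserving property is the usual argument for slerp when the inputs are roughly Gaussian vectors.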