madebyollin / taesd

Tiny AutoEncoder for Stable Diffusion

Using taesd encode with original SD decode #12

Closed · jon-chuang closed this issue 7 months ago

jon-chuang commented 7 months ago

How can I encode an image with TAESD and then decode the latents with the original SD VAE?

madebyollin commented 7 months ago

TAESD encodes into the same latent space as the original SD VAE, so any pipeline that can decode SD-generated latents should also be able to decode TAESD-encoded latents.

Here's an example showing how to decode TAESD-encoded latents with diffusers' AutoencoderKL implementation.

from diffusers import AutoencoderKL, AutoencoderTiny
from PIL import Image

import torch
import torchvision.transforms.functional as TF

# Load TAESD (tiny autoencoder) and the original SD VAE, both in fp16 on GPU
taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16).to("cuda")
sdvae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema", torch_dtype=torch.float16).to("cuda")

# Download a test image (notebook shell command) and center-crop it to 512x512
!wget -q https://upload.wikimedia.org/wikipedia/commons/9/9c/Crepe_with_LaFrance_and_strawberries_and_fresh_cream_in_it.jpg -O test_image.jpg
test_image = TF.center_crop(TF.resize(Image.open("test_image.jpg").convert("RGB"), 512), 512)
display(test_image)

def summarize_tensor(x):
    # One-line summary of a tensor's shape plus min / mean / max (ANSI-colored for notebook output)
    return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"

def latent_to_visualization(taesd, latent):
    # Map raw latents into [0, 1], then stack the first three channels above the fourth for display
    latent = taesd.scale_latents(latent)
    return torch.cat([latent[:3], latent[3:].expand(3, *latent.shape[-2:])], -2)

@torch.no_grad()
def demo_taesd_encoding(taesd, sdvae, image):
    # Convert the PIL image to a (1, 3, H, W) fp16 tensor in [-1, 1]
    image_raw = TF.to_tensor(image).unsqueeze(0).cuda().half().mul(2).sub(1)
    # Encode with TAESD, then decode with both TAESD and the original SD VAE
    image_enc = taesd.encoder(image_raw)
    image_dec = taesd.decoder(image_enc)
    # AutoencoderKL.decode expects unscaled latents, so divide by the scaling factor first
    image_dec_sd = sdvae.decode(image_enc / sdvae.config.scaling_factor)[0]

    print("input image", summarize_tensor(image_raw[0]))
    display(TF.to_pil_image(image_raw[0].mul(0.5).add(0.5)))

    print("latents", summarize_tensor(image_enc[0]))
    print("(these latents are the same size / scale as SD UNet-generated latents - no extra scale_factor is needed)")
    display(TF.to_pil_image(latent_to_visualization(taesd, image_enc[0])))

    print("decoded images (TAESD / SD-VAE)")
    display(TF.to_pil_image(torch.cat([image_dec[0], image_dec_sd[0]], -1).mul(0.5).add(0.5).clamp(0, 1)))

demo_taesd_encoding(taesd, sdvae, test_image)
[Screenshot: input image, latent visualization, and side-by-side TAESD / SD-VAE decoded outputs]
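If you prefer the public encode / decode entry points over poking at the raw encoder / decoder submodules, the same round trip looks roughly like this (a minimal sketch; the helper name taesd_encode_sdvae_decode is just for illustration):

import torch
from diffusers import AutoencoderKL, AutoencoderTiny

taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16).to("cuda")
sdvae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema", torch_dtype=torch.float16).to("cuda")

@torch.no_grad()
def taesd_encode_sdvae_decode(image_raw):
    # image_raw: a (N, 3, H, W) fp16 tensor in [-1, 1] on cuda
    latents = taesd.encode(image_raw).latents  # already in SD's scaled latent space
    # AutoencoderKL.decode expects unscaled latents, so divide by the scaling factor
    return sdvae.decode(latents / sdvae.config.scaling_factor).sample  # pixels in [-1, 1]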
jon-chuang commented 7 months ago

Thank you, I managed to get it to work.