madebyollin / taesd

Tiny AutoEncoder for Stable Diffusion
MIT License

Image quality (SD2.1 Encoder, taesd Decoder) #24

Closed: walsharry closed this issue 1 week ago

walsharry commented 1 week ago

Hey madebyollin,

Back again with another quick question. I want to use Stable Diffusion 2.1 to encode an image and then TAESD to decode it. However, the image quality is much lower than with the SD2.1 decoder. Is this expected, or am I doing something wrong?

    from diffusers import AutoencoderTiny, AutoencoderKL
    import torch
    import torchvision.transforms.functional as TF
    from PIL import Image

    def summarize_tensor(x):
        return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"

    taesd = AutoencoderTiny.from_pretrained(
        "madebyollin/taesd", torch_dtype=torch.float16
    sd_vae = AutoencoderKL.from_pretrained(

    def display(img):

    def demo_sd3_encode_and_taesd_decode(image: Image, taesd, sd_vae):
        image_raw = TF.to_tensor(image).unsqueeze(0).to(torch.float16).to("cuda")
        image_enc = (
        # decode with TAESD and scale back to [0, 1] range
        image_dec = taesd.decode(image_enc).sample.add_(1).div_(2).clamp_(0, 1)

        print("Raw input image", summarize_tensor(image_raw[0]))

        print("SD2.1-encoded latents", summarize_tensor(image_enc[0]))
        # visualize the first 3 latent channels, rescaled roughly into [0, 1]
        display(TF.to_pil_image(image_enc[0, :3].mul(0.3).add(0.5).clamp(0, 1)))

        print("TAESD-decoded image", summarize_tensor(image_dec[0]))

    test_image = TF.center_crop(
                  512), 512
    demo_sd3_encode_and_taesd_decode(test_image, taesd, sd_vae)

Input image: [image: bird_input]

SD2.1 decoded: [image: bird_sd2 1]

TAESD decoded: [image: bird_teasa]

Thanks again for the help, very much appreciated,


madebyollin commented 1 week ago

@walsharry Sadly, I think you're using it correctly! TAESD currently isn't as high-quality as the full-size VAE. I think TAESDXL and TAESD3 generally look better than the base TAESD (the SDXL and SD3 latent spaces are easier to decode), but if you want max quality you should always use the full-size VAEs.
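
To put a number on the quality gap rather than judging by eye, one option is to compute PSNR of each decoded image against the input. The sketch below is a minimal NumPy version. The `psnr` helper and the noisy stand-in arrays are illustrative assumptions, not part of TAESD or diffusers; in practice the two candidates would be the SD2.1-decoded and TAESD-decoded images converted with `np.asarray`.

```python
import numpy as np


def psnr(reference, candidate, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images."""
    ref = np.asarray(reference, dtype=np.float64)
    cand = np.asarray(candidate, dtype=np.float64)
    if ref.shape != cand.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {cand.shape}")
    mse = np.mean((ref - cand) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_value ** 2 / mse))


# Hypothetical stand-ins for the real images: random "input" plus two
# reconstructions with small and larger error.
rng = np.random.default_rng(0)
original = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
mild = np.clip(original + rng.normal(0, 2, original.shape), 0, 255)
strong = np.clip(original + rng.normal(0, 10, original.shape), 0, 255)

print(f"small error:  {psnr(original, mild):.2f} dB")
print(f"larger error: {psnr(original, strong):.2f} dB")
```

Higher is better. PSNR only measures pixel-wise error, so it will not capture every perceptual difference between the two decoders.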

walsharry commented 1 week ago

@madebyollin thanks for the quick response. I am currently trying to use the TAESD decoder during training so that I can add an image loss to a model. The newer TAESD3 has twice the embedding dimensions and is too computationally expensive. Do you believe it is possible to train an image generation network with the smaller TAESD decoder and then at inference swap to the original?

Also, is there any documentation/papers about how you create the TAESD models?



madebyollin commented 1 week ago

Do you believe it is possible to train an image generation network with the smaller TAESD decoder and then at inference swap to the original?

Yeah that should definitely work ( does this iirc). You can also always fine-tune your model with the full-size VAE at the end.
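
A rough sketch of that setup in PyTorch: keep the cheap decoder frozen, backpropagate an image-space loss through it into the model being trained, then decode with the full-size decoder at inference. Everything here is a hypothetical stand-in (tiny conv layers instead of a real generator, TAESD, or the SD VAE; `freeze` and `image_loss_step` are made-up helper names). The swap only makes sense with real models because both decoders read the same latent space.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def freeze(module):
    """Disable gradients for a module's weights and put it in eval mode."""
    module.requires_grad_(False)
    return module.eval()


def image_loss_step(generator, decoder, optimizer, cond, target_images):
    """One training step with a pixel-space loss through a frozen decoder."""
    latents = generator(cond)
    images = decoder(latents)  # gradients flow through the decoder to the latents
    loss = F.mse_loss(images, target_images)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
# Toy stand-ins: 4-channel latents at 8x8, decoded to 3-channel 64x64 images.
generator = nn.Conv2d(4, 4, 3, padding=1)
cheap_decoder = freeze(nn.Sequential(nn.Upsample(scale_factor=8), nn.Conv2d(4, 3, 3, padding=1)))
full_decoder = freeze(nn.Sequential(nn.Upsample(scale_factor=8), nn.Conv2d(4, 3, 3, padding=1)))
optimizer = torch.optim.Adam(generator.parameters(), lr=1e-3)

cond = torch.randn(2, 4, 8, 8)
target = torch.rand(2, 3, 64, 64)

# Training: use the cheap decoder for the image loss.
loss = image_loss_step(generator, cheap_decoder, optimizer, cond, target)

# Inference: swap in the full-size decoder.
with torch.no_grad():
    final_images = full_decoder(generator(cond))
```

Only the generator's parameters are handed to the optimizer, so the decoder weights never change even though gradients pass through them.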

is there any documentation/papers about how you create the TAESD models?

There's no paper or full report. I did post some example code here that shows how to do basic TAESDXL training with pure adversarial loss. The released TAESD checkpoints used slightly more complicated training recipes (with some augmentations, auxiliary regression losses, etc.) but the core structure is the same. I've answered questions about training here and here among other places.
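
For readers who have not seen adversarial training before, here is a generic sketch of one update using a standard non-saturating GAN loss. It is not the recipe from the example code mentioned above, nor the one used for the released checkpoints. `adversarial_step`, the toy `decoder` and `critic`, and all shapes are hypothetical stand-ins chosen so the snippet runs on CPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def adversarial_step(decoder, critic, opt_dec, opt_critic, latents, real_images):
    """One critic update followed by one decoder update."""
    bce = F.binary_cross_entropy_with_logits

    # 1) Critic update: real images should score 1, decoded images 0.
    fake_images = decoder(latents).detach()  # no decoder gradients here
    real_logits = critic(real_images)
    fake_logits = critic(fake_images)
    critic_loss = bce(real_logits, torch.ones_like(real_logits)) + bce(
        fake_logits, torch.zeros_like(fake_logits)
    )
    opt_critic.zero_grad()
    critic_loss.backward()
    opt_critic.step()

    # 2) Decoder update: try to make the critic score decoded images as real.
    gen_logits = critic(decoder(latents))
    decoder_loss = bce(gen_logits, torch.ones_like(gen_logits))
    opt_dec.zero_grad()
    decoder_loss.backward()
    opt_dec.step()
    return critic_loss.item(), decoder_loss.item()


torch.manual_seed(0)
# Toy stand-ins: 4-channel 8x8 latents decoded to 64x64 RGB images.
decoder = nn.Sequential(nn.Upsample(scale_factor=8), nn.Conv2d(4, 3, 3, padding=1))
critic = nn.Sequential(
    nn.Conv2d(3, 8, 4, stride=4), nn.LeakyReLU(0.2), nn.Conv2d(8, 1, 4, stride=4)
)
opt_dec = torch.optim.Adam(decoder.parameters(), lr=1e-3)
opt_critic = torch.optim.Adam(critic.parameters(), lr=1e-3)

latents = torch.randn(2, 4, 8, 8)      # would come from the frozen full-size encoder
real_images = torch.rand(2, 3, 64, 64)  # would be the images that were encoded

weights_before = decoder[1].weight.detach().clone()
c_loss, d_loss = adversarial_step(decoder, critic, opt_dec, opt_critic, latents, real_images)
```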

walsharry commented 1 week ago

Thanks again for the help! Really appreciated