madebyollin / taesd

Tiny AutoEncoder for Stable Diffusion
MIT License

Gray pixelated image output (SD3 encoder, taesd3 decoder) #20

Closed walsharry closed 1 month ago

walsharry commented 1 month ago

Hey madebyollin,

Great work here. Just a quick question: I want to use the Stable Diffusion 3 VAE to encode an image and then taesd3 to decode it. However, the decoded image comes out slightly pixelated and gray. Any idea what I'm doing wrong?

Code:

test_image = Image.open(p).convert("RGB").resize((512, 512))
taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd3", torch_dtype=torch.float16).to('cuda')
sd3 = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers",
                                               torch_dtype=torch.float16).vae.to('cuda')
image_raw = TF.to_tensor(test_image).unsqueeze(0).to(torch.float16).to('cuda')
image_enc = sd3.encoder(image_raw).latent_dist.mean
image_dec = taesd.decoder(image_enc).sample.clamp(0, 1)

print("input image", summarize_tensor(image_raw[0]))
display(TF.to_pil_image(image_raw[0]), 'sd3 input image')

print("latents", summarize_tensor(image_enc[0]))

print("decoded image", summarize_tensor(image_dec[0]))
display(TF.to_pil_image(image_dec[0]), 'taesd decoded image')

Input: [image]

Output: [image]

Thanks,

Harry

madebyollin commented 1 month ago

Unlike TAESD, the SD3 VAE requires you to keep track of two annoying external parameters (scaling_factor and shift_factor) which must be applied manually in order to get correct results from the SD3 VAE encoder / decoder. Furthermore, both the SD VAEs and the diffusers AutoencoderTiny implementation expect image values to be in [-1, 1] range (not [0, 1]).
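As a rough sketch of the arithmetic involved (not code from the repo): the helpers below show the [0, 1] to [-1, 1] range conversion and the shift-then-scale step applied to SD3 latents, plus their inverses. The function names are hypothetical, and the two constants are assumed values for illustration only; in real code, read them from `sd3_vae.config.shift_factor` and `sd3_vae.config.scaling_factor`. The functions use only basic arithmetic, so they should work on plain floats as well as on torch tensors.

```python
# Assumed values for illustration only; read the real ones from
# sd3_vae.config.shift_factor and sd3_vae.config.scaling_factor.
SHIFT_FACTOR = 0.0609
SCALING_FACTOR = 1.5305


def to_vae_range(x):
    """Map image values from [0, 1] to the [-1, 1] range the VAEs expect."""
    return x * 2 - 1


def to_image_range(x):
    """Map decoder output from [-1, 1] back to [0, 1]."""
    return (x + 1) / 2


def scale_latents(z, shift=SHIFT_FACTOR, scale=SCALING_FACTOR):
    """Apply after SD3 VAE encoding: subtract the shift, then multiply by the scale."""
    return (z - shift) * scale


def unscale_latents(z, shift=SHIFT_FACTOR, scale=SCALING_FACTOR):
    """Inverse of scale_latents, for use before SD3 VAE decoding."""
    return z / scale + shift


if __name__ == "__main__":
    print(to_vae_range(0.0), to_vae_range(1.0))
    print(unscale_latents(scale_latents(0.7)))
```

Skipping either step leaves the values in a range the other model was not trained on, which is the kind of mismatch described above.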

Here is corrected code for encoding an image with SD3 VAE and decoding with TAESD3:

from diffusers import AutoencoderTiny, AutoencoderKL
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from IPython.display import display  # display() is only predefined inside notebooks

def summarize_tensor(x):
    return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"

taesd3 = AutoencoderTiny.from_pretrained(
    "madebyollin/taesd3", torch_dtype=torch.float16
).to("cuda")
sd3_vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
    subfolder="vae",
).to("cuda")

def demo_sd3_encode_and_taesd_decode(image: Image.Image, taesd3, sd3_vae):
    # load the raw image onto the device with values in [0, 1] range
    image_raw = TF.to_tensor(image).unsqueeze(0).to(torch.float16).to("cuda")
    # scale to the [-1, 1] value range expected by diffusers VAEs, encode with the SD3 VAE, and manually apply the SD3 VAE scaling factors
    image_enc = (
        sd3_vae.encode(image_raw.mul(2).sub_(1))
        .latent_dist.sample()
        .sub_(sd3_vae.config.shift_factor)
        .mul_(sd3_vae.config.scaling_factor)
    )
    # decode with TAESD3 and scale back to [0, 1] range
    image_dec = taesd3.decode(image_enc).sample.add_(1).div_(2).clamp_(0, 1)

    print("Raw input image", summarize_tensor(image_raw[0]))
    display(TF.to_pil_image(image_raw[0]))

    print("SD3-encoded latents", summarize_tensor(image_enc[0]))
    display(TF.to_pil_image(image_enc[0, :3].mul(0.3).add(0.5).clamp(0, 1)))

    print("TAESD3-decoded image", summarize_tensor(image_dec[0]))
    display(TF.to_pil_image(image_dec[0]))

# !wget -nc -q https://lafeber.com/pet-birds/wp-content/uploads/2018/06/Scarlet-Macaw-2.jpg -O test_image.jpg
test_image = TF.center_crop(
    TF.resize(Image.open("test_image.jpg").convert("RGB"), 512), 512
)
demo_sd3_encode_and_taesd_decode(test_image, taesd3, sd3_vae)

Here is the result of the corrected code:

[image]

walsharry commented 1 month ago

Legend, thanks for the quick response