Unlike TAESD, the SD3 VAE requires you to keep track of two annoying external parameters (scaling_factor and shift_factor) which must be applied manually in order to get correct results from the SD3 VAE encoder / decoder. Furthermore, both the SD VAEs and the diffusers AutoencoderTiny implementation expect image values to be in the [-1, 1] range (not [0, 1]).
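Concretely, the bookkeeping looks roughly like this (a minimal sketch, not part of the corrected code below; vae stands for a loaded SD3 AutoencoderKL and pixels for an image tensor already scaled to [-1, 1]):

# encoding: sample a latent, then manually apply the shift and scale
latents = (vae.encode(pixels).latent_dist.sample() - vae.config.shift_factor) * vae.config.scaling_factor
# decoding: undo the scale and shift before calling the decoder
pixels_out = vae.decode(latents / vae.config.scaling_factor + vae.config.shift_factor).sample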
Here is corrected code for encoding an image with the SD3 VAE and decoding with TAESD3:
from diffusers import AutoencoderTiny, AutoencoderKL
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from IPython.display import display  # display() is used below to show PIL images inline
# pretty-print a tensor's shape along with its min / mean / max values (with ANSI colors)
def summarize_tensor(x):
    return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"
taesd3 = AutoencoderTiny.from_pretrained(
    "madebyollin/taesd3", torch_dtype=torch.float16
).to("cuda")

sd3_vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
    subfolder="vae",
).to("cuda")
def demo_sd3_encode_and_taesd_decode(image: Image.Image, taesd3, sd3_vae):
    # load raw image onto the device with values in [0, 1] range
    image_raw = TF.to_tensor(image).unsqueeze(0).to(torch.float16).to("cuda")
    # scale to the [-1, 1] value range expected by diffusers VAEs, encode with the SD3 VAE,
    # and manually apply the SD3 VAE shift / scaling factors
    image_enc = (
        sd3_vae.encode(image_raw.mul(2).sub_(1))
        .latent_dist.sample()
        .sub_(sd3_vae.config.shift_factor)
        .mul_(sd3_vae.config.scaling_factor)
    )
    # decode with TAESD3 and scale back to [0, 1] range
    image_dec = taesd3.decode(image_enc).sample.add_(1).div_(2).clamp_(0, 1)
    print("Raw input image", summarize_tensor(image_raw[0]))
    display(TF.to_pil_image(image_raw[0]))
    print("SD3-encoded latents", summarize_tensor(image_enc[0]))
    # visualize the first three latent channels as a rough RGB preview
    display(TF.to_pil_image(image_enc[0, :3].mul(0.3).add(0.5).clamp(0, 1)))
    print("TAESD3-decoded image", summarize_tensor(image_dec[0]))
    display(TF.to_pil_image(image_dec[0]))
# !wget -nc -q https://lafeber.com/pet-birds/wp-content/uploads/2018/06/Scarlet-Macaw-2.jpg -O test_image.jpg
test_image = TF.center_crop(
    TF.resize(Image.open("test_image.jpg").convert("RGB"), 512), 512
)
demo_sd3_encode_and_taesd_decode(test_image, taesd3, sd3_vae)
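If you also wanted to decode with the full SD3 VAE (e.g. to compare against TAESD3), the manual shift / scale has to be undone first; TAESD3 skips this bookkeeping, which is why taesd3.decode(image_enc) above works directly. A minimal sketch, assuming image_enc from the function above is in scope (for instance, returned from it):

# invert the manual scaling before calling the SD3 VAE decoder,
# then map the [-1, 1] output back to [0, 1] as in the TAESD3 path
image_dec_sd3 = sd3_vae.decode(
    image_enc / sd3_vae.config.scaling_factor + sd3_vae.config.shift_factor
).sample.add_(1).div_(2).clamp_(0, 1)
display(TF.to_pil_image(image_dec_sd3[0]))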
Here is the result of the corrected code:
Legend, thanks for the quick response
Hey madebyollin,
Great work here. Just a quick question: I want to use Stable Diffusion 3 to encode an image and then TAESD3 to decode it. However, the image comes out slightly pixelated and grey. Any idea what I'm doing wrong?
Code:
Input: ![input](https://github.com/madebyollin/taesd/assets/50142465/6a8cf458-4eac-46a8-a92e-d9517db7d1d0)
Output:
Thanks,
Harry