mlcommons / training

Reference implementations of MLPerf™ training benchmarks
https://mlcommons.org/en/groups/training
Apache License 2.0

[Stable Diffusion] Decoding VAE moments back to images outputs whited-out images #721

Status: Open. entrpn opened this issue 3 months ago

entrpn commented 3 months ago

Hi @ahmadki, I'm trying to reproduce the Stable Diffusion training results.

I noticed that when I decode the moments back to images with the VAE decoder, I get whited-out images. See:

[Image: from_latents]

I noticed that the images are not normalized in webdataset_images2latents.py:

https://github.com/mlcommons/training/blob/master/stable_diffusion/webdataset_images2latents.py#L86

If I add normalization as follows:

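```python
from torchvision import transforms

# Named "transform" so the torchvision "transforms" module is not shadowed
transform = transforms.Compose(
    [
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(size=512),
        transforms.Normalize([0.5], [0.5]),  # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
    ]
)
```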

The whiteout goes away:

[Image: encoded_decoded]
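A minimal sketch of why the missing normalization could produce washed-out decodes, under a deliberately idealized assumption: the VAE reconstructs whatever value range it is fed. The hypothetical helpers normalize and postprocess reproduce, per pixel, transforms.Normalize([0.5], [0.5]) and the image / 2 + 0.5 post-processing used in the repro script.

```python
# Hypothetical helpers illustrating the value-range mismatch; not taken from the MLPerf code.
# Idealized assumption: the VAE round trip returns its input range unchanged.

def normalize(x, mean=0.5, std=0.5):
    """Per-pixel equivalent of transforms.Normalize([0.5], [0.5]): [0, 1] -> [-1, 1]."""
    return (x - mean) / std

def postprocess(x):
    """Decode-side mapping from the repro script: [-1, 1] -> [0, 1], then clamp."""
    return min(max(x / 2 + 0.5, 0.0), 1.0)

pixels = [0.0, 0.25, 0.5, 0.75, 1.0]  # the range ToTensor() produces

with_norm = [postprocess(normalize(p)) for p in pixels]  # encoder sees [-1, 1]
without_norm = [postprocess(p) for p in pixels]          # encoder sees [0, 1]

print(with_norm)     # [0.0, 0.25, 0.5, 0.75, 1.0]
print(without_norm)  # [0.5, 0.625, 0.75, 0.875, 1.0]
```

If [0, 1] inputs come back out of the decoder as roughly [0, 1], the final mapping squeezes every pixel into [0.5, 1], which renders as a pale, whited-out image. A real VAE is lossy and was trained on [-1, 1] inputs, so actual outputs will not match these numbers exactly.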

Is there a reason the images were not normalized, and how does this affect training of the UNet?

Code to reproduce using Hugging Face diffusers:

```python
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

model_id = "stabilityai/stable-diffusion-2-base"

# Use the Euler scheduler here instead (decoding itself does not need a scheduler)
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# DiagonalGaussianDistribution splits the moments into mean and logvar along dim 1
moments = torch.from_numpy(np.load("000009999.npy")).to("cuda", dtype=torch.float16)
latents = DiagonalGaussianDistribution(moments).sample()

# Scale as during training, then undo it before decoding (the two steps cancel out)
latents = latents * pipe.vae.config.scaling_factor
latents = latents / pipe.vae.config.scaling_factor

with torch.no_grad():
    image = pipe.vae.decode(latents, return_dict=False)[0]

# The decoder outputs roughly [-1, 1]; map to [0, 1], then to 8-bit pixels
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = (image * 255).round().astype("uint8")
Image.fromarray(image[0]).save("test.png")
```
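For reference, a NumPy sketch of what DiagonalGaussianDistribution(moments).sample() computes in recent diffusers releases: split the moments in half along dim 1 into mean and logvar, clamp logvar to [-30, 20], and return mean + exp(0.5 * logvar) * noise. The helper sample_from_moments is hypothetical, and the (1, 8, 64, 64) shape is an assumption (a 512x512 image encoded by the 4-latent-channel SD 2 VAE).

```python
import numpy as np

def sample_from_moments(moments, rng):
    """Hypothetical NumPy mirror of DiagonalGaussianDistribution(moments).sample()."""
    # (B, 8, H, W) -> mean and logvar, each (B, 4, H, W); the split is along dim 1
    mean, logvar = np.split(moments, 2, axis=1)
    logvar = np.clip(logvar, -30.0, 20.0)  # same clamp range diffusers applies
    std = np.exp(0.5 * logvar)
    # Reparameterization: mean + std * standard normal noise
    return mean + std * rng.standard_normal(mean.shape)

rng = np.random.default_rng(0)
moments = np.zeros((1, 8, 64, 64), dtype=np.float32)
moments[:, :4] = 1.0    # mean channels
moments[:, 4:] = -30.0  # logvar channels -> std of about 3e-7
latents = sample_from_moments(moments, rng)

print(latents.shape)  # (1, 4, 64, 64)
```

Because the split is along dim 1, the saved moments must keep their batch dimension. When logvar is very negative the sample is essentially the mean, so decoding DiagonalGaussianDistribution(moments).mode() (which returns the mean) is a deterministic alternative for this kind of check.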

hiwotadese commented 2 weeks ago

@ahmadki, can you take a look at this?