entrpn opened this issue 3 months ago
Hi @ahmadki, I'm trying to reproduce the Stable Diffusion training results.
I noticed that when I decode the moments back to images using the VAE's decoder, I'm getting whited-out images:
I noticed that in `webdataset_images2latents.py` the images are not normalized:
https://github.com/mlcommons/training/blob/master/stable_diffusion/webdataset_images2latents.py#L86
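For reference, the transform there currently looks roughly like this (paraphrased from the linked line, not copied verbatim), so pixel values stay in the `[0, 1]` range produced by `ToTensor()`:

```python
from torchvision import transforms

# Paraphrase of the current pipeline (no Normalize step), so tensors
# keep the [0, 1] range produced by ToTensor()
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(size=512),
])
```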
If I add normalization as follows:
```python
transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(size=512),
    transforms.Normalize([0.5], [0.5]),
])
```
With that change, the whiteout goes away.
Is there a reason the images were not normalized, and how does this affect training of the UNet?
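For context on why I think this matters: `ToTensor()` yields values in `[0, 1]`, while `Normalize([0.5], [0.5])` maps them to `[-1, 1]`, which is the range the decode path below reverses with `(image / 2 + 0.5)`. A quick range check (standalone sketch, using a random image rather than real training data):

```python
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# Random RGB image standing in for a real training sample
img = Image.fromarray((np.random.rand(512, 512, 3) * 255).astype(np.uint8))

x = transforms.ToTensor()(img)             # values in [0, 1]
y = transforms.Normalize([0.5], [0.5])(x)  # (x - 0.5) / 0.5 -> values in [-1, 1]

print(x.min().item(), x.max().item())  # ~0.0, ~1.0
print(y.min().item(), y.max().item())  # ~-1.0, ~1.0

# The decode path undoes this mapping: (y / 2 + 0.5) recovers x
assert torch.allclose(y / 2 + 0.5, x)
```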
Code to reproduce using HuggingFace diffusers:
```python
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
import numpy as np
from PIL import Image

model_id = "stabilityai/stable-diffusion-2-base"

# Use the Euler scheduler here instead
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Load the precomputed moments and sample latents from the diagonal Gaussian
moments = torch.from_numpy(np.load("000009999.npy")).type(torch.float16)
moments = moments.to("cuda")
latents = DiagonalGaussianDistribution(moments).sample()

# Scale as during training, then unscale for decoding (the two steps cancel
# out; kept here to mirror the train/decode pipeline)
latents = latents * pipe.vae.config.scaling_factor
latents = 1 / pipe.vae.config.scaling_factor * latents

# Decode latents and map from [-1, 1] back to [0, 255] uint8
image = pipe.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).detach().float().numpy()
image = (image * 255).round().astype("uint8")
image = Image.fromarray(image[0])
image.save("test.png")
```
@ahmadki can you take a look at this?