keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

Buggy stable diffusion VQVAE image encoder/decoder? #1655

Closed rex-yue-wu closed 1 year ago

rex-yue-wu commented 1 year ago

Hey folks,

I encountered something strange in keras-cv-0.4.2. Long story short, I believe there is something wrong with the provided stable diffusion model's VQVAE image encoder/decoder, or with the related pre/post-processing -- the reconstructed image doesn't align with the original (the reconstruction has random noise in its first 8 rows and 8 columns).

[image: input / reconstructed / difference comparison showing noise in the first 8 rows and columns]

Below is the code block that I used to generate the image above, and a Colab notebook can be found here

resized = ... # load an RGB image from disk and resize it so both sides are multiples of 8
RESOLUTION = 256 # a smaller size makes the issue more prominent 

# load stable diffusion's vae encoder and decoder
from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
from keras_cv.models.stable_diffusion.decoder import Decoder

Img2Latents = ImageEncoder(RESOLUTION, RESOLUTION)
Latents2Img = Decoder(RESOLUTION, RESOLUTION)

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# encode test image to latents and reconstruct
# 1. normalize (0, 255) uint8 image to (-1, +1) float32 image
inp = tf.constant(np.float32(resized)[None], 'float32') / 127.5 - 1
# 2. get latents
latents = Img2Latents(inp, training=False)
# 3. I thought this step was needed, as it appears in many Keras diffusion examples,
#    but the reconstructed image looks wrong unless it stays commented out
# latents = latents * 0.18215
# 4. decode latents and reconstruct image
rec = Latents2Img(latents, training=False)
# 5. postprocess
rec = rec[0].numpy()
rec = np.clip(rec * 0.5 + 0.5, 0, 1)

# compute difference and show
diff = np.abs(rec - np.float32(resized)/255.).mean(axis=-1)

plt.figure(figsize=(15,5))
plt.subplot(131)
plt.imshow(resized)
plt.title('input')
plt.subplot(132)
plt.imshow(rec)
plt.title('reconstructed')
plt.subplot(133)
plt.imshow(diff)
plt.title('difference')
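As a sanity check on the pre/post-processing above (a minimal numpy-only sketch, independent of the model): mapping uint8 pixels to (-1, +1) with x / 127.5 - 1 and then back with y * 0.5 + 0.5 is exactly x / 255, so steps 1 and 5 cancel and any residual difference must come from the encoder/decoder themselves.

```python
import numpy as np

# fake uint8 image patch standing in for `resized` (hypothetical data)
x = np.arange(0, 256, dtype=np.uint8).reshape(16, 16)

# step 1: normalize to (-1, +1), as fed to the encoder
inp = np.float32(x) / 127.5 - 1

# step 5: postprocess back to (0, 1), as done on the decoder output
rec = np.clip(inp * 0.5 + 0.5, 0, 1)

# the two transforms cancel: rec equals x / 255
assert np.allclose(rec, np.float32(x) / 255.0)
```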

In addition, I would appreciate it if anyone could explain why step #3 is not needed, i.e. why we don't need to multiply by the magic number 0.18215 before feeding the latents into the decoder, whose input is supposed to be a latent in the range (-1, 1). I manually checked and confirmed that the provided decoder has a Rescaling layer with a scale parameter of 5.4899, which is the reciprocal of 0.18215.
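For reference, the two constants are indeed reciprocals to within rounding, so the decoder's built-in Rescaling layer already undoes the 0.18215 scaling (a quick numeric check, not keras-cv code):

```python
scale = 0.18215    # latent scaling factor used in many Stable Diffusion examples
rescale = 5.4899   # scale parameter of the decoder's Rescaling layer

# 1 / 0.18215 = 5.4899..., matching the decoder's rescale constant
assert abs(1.0 / scale - rescale) < 1e-3
```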

rex-yue-wu commented 1 year ago

Just noticed that such noisy pixel borders also exist in the released keras examples, e.g.,

https://keras.io/examples/generative/finetune_stable_diffusion/

LukeWood commented 1 year ago

I have been trying very hard to figure out this bug... it is super tricky. I'll do some work on it now.

rex-yue-wu commented 1 year ago

Thank you very much. Yes, I also noticed that the same set of models sometimes generates outputs without noisy borders.

rex-yue-wu commented 1 year ago

Regarding the question about step #3, I can answer it myself. Yes, we do need it, but we also need to change some code in step #2, as that was not the correct way of getting an image's latent embedding.

In short, I misunderstood the VQVAE model's architecture. The VQVAE encoder's final output should not be used directly as the input of the VQVAE decoder, because it is actually a noisy sample. The right way of getting the VQVAE's embedding is not intuitive for someone with little domain knowledge about VQVAEs, and it only appears in some tutorials' source code (one may easily miss it).

The right way of getting an image's latent embedding is,

  1. we need to get the output of the 2nd-to-last layer of the VQVAE's encoder
  2. we need to split that output into two halves along the channel axis, where the first half is the right embedding (the mean) and the second half should be ignored, as it is the logvar of the noise.
# get a new model that outputs the 2nd-to-last layer of the original VQVAE image encoder
vae = tf.keras.Model(
    image_encoder.input,
    image_encoder.layers[-2].output,
)
# get the output of the 2nd-to-last layer
outputs = vae(image)
# split it into two halves along the channel axis
mean, logvar = tf.split(outputs, 2, axis=-1)
# use mean (and ignore logvar) as the input to the VQVAE decoder
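To make the mean/logvar split concrete: the encoder's last layer samples a latent as mean + exp(0.5 * logvar) * eps (the VAE reparameterization trick), which is why decoding the raw encoder output injects noise. A minimal numpy sketch of that sampling step, with illustrative shapes and values (not keras-cv code):

```python
import numpy as np

rng = np.random.default_rng(0)

# fake encoder output: 8 channels = 4 mean channels + 4 logvar channels
outputs = rng.normal(size=(1, 32, 32, 8)).astype(np.float32)

# split along the channel axis, as in the snippet above
mean, logvar = np.split(outputs, 2, axis=-1)

# reparameterization trick: the sampled latent is the mean plus scaled noise
eps = rng.normal(size=mean.shape).astype(np.float32)
sample = mean + np.exp(0.5 * logvar) * eps

# a deterministic reconstruction should decode `mean`, not `sample`
assert mean.shape == (1, 32, 32, 4)
assert not np.allclose(sample, mean)
```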
LukeWood commented 1 year ago

Hey @rex-yue-wu - I figured it out! Opened a PR to fix in #1693