keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

BUG: StableDiffusion.inpaint() throws float tensor error #1019

Closed: BFauber closed this issue 1 year ago

BFauber commented 1 year ago

I am using today's version (15 Nov 2022) of https://github.com/keras-team/keras-cv/ and I get the following error when calling StableDiffusion.inpaint():

InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a half tensor but is a float tensor [Op:AddV2]

Is anyone else getting this error when running StableDiffusion.inpaint()?

Here is my code:

from tensorflow import keras
from keras_cv.models import StableDiffusion

keras.mixed_precision.set_global_policy("mixed_float16")

model = StableDiffusion(
    img_height=512, 
    img_width=512, 
    jit_compile=True
)

prompt = "photo of coffee beans"
seed = 52983

# NOTE: img1 was successfully generated using StableDiffusion.text_to_image()
# with the same prompt and seed values as shown above.
type(img1)  # numpy.ndarray (dtype=uint8)
img1.shape  # (1, 512, 512, 3)

type(mask)  # numpy.ndarray (dtype=uint8)
mask.shape  # (1, 512, 512)

img2 = model.inpaint(
    prompt=prompt,
    image=img1,
    mask=mask,
    num_resamples=1,
    batch_size=1,
    num_steps=25,
    unconditional_guidance_scale=7.5,
    diffusion_noise=None,
    seed=seed,
    verbose=True,
)
bhack commented 1 year ago

The problem is that internally we convert img1 to float32:

https://github.com/keras-team/keras-cv/blob/fdb40d0d5c59773bd6607af24145ce1f37f3a7c9/keras_cv/models/stable_diffusion/stable_diffusion.py#L325

So when we later call an op that combines the image and the mask, the dtypes mismatch, for example at: https://github.com/keras-team/keras-cv/blob/fdb40d0d5c59773bd6607af24145ce1f37f3a7c9/keras_cv/models/stable_diffusion/stable_diffusion.py#L385

It also seems that we have no test coverage for inpaint: https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/stable_diffusion/stable_diffusion_test.py

Currently StableDiffusion, probably because of its large scale, is still something of an outlier in the library.

bhack commented 1 year ago

/cc @bobqywei

bobqywei commented 1 year ago

Thanks for finding this and letting me know, @BFauber @bhack! I'm thinking of letting the user specify the dtype for these internal ops so that they have control over the tradeoffs. Does that sound reasonable?
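A minimal sketch of what a user-selectable dtype could look like, assuming NumPy arrays as inputs. The helper `prepare_inpaint_inputs` and its `dtype` parameter are hypothetical and not part of the KerasCV API; the idea is only that the image and the mask are cast to the same dtype before any op combines them, which is what the float16/float32 mismatch above lacks. The [-1, 1] pixel scaling is also an assumption.

```python
import numpy as np


def prepare_inpaint_inputs(image, mask, dtype="float32"):
    """Hypothetical helper: cast image and mask to one caller-chosen dtype.

    Assumes `image` holds uint8 pixels in [0, 255] and that the model wants
    values in [-1, 1]. `mask` is only cast, not rescaled.
    """
    dtype = np.dtype(dtype)
    one, two, max_pixel = dtype.type(1), dtype.type(2), dtype.type(255)
    # Scale pixels from [0, 255] to [-1, 1], staying in the requested dtype.
    image = image.astype(dtype) / max_pixel * two - one
    # Cast the mask to the same dtype so no later op mixes float16 and float32.
    mask = mask.astype(dtype)
    return image, mask


image = np.zeros((1, 512, 512, 3), dtype=np.uint8)
mask = np.ones((1, 512, 512), dtype=np.uint8)
img_f16, mask_f16 = prepare_inpaint_inputs(image, mask, dtype="float16")
print(img_f16.dtype, mask_f16.dtype)  # float16 float16
```

Defaulting to float32 matches the hard-coded cast bhack points to, while a caller running under `mixed_float16` could opt into float16 for both inputs.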

BFauber commented 1 year ago

Thanks for your suggestion @bobqywei.

I've found that the source of my error is the mixed_float16 Keras global policy:

keras.mixed_precision.set_global_policy("mixed_float16")

If I omit this global policy line, the inpainting code runs without error, but it outputs only a black image.

Thus, StableDiffusion.inpaint() is not compatible with keras.mixed_precision.set_global_policy("mixed_float16") (the policy used in https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/), and even without that policy it does not inpaint properly.

ianstenbit commented 1 year ago

It seems like, at a minimum, the action items (AIs) here are to:

@BFauber, can you please share a full repro script for your issue (either a Colab or a gist is fine)? I will take a look if you @ me on this thread.

BFauber commented 1 year ago

Thanks @ianstenbit! Here is a link to the gist: https://gist.github.com/BFauber/92a275cf29d2806353c02f9666f43887

Thanks for looking into inpaint, and please let me know if you have any other questions.

ianstenbit commented 1 year ago

I've been doing some digging. Here's what I've found so far: the output image is all zeros because the latent at the end of the inpainting loop is all NaNs.
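Finding the first bad step is easy to automate. A minimal sketch, assuming the latent after each denoising step has been collected into a list (nothing in this thread shows `inpaint` returning intermediate latents, so collecting them is an assumption); `first_nan_step` is a hypothetical helper, and the (1, 64, 64, 4) shape is the usual Stable Diffusion latent size for a 512x512 image.

```python
import numpy as np


def first_nan_step(latents):
    """Return the 1-based index of the first latent containing a NaN, or None."""
    for step, latent in enumerate(latents, start=1):
        if np.isnan(latent).any():
            return step
    return None


# Toy trajectory: three finite latents followed by an all-NaN one.
rng = np.random.default_rng(0)
trajectory = [rng.normal(size=(1, 64, 64, 4)) for _ in range(3)]
trajectory.append(np.full((1, 64, 64, 4), np.nan))
print(first_nan_step(trajectory))  # 4
```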

While debugging, I've found that the latents consistently become NaN at the 15th step, regardless of the seed. If I set num_resamples=2, the NaNs show up in the second half of the 7th step, so it seems like we're consistently getting NaNs around the 30th call to the diffusion model. Mysterious.
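The step-to-call mapping can be made concrete with a little arithmetic. This sketch assumes (not confirmed in this thread) that every resample of every step calls the diffusion model twice, once for the unconditional and once for the conditional prediction used by classifier-free guidance, and it reads "second half of the 7th step" as the second resample. `diffusion_call_range` is a hypothetical helper.

```python
def diffusion_call_range(step, resample, num_resamples, calls_per_resample=2):
    """Return the (first, last) 1-based diffusion-model call numbers.

    `step` and `resample` are 1-based. Assumes every resample of every step
    makes exactly `calls_per_resample` calls.
    """
    calls_per_step = num_resamples * calls_per_resample
    first = (step - 1) * calls_per_step + (resample - 1) * calls_per_resample + 1
    return first, first + calls_per_resample - 1


# num_resamples=1, step 15.
print(diffusion_call_range(step=15, resample=1, num_resamples=1))  # (29, 30)
# num_resamples=2, second resample of step 7.
print(diffusion_call_range(step=7, resample=2, num_resamples=2))   # (27, 28)
```

Under that assumption both runs fail within calls 27 to 30, which is consistent with the "around the 30th call" observation.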

Still not quite sure what the root cause is, but I am working on it.

ianstenbit commented 1 year ago

I've also noticed that, step by step, the latents become more and more uniform before turning into NaNs. The first latent looks normally distributed, but after a while the latents end up looking like this:

[[ 0.01294876, -0.01969388, -0.02878554, -0.02118983],
 [-0.01029961, -0.01716631, -0.03993208, -0.01427874],
 [-0.01029961, -0.01716631, -0.03993208, -0.01427874],
 ...,
 [-0.01029961, -0.01716631, -0.03993208, -0.01427874],
 [-0.01029961, -0.01716631, -0.03993208, -0.01427874],
 [-0.01993897, -0.00069592, -0.03201018, -0.01465789]],
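One way to quantify this "more and more uniform" trend, as a rough sketch: track the per-channel standard deviation over spatial positions at each step. It stays near 1 for a normally distributed latent and drops toward 0 when every position holds the same vector, like the repeated rows in the printout. `spatial_spread` is a hypothetical helper, and the (batch, height, width, channels) layout is an assumption.

```python
import numpy as np


def spatial_spread(latent):
    """Per-channel std over the spatial axes of a (B, H, W, C) latent."""
    return latent.std(axis=(1, 2))


rng = np.random.default_rng(0)
noisy = rng.normal(size=(1, 64, 64, 4))  # like the first, normally distributed latent
# Every spatial position holds the same 4-vector, like the repeated rows above.
row = np.array([-0.01029961, -0.01716631, -0.03993208, -0.01427874])
collapsed = np.broadcast_to(row, (1, 64, 64, 4))

print(spatial_spread(noisy))      # each value close to 1
print(spatial_spread(collapsed))  # each value close to 0
```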
ianstenbit commented 1 year ago

Aha, it turns out the issue is simply that your mask uses 0 and 255 instead of 0 and 1 😄 @BFauber, if you divide your mask by 255 before calling inpaint, this will work correctly.
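A minimal sketch of the fix, plus a toy illustration of why a 0/255 mask is harmful. The blend `mask * known + (1 - mask) * generated` is an assumed form of how an inpainting loop combines original and generated latents; it is not taken from the KerasCV source, and which side of the mask is preserved is not confirmed here. With 255 instead of 1, the blend is no longer a convex combination, so the difference between the two latents gets multiplied by hundreds.

```python
import numpy as np


def normalize_mask(mask):
    """Map a 0/255 uint8 mask to float32 values 0.0/1.0 by dividing by 255."""
    return mask.astype("float32") / 255.0


def blend(mask, known, generated):
    """Assumed inpainting blend: convex only when mask values lie in [0, 1]."""
    return mask * known + (1.0 - mask) * generated


raw_mask = np.array([255], dtype=np.uint8)
known, generated = np.float32(0.1), np.float32(-0.1)

print(blend(raw_mask.astype("float32"), known, generated))  # about [50.9]
print(blend(normalize_mask(raw_mask), known, generated))    # about [0.1]
```

Dividing by 255 keeps any intermediate (soft-edge) mask values; thresholding instead (for example `mask > 127`) would be the alternative if a strictly binary mask is wanted.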

I will add a note to the docstring stating that mixed precision is not currently supported for inpainting.