keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

InvalidArgumentError occurred when I implemented image generation using mixed float16 #2102

Closed. y-vectorfield closed this issue 11 months ago

y-vectorfield commented 1 year ago

Current Behavior:

An InvalidArgumentError occurs when I run image generation with mixed precision (the mixed_float16 policy). Environment (tutorial I followed): https://www.tensorflow.org/tutorials/generative/generate_images_with_stable_diffusion

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-13-80d21f9296ec> in <cell line: 2>()
      1 # Warm up model to run graph tracing before benchmarking.
----> 2 model.text_to_image("warming up the model", batch_size=3)
      3 
      4 start = time.time()
      5 images = model.text_to_image(

/usr/local/lib/python3.10/dist-packages/keras_cv/src/models/stable_diffusion/stable_diffusion.py in text_to_image(self, prompt, negative_prompt, batch_size, num_steps, unconditional_guidance_scale, seed)
     81         encoded_text = self.encode_text(prompt)
     82 
---> 83         return self.generate_image(
     84             encoded_text,
     85             negative_prompt=negative_prompt,

/usr/local/lib/python3.10/dist-packages/keras_cv/src/models/stable_diffusion/stable_diffusion.py in generate_image(self, encoded_text, negative_prompt, batch_size, num_steps, unconditional_guidance_scale, diffusion_noise, seed)
    236             )
    237             a_t, a_prev = alphas[index], alphas_prev[index]
--> 238             pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(
    239                 a_t
    240             )

/usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   5886 def raise_from_not_ok_status(e, name) -> NoReturn:
   5887   e.message += (" name: " + str(name if name is not None else ""))
-> 5888   raise core._status_to_exception(e) from None  # pylint: disable=protected-access
   5889 
   5890 

InvalidArgumentError: cannot compute Sub as input #1(zero-based) was expected to be a float tensor but is a half tensor [Op:Sub] name:
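The message says input #1 of the subtraction is float16 ("half") while input #0 is float32, which suggests that under mixed_float16 the diffusion model output (`latent`) is float16 while `latent_prev` stays float32, and TensorFlow does not promote dtypes implicitly. A minimal sketch reproducing the same class of error outside KerasCV, assuming plain eager TensorFlow 2.x; the tensor values and shapes are illustrative stand-ins, not the real values inside `generate_image`:

```python
import math

import tensorflow as tf

# Illustrative stand-ins for the operands on the failing line (shapes are arbitrary):
# a float32 latent and a float16 tensor, e.g. a model output under "mixed_float16".
latent_prev = tf.random.normal((1, 8, 8, 4), dtype=tf.float32)
latent = tf.cast(tf.random.normal((1, 8, 8, 4)), tf.float16)
a_t = 0.5

mismatch_raised = False
try:
    # Python float * float16 tensor stays float16; float32 - float16 is rejected.
    _ = latent_prev - math.sqrt(1 - a_t) * latent
except tf.errors.InvalidArgumentError as err:
    mismatch_raised = True
    print(err.message)  # TensorFlow's dtype-mismatch message for the Sub op
```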

Expected Behavior:

I expect the images to be generated the same way as when using float32.

Steps To Reproduce:

The following code cell, from the tutorial's mixed precision benchmark, fails with the error above.

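import time

import keras_cv
from tensorflow import keras

# Setup assumed from earlier cells of the tutorial: enable mixed precision
# *before* building the model. plot_images() is a helper defined in the tutorial.
keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion()
benchmark_result = []

# Warm up model to run graph tracing before benchmarking.
model.text_to_image("warming up the model", batch_size=3)

start = time.time()
images = model.text_to_image(
    "a cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["Mixed Precision", end - start])
plot_images(images)

print(f"Mixed precision model: {(end - start):.2f} seconds")
keras.backend.clear_session()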

Version:

0.6.4

Details:

Python 3.10.12
keras 2.14.0
keras-core 0.1.7
keras-cv 0.6.4
pytensor 2.14.2
tensorboard 2.14.1
tensorboard-data-server 0.7.1
tensorflow 2.14.0
tensorflow-datasets 4.9.3
tensorflow-estimator 2.14.0
tensorflow-gcs-config 2.13.0
tensorflow-hub 0.15.0
tensorflow-io-gcs-filesystem 0.34.0
tensorflow-metadata 1.14.0
tensorflow-probability 0.20.1
tensorstore 0.1.45

ianstenbit commented 1 year ago

Does this bug also appear when using Keras Core with the TF backend? It may be a simple fix, but I'd like to verify that it affects multi-backend Keras and not just Keras 2.

y-vectorfield commented 1 year ago

@ianstenbit Yes, I used this with the TF backend.

y-vectorfield commented 12 months ago

@ianstenbit, @jbischof This error started occurring when I applied the following commit:

https://github.com/keras-team/keras-cv/commit/9c18f5649b330b109d1a03f26e3b28546438036f
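As a possible local workaround while this is investigated, the pred_x0 line from the traceback can cast the float16 operand to the other operand's dtype before the arithmetic. A minimal sketch, assuming (from the error message) that `latent_prev` is float32 and the model output `latent` is float16; `predict_x0` is a hypothetical helper, not KerasCV API, and this is not necessarily how the upstream fix was done:

```python
import math

import tensorflow as tf


def predict_x0(latent_prev, noise_pred, a_t):
    """Hypothetical helper: the pred_x0 expression from the traceback, plus a cast.

    noise_pred (float16 under mixed precision) is cast to latent_prev's dtype
    first, so the Sub op never receives mixed float32/float16 inputs.
    """
    noise_pred = tf.cast(noise_pred, latent_prev.dtype)
    return (latent_prev - math.sqrt(1 - a_t) * noise_pred) / math.sqrt(a_t)


# Example: float32 latent and float16 noise prediction of the same shape.
latent_prev = tf.ones((1, 2, 2, 4), dtype=tf.float32)
noise_pred = tf.ones((1, 2, 2, 4), dtype=tf.float16)
pred_x0 = predict_x0(latent_prev, noise_pred, a_t=0.64)
print(pred_x0.dtype)  # <dtype: 'float32'>; each element is (1 - 0.6) / 0.8, about 0.5
```

Casting toward `latent_prev.dtype` (float32 here) rather than down to float16 keeps the sampler's running latent in full precision, which is the usual mixed-precision practice for numerically sensitive updates.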