kuprel / min-dalle

min(DALL·E) is a fast, minimal port of DALL·E Mini to PyTorch
MIT License

--mega / is_mega raises TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16. #15

Closed neilobremski closed 2 years ago

neilobremski commented 2 years ago

I tried running the following in Google Colab:

image = generate_image_from_text("court sketch of godzilla on trial", is_mega=True, seed=100)

This caused an exception:

parsing metadata from ./pretrained/dalle_bart_mega
tokenizing text
['Ġcourt']
['Ġsketch']
['Ġof']
['Ġgodzilla']
['Ġon']
['Ġtrial']
text tokens [0, 2634, 4189, 111, 14450, 133, 5167, 2]
loading flax encoder
encoding text tokens
loading flax decoder
sampling image tokens
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-5-53d46ed9885c> in <module>()
      2 
----> 3 image = generate_image_from_text("court sketch of godzilla on trial", is_mega=True, seed=100)
      4 display(image)

67 frames
UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
/content/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py in __call__(self, decoder_state, keys_state, values_state, attention_mask, state_index)
     38             keys_state,
     39             self.k_proj(decoder_state).reshape(shape_split),
---> 40             state_index
     41         )
     42         values_state = lax.dynamic_update_slice(

TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

The same thing happened when I ran it locally from the command line:

python3 image_from_text.py --text='court sketch of godzilla on trial' --mega --seed=100

NOTE: I had to add the following line to the Setup block of the Jupyter code:

! wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14

diapason-consulting commented 2 years ago

I also had two runs that produced a different message but ended with the same error: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.

kuprel commented 2 years ago

It should work if you pip install flax==0.4.2. I still need to track down what is causing the dtype mismatch in the latest flax version.
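
To confirm which flax version a Colab or local environment actually picked up, a small check along these lines may help. It is only a sketch: the helper names (`parse_version`, `installed_flax_version`, `matches_pin`) are hypothetical and not part of min-dalle, and the only assumption about versions is the 0.4.2 pin mentioned above.

```python
import re
from importlib import metadata

PINNED_FLAX = "0.4.2"  # the version reported to work in this thread


def parse_version(text):
    """Turn a string such as '0.4.2' into a tuple of its first three numbers."""
    return tuple(int(n) for n in re.findall(r"\d+", text)[:3])


def installed_flax_version():
    """Return the installed flax version string, or None if flax is absent."""
    try:
        return metadata.version("flax")
    except metadata.PackageNotFoundError:
        return None


def matches_pin(version_text, pinned=PINNED_FLAX):
    """True when the given version equals the pinned one."""
    return parse_version(version_text) == parse_version(pinned)


found = installed_flax_version()
if found is None:
    print("flax is not installed")
elif matches_pin(found):
    print(f"flax {found} matches the pinned version")
else:
    print(f"flax {found} differs from the pinned {PINNED_FLAX}")
```

Running this before loading the model shows whether the rollback actually took effect in the current runtime.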

warmlogic commented 2 years ago

I posted a fix that worked for me in this closed issue

ummjackson commented 2 years ago

@kuprel I suspect this might have something to do with the default dtype change that was implemented in v0.5.0 of flax - tracking down exactly how to fix that is beyond me. In the meantime, rolling back to 0.4.2 works as you suggested. 👍
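
For anyone who cannot roll back, here is a minimal sketch of the mismatch outside of min-dalle, together with one possible workaround: casting the update to the dtype of the buffer before calling `lax.dynamic_update_slice`. The names `update_cache`, `cache`, and `entry` are made up for illustration, and this is not necessarily the fix that was applied in the repo.

```python
import jax.numpy as jnp
from jax import lax


def update_cache(cache, new_entry, index):
    """Write new_entry into cache at row `index`, casting to the cache dtype first."""
    # lax.dynamic_update_slice wants operand and update to share a dtype,
    # so cast the update explicitly (an assumed workaround, see lead-in).
    new_entry = new_entry.astype(cache.dtype)
    return lax.dynamic_update_slice(cache, new_entry, (index, 0))


# A float32 buffer and a float16 update, mirroring the dtypes in the error.
cache = jnp.zeros((4, 3), dtype=jnp.float32)
entry = jnp.ones((1, 3), dtype=jnp.float16)

# Passing the mismatched dtypes directly is expected to raise the TypeError
# quoted in this issue; the flag records whether it did.
mismatch_raised = False
try:
    lax.dynamic_update_slice(cache, entry, (2, 0))
except TypeError:
    mismatch_raised = True

# With the explicit cast the update goes through.
updated = update_cache(cache, entry, 2)
print(updated.dtype, mismatch_raised)
```

Casting toward the buffer dtype keeps the cache in one precision; whether float32 or float16 is the right target for the mega weights is a separate question this sketch does not answer.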

kuprel commented 2 years ago

OK, it should work with the latest flax version now.