neilobremski closed this issue 2 years ago.
I also had two runs that ended with a different message but the same final error:
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float32, float16.
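The thread doesn't show where in the model code the mismatch arises. Here is a minimal sketch that assumes it comes from writing a float16 update into a float32 buffer. It reproduces the error and shows one common workaround: casting the update to the buffer's dtype before the call. The buffer and update shapes and the helper `safe_dynamic_update_slice` are hypothetical, and this is not necessarily the fix that was applied upstream.

```python
import jax.numpy as jnp
from jax import lax


def safe_dynamic_update_slice(operand, update, start_indices):
    """Hypothetical helper: cast `update` to `operand`'s dtype so that
    lax.dynamic_update_slice sees matching dtypes."""
    return lax.dynamic_update_slice(
        operand, update.astype(operand.dtype), start_indices
    )


# Hypothetical float32 buffer (e.g. a cache written into step by step).
buffer = jnp.zeros((4, 3), dtype=jnp.float32)
# Hypothetical float16 update, as a half-precision layer might produce.
update = jnp.ones((1, 3), dtype=jnp.float16)

# Passing mismatched dtypes directly raises the TypeError quoted above.
raised = False
try:
    lax.dynamic_update_slice(buffer, update, (2, 0))
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Casting first lets the update go through; the result keeps the buffer's dtype.
result = safe_dynamic_update_slice(buffer, update, (2, 0))
print(result.dtype)
```

Casting toward the buffer's dtype keeps the stored values in full precision. Casting the buffer down to float16 instead would also remove the error, but it would lose precision in everything already stored.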
It should work if you run pip install flax==0.4.2. I need to address what is causing the dtype mismatch in the latest flax version.
I posted a fix that worked for me in this closed issue
@kuprel I suspect this has something to do with the default dtype change introduced in flax v0.5.0; tracking down exactly how to fix it is beyond me. In the meantime, rolling back to 0.4.2 works as you suggested. 👍
OK, it should work with the latest flax version now.
I tried running the following in the Google Colab:
This caused an exception:
The same thing happened when I ran it locally from the command line:
NOTE: I had to add the following line to the Setup block of the Jupyter code: