yueyin85 closed this issue 1 year ago.
Facing the same issue. Could it be related to this JAX commit? https://github.com/google/jax/commit/009af386972b7d71f724220be0f60d0f00e6d6be
Running the following lines after the setup block got me past the error:

!pip install git+https://github.com/google-research/t5x.git@2b010160e7fe8a4505a6d1032a7b737a633636e5
!pip install git+https://github.com/google/jax.git@47df8628a0fa83900e38431c88a7a0e27660b7aa
I'm not sure the first line is necessary, and a later JAX commit might work as well.
It seems t5x started using a feature introduced in jax 0.4.9 without updating setup.py to require jax 0.4.9. Until that is fixed upstream, you can use @MalcolmMashig's workaround or run pip install jax==0.4.9.
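Until setup.py enforces the minimum version, a notebook can check the installed JAX before importing t5x. The following is a minimal sketch, assuming 0.4.9 is the minimum (as suggested above). The names MIN_JAX, parse_version and jax_is_new_enough are hypothetical helpers, not part of t5x or JAX. The parser reads only the leading numeric release segment, so pre-release tags such as "rc1" or ".dev1" are ignored.

```python
# Minimal sketch: fail fast if the installed JAX is older than the version
# t5x appears to need. MIN_JAX is an assumption taken from the thread.
from importlib import metadata

MIN_JAX = (0, 4, 9)


def parse_version(v: str) -> tuple:
    """Parse the leading numeric release segment, e.g. '0.4.10.dev1' -> (0, 4, 10)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break  # stop at the first non-digit (e.g. 'rc1', 'dev1')
        if not digits:
            break  # a purely non-numeric segment ends the release part
        parts.append(int(digits))
    return tuple(parts)


def jax_is_new_enough(minimum=MIN_JAX) -> bool:
    """Return True if 'jax' is installed and its release is >= minimum."""
    try:
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        return False
    # Tuples compare element by element, so (0, 4, 10) >= (0, 4, 9).
    return parse_version(installed) >= minimum


if __name__ == "__main__":
    if not jax_is_new_enough():
        print("jax is missing or older than 0.4.9; try: pip install jax==0.4.9")
```

Comparing integer tuples avoids the classic string-comparison bug where "0.4.10" sorts before "0.4.9". If the third-party packaging library is available, packaging.version.Version is the more robust choice.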
This should be fixed by https://github.com/google-research/t5x/pull/1278
When running the Imports and Definitions cell in Google Colab, I get No module named 'jax.experimental.array_serialization'. As the prompt indicated, I reinstalled jax 0.2.19, but the issue persists.
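To confirm whether the installed JAX actually provides the module t5x is trying to import, independent of which version pip reports, you can probe for it directly. The following is a minimal sketch. It assumes the missing module is jax.experimental.array_serialization (our reading of the garbled error text), and module_available is a hypothetical helper. importlib.util.find_spec imports parent packages for dotted names, so importing jax itself still happens, but t5x is not imported.

```python
# Minimal sketch: check whether a (possibly dotted) module can be located.
# The target module name is an assumption about the error in the report.
import importlib.util


def module_available(name: str) -> bool:
    """Return True if the import system can find `name`."""
    try:
        # For dotted names, find_spec imports the parent package first and
        # returns None if the submodule itself cannot be found.
        return importlib.util.find_spec(name) is not None
    except ImportError:
        # ModuleNotFoundError (an ImportError subclass) is raised when a
        # parent package such as 'jax' is not installed at all.
        return False


if __name__ == "__main__":
    target = "jax.experimental.array_serialization"
    print(target, "available:", module_available(target))
```

If this prints False after a reinstall, the JAX version in the runtime likely does not ship that module. Restarting the Colab runtime after pip install ensures the new version is actually the one loaded.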