magenta / mt3

MT3: Multi-Task Multitrack Music Transcription
Apache License 2.0

No module named 'jax.experimental.array_serialization' in Google Colab #124

Closed yueyin85 closed 1 year ago

yueyin85 commented 1 year ago

When running the Imports and Definitions cell in Google Colab, I get the error No module named 'jax.experimental.array_serialization'. As the prompt suggested, I reinstalled jax==0.2.19, but that did not resolve the issue.

MalcolmMashig commented 1 year ago

Facing the same issue. Does it have to do with this JAX commit? https://github.com/google/jax/commit/009af386972b7d71f724220be0f60d0f00e6d6be

Running the following lines after the setup block got me past the error:

!pip install git+https://github.com/google-research/t5x.git@2b010160e7fe8a4505a6d1032a7b737a633636e5
!pip install git+https://github.com/google/jax.git@47df8628a0fa83900e38431c88a7a0e27660b7aa

I'm not sure whether the first line is necessary, and a later jax commit might work as well.

iansimon commented 1 year ago

Seems like t5x started using a feature introduced in jax 0.4.9 without updating setup.py to require jax 0.4.9. Until it's fixed there, you can use @MalcolmMashig's fix or do pip install jax==0.4.9.
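To confirm which case a runtime is in before reinstalling anything, a check along these lines may help. It is only a rough sketch: the helper names (version_tuple, jax_is_new_enough, installed_jax_version) are made up for illustration and are not part of jax or t5x, and the 0.4.9 threshold is simply the version mentioned above.

```python
import importlib.metadata

# Minimum jax version named in this thread (assumption: nothing newer is required).
MIN_JAX = (0, 4, 9)


def version_tuple(version):
    """Turn '0.4.9' or '0.4.10.dev20230501' into a tuple of leading integers."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component, e.g. 'dev20230501'
        parts.append(int(digits))
    return tuple(parts)


def jax_is_new_enough(version, minimum=MIN_JAX):
    """Return True if the given version string is at least `minimum`."""
    return version_tuple(version) >= minimum


def installed_jax_version():
    """Return the installed jax version string, or None if jax is not installed."""
    try:
        return importlib.metadata.version("jax")
    except importlib.metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_jax_version()
    if found is None:
        print("jax is not installed")
    elif jax_is_new_enough(found):
        print(f"jax {found} should be recent enough")
    else:
        print(f"jax {found} is older than 0.4.9; try: pip install jax==0.4.9")
```

In Colab this can be pasted into a cell right after the setup block. It reads package metadata instead of importing jax, so it works even when the import itself fails.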

iansimon commented 1 year ago

Should be fixed with https://github.com/google-research/t5x/pull/1278