kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

AttributeError: module 'jaxlib.pocketfft' has no attribute 'pocketfft' #233

Open umm-maybe opened 2 years ago

umm-maybe commented 2 years ago

Hello, I followed the (very much appreciated) howto_finetune.md guide, and upon attempting to run the magic python device_train.py command, I received the error noted above. The only Google search result that seems to mention something similar is this: https://bytemeta.vip/repo/deepmind/alphafold/issues/515

The answer to that question seems to imply it has to do with a version incompatibility between jax and jaxlib, but the solution they link to doesn't work here. Any tips or advice for working around this would be greatly appreciated!

Ontopic commented 2 years ago

Off the top of my head:

pip install jax==0.2.12 jaxlib==0.1.67

I can't try it right now, but that version combination should work on a TPU VM.

Edit: I think it also depends on which Python version you're running (3.7 on TPU v2 and Colab, 3.8 on v3) and on the TPU version / accelerator type. I think I've seen jaxlib==0.1.68 in v2 setups, so that's also worth a shot.
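To see which combination is actually installed before trying these pins, a small check along these lines may help. It is only a sketch and not part of the repo: the pinned pair is the one suggested in this comment (jax 0.2.12 / jaxlib 0.1.67), and the helper names installed_version and check_pins are hypothetical. It reads pip's package metadata rather than running import jax, because the import is what raises the AttributeError, and it also prints the Python version since the Edit above suggests that matters too.

```python
import sys

try:
    from importlib import metadata  # Python 3.8+
except ImportError:  # Python 3.7 (TPU v2, older Colab): pip install importlib-metadata
    import importlib_metadata as metadata

# The combination suggested in this comment, not an official compatibility table.
PINNED = {"jax": "0.2.12", "jaxlib": "0.1.67"}


def installed_version(dist_name):
    """Return the installed version of a pip distribution, or None if it is missing.

    Reads package metadata instead of importing the package, because importing
    a mismatched jax is what raises the AttributeError in the first place.
    """
    try:
        version = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None
    # Drop a local build tag such as "+cuda111" so only the release is compared.
    return version.split("+")[0]


def check_pins(pins=PINNED):
    """Return {name: (installed, wanted)} for each package that differs from its pin."""
    mismatches = {}
    for name, wanted in pins.items():
        found = installed_version(name)
        if found != wanted:
            mismatches[name] = (found, wanted)
    return mismatches


if __name__ == "__main__":
    print("Python", ".".join(str(part) for part in sys.version_info[:3]))
    problems = check_pins()
    for name, (found, wanted) in problems.items():
        print(f"{name}: installed {found}, suggested {wanted}")
    if not problems:
        print("jax and jaxlib match the suggested pins")
```

If you want to try jaxlib==0.1.68 on a v2 setup instead, edit PINNED accordingly; the check only compares versions and does not know which pairs are actually compatible.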

dunstantom commented 2 years ago

I'm also using a TPU v2 setup and ran into this problem. I used the JAX TPU install instructions from their README and it worked for me.

sxiii commented 2 years ago

I'm now also getting "AttributeError: module 'jax' has no attribute 'version'", or alternatively "AttributeError: module 'jaxlib.pocketfft' has no attribute 'pocketfft'". I tried a couple of different Colab notebooks, and none of them work.

musabgultekin commented 1 year ago

I fixed it by doing this right after the install dependencies section:

!pip install jaxlib==0.1.67

Then restart the runtime if it asks you to.

It feels very fragile, though, and I don't know why it works.
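The restart step is needed because a running Python process keeps already-imported modules in memory: pip install replaces the files on disk, but not the copy that is already loaded. A hedged sketch for spotting that situation in a notebook is below. stale_modules is a hypothetical helper, it assumes each module name matches its pip distribution name (true for jax and jaxlib), and it skips modules that have no __version__ attribute.

```python
import sys

try:
    from importlib import metadata  # Python 3.8+
except ImportError:  # Python 3.7: requires pip install importlib-metadata
    import importlib_metadata as metadata


def stale_modules(names=("jax", "jaxlib")):
    """List (name, loaded, on_disk) for modules whose loaded version differs from disk.

    Modules already imported stay in sys.modules, so a pip install does not
    affect them until the interpreter (the notebook runtime) is restarted.
    """
    stale = []
    for name in names:
        module = sys.modules.get(name)
        if module is None:
            continue  # not imported yet, so the next import loads the installed files
        loaded = getattr(module, "__version__", None)
        if loaded is None:
            continue  # no version attribute to compare against
        try:
            on_disk = metadata.version(name)
        except metadata.PackageNotFoundError:
            on_disk = None
        if loaded != on_disk:
            stale.append((name, loaded, on_disk))
    return stale
```

If stale_modules() returns a non-empty list after the pip install cell, the notebook is still running the old version and needs the restart.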