Open umm-maybe opened 2 years ago
Off the top of my head: pip install jax==0.2.12 jaxlib==0.1.67
I can't try it right now, but that version combination should work on a TPU VM.
Edit: I think it also depends on which Python version you're running (3.7 on TPU v2 and Colab, 3.8 on v3) and on the TPU version / accelerator type. I think I've seen jaxlib==0.1.68 in v2 setups, so that's also worth a shot.
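Since the suggestions above all come down to matching jax and jaxlib versions (and knowing which Python the runtime uses), it can help to check what a notebook actually has installed before running anything. This is a minimal sketch, not something from the repo or the JAX docs: the pins in `SUGGESTED_PINS` are simply the ones from the comment above, and `installed_version` / `check_pins` are hypothetical helper names.

```python
import sys

try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7 needs the importlib_metadata backport from PyPI
    from importlib_metadata import PackageNotFoundError, version

# Pins suggested above; other combinations (e.g. jaxlib==0.1.68) may also work.
SUGGESTED_PINS = {"jax": "0.2.12", "jaxlib": "0.1.67"}


def installed_version(package):
    """Return the installed distribution version, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def check_pins(pins, get_version=installed_version):
    """Return {package: (wanted, found)} for every package that does not match."""
    mismatches = {}
    for package, wanted in pins.items():
        found = get_version(package)
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    problems = check_pins(SUGGESTED_PINS)
    if not problems:
        print("jax/jaxlib match the suggested pins")
    for package, (wanted, found) in problems.items():
        print(f"{package}: wanted {wanted}, found {found or 'not installed'}")
```

A mismatch reported here does not prove it is the cause of the error, but it rules the version question in or out quickly.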
I'm also using a TPU v2 setup and ran into this problem. I followed the JAX TPU install instructions in their README, and that worked for me.
Now I'm also getting "AttributeError: module 'jax' has no attribute 'version'", or alternatively:
AttributeError: module 'jaxlib.pocketfft' has no attribute 'pocketfft'
I've tried a couple of different Colab notebooks, and none of them work.
I fixed it by running this right after the install dependencies section:
!pip install jaxlib==0.1.67
Then restart the runtime if Colab asks you to.
It feels very fragile, though, and I don't know why it works.
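One likely reason the restart matters: a module that is already imported stays loaded in the running Python process, so a `!pip install` only takes effect on the next fresh import. The sketch below is an assumption-laden illustration, not a Colab or JAX API: `needs_restart` is a hypothetical helper that compares a loaded module's `__version__` with the version installed on disk, and it assumes the module name matches the distribution name (true for `jax` and `jaxlib`).

```python
import sys

try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7 needs the importlib_metadata backport
    from importlib_metadata import PackageNotFoundError, version


def needs_restart(package):
    """True if `package` is already imported but its __version__ differs from
    the version now installed on disk (e.g. after `!pip install` in Colab).

    Assumes the module exposes a top-level __version__ attribute; returns
    False when it does not, since there is nothing to compare.
    """
    module = sys.modules.get(package)
    if module is None:
        return False  # not imported yet: the next import loads the new files
    loaded = getattr(module, "__version__", None)
    if loaded is None:
        return False
    try:
        on_disk = version(package)
    except PackageNotFoundError:
        return True  # imported, but the distribution has since been removed
    return loaded != on_disk
```

For example, `needs_restart("jax")` after reinstalling would flag a stale import, since jax exposes `jax.__version__`; whether a given jaxlib release exposes a top-level `__version__` is not guaranteed, in which case the helper simply returns False.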
Hello, I have followed the (very much appreciated) howto_finetune.md guide and, upon attempting to run the magic python device_train.py command, received the error noted above. The only Google search result that seems to mention something similar is this: https://bytemeta.vip/repo/deepmind/alphafold/issues/515
The answer to that question seems to imply it has to do with a version incompatibility between jax and jaxlib, but the solution they link to doesn't work here. Any tips or advice for working around this would be greatly appreciated!