Hello,

I'm running a TPU v3-8 VM on Google Cloud. On the VM I installed jax with
pip install "jax[tpu]==0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Unfortunately, I'm getting the message
No GPU/TPU found, falling back to CPU.
when issuing jax.device_count(). The same holds for
pip install jax==0.2.12
Only when I'm using
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
(newest jax version), it works. As far as I can see, for fine-tuning we need jax version 0.2.12 or 0.2.16.

How can I get it running with these versions?
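
In case it helps with reproducing this, here is a minimal sketch of the check I would run on the VM after each install. The helper names (installed_version, describe_jax_setup) are made up for this post, and the pip package names libtpu-nightly and libtpu are an assumption about how the TPU runtime is packaged. It only reports which versions are installed and which device platforms jax can see.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version of a pip package, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def describe_jax_setup():
    """Collect package versions and the device platforms jax can see."""
    # "libtpu-nightly" and "libtpu" are assumed package names for the TPU runtime.
    info = {name: installed_version(name)
            for name in ("jax", "jaxlib", "libtpu-nightly", "libtpu")}
    try:
        import jax
        # Each device exposes a platform string such as "cpu", "gpu" or "tpu".
        info["platforms"] = sorted({d.platform for d in jax.devices()})
        info["device_count"] = jax.device_count()
    except (ImportError, RuntimeError):
        # jax is not installed, or no backend could be initialised.
        info["platforms"] = []
        info["device_count"] = 0
    return info


if __name__ == "__main__":
    for key, value in describe_jax_setup().items():
        print(f"{key}: {value}")
```

If platforms only lists cpu, or no libtpu package shows up, that would at least match the fallback message above.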