jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JAX 0.1.58 doesn't see TPU in public Colab #2156

Closed: romanngg closed this issue 4 years ago

romanngg commented 4 years ago

Example notebook, run with the Python 3 TPU runtime selected: https://colab.research.google.com/gist/romanngg/e02e7a08514d13d57be8fc48bd76dd83/no_tpu.ipynb

(It works with a GPU runtime, though.)

skye commented 4 years ago

You need the special preamble that switches to the new Cloud TPU stack that jax uses (we'd eventually like Cloud TPUs to work out of the box, but for now this boilerplate is necessary). See the first cell of any of our example Cloud TPU notebooks, e.g. https://github.com/google/jax/blob/master/cloud_tpu_colabs/NeurIPS_2019_JAX_demo.ipynb. You can probably bump the jax and jaxlib versions on that first pip install line too; please report if that somehow breaks things!
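A quick way to tell whether the preamble took effect is to list the devices JAX can see. The snippet below is a minimal diagnostic sketch, not part of the preamble itself (copy that from the first cell of the linked notebook). It assumes a JAX version that exposes `jax.devices()` and the `Device.platform` attribute; the helper names `visible_platforms` and `tpu_visible` are hypothetical. On a TPU runtime where JAX has not been switched to the Cloud TPU stack, you would expect it to report only CPU devices.

```python
import jax


def visible_platforms():
    """Return the set of platform names (e.g. "cpu", "gpu", "tpu") of the devices JAX sees."""
    return {d.platform for d in jax.devices()}


def tpu_visible():
    # True only if at least one device enumerated by JAX is a TPU.
    return "tpu" in visible_platforms()


if __name__ == "__main__":
    print(jax.devices())
    print("TPU visible:", tpu_visible())
```

Run this in a cell after the preamble: if it still prints `TPU visible: False`, JAX is falling back to the CPU backend.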

romanngg commented 4 years ago

Thanks for the explanation Skye!