romanngg closed this issue 4 years ago
You need the special preamble that switches to the new Cloud TPU stack that jax uses (we'd eventually like Cloud TPUs to work out of the box, but for now this boilerplate is necessary). See the first cell of any of our example Cloud TPU notebooks, e.g. https://github.com/google/jax/blob/master/cloud_tpu_colabs/NeurIPS_2019_JAX_demo.ipynb. You can probably bump the jax and jaxlib versions on that first pip install line too; please report if that somehow breaks things!
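For anyone skimming, here is a rough sketch of what that preamble boils down to. It is written from memory rather than copied from the notebook, so treat the details as assumptions: the `COLAB_TPU_ADDR` environment variable, port `8475`, the `tpu_driver_nightly` version name, and the `jax_xla_backend` / `jax_backend_target` flag names may all differ in the version you run. The helper `colab_tpu_settings` is hypothetical and only computes the values; the first cell of the linked notebook is the authoritative version.

```python
import os


def colab_tpu_settings(environ=os.environ):
    """Hypothetical helper: derive the TPU settings from the Colab environment.

    Returns None when COLAB_TPU_ADDR is missing, which usually means the
    TPU runtime is not selected.
    """
    addr = environ.get("COLAB_TPU_ADDR")  # assumed to look like "host:port"
    if not addr:
        return None
    host = addr.split(":")[0]
    return {
        # Assumed endpoint used to ask the TPU host for a matching driver.
        "driver_request_url": f"http://{host}:8475/requestversion/tpu_driver_nightly",
        # Assumed JAX flag values that select the TPU driver backend.
        "jax_xla_backend": "tpu_driver",
        "jax_backend_target": "grpc://" + addr,
    }


settings = colab_tpu_settings()
if settings is None:
    print("COLAB_TPU_ADDR is not set; is the TPU runtime selected?")
else:
    # The notebook preamble then applies these values roughly as follows
    # (left as comments because they need a live TPU runtime):
    #   requests.post(settings["driver_request_url"])
    #   config.FLAGS.jax_xla_backend = settings["jax_xla_backend"]
    #   config.FLAGS.jax_backend_target = settings["jax_backend_target"]
    print(settings["jax_backend_target"])
```

If this prints the "not set" message on a TPU runtime, the environment is not what the preamble expects, and the flags would have nothing to point at.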
Thanks for the explanation Skye!
Example, when selecting the Python 3 TPU runtime: https://colab.research.google.com/gist/romanngg/e02e7a08514d13d57be8fc48bd76dd83/no_tpu.ipynb
(It works on a GPU runtime, though.)