Context
Using the example Colab on a TPU instance, users get the following error when running the second code chunk:
Attempted Fixes
I tried updating the jax version to solve the issue:
!pip install -U jax jaxlib
This now leads to a new error when running the second code chunk:
Suggested Fix
Delete the following lines from the second code chunk of the Google Colab example:
if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu
    colab_tpu.setup_tpu()
This avoids the jax/TPU conflict. It does reduce the notebook's functionality by removing TPUs from consideration, but it also reduces user friction by eliminating the jax/TPU errors.
A separate notebook demonstrating full TPU functionality could then be developed once a proper fix for the jax issue is found. However, judging from the error message I encountered, that may be difficult within the Colab ecosystem.
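With those lines removed, jax no longer attempts the Colab TPU setup and instead uses whatever default backend it can find. The following is a minimal sketch, not part of the original notebook, that could be run afterwards to confirm which backend is active and that a small computation succeeds on it. It assumes only that `jax` is installed. The expectation that a runtime without the TPU setup lands on CPU (or GPU on a GPU runtime) is an assumption, not something verified here.

```python
import jax
import jax.numpy as jnp

# Without colab_tpu.setup_tpu(), jax selects its default backend itself.
# On a runtime without TPU setup this is expected (not guaranteed) to be
# "cpu", or "gpu" on a GPU runtime.
backend = jax.default_backend()
devices = jax.devices()
print(f"jax backend: {backend}")
print(f"devices: {devices}")

# A tiny computation to confirm the selected backend actually works.
x = jnp.arange(4.0)           # [0., 1., 2., 3.]
result = float(jnp.dot(x, x))  # 0 + 1 + 4 + 9
print(f"dot(x, x) = {result}")
```

If the printed backend is `cpu` on a TPU runtime, that matches the trade-off described above: the notebook runs, but without TPU acceleration.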