instadeepai / nucleotide-transformer

🧬 Nucleotide Transformer: Building and Evaluating Robust Foundation Models for Human Genomics
https://www.biorxiv.org/content/10.1101/2023.01.11.523679v2

Bug: jax version issues on TPU Colab #36

Open · opened by hlydecker 8 months ago

Context

When running the example Colab on a TPU runtime, users get the following error in the second code chunk:

[Screenshot: error output from the second code chunk]

Attempted Fixes

Upgrading jax and jaxlib to resolve the issue:

!pip install -U jax jaxlib

This now leads to a new error when running the second code chunk:

[Screenshot: new error output after upgrading jax and jaxlib]
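When reporting or debugging a version mismatch like this, it helps to record exactly which jax and jaxlib versions are installed and whether the notebook's TPU branch (the `COLAB_TPU_ADDR` check) will run at all. The following is a minimal sketch using only the Python standard library; the helper names `installed_version` and `jax_environment_report` are hypothetical and not part of the notebook.

```python
import os
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def jax_environment_report() -> dict:
    """Collect the details most useful when reporting a jax/TPU mismatch."""
    return {
        "jax": installed_version("jax"),
        "jaxlib": installed_version("jaxlib"),
        # The notebook's TPU setup branch only runs when this variable is set.
        "colab_tpu_addr_set": "COLAB_TPU_ADDR" in os.environ,
    }


if __name__ == "__main__":
    for key, value in jax_environment_report().items():
        print(f"{key}: {value}")
```

Running this in a cell before and after `!pip install -U jax jaxlib` shows whether the upgrade actually changed both packages together.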

Suggested Fix

Delete the following lines from the second code chunk of the Google Colab example:

if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

This avoids the jax/TPU conflict. It does reduce the notebook's functionality by removing TPUs from consideration, but it also reduces user friction by eliminating the jax/TPU errors.
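With the setup lines deleted, a quick check can confirm which backend the notebook actually ends up using. This is a hedged sketch, not part of the original notebook: `describe_jax_devices` is a hypothetical helper, and it assumes that without `colab_tpu.setup_tpu()` jax only uses backends it can initialise locally (typically CPU on the legacy Colab TPU runtime). It uses the public `jax.devices()` and `jax.default_backend()` calls and degrades gracefully if jax is missing.

```python
import importlib.util


def describe_jax_devices() -> str:
    """Summarise the backend and device count jax will use, with no TPU setup call."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax

    # No colab_tpu.setup_tpu() here: jax picks from the backends it can
    # initialise locally, so a remote Colab TPU is not used.
    devices = jax.devices()
    return f"backend={jax.default_backend()}, devices={len(devices)}"


if __name__ == "__main__":
    print(describe_jax_devices())
```

If this reports `backend=cpu` on a TPU runtime, the notebook is running without the TPU, which is the trade-off described above.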

A separate notebook demonstrating full TPU functionality could then be developed once a proper fix for the jax issue is found. However, judging from the error message I encountered, that may be difficult within the Colab ecosystem.