Open RobertLiJN opened 1 year ago
I'm not sure, but I think the t5x requirements may be messed up and not installing libtpu, which is the low-level library required for jax and other frameworks to access the TPU. Can you try manually installing the jax[tpu] setup, which includes the proper libtpu version, and see if that fixes the issue?
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
I also recommend using --version=tpu-vm-base when creating a TPU VM for use with jax (or tpu-vm-v4-base if creating a TPU v4 VM). The TF images come with a preinstalled libtpu version for the specified TF version, whereas the base images do not. I think jax ended up using the incorrect preinstalled libtpu version in this case, which can lead to confusing errors like this one. (The Unable to initialize backend 'tpu_driver' error is actually talking about the old TPU Node architecture, and doesn't mean jax isn't using the TPU on a TPU VM.)
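One way to check whether a pip-installed libtpu is actually present is to look it up in the environment's package metadata. The following is a minimal sketch using only the standard library. The helper name `installed_version` is made up for illustration, and the distribution names `libtpu-nightly` and `libtpu` are assumptions about how the package is published, so adjust them to whatever `pip list` shows.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a pip distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # The libtpu distribution names below are assumptions, not confirmed names.
    for name in ("jax", "jaxlib", "libtpu-nightly", "libtpu"):
        print(f"{name}: {installed_version(name) or 'not installed'}")
```

If jax and jaxlib are listed but no libtpu distribution is, that would be consistent with the install problem described above.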
Hi Skye,
This command of manually installing jax[tpu] fixes the issue, even on a TensorFlow VM! Looks like the command in this T5X repo for installing jax[tpu] is problematic here for some reason. Thank you so much for your help!
Awesome!
> even on a TensorFlow VM
Yup this makes sense, since jax will always use the pip-installed libtpu if available. I just recommend using the base image because in cases like this where the pip-installed libtpu isn't present for some reason, it can make it a bit easier to debug since it'll fall back to CPU, instead of crashing in a weird way.
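To see which case applies on a given VM, it may help to ask jax directly which backend it picked. This is a rough sketch, not an official diagnostic. The helper name `describe_backend` is hypothetical, and it only relies on `jax.devices()` and `jax.default_backend()`.

```python
def describe_backend():
    """Summarize which backend jax picked, without raising on failure."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    try:
        devices = jax.devices()
    except RuntimeError as err:
        # Backend initialization can fail outright, e.g. with a mismatched libtpu.
        return f"backend initialization failed: {err}"
    return f"backend={jax.default_backend()}, devices={len(devices)}"


if __name__ == "__main__":
    # backend=cpu on a TPU VM would suggest that jax fell back to CPU.
    print(describe_backend())
```

On a base image without a pip-installed libtpu, the expectation from the comment above is that this reports the CPU backend rather than crashing.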
I'm gonna leave this issue open until we fix the underlying install issue, since other people could easily hit this. It looks like t5x[tpu] pulls in jax[tpu] here:
https://github.com/google-research/t5x/blob/2a62e14fd2806a28c8b24c7674fdd5423aa95e3d/setup.py#L72
I don't understand why this is only pulling in jaxlib and not libtpu. Here's the jax[tpu] definition:
https://github.com/google/jax/blob/fc04c71d9342186b1ec51fcdb0a13fe1c6fcd5e2/setup.py#L84-L87
I can't dig into this right now, but I wonder if we're hitting some strange pip edge case around custom indices (which is how it locates the libtpu package).
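As a starting point for digging in, one could list what the installed jax distribution declares for its tpu extra and compare that with what pip actually installed. The sketch below uses `importlib.metadata.requires` from the standard library. The helper names are hypothetical, and it only shows declared requirements, not what pip resolved or which index it consulted.

```python
import re
from importlib import metadata


def requirements_for_extra(requirements, extra):
    """Keep only the requirement strings gated on the given extra."""
    marker = re.compile(r"extra\s*==\s*['\"]" + re.escape(extra) + r"['\"]")
    return [req for req in requirements if marker.search(req)]


def declared_extra(dist_name, extra):
    """Requirements an installed distribution declares for one extra."""
    try:
        declared = metadata.requires(dist_name) or []
    except metadata.PackageNotFoundError:
        return []
    return requirements_for_extra(declared, extra)


if __name__ == "__main__":
    # Prints nothing if jax is not installed or declares no tpu extra.
    for req in declared_extra("jax", "tpu"):
        print(req)
```

If a libtpu requirement shows up here but the package is missing from the environment, that would point at the resolution step rather than at the extra's definition.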
Oh I can't reopen it. @RobertLiJN if you're able to reopen please do so
Hi Skye, I have reopened it. Thanks again!
Hi,
I have been trying to run the wmt demo on TPUv2 or TPUv3 VMs, but I keep encountering a bad_alloc error before training even starts. It seems that the output also says that no TPU backend is found, although I have verified that JAX is able to see the 8 TPUs. Specifically, I first acquire a TPU VM with the command
gcloud compute tpus tpu-vm create t5_test_3 --zone=europe-west4-a --accelerator-type=v3-8 --version=tpu-vm-tf-2.11.0
Then, I log in to the VM using
gcloud alpha compute tpus tpu-vm ssh t5_test_3 --zone=europe-west4-a
Now, I get T5X and dependencies with
The directories are set with
Now I run the pretrain command
And the following is the output
Note the line at 20:23:28.534243 that says
I0106 20:23:28.534243 140414174514240 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I wonder if this is the cause of the bad_alloc error and if there is a way to fix it. Thanks in advance!
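For anyone scanning a long training log for this kind of message, a small filter can list which backends jax reported as failing to initialize. This is only a sketch based on the single log line quoted above. The helper name `failed_backends` is hypothetical, and the pattern assumes other such messages follow the same wording.

```python
import re

# Matches the message text seen in the xla_bridge.py log line quoted above.
_PATTERN = re.compile(r"Unable to initialize backend '([^']+)': (.*)")


def failed_backends(log_text):
    """Return (backend, reason) pairs for each reported initialization failure."""
    return [(m.group(1), m.group(2).strip()) for m in _PATTERN.finditer(log_text)]


if __name__ == "__main__":
    line = (
        "I0106 20:23:28.534243 140414174514240 xla_bridge.py:355] "
        "Unable to initialize backend 'tpu_driver': NOT_FOUND: "
        "Unable to find driver in registry given worker:"
    )
    for backend, reason in failed_backends(line):
        print(backend, "->", reason)
```

As noted in the reply earlier in this thread, a 'tpu_driver' entry by itself refers to the old TPU Node architecture and does not mean jax is missing the TPU on a TPU VM.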