zurcnilva213 opened this issue 1 year ago
GKE is still on the legacy TPU Node architecture, which is different from TPU VMs (see this blog post for a short summary of the difference). I think you're getting that particular error because you're trying to use a TPU VM setup on a TPU Node. The `jax[tpu]` install doesn't work on TPU Nodes, only on TPU VMs. I think a regular `pip install jax jaxlib` would get around that error.
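To make the TPU Node vs. TPU VM distinction concrete, here is a minimal sketch of how a container could guess which setup it is running in. It rests on two assumptions: that GKE passes the TPU Node's gRPC address through the `KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS` environment variable (the variable TensorFlow's `TPUClusterResolver` reads on GKE), and that TPU VM chips show up locally as `/dev/accel*` device files. `parse_tpu_endpoints` and `classify_tpu_setup` are hypothetical helpers, not part of JAX.

```python
import os
from typing import List, Mapping, Sequence

# Assumed: GKE injects the TPU Node address(es) into pods via this variable.
GKE_TPU_ENV = "KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS"


def parse_tpu_endpoints(value: str) -> List[str]:
    """Turn a comma-separated list like 'grpc://10.0.0.2:8470' into host:port strings."""
    endpoints = []
    for item in value.split(","):
        item = item.strip()
        if not item:
            continue  # skip empty entries, e.g. an unset variable
        if item.startswith("grpc://"):
            item = item[len("grpc://"):]
        endpoints.append(item)
    return endpoints


def classify_tpu_setup(env: Mapping[str, str], dev_entries: Sequence[str]) -> str:
    """Hypothetical helper: guess whether this process sees a TPU VM or a TPU Node.

    - TPU VM: chips are attached to this machine (assumed to appear as /dev/accel*).
    - TPU Node: the TPU is a separate host reached over gRPC; on GKE its
      address is assumed to arrive through GKE_TPU_ENV.
    """
    if any(name.startswith("accel") for name in dev_entries):
        return "tpu_vm"
    if parse_tpu_endpoints(env.get(GKE_TPU_ENV, "")):
        return "tpu_node"
    return "none"


if __name__ == "__main__":
    try:
        entries = os.listdir("/dev")
    except OSError:
        entries = []  # e.g. /dev not readable in a restricted container
    print(classify_tpu_setup(os.environ, entries),
          parse_tpu_endpoints(os.environ.get(GKE_TPU_ENV, "")))
```

If this reports `tpu_node`, the TPU is a remote host rather than a local device, which fits the explanation above of why the `jax[tpu]` install can't find it.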
However, as discussed in https://github.com/google/jax/issues/12917, JAX doesn't work very well on TPU Nodes, and at this point has only best-effort support. (Thanks @OrenLeung for digging that thread up!) See that thread for options on how to proceed. Feel free to ask about any of them here.
Description
We are trying to run JAX on GKE TPU Nodes: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_nodes_4
The Docker image works fine on TPU VM instances, but the same image never works when we run it on Google Kubernetes Engine (GKE) TPU Nodes.
Following is part of the error log from GKE:
We were able to run the TensorFlow base image by following https://cloud.google.com/tpu/docs/kubernetes-engine-setup#job-spec, and TensorFlow found the Cloud TPU devices.
So the issue seems to be that JAX, specifically, is unable to find the Cloud TPU devices for some reason.
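One way to confirm what JAX itself sees inside the pod is to ask it for its devices. This is a minimal diagnostic sketch: `describe_jax_devices` is a hypothetical helper, and it relies only on `jax.devices()` and each device's `platform` attribute, which should be available in the 0.2.x releases as well as later ones.

```python
import importlib


def describe_jax_devices() -> str:
    """Report which devices JAX initialised, or why it could not."""
    try:
        jax = importlib.import_module("jax")
    except Exception as exc:  # broad on purpose: import/version errors are the diagnosis
        return f"jax could not be imported: {exc}"
    try:
        devices = jax.devices()
    except RuntimeError as exc:
        return f"jax found no usable backend: {exc}"
    platforms = sorted({d.platform for d in devices})
    return f"{len(devices)} device(s) on platform(s): {', '.join(platforms)}"


if __name__ == "__main__":
    print(describe_jax_devices())
```

Running this with the same image on a TPU VM and in the GKE pod should show whether JAX initialised a TPU backend in each case, or silently fell back to CPU.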
What jax/jaxlib version are you using?
0.2.16
Which accelerator(s) are you using?
TPU Node supported by GKE
Additional system info
Python 3.7
NVIDIA GPU info
No response