jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Can't initialize cloud TPU on GKE TPU Nodes #13969

Open zurcnilva213 opened 1 year ago

zurcnilva213 commented 1 year ago

Description

We are trying to run JAX on GKE TPU Nodes (https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_nodes_4).

The Docker image works on TPU VM instances, but the same image always fails when we run it on Google Kubernetes Engine (GKE) TPU Nodes.

The following is part of the error log from GKE:


Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/jax/_src/lib/xla_bridge.py", line 335, in backends
    backend = _init_backend(platform)
  File "/usr/local/lib/python3.9/dist-packages/jax/_src/lib/xla_bridge.py", line 387, in _init_backend
    backend = factory()
  File "/usr/local/lib/python3.9/dist-packages/jax/_src/lib/xla_bridge.py", line 191, in tpu_client_timer_callback
    client = xla_client.make_tpu_client()
  File "/usr/local/lib/python3.9/dist-packages/jaxlib/xla_client.py", line 126, in make_tpu_client
    return _xla.get_tpu_client(
jaxlib.xla_extension.XlaRuntimeError: NOT_FOUND: No ba16c7433 device found.


By contrast, the TensorFlow base image, deployed by following https://cloud.google.com/tpu/docs/kubernetes-engine-setup#job-spec, runs and finds the Cloud TPU devices.

So the problem appears to be that JAX, unlike TensorFlow, cannot find the Cloud TPU devices on GKE.

What jax/jaxlib version are you using?

0.2.16

Which accelerator(s) are you using?

TPU Node supported by GKE

Additional system info

Python 3.7

NVIDIA GPU info

No response

OrenLeung commented 1 year ago

https://github.com/google/jax/issues/12917

skye commented 1 year ago

GKE still uses the legacy TPU Node architecture, which is different from TPU VMs (see this blog post for a short summary of the difference). I think you're getting that particular error because you're using a TPU VM setup on a TPU Node: the jax[tpu] install works only on TPU VMs, not on TPU Nodes. I think a regular pip install jax jaxlib would get around that error.

However, as discussed in https://github.com/google/jax/issues/12917, JAX doesn't work very well on TPU Nodes, and at this point has only best-effort support. (Thanks @OrenLeung for digging that thread up!) See that thread for options on how to proceed. Feel free to ask about any of them here.
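For anyone experimenting with the best-effort TPU Node path, one first step is locating the TPU Node from inside the pod, since on a TPU Node the chips are reached over the network rather than attached to the host. The sketch below is hypothetical and rests on two assumptions: that GKE exposes the TPU Node address through the KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS environment variable (the variable TensorFlow's TPU cluster resolver reads), and that older JAX releases selected the remote TPU Node backend through the jax_xla_backend ("tpu_driver") and jax_backend_target options. The helper names and the 10.0.0.2 address are illustrative only; check issue 12917 and the JAX release notes for what the installed version actually supports before applying these settings.

```python
import os
from typing import Mapping, Optional

# Assumption: GKE sets this variable in pods that request a Cloud TPU Node,
# holding one or more comma-separated "grpc://host:port" endpoints.
GKE_TPU_ENV_VAR = "KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS"


def tpu_node_endpoint(environ: Optional[Mapping[str, str]] = None) -> Optional[str]:
    """Hypothetical helper: return the first TPU Node gRPC endpoint, or None.

    On a TPU VM the variable is normally absent, because the TPU chips are
    attached to the VM itself and no remote endpoint is needed.
    """
    environ = os.environ if environ is None else environ
    raw = environ.get(GKE_TPU_ENV_VAR, "").strip()
    if not raw:
        return None
    first = raw.split(",")[0].strip()
    # Normalise a bare "host:port" value to the grpc:// form.
    if not first.startswith("grpc://"):
        first = "grpc://" + first
    return first


def jax_tpu_node_settings(environ: Optional[Mapping[str, str]] = None) -> dict:
    """Hypothetical helper: map the endpoint onto the option names that older
    JAX releases used for the remote TPU Node ("tpu_driver") backend.

    The option names are assumptions; verify them against the installed JAX.
    Returns an empty dict when no TPU Node endpoint is found.
    """
    endpoint = tpu_node_endpoint(environ)
    if endpoint is None:
        return {}
    return {"jax_xla_backend": "tpu_driver", "jax_backend_target": endpoint}


if __name__ == "__main__":
    # Illustrative value only; a real pod would read os.environ instead.
    print(jax_tpu_node_settings({GKE_TPU_ENV_VAR: "grpc://10.0.0.2:8470"}))
```

Keeping the lookup separate from any JAX import makes it easy to log what the pod sees before JAX initializes its backends, which helps tell a TPU VM setup (no endpoint) apart from a TPU Node setup (remote endpoint).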