jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

cloud_tpu_init failed #13409

Open · krahnikblis opened this issue 2 years ago

krahnikblis commented 2 years ago

Description

here's what I ran (though why should anyone using the Google Colab service, which offers Google TPU hardware, need to separately install special versions of Google's JAX software just to get the paid service to work?):

#@title JAX TPU setup for Colab - run this then restart the runtime if output says to!
!pip install --upgrade jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install --upgrade jaxlib

#@title AFTER restart, check TPU status using this - should get 8 devices!
# https://github.com/googlecolab/colabtools/issues/3009 
# "It seems it is necessary to both install the latest versions of JAX manually and supply the TPU driver explicitly:"
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu('tpu_driver_20221122') # "explicitly" eh... ok just random pick some recent date... 

jax.local_devices()

on the "jax.local_devices()" command this is the error:
AttributeError: module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'

/usr/local/lib/python3.7/dist-packages/jax/__init__.py:27: UserWarning: cloud_tpu_init failed: KeyError('')
 This a JAX bug; please report an issue at https://github.com/google/jax/issues
  _warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report "
---------------------------------------------------------------------------
.. derp derp buncha nonsense ..
RuntimeError: Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client' (set JAX_PLATFORMS='' to automatically choose an available backend)

What jax/jaxlib version are you using?

the one preinstalled in Colab Pro, and also whatever pip install --upgrade produces

Which accelerator(s) are you using?

TPU in Colab Pro

Additional system info

Google Colab Pro with a Google TPU and Google JAX

NVIDIA GPU info

No response

jakevdp commented 2 years ago

Hi, thanks for the question! Note that the instructions in https://github.com/googlecolab/colabtools/issues/3009 are for working around a previous bug that has been fixed in current jax/jaxlib/libtpu versions. You shouldn't need any special instructions now beyond connecting to a Colab TPU runtime and running jax.tools.colab_tpu.setup_tpu() (if you've already changed package installations in your runtime, choose Runtime->Disconnect and Delete Runtime and then reconnect to get a new runtime with default settings & packages).

Let me know if that doesn't solve your issue, or if there are other considerations at play.
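
On a fresh TPU runtime, that should amount to just the following (a minimal sketch; the device check at the end is added here for illustration):

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()  # defaults should match the preinstalled jax/jaxlib

import jax
print(jax.local_devices())  # expect 8 TpuDevice entries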

joytianya commented 2 years ago

I tried it on Colab (jax version 0.3.25) and got this error:

import jax
jax.tools.colab_tpu.setup_tpu()
/usr/local/lib/python3.8/dist-packages/jax/__init__.py:27: UserWarning: cloud_tpu_init failed: KeyError('')
 This a JAX bug; please report an issue at https://github.com/google/jax/issues
  _warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report "
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-8-c0a8b7c25ce1> in <module>
      1 import jax
----> 2 jax.tools.colab_tpu.setup_tpu()

AttributeError: module 'jax' has no attribute 'tools'

jakevdp commented 2 years ago

The jax.tools submodule is not imported by default into the JAX namespace. Try this instead:

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

See https://github.com/google/jax#pip-installation-colab-tpu, where this import is recommended.

krahnikblis commented 1 year ago

something is amiss and I'm having trouble chasing it down, because it's sporadic in behavior and seems to affect different things...

  1. I'm able to get notebooks running on the TPU runtime type using the imports and setup pasted below. They're copypasta from various searches on this issue; this particular combination gives me the least grief (but ideally users shouldn't need any of it other than the import, right? shouldn't Google software and Google hardware work out of the box with the Google service that has a literal drop-down selector for TPU?).
  2. HOWEVER, sometimes after running the script below, I get the error reported in the linked issue, something like "didn't connect to device in time". The only fix is to restart the notebook runtime. It was while trying to work around this behavior that I found the threads describing how to get nightly drivers or run various wizardry to get the TPUs recognized, and from that madness I ultimately hit the issue I posted here about. Having reverted from that madness, I still see the issues described here. I.e., the "this is a bug" error happened while trying to work around other errors and bad behaviors, so I can avoid it, but I can't avoid these...
  3. Very often, compilation takes a really long time and will often hang completely, requiring another runtime restart. This can happen with something as stupidly simple as splitting a PRNG key (a minimal sketch is below, after the setup code), or as large as a complex CV model with its params replicated over all TPUs and nearing their max RAM. It can hang immediately on calling the cell, or after an hour of happily fast training loops. Or training succeeds and things are fine. Or I stop training, adjust something, restart, and more often than not it hangs during compile that time. And for you-know-whats-and-giggles, a compiled function will randomly decide it's time to recompile for a few minutes (and maybe hang), seemingly only to charge more "compute units"...
  4. When that hang happens, the readout in Runtime > View Runtime Logs prints this every 10 seconds for as long as you let it hang: "W external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc:617] TPU Execute is taking a long time. This might be due to a deadlock between multiple TPU cores or a very slow program."
  5. I find that message odd because I'm using JAX and Flax, not TensorFlow, though perhaps amid the millions of wrappers one of the libraries imports from TF? (More likely the TPU driver itself is built from the TensorFlow source tree, where XLA lived at the time, hence the path.) And if the two reasons it gives are the only ones, then it seems to be "deadlock between multiple TPU cores", because when things are running everything is fine and it's not "a very slow program".

here's what I'm using to get things started:

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()  # register the Colab TPU backend before any other JAX work
import jax
import os
os.environ["USE_FLAX"] = "1"     # env flags presumably read by downstream libraries, not by JAX itself
os.environ["XLA_USE_BF16"] = "1"
num_devices = jax.device_count()            # expect 8 on a Colab TPU
device_type = jax.devices()[0].device_kind  # e.g. "TPU v2"
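
A minimal sketch of the trivial kind of program that can hang at compile time, per point 3 above (illustrative only; the hang is sporadic and not reproducible on demand):

key = jax.random.PRNGKey(0)
split = jax.jit(lambda k: jax.random.split(k))  # compiling even this sometimes hangs
key, subkey = split(key)
print(jax.random.normal(subkey, (2, 2)))
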
zurcnilva213 commented 1 year ago

I am facing the same error on a TPU node. Is there any way to set up the TPU on a TPU node?

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

Does this work on a TPU node as well?

skye commented 1 year ago

import jax.tools.colab_tpu; jax.tools.colab_tpu.setup_tpu() should be all you need to get running on a TPU node on Colab. However, it's very possible there are other bugs once you start running things, as it looks like @krahnikblis is running into.

@krahnikblis, if you can provide an example notebook(s) demonstrating these problems I can try to take a look.

If possible, I also recommend trying Kaggle Notebooks (https://www.kaggle.com/code, click on "New Notebook" near the top). You have to create an account and log in to get accelerator support. Once you do that, there's a new "TPU VM v3-8" accelerator option. This gives you a TPU notebook environment similar to Colab, but using the newer TPU VM architecture. This should be a less buggy, more performant, and overall better experience than the older TPU Node architecture (see this blog post for a brief overview of the difference). I'd be interested to hear your feedback if you give it a shot.
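
For reference, on a TPU VM runtime the Colab-specific setup_tpu() call shouldn't be needed at all; assuming a jax[tpu] install is present, checking the devices is just:

import jax
print(jax.devices())  # expect 8 TPU devices on a v3-8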

krahnikblis commented 1 year ago

so yesterday I got a bunch more errors and trouble trying to use a previously-working notebook with TPU. I suspect some change in backend stuff in Colab... The version of JAX pre-installed in Colab seems a bit old: 0.3.25 is what I got when starting a new session, while the latest PyPI version is 0.4.1, so I gave that a go. I also noticed that the default driver in the setup_tpu function was dated in November; the default with the 0.4.1 version was something in December, but it also didn't work. Both JAX versions with their default drivers gave the same "timed out while waiting for dependency" error and "no TPU devices found". Only 0.4.1 with the driver version set to "tpu_driver_nightly" worked.

here's what finally worked, yesterday anyway:

!pip install -U jax  # upgrade from the preinstalled 0.3.25 to the latest release
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu(tpu_driver_version='tpu_driver_nightly')

skye commented 1 year ago

In general the driver version should be from the same day that the jaxlib version was released (https://jax.readthedocs.io/en/latest/changelog.html). The tpu_driver_nightly version is automatically updated every night, so I recommend pinning to today's date (which I believe should be the same bits as the nightly, until tomorrow at least) instead of relying on the nightly version.
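
For example (the date below is illustrative, not a verified driver release; match it to the release date of your jaxlib version from the changelog):

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu(tpu_driver_version='tpu_driver_20230110')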

I would expect all the default versions shipped with Colab to "just work" without having to update anything though. It sounds like maybe we should update the default versions if you find newer versions working better.