krahnikblis opened 2 years ago
Hi, thanks for the question! Note that the instructions in https://github.com/googlecolab/colabtools/issues/3009 are for working around a previous bug that has been fixed in current jax/jaxlib/libtpu versions. You shouldn't need any special instructions now beyond connecting to a Colab TPU runtime and running jax.tools.colab_tpu.setup_tpu(). (If you've already changed package installations in your runtime, choose Runtime -> Disconnect and Delete Runtime, then reconnect to get a new runtime with default settings and packages.)
Let me know if that doesn't solve your issue, or if there are other considerations at play.
I tried it and got this error on Colab (JAX version 0.3.25):
import jax
jax.tools.colab_tpu.setup_tpu()
/usr/local/lib/python3.8/dist-packages/jax/__init__.py:27: UserWarning: cloud_tpu_init failed: KeyError('')
This a JAX bug; please report an issue at https://github.com/google/jax/issues
_warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report "
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-8-c0a8b7c25ce1> in <module>
1 import jax
----> 2 jax.tools.colab_tpu.setup_tpu()
AttributeError: module 'jax' has no attribute 'tools'
The jax.tools submodule is not imported by default into the JAX namespace. Try this instead:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
See https://github.com/google/jax#pip-installation-colab-tpu, where this import is recommended.
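For anyone puzzled by why `import jax` alone is not enough: this is ordinary Python behavior. Importing a package does not import its submodules unless the package's `__init__.py` does so itself. A minimal, generic sketch using a throwaway package named `demo_pkg` (a hypothetical name, created in a temporary directory; nothing here is JAX-specific):

```python
import sys
import tempfile
from pathlib import Path

# Build a throwaway package on disk. demo_pkg and demo_pkg.tools are
# hypothetical names used only for this illustration.
root = Path(tempfile.mkdtemp())
pkg = root / "demo_pkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("")  # empty: does NOT import .tools
(pkg / "tools.py").write_text("def setup():\n    return 'ready'\n")
sys.path.insert(0, str(root))

import demo_pkg  # noqa: E402

# Importing only the parent package leaves the submodule attribute unbound,
# the same situation as "module 'jax' has no attribute 'tools'".
print(hasattr(demo_pkg, "tools"))  # False

# Importing the submodule explicitly loads it and binds demo_pkg.tools,
# which is why "import jax.tools.colab_tpu" fixes the AttributeError.
import demo_pkg.tools  # noqa: E402

print(hasattr(demo_pkg, "tools"))  # True
print(demo_pkg.tools.setup())  # ready
```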
Something is amiss and I'm having trouble tracking it down, because the behavior is sporadic and seems to affect different things...
Here's what I'm using to get things started:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
import jax
import os
os.environ["USE_FLAX"] = "1"
os.environ["XLA_USE_BF16"] = "1"
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
I am facing the same error on a TPU node. Is there any way to set up the TPU on a TPU node?
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
Does this also work on a TPU node?
import jax.tools.colab_tpu; jax.tools.colab_tpu.setup_tpu()
should be all you need to get running on a TPU node on Colab. However, it's quite possible you'll hit other bugs once you start running things, as @krahnikblis appears to be doing.
@krahnikblis, if you can provide an example notebook(s) demonstrating these problems I can try to take a look.
If possible, I also recommend trying Kaggle Notebooks (https://www.kaggle.com/code, click on "New Notebook" near the top). You have to create an account and log in to get accelerator support. Once you do that, there's a new "TPU VM v3-8" accelerator option. This gives you a TPU notebook environment similar to Colab, but using the newer TPU VM architecture. This should be a less buggy, more performant, and overall better experience than the older TPU Node architecture (see this blog post for a brief overview of the difference). I'd be interested to hear your feedback if you give it a shot.
So yesterday I got a bunch more errors and trouble trying to use a previously working notebook with TPU. I suspect some change in the backend in Colab... The version of JAX pre-installed in Colab seems a bit old: version 0.3.25 was what I got when I started a new session, but the latest PyPI version is 0.4.1, so I gave that a go. I also noticed that the default value in the setup_tpu function was a driver dated in November. The default with version 0.4.1 was one dated in December, but it also didn't work: both JAX versions with their default drivers gave the same "timed out while waiting for dependency" error and "no TPU devices found". Only 0.4.1 with the driver version set to "tpu_driver_nightly" worked.
Here's what finally worked (yesterday, anyway):
!pip install -U jax
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu(tpu_driver_version='tpu_driver_nightly')
In general, the driver version should be from the same day that the jaxlib version was released (https://jax.readthedocs.io/en/latest/changelog.html). The tpu_driver_nightly version is automatically updated every night, so I recommend pinning to today's date (which I believe should be the same bits as the nightly, until tomorrow at least) instead of relying on the nightly version.
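To follow the pinning advice above without typing the name by hand, a small helper can build a dated driver name. This is a sketch that assumes dated drivers follow the `tpu_driver_YYYYMMDD` pattern (the naming I believe the `setup_tpu` default uses); the helper name `pinned_driver_version` is hypothetical, the date below is a placeholder rather than any particular jaxlib release date, and a driver build is not guaranteed to exist for every date.

```python
import datetime


def pinned_driver_version(day: datetime.date) -> str:
    """Hypothetical helper: format a date as an assumed tpu_driver_YYYYMMDD name."""
    return f"tpu_driver_{day:%Y%m%d}"


# Placeholder date; substitute the jaxlib release date from the changelog,
# or today's date, instead of the moving "tpu_driver_nightly" name.
release_day = datetime.date(2022, 12, 13)
print(pinned_driver_version(release_day))  # tpu_driver_20221213
print(pinned_driver_version(datetime.date.today()))

# On a Colab TPU runtime the result would then be passed to:
#   import jax.tools.colab_tpu
#   jax.tools.colab_tpu.setup_tpu(
#       tpu_driver_version=pinned_driver_version(release_day))
```

Pinning a fixed name keeps a notebook reproducible: the nightly name points at different bits each day, while a dated name does not change underneath you.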
I would expect all the default versions shipped with Colab to "just work" without having to update anything, though. It sounds like we may need to update the default versions if you find newer versions work better.
Description
Here's what I ran (but why should anyone using the Google Colab service, which offers Google TPU hardware, need to separately install special versions of Google's JAX software to get the paid service to work?):
What jax/jaxlib version are you using?
The one in Colab Pro, and also whatever version "pip install --upgrade" installs.
Which accelerator(s) are you using?
TPU in Colab Pro
Additional system info
Google Colab Pro with Google TPU and Google JAX
NVIDIA GPU info
No response