Unable to use TPU on Google Colab #19274

Open BrandonStudio opened 8 months ago

BrandonStudio commented 8 months ago

Bug description

The PyTorch Lightning Trainer does not find the TPU.

What version are you seeing the problem on?


How to reproduce the bug

  1. Install torch_xla following the official guidance, then install lightning:

    !pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f
    !pip install "jax[tpu]>=0.2.16" -f --upgrade
    !pip install lightning
  2. Restart session as prompted

  3. Run code

    import torch_xla # no error
    import pytorch_lightning as pl # no error
    trainer = pl.trainer(accelerator="tpu") # error occurred

Error messages and logs

WARNING:root:PJRT is now the default runtime. For more information, see
WARNING:root:Defaulting to PJRT_DEVICE=CPU
MisconfigurationException                 Traceback (most recent call last)
[<ipython-input-1-ce9cc1967aee>](https://localhost:8080/#) in <cell line: 3>()
      1 import torch_xla
      2 import pytorch_lightning as pl
----> 3 trainer = pl.Trainer(accelerator="tpu")

3 frames
[/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/](https://localhost:8080/#) in _set_parallel_devices_and_init_accelerator(self)
    379                 if AcceleratorRegistry[acc_str]["accelerator"].is_available()
    380             ]
--> 381             raise MisconfigurationException(
    382                 f"`{accelerator_cls.__qualname__}` can not run on your system"
    383                 " since the accelerator is not available. The following accelerator(s)"

MisconfigurationException: `XLAAccelerator` can not run on your system since the accelerator is not available. The following accelerator(s) is available and can be passed into `accelerator` argument of `Trainer`: ['cpu'].

Setting PJRT_DEVICE to TPU does not help.
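
For reference, this is roughly how the variable can be set from a notebook cell. It is only a sketch: the helper name `configure_pjrt_for_tpu` is made up for illustration, and running it before the first `import torch_xla` in a fresh session is a conservative assumption rather than a documented requirement. As noted above, it made no difference in this case.

```python
import os


def configure_pjrt_for_tpu(env=None):
    """Set PJRT_DEVICE=TPU in `env` (defaults to os.environ).

    Returns the previous value, or None if the variable was unset.
    Hypothetical helper, not part of torch_xla or Lightning.
    """
    env = os.environ if env is None else env
    previous = env.get("PJRT_DEVICE")
    env["PJRT_DEVICE"] = "TPU"
    return previous


# Assumption: run this in a fresh session, before importing torch_xla.
previous = configure_pjrt_for_tpu()
print("PJRT_DEVICE was", previous, "and is now", os.environ["PJRT_DEVICE"])
```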


More info


Running

    import torch_xla.core.xla_model as xm
    xm.xla_device()

sometimes returns

    device(type='xla', index=0)

and sometimes raises an error with this stack trace:

```python
RuntimeError                              Traceback (most recent call last)
[](https://localhost:8080/#) in ()
      1 import torch_xla.core.xla_model as xm
----> 2 xm.xla_device()

2 frames
[/usr/local/lib/python3.10/dist-packages/torch_xla/core/](https://localhost:8080/#) in xla_device(n, devkind)
    195     return torch.device(device)
    196
--> 197   return runtime.xla_device(n, devkind)
    198
    199

[/usr/local/lib/python3.10/dist-packages/torch_xla/](https://localhost:8080/#) in wrapper(*args, **kwargs)
     80                              fn.__name__))
     81
---> 82     return fn(*args, **kwargs)
     83
     84   return wrapper

[/usr/local/lib/python3.10/dist-packages/torch_xla/](https://localhost:8080/#) in xla_device(n, devkind)
    109   """
    110   if n is None:
--> 111     return torch.device(torch_xla._XLAC._xla_get_default_device())
    112
    113   devices = xm.get_xla_supported_devices(devkind=devkind)

RuntimeError: torch_xla/csrc/runtime/ : Check failed: tpu_status.ok()
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::runtime::PjRtComputationClient::PjRtComputationClient()
	torch_xla::runtime::GetComputationClient()
	torch_xla::ParseDeviceString(std::string const&)
	torch_xla::GetDefaultDevice()
	torch_xla::GetCurrentDevice()
	torch_xla::bridge::GetCurrentAtenDevice()
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	PyEval_EvalCode
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	PyEval_EvalCode
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	Py_RunMain
	Py_BytesMain
	__libc_start_main
	_start
*** End stack trace ***
```

The behavior may differ depending on whether torch_xla[tpu] or jax[tpu] is installed first, but in both cases Lightning does not recognize the TPUs.
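
Since the result is not consistent between runs, a guarded probe can make the two outcomes easier to compare. This is a minimal sketch, not an official diagnostic: `probe_xla_device` is a hypothetical helper name, and it only wraps the `xm.xla_device()` call shown above so that a failure is reported as a short message instead of the full stack trace.

```python
def probe_xla_device():
    """Return (ok, detail) describing whether an XLA device can be acquired.

    Hypothetical helper: wraps xm.xla_device() and never raises.
    """
    try:
        import torch_xla.core.xla_model as xm
    except Exception as exc:  # torch_xla missing or failing at import time
        return False, f"torch_xla is not importable: {type(exc).__name__}: {exc}"
    try:
        device = xm.xla_device()
    except Exception as exc:  # the failure reported above was a RuntimeError
        lines = str(exc).splitlines()
        first_line = lines[0] if lines else ""
        return False, f"{type(exc).__name__}: {first_line}"
    return True, str(device)


ok, detail = probe_xla_device()
print("XLA device available:", ok)
print(detail)
```

Note that a successful probe does not by itself prove a TPU is in use, because the log above shows the runtime can default to PJRT_DEVICE=CPU.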

carmocca commented 6 months ago

Hi @BrandonStudio. Can you print out the result of this?

This is the function that checks if xla is available. It would be useful to know what is happening there in your environment
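
Judging by the traceback in the reply that follows, the requested check is `tpu.num_available_devices()` from `torch_xla._internal`. A guarded sketch of running it, assuming that is indeed the snippet meant here (the wrapper name `count_tpu_devices` is made up for illustration):

```python
def count_tpu_devices():
    """Return the TPU device count reported by torch_xla, or an error string.

    Hypothetical wrapper around torch_xla._internal.tpu.num_available_devices(),
    the call that appears in the traceback posted in this thread.
    """
    try:
        from torch_xla._internal import tpu
        return tpu.num_available_devices()
    except Exception as exc:  # e.g. torch_xla missing, or TPU metadata unavailable
        return f"{type(exc).__name__}: {exc}"


print(count_tpu_devices())
```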

BrandonStudio commented 6 months ago

Apologies, the code above should read pl.Trainer rather than pl.trainer.


```
---------------------------------------------------------------------------
HTTPError                                 Traceback (most recent call last)
[/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/](https://localhost:8080/#) in version()
    177   try:
--> 178     env = get_tpu_env()
    179   except requests.HTTPError as e:

6 frames
[/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/](https://localhost:8080/#) in get_tpu_env()
    171     return build_tpu_env_from_vars()
--> 172   metadata = _get_metadata('tpu-env')
    173   return yaml.load(metadata, yaml.Loader)

[/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/](https://localhost:8080/#) in _get_metadata(key)
     80   resp = requests.get(path, headers={'Metadata-Flavor': 'Google'})
---> 81   resp.raise_for_status()
     82

[/usr/local/lib/python3.10/dist-packages/requests/](https://localhost:8080/#) in raise_for_status(self)
   1020         if http_error_msg:
-> 1021             raise HTTPError(http_error_msg, response=self)
   1022

HTTPError: 404 Client Error: Not Found for url:

The above exception was the direct cause of the following exception:

OSError                                   Traceback (most recent call last)
[](https://localhost:8080/#) in ()
      1 from torch_xla._internal import tpu
----> 2 tpu.num_available_devices()

[/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/](https://localhost:8080/#) in num_available_devices()
    117   before `xmp.spawn`.
    118   """
--> 119   return num_available_chips() * num_logical_cores_per_chip()
    120
    121

[/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/](https://localhost:8080/#) in num_logical_cores_per_chip()
    108 def num_logical_cores_per_chip() -> int:
    109   """Returns number of XLA TPU devices per physical chip on the current host."""
--> 110   return 2 if version() <= 3 else 1
    111
    112

[/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/](https://localhost:8080/#) in version()
    178     env = get_tpu_env()
    179   except requests.HTTPError as e:
--> 180     raise EnvironmentError('Failed to get TPU metadata') from e
    181
    182   match = re.match(r'^v(\d)([A-Za-z]?){7}-(\d+)$', env[xenv.ACCELERATOR_TYPE])

OSError: Failed to get TPU metadata
```

This is a plain Google Colab environment, so anyone can reproduce it.

carmocca commented 6 months ago

@BrandonStudio Can you report this in This seems to be a Colab issue or a PyTorch XLA issue.

ckwastra commented 6 months ago

Inspired by this reply, use the following setup:

!pip install torch==2.0.0
!pip install cloud-tpu-client
!pip install
!pip install lightning==2.0.0

Install specific versions (2.0.0) of torch, torch_xla, and lightning. The Trainer output for the given setup was as follows:

INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: True, using: 8 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: True, using: 8 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs

BrandonStudio commented 6 months ago

@ckwastra Thank you for your solution. However, lightning==2.0.0 does not work on the legacy TPU runtime.

Putting all the comments together, I arrived at the following solution:

%pip install torch==2.0.0 torchaudio torchdata torchtext torchvision cloud-tpu-client "pytorch_lightning<2"

@carmocca Could you add this to the docs?
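
Because the working setups in this thread depend on specific version combinations, it can help to print what actually got installed after the session restarts. A small sketch using only the standard library; the helper name `installed_versions` and the package list are assumptions chosen to match the packages mentioned above.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages mentioned in this thread; adjust the list as needed.
PACKAGES = ("torch", "torch_xla", "pytorch_lightning", "lightning", "cloud-tpu-client")


def installed_versions(packages=PACKAGES):
    """Map each package name to its installed version, or None if absent."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```

Posting this output alongside a traceback makes it easier for others to tell which of the setups above was in effect.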