Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Unable to use TPU on Google Colab #19274

Open BrandonStudio opened 8 months ago

BrandonStudio commented 8 months ago

Bug description

The PyTorch Lightning `Trainer` does not detect the TPU on a Google Colab TPU runtime.

What version are you seeing the problem on?

v2.1

How to reproduce the bug

  1. Install torch_xla following the official instructions, then install lightning:

    !pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
    !pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --upgrade
    !pip install lightning
  2. Restart the session when prompted.

  3. Run the following code:

    import torch_xla  # imports without error
    import pytorch_lightning as pl  # imports without error
    trainer = pl.Trainer(accelerator="tpu")  # raises MisconfigurationException

Error messages and logs

```
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:Defaulting to PJRT_DEVICE=CPU
---------------------------------------------------------------------------
MisconfigurationException                 Traceback (most recent call last)
<ipython-input-1-ce9cc1967aee> in <cell line: 3>()
      1 import torch_xla
      2 import pytorch_lightning as pl
----> 3 trainer = pl.Trainer(accelerator="tpu")

3 frames
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py in _set_parallel_devices_and_init_accelerator(self)
    379                 if AcceleratorRegistry[acc_str]["accelerator"].is_available()
    380             ]
--> 381             raise MisconfigurationException(
    382                 f"`{accelerator_cls.__qualname__}` can not run on your system"
    383                 " since the accelerator is not available. The following accelerator(s)"

MisconfigurationException: `XLAAccelerator` can not run on your system since the accelerator is not available. The following accelerator(s) is available and can be passed into `accelerator` argument of `Trainer`: ['cpu'].
```

Setting PJRT_DEVICE to TPU does not help.
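For anyone retrying this: the warnings in the log above show torch_xla defaulting to `PJRT_DEVICE=CPU` when the variable is unset, so if it is set at all it should be set before `torch_xla` is first imported, ideally in a fresh session. A minimal sketch, assuming that setting it through `os.environ` from the notebook is equivalent to exporting it in the shell; `select_pjrt_device` is a hypothetical helper, and given the report above this alone is not expected to fix detection.

```python
import os


def select_pjrt_device(device: str = "TPU") -> str:
    """Hypothetical helper: set PJRT_DEVICE before torch_xla is imported.

    It only sets the environment variable; whether the runtime then finds a
    TPU still depends on torch_xla and the host.
    """
    previous = os.environ.get("PJRT_DEVICE")
    if previous is not None and previous != device:
        # If torch_xla was already imported in this process it may have picked
        # a device already, so restart the session after changing the value.
        print(f"PJRT_DEVICE was {previous!r}; restart the session if torch_xla is already imported")
    os.environ["PJRT_DEVICE"] = device
    return device


if __name__ == "__main__":
    select_pjrt_device("TPU")
    try:
        import torch_xla  # noqa: F401  (imported only after PJRT_DEVICE is set)
    except ImportError:
        print("torch_xla is not installed in this environment")
```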

Environment

Current environment

* CUDA: - GPU: None - available: False - version: 12.1
* Lightning: - lightning: 2.1.3 - lightning-utilities: 0.10.0 - pytorch-lightning: 2.1.3 - torch: 2.1.0+cu121 - torch-xla: 2.1.0 - torchaudio: 2.1.0+cu121 - torchdata: 0.7.0 - torchmetrics: 1.3.0 - torchsummary: 1.5.1 - torchtext: 0.16.0 - torchvision: 0.16.0+cu121
* Packages: - absl-py: 1.4.0 - aiohttp: 3.9.1 - aiosignal: 1.3.1 - alabaster: 0.7.13 - albumentations: 1.3.1 - altair: 4.2.2 - anyio: 3.7.1 - appdirs: 1.4.4 - argon2-cffi: 23.1.0 - argon2-cffi-bindings: 21.2.0 - array-record: 0.5.0 - arviz: 0.15.1 - astropy: 5.3.4 - astunparse: 1.6.3 - async-timeout: 4.0.3 - atpublic: 4.0 - attrs: 23.2.0 - audioread: 3.0.1 - autograd: 1.6.2 - babel: 2.14.0 - backcall: 0.2.0 - beautifulsoup4: 4.11.2 - bidict: 0.22.1 - bigframes: 0.18.0 - bleach: 6.1.0 - blinker: 1.4 - blis: 0.7.11 - blosc2: 2.0.0 - bokeh: 3.3.2 - bqplot: 0.12.42 - branca: 0.7.0 - build: 1.0.3 - cachecontrol: 0.13.1 - cached-property: 1.5.2 - cachetools: 5.3.2 - catalogue: 2.0.10 - certifi: 2023.11.17 - cffi: 1.16.0 - chardet: 5.2.0 - charset-normalizer: 3.3.2 - chex: 0.1.6 - click: 8.1.7 - click-plugins: 1.1.1 - cligj: 0.7.2 - cloud-tpu-client: 0.10 - cloudpickle: 2.2.1 - cmake: 3.27.9 - cmdstanpy: 1.2.0 - colorcet: 3.0.1 - colorlover: 0.3.0 - colour: 0.1.5 - community: 1.0.0b1 - confection: 0.1.4 - cons: 0.4.6 - contextlib2: 21.6.0 - contourpy: 1.2.0 - cryptography: 41.0.7 - cufflinks: 0.17.3 - cupy-cuda12x: 12.2.0 - cvxopt: 1.3.2 - cvxpy: 1.3.2 - cycler: 0.12.1 - cymem: 2.0.8 - cython: 3.0.7 - dask: 2023.8.1 - datascience: 0.17.6 - db-dtypes: 1.2.0 - dbus-python: 1.2.18 - debugpy: 1.6.6 - decorator: 4.4.2 - defusedxml: 0.7.1 - diskcache: 5.6.3 - distributed: 2023.8.1 - distro: 1.7.0 - dlib: 19.24.2 - dm-tree: 0.1.8 - docutils: 0.18.1 - dopamine-rl: 4.0.6 - duckdb: 0.9.2 - earthengine-api: 0.1.384 - easydict: 1.11 - ecos: 2.0.12 - editdistance: 0.6.2 - eerepr: 0.0.4 - en-core-web-sm: 3.6.0 - entrypoints: 0.4 - et-xmlfile: 1.1.0 - etils: 1.6.0 - etuples: 0.3.9 - exceptiongroup: 1.2.0 - fastai: 2.7.13 - fastcore: 1.5.29 - fastdownload: 0.0.7 - fastjsonschema: 2.19.1 - fastprogress: 1.0.3 - fastrlock: 0.8.2 - filelock: 3.13.1 - fiona: 1.9.5 - firebase-admin: 5.3.0 - flask: 2.2.5 - flatbuffers: 23.5.26 - flax: 0.6.4 - folium: 0.14.0 - fonttools: 4.47.0 - frozendict: 2.4.0 - frozenlist: 1.4.1 - fsspec: 2023.6.0 - future: 0.18.3 - gast: 0.4.0 - gcsfs: 2023.6.0 - gdal: 3.4.3 - gdown: 4.6.6 - geemap: 0.30.0 - gensim: 4.3.2 - geocoder: 1.38.1 - geographiclib: 2.0 - geopandas: 0.13.2 - geopy: 2.3.0 - gin-config: 0.5.0 - glob2: 0.7 - google: 2.0.3 - google-ai-generativelanguage: 0.4.0 - google-api-core: 1.34.0 - google-api-python-client: 1.8.0 - google-auth: 2.17.3 - google-auth-httplib2: 0.1.1 - google-auth-oauthlib: 0.4.6 - google-cloud-aiplatform: 1.39.0 - google-cloud-bigquery: 3.12.0 - google-cloud-bigquery-connection: 1.12.1 - google-cloud-bigquery-storage: 2.24.0 - google-cloud-core: 2.3.3 - google-cloud-datastore: 2.15.2 - google-cloud-firestore: 2.11.1 - google-cloud-functions: 1.13.3 - google-cloud-iam: 2.13.0 - google-cloud-language: 2.9.1 - google-cloud-resource-manager: 1.11.0 - google-cloud-storage: 2.8.0 - google-cloud-translate: 3.11.3 - google-colab: 1.0.0 - google-crc32c: 1.5.0 - google-generativeai: 0.3.2 - google-pasta: 0.2.0 - google-resumable-media: 2.7.0 - googleapis-common-protos: 1.62.0 - googledrivedownloader: 0.4 - graphviz: 0.20.1 - greenlet: 3.0.3 - grpc-google-iam-v1: 0.13.0 - grpcio: 1.60.0 - grpcio-status: 1.48.2 - gspread: 3.4.2 - gspread-dataframe: 3.3.1 - gym: 0.25.2 - gym-notices: 0.0.8 - h5netcdf: 1.3.0 - h5py: 3.9.0 - holidays: 0.40 - holoviews: 1.17.1 - html5lib: 1.1 - httpimport: 1.3.1 - httplib2: 0.22.0 - huggingface-hub: 0.20.2 - humanize: 4.7.0 - hyperopt: 0.2.7 - ibis-framework: 7.1.0 - idna: 3.6 - imageio: 2.31.6 - imageio-ffmpeg: 0.4.9 - imagesize: 1.4.1 - imbalanced-learn: 0.10.1 - imgaug: 0.4.0 - importlib-metadata: 7.0.1 - importlib-resources: 6.1.1 - imutils: 0.5.4 - inflect: 7.0.0 - iniconfig: 2.0.0 - install: 1.3.5 - intel-openmp: 2023.2.3 - ipyevents: 2.0.2 - ipyfilechooser: 0.6.0 - ipykernel: 5.5.6 - ipyleaflet: 0.18.1 - ipython: 7.34.0 - ipython-genutils: 0.2.0 - ipython-sql: 0.5.0 - ipytree: 0.2.2 - ipywidgets: 7.7.1 - itsdangerous: 2.1.2 - jax: 0.4.23 - jaxlib: 0.4.23 - jedi: 0.19.1 - jeepney: 0.7.1 - jieba: 0.42.1 - jinja2: 3.1.2 - joblib: 1.3.2 - jsonpickle: 3.0.2 - jsonschema: 4.19.2 - jsonschema-specifications: 2023.12.1 - jupyter-client: 6.1.12 - jupyter-console: 6.1.0 - jupyter-core: 5.7.0 - jupyter-server: 1.24.0 - jupyterlab-pygments: 0.3.0 - jupyterlab-widgets: 3.0.9 - kaggle: 1.5.16 - kagglehub: 0.1.4 - keras: 2.12.0 - keyring: 23.5.0 - kiwisolver: 1.4.5 - langcodes: 3.3.0 - launchpadlib: 1.10.16 - lazr.restfulclient: 0.14.4 - lazr.uri: 1.0.6 - lazy-loader: 0.3 - libclang: 16.0.6 - librosa: 0.10.1 - libtpu-nightly: 0.1.dev20231213 - lida: 0.0.10 - lightgbm: 4.1.0 - lightning: 2.1.3 - lightning-utilities: 0.10.0 - linkify-it-py: 2.0.2 - llmx: 0.0.15a0 - llvmlite: 0.41.1 - locket: 1.0.0 - logical-unification: 0.4.6 - lxml: 4.9.4 - malloy: 2023.1067 - markdown: 3.5.1 - markdown-it-py: 3.0.0 - markupsafe: 2.1.3 - matplotlib: 3.7.1 - matplotlib-inline: 0.1.6 - matplotlib-venn: 0.11.9 - mdit-py-plugins: 0.4.0 - mdurl: 0.1.2 - minikanren: 1.0.3 - missingno: 0.5.2 - mistune: 0.8.4 - mizani: 0.9.3 - mkl: 2023.2.0 - ml-dtypes: 0.2.0 - mlxtend: 0.22.0 - more-itertools: 10.1.0 - moviepy: 1.0.3 - mpmath: 1.3.0 - msgpack: 1.0.7 - multidict: 6.0.4 - multipledispatch: 1.0.0 - multitasking: 0.0.11 - murmurhash: 1.0.10 - music21: 9.1.0 - natsort: 8.4.0 - nbclassic: 1.0.0 - nbclient: 0.9.0 - nbconvert: 6.5.4 - nbformat: 5.9.2 - nest-asyncio: 1.5.8 - networkx: 3.2.1 - nibabel: 4.0.2 - nltk: 3.8.1 - notebook: 6.5.5 - notebook-shim: 0.2.3 - numba: 0.58.1 - numexpr: 2.8.8 - numpy: 1.23.5 - oauth2client: 4.1.3 - oauthlib: 3.2.2 - opencv-contrib-python: 4.8.0.76 - opencv-python: 4.8.0.76 - opencv-python-headless: 4.9.0.80 - openpyxl: 3.1.2 - opt-einsum: 3.3.0 - optax: 0.1.7 - orbax: 0.1.2 - orbax-checkpoint: 0.4.4 - osqp: 0.6.2.post8 - packaging: 23.2 - pandas: 1.5.3 - pandas-datareader: 0.10.0 - pandas-gbq: 0.19.2 - pandas-stubs: 1.5.3.230304 - pandocfilters: 1.5.0 - panel: 1.3.6 - param: 2.0.1 - parso: 0.8.3 - parsy: 2.1 - partd: 1.4.1 - pathlib: 1.0.1 - pathy: 0.10.3 - patsy: 0.5.6 - peewee: 3.17.0 - pexpect: 4.9.0 - pickleshare: 0.7.5 - pillow: 9.4.0 - pins: 0.8.4 - pip: 23.1.2 - pip-tools: 6.13.0 - platformdirs: 4.1.0 - plotly: 5.15.0 - plotnine: 0.12.4 - pluggy: 1.3.0 - polars: 0.17.3 - pooch: 1.8.0 - portpicker: 1.5.2 - prefetch-generator: 1.0.3 - preshed: 3.0.9 - prettytable: 3.9.0 - proglog: 0.1.10 - progressbar2: 4.2.0 - prometheus-client: 0.19.0 - promise: 2.3 - prompt-toolkit: 3.0.43 - prophet: 1.1.5 - proto-plus: 1.23.0 - protobuf: 3.20.3 - psutil: 5.9.5 - psycopg2: 2.9.9 - ptyprocess: 0.7.0 - py-cpuinfo: 9.0.0 - py4j: 0.10.9.7 - pyarrow: 10.0.1 - pyarrow-hotfix: 0.6 - pyasn1: 0.5.1 - pyasn1-modules: 0.3.0 - pycocotools: 2.0.7 - pycparser: 2.21 - pyct: 0.5.0 - pydantic: 1.10.13 - pydata-google-auth: 1.8.2 - pydot: 1.4.2 - pydot-ng: 2.0.0 - pydotplus: 2.0.2 - pydrive: 1.3.1 - pydrive2: 1.6.3 - pyerfa: 2.0.1.1 - pygame: 2.5.2 - pygments: 2.16.1 - pygobject: 3.42.1 - pyjwt: 2.3.0 - pymc: 5.7.2 - pymystem3: 0.2.0 - pyopengl: 3.1.7 - pyopenssl: 23.3.0 - pyparsing: 3.1.1 - pyperclip: 1.8.2 - pyproj: 3.6.1 - pyproject-hooks: 1.0.0 - pyshp: 2.3.1 - pysocks: 1.7.1 - pytensor: 2.14.2 - pytest: 7.4.4 - python-apt: 0.0.0 - python-box: 7.1.1 - python-dateutil: 2.8.2 - python-louvain: 0.16 - python-slugify: 8.0.1 - python-utils: 3.8.1 - pytorch-lightning: 2.1.3 - pytz: 2023.3.post1 - pyviz-comms: 3.0.0 - pywavelets: 1.5.0 - pyyaml: 6.0.1 - pyzmq: 23.2.1 - qdldl: 0.1.7.post0 - qudida: 0.0.4 - ratelim: 0.1.6 - referencing: 0.32.1 - regex: 2023.6.3 - requests: 2.31.0 - requests-oauthlib: 1.3.1 - requirements-parser: 0.5.0 - rich: 13.7.0 - rpds-py: 0.16.2 - rpy2: 3.4.2 - rsa: 4.9 - safetensors: 0.4.1 - scikit-image: 0.19.3 - scikit-learn: 1.2.2 - scipy: 1.11.4 - scooby: 0.9.2 - scs: 3.2.4.post1 - seaborn: 0.12.2 - secretstorage: 3.3.1 - send2trash: 1.8.2 - setuptools: 67.7.2 - shapely: 2.0.2 - six: 1.16.0 - sklearn-pandas: 2.2.0 - smart-open: 6.4.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - sortedcontainers: 2.4.0 - soundfile: 0.12.1 - soupsieve: 2.5 - soxr: 0.3.7 - spacy: 3.6.1 - spacy-legacy: 3.0.12 - spacy-loggers: 1.0.5 - sphinx: 5.0.2 - sphinxcontrib-applehelp: 1.0.7 - sphinxcontrib-devhelp: 1.0.5 - sphinxcontrib-htmlhelp: 2.0.4 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-qthelp: 1.0.6 - sphinxcontrib-serializinghtml: 1.1.9 - sqlalchemy: 2.0.24 - sqlglot: 19.9.0 - sqlparse: 0.4.4 - srsly: 2.4.8 - stanio: 0.3.0 - statsmodels: 0.14.1 - sympy: 1.12 - tables: 3.8.0 - tabulate: 0.9.0 - tbb: 2021.11.0 - tblib: 3.0.0 - tenacity: 8.2.3 - tensorboard: 2.12.0 - tensorboard-data-server: 0.7.2 - tensorboard-plugin-wit: 1.8.1 - tensorflow: 2.12.0 - tensorflow-datasets: 4.9.4 - tensorflow-estimator: 2.12.0 - tensorflow-gcs-config: 2.12.0 - tensorflow-hub: 0.15.0 - tensorflow-io-gcs-filesystem: 0.35.0 - tensorflow-metadata: 1.14.0 - tensorflow-probability: 0.22.0 - tensorstore: 0.1.45 - termcolor: 2.4.0 - terminado: 0.18.0 - text-unidecode: 1.3 - textblob: 0.17.1 - tf-slim: 1.1.0 - thinc: 8.1.12 - threadpoolctl: 3.2.0 - tifffile: 2023.12.9 - tinycss2: 1.2.1 - tokenizers: 0.15.0 - toml: 0.10.2 - tomli: 2.0.1 - toolz: 0.12.0 - torch: 2.1.0+cu121 - torch-xla: 2.1.0 - torchaudio: 2.1.0+cu121 - torchdata: 0.7.0 - torchmetrics: 1.3.0 - torchsummary: 1.5.1 - torchtext: 0.16.0 - torchvision: 0.16.0+cu121 - tornado: 6.3.2 - tqdm: 4.66.1 - traitlets: 5.7.1 - traittypes: 0.2.1 - transformers: 4.35.2 - triton: 2.1.0 - tweepy: 4.14.0 - typer: 0.9.0 - types-pytz: 2023.3.1.1 - types-setuptools: 69.0.0.20240106 - typing-extensions: 4.5.0 - tzlocal: 5.2 - uc-micro-py: 1.0.2 - uritemplate: 3.0.1 - urllib3: 2.0.7 - vega-datasets: 0.9.0 - wadllib: 1.3.6 - wasabi: 1.1.2 - wcwidth: 0.2.12 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.7.0 - werkzeug: 3.0.1 - wheel: 0.42.0 - widgetsnbextension: 3.6.6 - wordcloud: 1.9.3 - wrapt: 1.14.1 - xarray: 2023.7.0 - xarray-einstats: 0.6.0 - xgboost: 2.0.3 - xlrd: 2.0.1 - xxhash: 3.4.1 - xyzservices: 2023.10.1 - yarl: 1.9.4 - yellowbrick: 1.5 - yfinance: 0.2.33 - zict: 3.0.0 - zipp: 3.17.0
* System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.10.12 - release: 6.1.58+ - version: #1 SMP PREEMPT_DYNAMIC Sat Nov 18 15:31:17 UTC 2023

More info

Running

```python
import torch_xla.core.xla_model as xm
xm.xla_device()
```

sometimes returns a device:

```
device(type='xla', index=0)
```

and sometimes raises an error with this stack trace:

```
RuntimeError                              Traceback (most recent call last)
in ()
      1 import torch_xla.core.xla_model as xm
----> 2 xm.xla_device()

2 frames
/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py in xla_device(n, devkind)
    195 return torch.device(device)
    196
--> 197 return runtime.xla_device(n, devkind)
    198
    199

/usr/local/lib/python3.10/dist-packages/torch_xla/runtime.py in wrapper(*args, **kwargs)
     80 fn.__name__))
     81
---> 82 return fn(*args, **kwargs)
     83
     84 return wrapper

/usr/local/lib/python3.10/dist-packages/torch_xla/runtime.py in xla_device(n, devkind)
    109 """
    110 if n is None:
--> 111 return torch.device(torch_xla._XLAC._xla_get_default_device())
    112
    113 devices = xm.get_xla_supported_devices(devkind=devkind)

RuntimeError: torch_xla/csrc/runtime/pjrt_computation_client.cc:108 : Check failed: tpu_status.ok()
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::runtime::PjRtComputationClient::PjRtComputationClient()
	torch_xla::runtime::GetComputationClient()
	torch_xla::ParseDeviceString(std::string const&)
	torch_xla::GetDefaultDevice()
	torch_xla::GetCurrentDevice()
	torch_xla::bridge::GetCurrentAtenDevice()
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	PyEval_EvalCode
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyEval_EvalFrameDefault
	PyEval_EvalCode
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	Py_RunMain
	Py_BytesMain
	__libc_start_main
	_start
*** End stack trace ***
```

Swapping the installation order of `torch_xla[tpu]` and `jax[tpu]` may change the behavior above, but in both cases Lightning does not recognize the TPU.

cc @carmocca @JackCaoG @Liyang90 @gkroiz

carmocca commented 6 months ago

Hi @BrandonStudio. Can you print out the result of this? https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/accelerators/xla.py#L77-L79

This is the function that checks whether XLA is available. It would be useful to know what it does in your environment.
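The relevant detail is that Lightning's availability check wraps the device count in a `try`/`except`, so an exception raised inside torch_xla is reported as "accelerator not available" instead of surfacing directly, which matches the `MisconfigurationException` in the original report. A minimal sketch of that pattern, assuming Lightning 2.1 catches `ValueError`, `AssertionError` and `OSError` (check the linked source for the exact list); `xla_is_available` and `probe_torch_xla` are hypothetical names, not Lightning API.

```python
from typing import Callable, Optional


def xla_is_available(count_devices: Callable[[], int]) -> bool:
    """Mirror the shape of an availability check that hides XLA setup errors."""
    try:
        return count_devices() > 0
    except (ValueError, AssertionError, OSError):
        # A misconfigured XLA runtime ends up here and is reported as
        # "not available", so the Trainer only sees False, not the cause.
        return False


def probe_torch_xla() -> Optional[int]:
    """Call torch_xla's device count directly so its exception stays visible."""
    try:
        from torch_xla._internal import tpu
    except ImportError:
        return None  # torch_xla is not installed
    return tpu.num_available_devices()  # any OSError etc. propagates to the caller
```

Calling `probe_torch_xla()` in the Colab session surfaces the underlying error that the availability check turns into `False`.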

BrandonStudio commented 6 months ago

Apologies: the code above should use `pl.Trainer`, not `pl.trainer`.

Output:

```
---------------------------------------------------------------------------
HTTPError                                 Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/tpu.py in version()
    177 try:
--> 178 env = get_tpu_env()
    179 except requests.HTTPError as e:

6 frames
/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/tpu.py in get_tpu_env()
    171 return build_tpu_env_from_vars()
--> 172 metadata = _get_metadata('tpu-env')
    173 return yaml.load(metadata, yaml.Loader)

/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/tpu.py in _get_metadata(key)
     80 resp = requests.get(path, headers={'Metadata-Flavor': 'Google'})
---> 81 resp.raise_for_status()
     82

/usr/local/lib/python3.10/dist-packages/requests/models.py in raise_for_status(self)
   1020 if http_error_msg:
-> 1021 raise HTTPError(http_error_msg, response=self)
   1022

HTTPError: 404 Client Error: Not Found for url: http://metadata.google.internal/computeMetadata/v1/instance/attributes/tpu-env

The above exception was the direct cause of the following exception:

OSError                                   Traceback (most recent call last)
in ()
      1 from torch_xla._internal import tpu
----> 2 tpu.num_available_devices()

/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/tpu.py in num_available_devices()
    117 before `xmp.spawn`.
    118 """
--> 119 return num_available_chips() * num_logical_cores_per_chip()
    120
    121

/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/tpu.py in num_logical_cores_per_chip()
    108 def num_logical_cores_per_chip() -> int:
    109 """Returns number of XLA TPU devices per physical chip on the current host."""
--> 110 return 2 if version() <= 3 else 1
    111
    112

/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/tpu.py in version()
    178 env = get_tpu_env()
    179 except requests.HTTPError as e:
--> 180 raise EnvironmentError('Failed to get TPU metadata') from e
    181
    182 match = re.match(r'^v(\d)([A-Za-z]?){7}-(\d+)$', env[xenv.ACCELERATOR_TYPE])

OSError: Failed to get TPU metadata
```

This is a stock Google Colab runtime, so anyone can reproduce it.
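The root cause in this traceback is the HTTP 404 from the metadata server: torch_xla requests the `tpu-env` instance attribute and the runtime apparently does not provide it. A minimal sketch that repeats the same request with only the standard library, so the result can be checked without torch_xla; the URL and the `Metadata-Flavor: Google` header are copied from the traceback, while `metadata_request` and `fetch_tpu_env` are hypothetical helpers.

```python
import urllib.error
import urllib.request
from typing import Optional

# Endpoint and header copied from the torch_xla traceback above.
METADATA_ATTRIBUTES = "http://metadata.google.internal/computeMetadata/v1/instance/attributes/"


def metadata_request(key: str) -> urllib.request.Request:
    """Build a GET request for one instance attribute (hypothetical helper)."""
    return urllib.request.Request(METADATA_ATTRIBUTES + key, headers={"Metadata-Flavor": "Google"})


def fetch_tpu_env(timeout: float = 5.0) -> Optional[str]:
    """Return the raw `tpu-env` attribute, or None if it cannot be fetched."""
    try:
        with urllib.request.urlopen(metadata_request("tpu-env"), timeout=timeout) as resp:
            return resp.read().decode("utf-8")
    except urllib.error.HTTPError as err:
        # A 404 here corresponds to the "Not Found for url" error above.
        print(f"metadata server answered HTTP {err.code}")
        return None
    except OSError as err:
        # DNS failures, refused connections and timeouts are all OSError
        # subclasses; off Google Cloud the metadata host usually does not resolve.
        print(f"metadata server unreachable: {err}")
        return None
```

In the Colab session, `fetch_tpu_env()` printing HTTP 404 and returning `None` would be consistent with the `OSError` above.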

carmocca commented 6 months ago

@BrandonStudio Can you report this in https://github.com/pytorch/xla/issues? This seems to be a Colab issue or a PyTorch XLA issue.

ckwastra commented 6 months ago

Inspired by this reply, I used the following setup:

```
!pip install torch==2.0.0
!pip install cloud-tpu-client
!pip install https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
!pip install lightning==2.0.0
```

This pins torch and lightning to 2.0.0 and installs the matching torch_xla 2.0 Colab wheel. With this setup, the Trainer output was:

```
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: True, using: 8 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: True, using: 8 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
```

BrandonStudio commented 6 months ago

@ckwastra Thank you for your solution, but lightning==2.0.0 does not work on the legacy TPU runtime.

Combining all the comments above, this is the setup that works for me:

```
%pip install torch==2.0.0 torchaudio torchdata torchtext torchvision cloud-tpu-client https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl "pytorch_lightning<2"
```

@carmocca Could you add this to docs?