Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Unable to train on v3-8 TPUs with Lightning. Training is stuck/deadlocked? #18131

Open vikigenius opened 11 months ago

vikigenius commented 11 months ago

Bug description

The training code simply gets stuck on the TPU.

What version are you seeing the problem on?

master

How to reproduce the bug

I just used the following calls to the Trainer and fit:

    pl.seed_everything(7, workers=True)
    torch.set_float32_matmul_precision("high")
    model = SiameseEncoder(model_name_or_path)
    datamodule = RetrievalDataModule(
        model_name_or_path,
        {"train": str(train_path), "val": str(val_path)},
        {"train": train_batch_size, "val": val_batch_size},
        padding_style,
        workers=workers,
    )
    monitor = "rec@1"

    # TODO Add clipping, control validation intervals etc. Lots of work to be done
    trainer = pl.Trainer(
        logger=AimLogger(experiment="SiameseEncoder"),
        accelerator=accelerator,
        devices=devices,
        deterministic=True,
        max_epochs=max_epochs,
        val_check_interval=0.1,
        gradient_clip_val=1,
        precision="16-mixed",
        callbacks=[
            EarlyStopping(monitor=monitor, mode="max", patience=10),
            ModelCheckpoint(monitor=monitor, mode="max", save_top_k=1),
        ],
    )
    trainer.fit(model, datamodule=datamodule)

I also set export PJRT_DEVICE=TPU before calling the trainer code from CLI.
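
For reference, a minimal sketch of making the same runtime selection from Python instead of the shell (assuming it runs before torch_xla initializes its runtime client; this snippet is an illustration, not part of the original report):

```
import os

# Equivalent to `export PJRT_DEVICE=TPU`; must run before torch_xla creates its
# ComputationClient, i.e. before the Trainer touches an XLA device.
os.environ.setdefault("PJRT_DEVICE", "TPU")
```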

Error messages and logs

Global seed set to 7
Some weights of the model checkpoint at sentence-transformers/multi-qa-mpnet-base-cos-v1 were not used when initializing MPNetModel: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing MPNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MPNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
INFO:torch_xla:Letting libtpu.so load fail during _XLAC import. libtpu.so will be loaded from `libtpu` Python package when the ComputationClient is created.
INFO:torch_xla:Using bundled libtpu.so (/home/void/miniconda3/envs/siamenc/lib/python3.8/site-packages/torch_xla/lib/libtpu.so)
/home/void/miniconda3/envs/siamenc/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:487: UserWarning: You passed `Trainer(accelerator='tpu', precision='16-mixed')` but AMP with fp16 is not supported on TPUs. Using `precision='bf16-mixed'` instead.
  rank_zero_warn(
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
WARNING:root:Unsupported nprocs (8), ignoring...

Environment

Current environment
* CUDA: - GPU: None - available: False - version: None
* Lightning: - lightning: 2.1.0.dev0 - lightning-cloud: 0.5.37 - lightning-utilities: 0.9.0 - torch: 2.0.1 - torch-xla: 2.0 - torchmetrics: 0.11.4
* Packages: - absl-py: 1.4.0 - aim: 3.17.5 - aim-ui: 3.17.5 - aimrecords: 0.0.7 - aimrocks: 0.4.0 - aiofiles: 23.1.0 - aiohttp: 3.8.4 - aiosignal: 1.3.1 - alabaster: 0.7.13 - alembic: 1.11.1 - annotated-types: 0.5.0 - anyio: 3.7.1 - arger: 1.4.8 - arrow: 1.2.3 - async-timeout: 4.0.2 - attrs: 23.1.0 - babel: 2.12.1 - backoff: 2.2.1 - base58: 2.0.1 - beautifulsoup4: 4.12.2 - blessed: 1.20.0 - boto3: 1.28.4 - botocore: 1.31.4 - build: 0.10.0 - cachecontrol: 0.12.14 - cachetools: 5.3.1 - cattrs: 23.1.2 - certifi: 2023.5.7 - cffi: 1.15.1 - charset-normalizer: 3.2.0 - cleo: 2.0.1 - click: 8.1.5 - cloud-tpu-client: 0.10 - coverage: 7.2.7 - crashtest: 0.4.1 - croniter: 1.4.1 - cryptography: 41.0.2 - datasets: 2.13.1 - dateutils: 0.6.12 - deepdiff: 6.3.1 - dill: 0.3.6 - distlib: 0.3.7 - docstring-to-markdown: 0.12 - docutils: 0.20.1 - dparse: 0.6.3 - dulwich: 0.21.5 - exceptiongroup: 1.1.2 - execnet: 2.0.2 - fastapi: 0.100.0 - filelock: 3.12.2 - frozenlist: 1.4.0 - fsspec: 2023.6.0 - gmpy2: 2.1.2 - google-api-core: 1.16.0 - google-api-python-client: 1.8.0 - google-auth: 1.6.3 - google-auth-httplib2: 0.1.0 - googleapis-common-protos: 1.59.1 - greenlet: 2.0.2 - grpcio: 1.56.0 - h11: 0.14.0 - html5lib: 1.1 - httplib2: 0.22.0 - huggingface-hub: 0.16.4 - idna: 3.4 - imagesize: 1.4.1 - importlib-metadata: 6.8.0 - importlib-resources: 6.0.0 - iniconfig: 2.0.0 - inquirer: 3.1.3 - installer: 0.7.0 - itsdangerous: 2.1.2 - jaraco.classes: 3.3.0 - jedi: 0.18.2 - jeepney: 0.8.0 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.3.1 - jsonschema: 4.18.4 - jsonschema-specifications: 2023.7.1 - keyring: 23.13.1 - lightning: 2.1.0.dev0 - lightning-cloud: 0.5.37 - lightning-utilities: 0.9.0 - lockfile: 0.12.2 - lsprotocol: 2023.0.0a2 - m2r2: 0.3.3.post2 - mako: 1.2.4 - markdown-it-py: 3.0.0 - markupsafe: 2.1.3 - mdurl: 0.1.2 - mistune: 0.8.4 - monotonic: 1.6 - mpmath: 1.3.0 - msgpack: 1.0.5 - multidict: 6.0.4 - multiprocess: 0.70.14 - mypy: 1.4.1 - mypy-extensions: 1.0.0 - networkx: 3.1 - nltk: 3.8.1 - numpy: 1.24.4 - oauth2client: 4.1.3 - ordered-set: 4.1.0 - packaging: 23.1 - pandas: 2.0.3 - parso: 0.8.3 - pexpect: 4.8.0 - pillow: 10.0.0 - pip: 23.2 - pkginfo: 1.9.6 - pkgutil-resolve-name: 1.3.10 - platformdirs: 3.9.1 - pluggy: 1.2.0 - poetry-core: 1.6.1 - poetry-plugin-export: 1.4.0 - pprintpp: 0.4.0 - protobuf: 4.23.4 - psutil: 5.9.5 - ptyprocess: 0.7.0 - py3nvml: 0.2.7 - pyarrow: 12.0.1 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pycparser: 2.21 - pydantic: 2.0.3 - pydantic-core: 2.3.0 - pygls: 1.0.2 - pygments: 2.15.1 - pyjwt: 2.8.0 - pyparsing: 3.1.0 - pyproject-hooks: 1.0.0 - pytest: 7.4.0 - pytest-clarity: 1.0.1 - pytest-cov: 4.1.0 - pytest-randomly: 3.13.0 - pytest-sugar: 0.9.7 - pytest-xdist: 3.3.1 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - python-lsp-jsonrpc: 1.0.0 - python-lsp-server: 1.7.4 - python-multipart: 0.0.6 - pytz: 2023.3 - pyyaml: 6.0.1 - rapidfuzz: 2.15.1 - readchar: 4.0.5 - referencing: 0.30.0 - regex: 2023.6.3 - requests: 2.31.0 - requests-toolbelt: 1.0.0 - restrictedpython: 6.1 - rich: 13.4.2 - rpds-py: 0.9.2 - rsa: 4.9 - ruamel.yaml: 0.17.32 - ruamel.yaml.clib: 0.2.7 - ruff: 0.0.278 - ruff-lsp: 0.0.35 - s3transfer: 0.6.1 - safetensors: 0.3.1 - safety: 2.3.5 - secretstorage: 3.3.3 - segment-analytics-python: 2.2.3 - setuptools: 68.0.0 - shellingham: 1.5.0.post1 - siamenc: 2.0.0 - six: 1.16.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - soupsieve: 2.4.1 - sphinx: 7.0.1 - sphinx-autodoc-typehints: 1.23.3 - sphinxcontrib-applehelp: 1.0.4 - sphinxcontrib-devhelp: 1.0.2 - sphinxcontrib-htmlhelp: 2.0.1 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-qthelp: 1.0.3 - sphinxcontrib-serializinghtml: 1.1.5 - sqlalchemy: 1.4.49 - starlette: 0.27.0 - starsessions: 1.3.0 - sympy: 1.12 - termcolor: 2.3.0 - tokenizers: 0.13.3 - tomli: 2.0.1 - tomlkit: 0.11.8 - torch: 2.0.1 - torch-xla: 2.0 - torchmetrics: 0.11.4 - tqdm: 4.65.0 - traitlets: 5.9.0 - transformers: 4.30.2 - trove-classifiers: 2023.7.6 - typeguard: 3.0.2 - typing-extensions: 4.7.1 - tzdata: 2023.3 - ujson: 5.8.0 - uritemplate: 3.0.1 - urllib3: 1.26.16 - uvicorn: 0.23.1 - virtualenv: 20.24.0 - wcwidth: 0.2.6 - webencodings: 0.5.1 - websocket-client: 1.6.1 - websockets: 11.0.3 - wheel: 0.38.4 - xmltodict: 0.13.0 - xxhash: 3.2.0 - yarl: 1.9.2 - zipp: 3.16.2
* System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.8.17 - release: 5.13.0-1027-gcp - version: #32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022

More info

The trainer simply gets stuck after the message

WARNING:root:Unsupported nprocs (8), ignoring...

Pressing Ctrl-C leads to

Messages like Process ForkProcess-4, Process ForkProcess-5, Process ForkProcess-2, Process ForkProcess-3,

and it seems to be getting stuck somewhere in:

  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
    call_item = call_queue.get(block=True)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
    call_item = call_queue.get(block=True)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
    call_item = call_queue.get(block=True)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/queues.py", line 96, in get
    with self._rlock:
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/queues.py", line 96, in get
    with self._rlock:
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/queues.py", line 96, in get
    with self._rlock:
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
Traceback (most recent call last):
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
    call_item = call_queue.get(block=True)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/queues.py", line 97, in get
    res = self._recv_bytes()
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/connection.py", line 414, in _recv_bytes
    buf = self._recv(4)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
KeyboardInterrupt

etc.

cc @carmocca @JackCaoG @steventk-g @Liyang90

vikigenius commented 11 months ago

I thought my issue might be related to https://github.com/Lightning-AI/lightning/issues/17936, so I updated Lightning to master.

But I get the same issue on master and 2.0.5.

Also, there are 8 TPU cores, but it seems like only 5 processes are created before it gets stuck. Does TPU training use multiprocessing?

vikigenius commented 11 months ago

I don't get this issue when I set devices=1 in the Trainer, but obviously that means I am not using all the TPU cores (and this leads to OOM issues).
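
For reference, a sketch of the single-core configuration described above, reusing the names (pl, model, datamodule, max_epochs) from the first snippet; it avoids the hang at the cost of using only one core and its memory:

```
trainer = pl.Trainer(
    accelerator="tpu",
    devices=1,  # single TPU core: no multi-process spawn, hence no hang, but may OOM
    max_epochs=max_epochs,
)
trainer.fit(model, datamodule=datamodule)
```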

vikigenius commented 11 months ago

I am able to reproduce this even with the following mnist example.

import pytorch_lightning as L
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

BATCH_SIZE = 1024

class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

class LitModel(L.LightningModule):
    def __init__(
        self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4
    ):
        super().__init__()

        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        # acc = accuracy(preds, y, task='multilabel', num_labels=10)
        self.log("val_loss", loss, prog_bar=True)
        # self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)

trainer = L.Trainer(
    max_epochs=1,
    accelerator="tpu",
    devices=8,
)
# Train
trainer.fit(model, dm)

carmocca commented 11 months ago

Which VM instance are you using?

What's the output of py-spy dump -n -p PID_OF_THE_HANGING_PROGRAM? (requires pip install py-spy)

vikigenius commented 11 months ago

@carmocca I am using Google Cloud's v3-8 with the TPU VM architecture and the tpu-vm-pt-2.0 software version.

Here is the py-spy dump

Process 5408: python mnist.py
Python v3.8.17 (/home/void/miniconda3/envs/siamenc/bin/python3.8)

Thread 5408 (idle): "MainThread"
    do_futex_wait.constprop.0 (libpthread-2.31.so)
    __new_sem_wait_slow.constprop.0 (libpthread-2.31.so)
    PyThread_acquire_lock_timed (python3.8)
    lock_PyThread_acquire_lock (python3.8)
    _wait_for_tstate_lock (threading.py:1027)
    join (threading.py:1011)
    shutdown (concurrent/futures/process.py:686)
    __exit__ (concurrent/futures/_base.py:644)
    _run_multiprocess (torch_xla/experimental/pjrt.py:322)
    wrapper (torch_xla/experimental/pjrt.py:92)
    spawn (torch_xla/experimental/pjrt.py:365)
    spawn (torch_xla/distributed/xla_multiprocessing.py:386)
    launch (pytorch_lightning/strategies/launchers/xla.py:75)
    _call_and_handle_interrupt (pytorch_lightning/trainer/call.py:41)
    fit (pytorch_lightning/trainer/trainer.py:529)
    <module> (mnist.py:108)
Thread 6753 (idle): "QueueManagerThread"
    poll (libc-2.31.so)
    select_poll_poll_impl (selectmodule.c:634)
    select_poll_poll (selectmodule.c.h:219)
    select (selectors.py:415)
    wait (multiprocessing/connection.py:931)
    _queue_management_worker (concurrent/futures/process.py:362)
    run (threading.py:870)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
Thread 6754 (idle): "QueueFeederThread"
    do_futex_wait.constprop.0 (libpthread-2.31.so)
    __new_sem_wait_slow.constprop.0 (libpthread-2.31.so)
    PyThread_acquire_lock_timed (python3.8)
    lock_PyThread_acquire_lock (python3.8)
    wait (threading.py:302)
    _feed (multiprocessing/queues.py:227)
    run (threading.py:870)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
Thread 6802 (idle): "Thread-1"
    accept4 (libc-2.31.so)
    sock_accept_impl (socketmodule.c:2640)
    sock_call_ex (socketmodule.c:935)
    sock_accept (socketmodule.c:2682)
    accept (socket.py:292)
    accept (multiprocessing/connection.py:609)
    accept (multiprocessing/connection.py:463)
    _serve (multiprocessing/resource_sharer.py:142)
    run (threading.py:870)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)

visheratin commented 11 months ago

I spent the whole day on this issue. The reason appears to be the spawned workers in the DataLoader. When I set num_workers=1 (or remove it), the training works fine. When I tried using parallel workers, the process hung. Maybe it happens because the main process spawns subprocesses that, in turn, try to spawn DataLoader workers?
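
A minimal sketch of that workaround applied to the MNIST example above (MNISTDataModule and BATCH_SIZE come from that example; num_workers=0, the DataLoader default, keeps data loading in the training process):

```
from torch.utils.data import DataLoader

class MNISTDataModuleNoWorkers(MNISTDataModule):
    def train_dataloader(self):
        # Avoid spawning DataLoader worker processes inside the XLA-spawned
        # training processes; num_workers=0 loads batches in the main process.
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, num_workers=0)
```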

Below is the output from py-spy when using data loader workers.

py-spy output
```
Thread 42377 (idle): "MainThread"
    0x7f8a339a8cab (libgomp-a34b3233.so.1)
    0x7f8a339a77e9 (libgomp-a34b3233.so.1)
    at::native::randperm_out_cpu (libtorch_cpu.so)
    at::(anonymous namespace)::(anonymous namespace)::wrapper_CPU_generator_out_randperm_out (libtorch_cpu.so)
    c10::impl::wrap_kernel_functor_unboxed_, at::Tensor&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CPU_generator_out_randperm_out(long, c10::optional, at::Tensor&)>, at::Tensor&, c10::guts::typelist::typelist, at::Tensor&> >, at::Tensor& (long, c10::optional, at::Tensor&)>::call (libtorch_cpu.so)
    at::_ops::randperm_generator_out::call (libtorch_cpu.so)
    at::native::randperm (libtorch_cpu.so)
    at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd_generator_randperm (libtorch_cpu.so)
    c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd_generator_randperm(long, c10::optional, c10::optional, c10::optional, c10::optional, c10::optional)>, at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor(long, c10::optional, c10::optional, c10::optional, c10::optional, c10::optional)>::call (libtorch_cpu.so)
    at::_ops::randperm_generator::redispatch (libtorch_cpu.so)
    c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), &at::(anonymous namespace)::randperm_generator(long, c10::optional, c10::optional, c10::optional, c10::optional, c10::optional)>, at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor(long, c10::optional, c10::optional, c10::optional, c10::optional, c10::optional)>::call (libtorch_cpu.so)
    at::_ops::randperm_generator::call (libtorch_cpu.so)
    torch::autograd::THPVariable_randperm (libtorch_python.so)
    __iter__ (torch/utils/data/sampler.py:132)
    __iter__ (torch/utils/data/sampler.py:254)
    _next_index (torch/utils/data/dataloader.py:623)
    _try_put_index (torch/utils/data/dataloader.py:1351)
    _reset (torch/utils/data/dataloader.py:1117)
    __init__ (torch/utils/data/dataloader.py:1084)
    _get_iterator (torch/utils/data/dataloader.py:388)
    __iter__ (torch/utils/data/dataloader.py:441)
    _check_dataloader_iterable (lightning/pytorch/trainer/connectors/data_connector.py:391)
    setup_data (lightning/pytorch/loops/fit_loop.py:236)
    run (lightning/pytorch/loops/fit_loop.py:193)
    _run_stage (lightning/pytorch/trainer/trainer.py:1025)
    _run (lightning/pytorch/trainer/trainer.py:982)
    _fit_impl (lightning/pytorch/trainer/trainer.py:577)
    _wrapping_function (lightning/pytorch/strategies/launchers/xla.py:130)
    _start_fn (torch_xla/distributed/xla_multiprocessing.py:328)
    _mp_start_fn (torch_xla/distributed/xla_multiprocessing.py:334)
    _wrap (torch/multiprocessing/spawn.py:69)
    run (multiprocessing/process.py:108)
    _bootstrap (multiprocessing/process.py:315)
    _launch (multiprocessing/popen_fork.py:75)
    __init__ (multiprocessing/popen_fork.py:19)
    _Popen (multiprocessing/context.py:277)
    start (multiprocessing/process.py:121)
    start_processes (torch/multiprocessing/spawn.py:188)
    spawn (torch_xla/distributed/xla_multiprocessing.py:397)
    launch (lightning/pytorch/strategies/launchers/xla.py:88)
    _call_and_handle_interrupt (lightning/pytorch/trainer/call.py:41)
    fit (lightning/pytorch/trainer/trainer.py:538)
    (main.py:137)
Thread 45753 (idle): "Thread-1"
    do_futex_wait.constprop.0 (libpthread-2.31.so)
    __new_sem_wait_slow.constprop.0 (libpthread-2.31.so)
    PyThread_acquire_lock_timed (python3.8)
    wait (threading.py:306)
    get (queue.py:179)
    _run (tensorboard/summary/writer/event_file_writer.py:269)
    run (tensorboard/summary/writer/event_file_writer.py:244)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
Thread 45964 (idle): "Thread-2"
    do_futex_wait.constprop.0 (libpthread-2.31.so)
    __new_sem_wait_slow.constprop.0 (libpthread-2.31.so)
    PyThread_acquire_lock_timed (python3.8)
    wait (threading.py:306)
    wait (threading.py:558)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
Thread 59772 (idle): "Thread-3"
    do_futex_wait.constprop.0 (libpthread-2.31.so)
    __new_sem_wait_slow.constprop.0 (libpthread-2.31.so)
    PyThread_acquire_lock_timed (python3.8)
    wait (threading.py:302)
    put (torch_xla/utils/keyd_queue.py:72)
    _loader_worker (torch_xla/distributed/parallel_loader.py:146)
    run (threading.py:870)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
Thread 59778 (idle): "Thread-4"
    do_futex_wait.constprop.0 (libpthread-2.31.so)
    __new_sem_wait_slow.constprop.0 (libpthread-2.31.so)
    PyThread_acquire_lock_timed (python3.8)
    wait (threading.py:302)
    put (torch_xla/utils/keyd_queue.py:72)
    _worker (torch_xla/distributed/parallel_loader.py:168)
    run (threading.py:870)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
Thread 65684 (idle): "QueueFeederThread"
    do_futex_wait.constprop.0 (libpthread-2.31.so)
    __new_sem_wait_slow.constprop.0 (libpthread-2.31.so)
    PyThread_acquire_lock_timed (python3.8)
    wait (threading.py:302)
    _feed (multiprocessing/queues.py:227)
    run (threading.py:870)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
Thread 65685 (idle): "QueueFeederThread"
    do_futex_wait.constprop.0 (libpthread-2.31.so)
    __new_sem_wait_slow.constprop.0 (libpthread-2.31.so)
    PyThread_acquire_lock_timed (python3.8)
    wait (threading.py:302)
    _feed (multiprocessing/queues.py:227)
    run (threading.py:870)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
Thread 65686 (idle): "QueueFeederThread"
    do_futex_wait.constprop.0 (libpthread-2.31.so)
    __new_sem_wait_slow.constprop.0 (libpthread-2.31.so)
    PyThread_acquire_lock_timed (python3.8)
    wait (threading.py:302)
    _feed (multiprocessing/queues.py:227)
    run (threading.py:870)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
Thread 65689 (idle): "QueueFeederThread"
    do_futex_wait.constprop.0 (libpthread-2.31.so)
    __new_sem_wait_slow.constprop.0 (libpthread-2.31.so)
    PyThread_acquire_lock_timed (python3.8)
    wait (threading.py:302)
    _feed (multiprocessing/queues.py:227)
    run (threading.py:870)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
Thread 65705 (idle): "QueueFeederThread"
    do_futex_wait.constprop.0 (libpthread-2.31.so)
    __new_sem_wait_slow.constprop.0 (libpthread-2.31.so)
    PyThread_acquire_lock_timed (python3.8)
    wait (threading.py:302)
    _feed (multiprocessing/queues.py:227)
    run (threading.py:870)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
```

visheratin commented 11 months ago

@vikigenius I also noticed that you are using the PJRT runtime, which is not fully supported by Lightning, as far as I understand. I was able to run the training using export XRT_TPU_CONFIG="localservice;0;localhost:51011".

Also, if you enable the PJRT runtime, the only way to disable it is to restart the node.
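
A sketch of that workaround (assuming a node where PJRT has not already been enabled, per the caveat above), placed before anything imports torch_xla:

```
import os

# Fall back to the legacy XRT runtime on the TPU VM instead of PJRT.
os.environ["XRT_TPU_CONFIG"] = "localservice;0;localhost:51011"
os.environ.pop("PJRT_DEVICE", None)  # make sure PJRT is not selected
```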

carmocca commented 11 months ago

We've been working a lot on improving XLA support for the next release. Master contains complete support for PJRT.

visheratin commented 11 months ago

@carmocca Thank you for your reply! Unfortunately, the issues I described happened with the nightly version of Lightning installed from master. I'm still running the experiments on the pod (v2-8), so if you need some logs/dumps, let me know.

vikigenius commented 11 months ago

@visheratin the mnist example I posted does not use dataloader workers at all, so it does not spawn any additional processes.

But you are right. After removing PJRT_DEVICE=TPU, I could get the mnist example working.

@carmocca I can also confirm that even with the master version, setting PJRT_DEVICE=TPU causes the mnist example I posted to hang.

carmocca commented 11 months ago

If you are experiencing issues on v2, unfortunately I don't have access to try it (are you using Colab?). Currently we test on v4, and I could also access v3 if there happens to be some capacity available.

PJRT is very new; I don't know how well supported v2 pods are with it. Perhaps @Liyang90 or @gkroiz have more info here.

Either way, I think using XRT is an acceptable workaround with the older chips.

gkroiz commented 11 months ago

PJRT should be supported for TPU v2, as suggested here: https://pytorch.org/xla/master/#pjrt-runtime-beta

I tried your mnist script on v4-8 and was able to get it working.

One thing I noticed is that you are using import pytorch_lightning as L. AFAIK this is deprecated and the import should instead be import lightning.pytorch as L. I had to make this change to get the mnist script working.
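
That is, under the unified package the example's first import becomes the following, with the rest of the script unchanged:

```
import lightning.pytorch as L  # instead of `import pytorch_lightning as L`
```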

vikigenius commented 11 months ago

@gkroiz @carmocca It seems like in theory v2 and v3 should be supported in the PJRT runtime.

But on v2 and v3, PJRT causes each process to run multi-threaded, whereas on v4 each process runs single-threaded: https://pytorch.org/xla/master/#multithreading-on-tpu-v2v3
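
A small probe, independent of Lightning, that makes this process/thread layout visible (a sketch assuming torch_xla 2.0 with PJRT_DEVICE=TPU set; not taken from the thread):

```
import os

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _probe(index):
    # On v2/v3 under PJRT, four worker processes are started and each runs two
    # threads, so eight replicas report here from only four worker processes.
    print(f"index={index} ordinal={xm.get_ordinal()} "
          f"pid={os.getpid()} device={xm.xla_device()}")


if __name__ == "__main__":
    # nprocs values other than None/1 are ignored under PJRT
    # (hence the "Unsupported nprocs (8), ignoring..." warning above).
    xmp.spawn(_probe, nprocs=None)
```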

I am guessing this is the issue: because of the GIL and some non-thread-safe code in Lightning, mixing multiprocessing and multithreading probably causes problems.

carmocca commented 11 months ago

@vikigenius You are probably right. I'll see if I can get access to a v3 after my vacation to fix it. Ideally we'd also add at least one such job to our testing matrix. In the meantime, contributions are of course welcome.

carlesoctav commented 9 months ago

Any updates on this? I still can't use all the cores on the TPU.

jaketae commented 9 months ago

This seems related to https://github.com/Lightning-AI/lightning/issues/17936. I was able to get some training working by using the nightly as suggested in the linked issue, but training is slow (~1 iteration/second, which is comparable to or slower than an RTX A6000 GPU).