vikigenius opened 11 months ago
I thought my issue might be related to https://github.com/Lightning-AI/lightning/issues/17936, so I updated Lightning to master.
But I get the same issue on both master and 2.0.5.
Also, there are 8 TPU cores, but it seems like only 5 processes are created before it gets stuck. Does TPU training use multiprocessing?
I don't get this issue when I set devices=1 in the trainer, but that obviously means I am not using all the TPU cores (and that leads to OOM issues).
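For reference, the single-device configuration that does run (but only uses one core) is roughly this sketch, with the rest of the script below unchanged:

```python
# Workaround mentioned above: a single TPU core runs fine, but leaves the other cores idle.
trainer = L.Trainer(max_epochs=1, accelerator="tpu", devices=1)
trainer.fit(model, dm)
```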
I am able to reproduce this even with the following MNIST example:
import pytorch_lightning as L
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
BATCH_SIZE = 1024
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
class LitModel(L.LightningModule):
    def __init__(
        self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4
    ):
        super().__init__()
        self.save_hyperparameters()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        # acc = accuracy(preds, y, task='multilabel', num_labels=10)
        self.log("val_loss", loss, prog_bar=True)
        # self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer
# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
trainer = L.Trainer(
    max_epochs=1,
    accelerator="tpu",
    devices=8,
)
# Train
trainer.fit(model, dm)
Which VM instance are you using?
What's the output of `py-spy dump -n -p PID_OF_THE_HANGING_PROGRAM`? (requires `pip install py-spy`)
@carmocca I am using Google Cloud's v3-8 with the TPU VM architecture and the tpu-vm-pt-2.0 software version.
Here is the py-spy dump:
Process 5408: python mnist.py
Python v3.8.17 (/home/void/miniconda3/envs/siamenc/bin/python3.8)

Thread 5408 (idle): "MainThread"
    do_futex_wait.constprop.0 (libpthread-2.31.so)
    __new_sem_wait_slow.constprop.0 (libpthread-2.31.so)
    PyThread_acquire_lock_timed (python3.8)
    lock_PyThread_acquire_lock (python3.8)
    _wait_for_tstate_lock (threading.py:1027)
    join (threading.py:1011)
    shutdown (concurrent/futures/process.py:686)
    __exit__ (concurrent/futures/_base.py:644)
    _run_multiprocess (torch_xla/experimental/pjrt.py:322)
    wrapper (torch_xla/experimental/pjrt.py:92)
    spawn (torch_xla/experimental/pjrt.py:365)
    spawn (torch_xla/distributed/xla_multiprocessing.py:386)
    launch (pytorch_lightning/strategies/launchers/xla.py:75)
    _call_and_handle_interrupt (pytorch_lightning/trainer/call.py:41)
    fit (pytorch_lightning/trainer/trainer.py:529)
    <module> (mnist.py:108)
Thread 6753 (idle): "QueueManagerThread"
    poll (libc-2.31.so)
    select_poll_poll_impl (selectmodule.c:634)
    select_poll_poll (selectmodule.c.h:219)
    select (selectors.py:415)
    wait (multiprocessing/connection.py:931)
    _queue_management_worker (concurrent/futures/process.py:362)
    run (threading.py:870)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
Thread 6754 (idle): "QueueFeederThread"
    do_futex_wait.constprop.0 (libpthread-2.31.so)
    __new_sem_wait_slow.constprop.0 (libpthread-2.31.so)
    PyThread_acquire_lock_timed (python3.8)
    lock_PyThread_acquire_lock (python3.8)
    wait (threading.py:302)
    _feed (multiprocessing/queues.py:227)
    run (threading.py:870)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
Thread 6802 (idle): "Thread-1"
    accept4 (libc-2.31.so)
    sock_accept_impl (socketmodule.c:2640)
    sock_call_ex (socketmodule.c:935)
    sock_accept (socketmodule.c:2682)
    accept (socket.py:292)
    accept (multiprocessing/connection.py:609)
    accept (multiprocessing/connection.py:463)
    _serve (multiprocessing/resource_sharer.py:142)
    run (threading.py:870)
    _bootstrap_inner (threading.py:932)
    _bootstrap (threading.py:890)
    clone (libc-2.31.so)
I spent the whole day on this issue. The reason appears to be spawned workers in the DataLoader. When I set `num_workers=1` (or remove it), the training works fine. When I tried using parallel workers, the process hung. Maybe it happens because the main process spawns subprocesses that, in turn, try to spawn data loader workers?
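For illustration, this is the kind of dataloader change that made the hang go away for me (a sketch written against the MNIST datamodule above, not my actual code; `SafeMNISTDataModule` is a hypothetical name):

```python
from torch.utils.data import DataLoader


class SafeMNISTDataModule(MNISTDataModule):  # hypothetical subclass of the datamodule above
    def train_dataloader(self):
        # num_workers=0 (the default) or 1 avoids extra worker processes inside each TPU replica
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, num_workers=1)
```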
Below is the output from `py-spy` when using data loader workers.
@vikigenius I also noticed that you are using the PJRT runtime, which, as far as I understand, is not fully supported by Lightning. I was able to run the training using `export XRT_TPU_CONFIG="localservice;0;localhost:51011"`.
Also, if you enable the PJRT runtime, the only way to disable it is to restart the node.
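In case it is useful, this is roughly how the same thing can be done from Python (a sketch; exporting the variable in the shell before launching the script is what I actually did):

```python
import os

# Assumption: the runtime config has to be in the environment before torch_xla is
# initialized, so set it at the very top of the training script.
os.environ["XRT_TPU_CONFIG"] = "localservice;0;localhost:51011"

import pytorch_lightning as L  # import only after the environment is configured
```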
We've been working a lot on improving XLA support for the next release. Master contains complete support for PJRT.
@carmocca Thank you for your reply! Unfortunately, the issues I described happened with the nightly version of Lightning installed from master. I'm still running the experiments on the pod (v2-8), so if you need some logs/dumps, let me know.
@visheratin the MNIST example I posted does not use dataloader workers at all, so it does not spawn any additional processes.
But you are right: after removing `PJRT_RUNTIME=TPU`, I could get the MNIST example working.
@carmocca I can also confirm that even with the master version, setting `PJRT_RUNTIME=TPU` causes the MNIST example I posted to hang.
If you are experiencing issues on v2, unfortunately, I don't have access to try it (are you using Colab?). Currently we test on v4, and I could also access v3 if there happens to be some capacity available.
PjRT is very new, and I don't know how well-supported v2 pods are with it. Perhaps @Liyang90 or @gkroiz have more info here.
Either way, I think using XRT is an acceptable workaround with the older chips.
PJRT should be supported for TPU v2, as suggested here: https://pytorch.org/xla/master/#pjrt-runtime-beta
I tried your MNIST script on a v4-8 and was able to get it working.
One thing I noticed is that you are using `import pytorch_lightning as L`. AFAIK this is deprecated, and the import should instead be `import lightning.pytorch as L`. I had to make this change to get the MNIST script working.
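Concretely, the change is just the top-level import (a one-line sketch; the rest of the script stays the same):

```python
# Unified package layout: the Trainer/LightningModule API lives under lightning.pytorch.
import lightning.pytorch as L  # instead of: import pytorch_lightning as L
```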
@gkroiz @carmocca It seems like, in theory, v2 and v3 should be supported by the PJRT runtime.
But on v2 and v3, PJRT makes each process run multi-threaded, whereas on v4 each process runs single-threaded: https://pytorch.org/xla/master/#multithreading-on-tpu-v2v3
I am guessing this is the issue, probably because of the GIL and some non-thread-safe code in Lightning that misbehaves when multiprocessing and multithreading are mixed.
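A small diagnostic along these lines (a sketch outside of Lightning, assuming torch_xla 2.0 with `PJRT_DEVICE=TPU` set) should show several ordinals sharing a PID on v2/v3 but one process per ordinal on v4:

```python
import os
import threading

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def report(index):
    # Under PJRT on v2/v3, replicas are spread over 4 processes with 2 threads each,
    # so different ordinals can report the same pid; on v4 each ordinal has its own pid.
    print(f"ordinal {xm.get_ordinal()}/{xm.xrt_world_size()} "
          f"pid={os.getpid()} tid={threading.get_ident()}")


if __name__ == "__main__":
    xmp.spawn(report, args=())
```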
@vikigenius You are probably right. I'll see if I can get access to a v3 after my vacation to fix it. Ideally we'd also add at least one job for it to our testing matrix. In the meantime, contributions are of course welcome.
Any updates on this? I still can't use all the cores on the TPU.
This seems related to https://github.com/Lightning-AI/lightning/issues/17936. I was able to get some training working by using the nightly as suggested in the linked issue, but training is slow (~1 iteration/second, which is comparable to or slower than an RTX A6000 GPU).
Bug description
The training code simply gets stuck on the TPU.
What version are you seeing the problem on?
master
How to reproduce the bug
I just used the following calls to the trainer and fit.
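(Same as in the script above, reproduced here for completeness.)

```python
trainer = L.Trainer(
    max_epochs=1,
    accelerator="tpu",
    devices=8,
)
trainer.fit(model, dm)
```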
I also set `export PJRT_DEVICE=TPU` before calling the trainer code from the CLI.
Error messages and logs
Environment
Current environment
* CUDA: - GPU: None - available: False - version: None
* Lightning: - lightning: 2.1.0.dev0 - lightning-cloud: 0.5.37 - lightning-utilities: 0.9.0 - torch: 2.0.1 - torch-xla: 2.0 - torchmetrics: 0.11.4
* Packages: - absl-py: 1.4.0 - aim: 3.17.5 - aim-ui: 3.17.5 - aimrecords: 0.0.7 - aimrocks: 0.4.0 - aiofiles: 23.1.0 - aiohttp: 3.8.4 - aiosignal: 1.3.1 - alabaster: 0.7.13 - alembic: 1.11.1 - annotated-types: 0.5.0 - anyio: 3.7.1 - arger: 1.4.8 - arrow: 1.2.3 - async-timeout: 4.0.2 - attrs: 23.1.0 - babel: 2.12.1 - backoff: 2.2.1 - base58: 2.0.1 - beautifulsoup4: 4.12.2 - blessed: 1.20.0 - boto3: 1.28.4 - botocore: 1.31.4 - build: 0.10.0 - cachecontrol: 0.12.14 - cachetools: 5.3.1 - cattrs: 23.1.2 - certifi: 2023.5.7 - cffi: 1.15.1 - charset-normalizer: 3.2.0 - cleo: 2.0.1 - click: 8.1.5 - cloud-tpu-client: 0.10 - coverage: 7.2.7 - crashtest: 0.4.1 - croniter: 1.4.1 - cryptography: 41.0.2 - datasets: 2.13.1 - dateutils: 0.6.12 - deepdiff: 6.3.1 - dill: 0.3.6 - distlib: 0.3.7 - docstring-to-markdown: 0.12 - docutils: 0.20.1 - dparse: 0.6.3 - dulwich: 0.21.5 - exceptiongroup: 1.1.2 - execnet: 2.0.2 - fastapi: 0.100.0 - filelock: 3.12.2 - frozenlist: 1.4.0 - fsspec: 2023.6.0 - gmpy2: 2.1.2 - google-api-core: 1.16.0 - google-api-python-client: 1.8.0 - google-auth: 1.6.3 - google-auth-httplib2: 0.1.0 - googleapis-common-protos: 1.59.1 - greenlet: 2.0.2 - grpcio: 1.56.0 - h11: 0.14.0 - html5lib: 1.1 - httplib2: 0.22.0 - huggingface-hub: 0.16.4 - idna: 3.4 - imagesize: 1.4.1 - importlib-metadata: 6.8.0 - importlib-resources: 6.0.0 - iniconfig: 2.0.0 - inquirer: 3.1.3 - installer: 0.7.0 - itsdangerous: 2.1.2 - jaraco.classes: 3.3.0 - jedi: 0.18.2 - jeepney: 0.8.0 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.3.1 - jsonschema: 4.18.4 - jsonschema-specifications: 2023.7.1 - keyring: 23.13.1 - lightning: 2.1.0.dev0 - lightning-cloud: 0.5.37 - lightning-utilities: 0.9.0 - lockfile: 0.12.2 - lsprotocol: 2023.0.0a2 - m2r2: 0.3.3.post2 - mako: 1.2.4 - markdown-it-py: 3.0.0 - markupsafe: 2.1.3 - mdurl: 0.1.2 - mistune: 0.8.4 - monotonic: 1.6 - mpmath: 1.3.0 - msgpack: 1.0.5 - multidict: 6.0.4 - multiprocess: 0.70.14 - mypy: 1.4.1 - mypy-extensions: 1.0.0 - networkx: 3.1 - nltk: 3.8.1 - numpy: 1.24.4 - oauth2client: 4.1.3 - ordered-set: 4.1.0 - packaging: 23.1 - pandas: 2.0.3 - parso: 0.8.3 - pexpect: 4.8.0 - pillow: 10.0.0 - pip: 23.2 - pkginfo: 1.9.6 - pkgutil-resolve-name: 1.3.10 - platformdirs: 3.9.1 - pluggy: 1.2.0 - poetry-core: 1.6.1 - poetry-plugin-export: 1.4.0 - pprintpp: 0.4.0 - protobuf: 4.23.4 - psutil: 5.9.5 - ptyprocess: 0.7.0 - py3nvml: 0.2.7 - pyarrow: 12.0.1 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pycparser: 2.21 - pydantic: 2.0.3 - pydantic-core: 2.3.0 - pygls: 1.0.2 - pygments: 2.15.1 - pyjwt: 2.8.0 - pyparsing: 3.1.0 - pyproject-hooks: 1.0.0 - pytest: 7.4.0 - pytest-clarity: 1.0.1 - pytest-cov: 4.1.0 - pytest-randomly: 3.13.0 - pytest-sugar: 0.9.7 - pytest-xdist: 3.3.1 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - python-lsp-jsonrpc: 1.0.0 - python-lsp-server: 1.7.4 - python-multipart: 0.0.6 - pytz: 2023.3 - pyyaml: 6.0.1 - rapidfuzz: 2.15.1 - readchar: 4.0.5 - referencing: 0.30.0 - regex: 2023.6.3 - requests: 2.31.0 - requests-toolbelt: 1.0.0 - restrictedpython: 6.1 - rich: 13.4.2 - rpds-py: 0.9.2 - rsa: 4.9 - ruamel.yaml: 0.17.32 - ruamel.yaml.clib: 0.2.7 - ruff: 0.0.278 - ruff-lsp: 0.0.35 - s3transfer: 0.6.1 - safetensors: 0.3.1 - safety: 2.3.5 - secretstorage: 3.3.3 - segment-analytics-python: 2.2.3 - setuptools: 68.0.0 - shellingham: 1.5.0.post1 - siamenc: 2.0.0 - six: 1.16.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - soupsieve: 2.4.1 - sphinx: 7.0.1 - sphinx-autodoc-typehints: 1.23.3 - sphinxcontrib-applehelp: 1.0.4 - sphinxcontrib-devhelp: 1.0.2 - sphinxcontrib-htmlhelp: 2.0.1 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-qthelp: 1.0.3 - sphinxcontrib-serializinghtml: 1.1.5 - sqlalchemy: 1.4.49 - starlette: 0.27.0 - starsessions: 1.3.0 - sympy: 1.12 - termcolor: 2.3.0 - tokenizers: 0.13.3 - tomli: 2.0.1 - tomlkit: 0.11.8 - torch: 2.0.1 - torch-xla: 2.0 - torchmetrics: 0.11.4 - tqdm: 4.65.0 - traitlets: 5.9.0 - transformers: 4.30.2 - trove-classifiers: 2023.7.6 - typeguard: 3.0.2 - typing-extensions: 4.7.1 - tzdata: 2023.3 - ujson: 5.8.0 - uritemplate: 3.0.1 - urllib3: 1.26.16 - uvicorn: 0.23.1 - virtualenv: 20.24.0 - wcwidth: 0.2.6 - webencodings: 0.5.1 - websocket-client: 1.6.1 - websockets: 11.0.3 - wheel: 0.38.4 - xmltodict: 0.13.0 - xxhash: 3.2.0 - yarl: 1.9.2 - zipp: 3.16.2
* System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.8.17 - release: 5.13.0-1027-gcp - version: #32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022
More info
The trainer simply gets stuck after the message
WARNING:root:Unsupported nprocs (8), ignoring...
Pressing Ctrl-C leads to messages like `Process ForkProcess-4`, `Process ForkProcess-5`, `Process ForkProcess-2`, `Process ForkProcess-3`, etc., and it seems to be getting stuck somewhere.
cc @carmocca @JackCaoG @steventk-g @Liyang90