Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Cannot load checkpoint when using multithreaded distributed training #17665

Closed: Quasar-Kim closed this issue 1 year ago

Quasar-Kim commented 1 year ago

Bug description

I'm experimenting with the XLA PJRT runtime using the nightly version (currently 2.1.0.dev0). When I try to load a checkpoint by calling LightningModule's load_from_checkpoint() method, it hangs indefinitely.

I investigated this behavior further and found that the following line in pl_legacy_patch.__exit__() is the culprit:

del sys.modules["lightning.pytorch.utilities.argparse_utils"]

The problem with this line is that when using the XLA PJRT runtime on TPU v2/v3, xmp.spawn() actually spawns 4 processes with 2 threads in each. Because both threads share the same sys.modules object, whichever thread reaches this line first deletes the entry, and when the other thread runs the same del statement it fails (with a KeyError in plain threading, as shown below; under xmp.spawn() the failure is silent). This bug seems to affect all multithreading-based strategies, so it needs to be addressed to support them.
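The race can be shown without Lightning or a TPU. The sketch below is a minimal, hypothetical stand-in for the pattern in pl_legacy_patch (register a module alias on enter, delete it unconditionally on exit); the alias name and the naive_patch class are made up for illustration. A threading.Barrier forces both threads to be inside the patch at the same time, so exactly one of the two del statements finds the entry already gone.

```python
import sys
import threading
import types

# Hypothetical alias name; the real patch uses
# "lightning.pytorch.utilities.argparse_utils".
ALIAS = "demo_naive_legacy_alias"


class naive_patch:
    """Add an alias on enter, unconditionally delete it on exit."""

    def __enter__(self):
        sys.modules[ALIAS] = types.ModuleType(ALIAS)
        return self

    def __exit__(self, *exc_info):
        # Raises KeyError if another thread has already removed the entry.
        del sys.modules[ALIAS]


def load(barrier, errors):
    try:
        with naive_patch():
            barrier.wait()  # guarantee both threads are inside the patch at once
    except KeyError as err:
        errors.append(err)


errors = []
barrier = threading.Barrier(2)
threads = [threading.Thread(target=load, args=(barrier, errors)) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(errors))  # 1: the second thread to exit hits the KeyError
```

Because the dict itself is shared, no amount of per-thread bookkeeping inside naive_patch helps; the two exits have to be coordinated.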

What version are you seeing the problem on?

master

How to reproduce the bug

import threading

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, BoringDataModule

def create_checkpoint():
    trainer = Trainer(max_steps=1)
    trainer.fit(BoringModel(), BoringDataModule())
    trainer.save_checkpoint('checkpoint.ckpt')

def load_model():
    model = BoringModel.load_from_checkpoint('checkpoint.ckpt')

if __name__ == '__main__':
    create_checkpoint()

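    # Two threads in one process share a single sys.modules dict, like the
    # threads that xmp.spawn() starts under PJRT.
    t1 = threading.Thread(target=load_model)
    t2 = threading.Thread(target=load_model)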

    t1.start()
    t2.start()

    t1.join()
    t2.join()

output:

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | layer | Linear | 66    
---------------------------------
66        Trainable params
0         Non-trainable params
66        Total params
0.000     Total estimated model params size (MB)
Sanity Checking: 0it [00:00, ?it/s]/home/quasarkim/lightning-contrib/src/lightning/pytorch/trainer/connectors/data_connector.py:435: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/quasarkim/lightning-contrib/src/lightning/pytorch/trainer/connectors/data_connector.py:435: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Epoch 0:   2%|█▋                                                                                                        | 1/64 [00:00<00:00, 149.41it/s, v_num=34]`Trainer.fit` stopped: `max_steps=1` reached.
Epoch 0:   2%|█▋                                                                                                         | 1/64 [00:00<00:00, 72.28it/s, v_num=34]
Exception in thread Thread-4 (load_model):
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/home/quasarkim/lightning-contrib/reproduce.py", line 12, in load_model
    model = BoringModel.load_from_checkpoint('checkpoint.ckpt')
  File "/home/quasarkim/lightning-contrib/src/lightning/pytorch/core/module.py", line 1515, in load_from_checkpoint
    loaded = _load_from_checkpoint(
  File "/home/quasarkim/lightning-contrib/src/lightning/pytorch/core/saving.py", line 60, in _load_from_checkpoint
    with pl_legacy_patch():
  File "/home/quasarkim/lightning-contrib/src/lightning/pytorch/utilities/migration/utils.py", line 110, in __exit__
    del sys.modules["lightning.pytorch.utilities.argparse_utils"]
KeyError: 'lightning.pytorch.utilities.argparse_utils'

Error messages and logs

No response

Environment

Current environment
* CUDA:
  - GPU: None
  - available: False
  - version: 11.7
* Lightning:
  - lightning: 2.1.0.dev0
  - lightning-api-access: 0.0.5
  - lightning-cloud: 0.5.36
  - lightning-fabric: 2.0.2
  - lightning-utilities: 0.8.0
  - pytorch-lightning: 2.0.2
  - torch: 2.0.0
  - torch-xla: 2.0
  - torchmetrics: 0.11.4
* Packages:
  - absl-py: 1.4.0
  - aiobotocore: 2.4.2
  - aiohttp: 3.8.4
  - aioitertools: 0.11.0
  - aiosignal: 1.3.1
  - altair: 4.2.2
  - antlr4-python3-runtime: 4.9.3
  - anyio: 3.6.2
  - arrow: 1.2.3
  - asttokens: 2.2.1
  - async-timeout: 4.0.2
  - attrs: 23.1.0
  - backcall: 0.2.0
  - backports.zoneinfo: 0.2.1
  - beautifulsoup4: 4.12.2
  - bleach: 6.0.0
  - blessed: 1.20.0
  - blinker: 1.6.2
  - bokeh: 2.4.3
  - botocore: 1.27.59
  - cachetools: 5.3.0
  - certifi: 2023.5.7
  - charset-normalizer: 3.1.0
  - click: 8.1.3
  - cloud-tpu-client: 0.10
  - cmake: 3.26.3
  - comm: 0.1.3
  - contourpy: 1.0.7
  - croniter: 1.3.14
  - cycler: 0.11.0
  - dateutils: 0.6.12
  - debugpy: 1.6.7
  - decorator: 5.1.1
  - deepdiff: 6.3.0
  - docker: 6.1.2
  - docstring-parser: 0.15
  - entrypoints: 0.4
  - executing: 1.2.0
  - fastapi: 0.88.0
  - filelock: 3.12.0
  - fonttools: 4.39.4
  - frozenlist: 1.3.3
  - fsspec: 2022.11.0
  - gitdb: 4.0.10
  - gitpython: 3.1.31
  - google-api-core: 1.34.0
  - google-api-python-client: 1.8.0
  - google-auth: 2.17.3
  - google-auth-httplib2: 0.1.0
  - googleapis-common-protos: 1.59.0
  - h11: 0.14.0
  - httplib2: 0.22.0
  - hydra-core: 1.3.2
  - idna: 3.4
  - importlib-metadata: 6.6.0
  - importlib-resources: 5.12.0
  - inquirer: 3.1.3
  - intel-openmp: 2023.1.0
  - ipykernel: 6.23.0
  - ipython: 8.12.2
  - itsdangerous: 2.1.2
  - jax: 0.4.10
  - jaxlib: 0.4.10
  - jedi: 0.18.2
  - jinja2: 3.1.2
  - jmespath: 1.0.1
  - jsonargparse: 4.21.1
  - jsonschema: 4.17.3
  - jupyter-client: 8.2.0
  - jupyter-core: 5.3.0
  - kiwisolver: 1.4.4
  - lightning: 2.1.0.dev0
  - lightning-api-access: 0.0.5
  - lightning-cloud: 0.5.36
  - lightning-fabric: 2.0.2
  - lightning-utilities: 0.8.0
  - lit: 16.0.3
  - markdown: 3.4.3
  - markdown-it-py: 2.2.0
  - markupsafe: 2.1.2
  - matplotlib: 3.7.1
  - matplotlib-inline: 0.1.6
  - mdurl: 0.1.2
  - mkl: 2023.1.0
  - ml-dtypes: 0.1.0
  - mpmath: 1.3.0
  - multidict: 6.0.4
  - nest-asyncio: 1.5.6
  - networkx: 3.1
  - numpy: 1.24.3
  - nvidia-cublas-cu11: 11.10.3.66
  - nvidia-cuda-cupti-cu11: 11.7.101
  - nvidia-cuda-nvrtc-cu11: 11.7.99
  - nvidia-cuda-runtime-cu11: 11.7.99
  - nvidia-cudnn-cu11: 8.5.0.96
  - nvidia-cufft-cu11: 10.9.0.58
  - nvidia-curand-cu11: 10.2.10.91
  - nvidia-cusolver-cu11: 11.4.0.1
  - nvidia-cusparse-cu11: 11.7.4.91
  - nvidia-nccl-cu11: 2.14.3
  - nvidia-nvtx-cu11: 11.7.91
  - oauth2client: 4.1.3
  - omegaconf: 2.3.0
  - opt-einsum: 3.3.0
  - ordered-set: 4.1.0
  - packaging: 23.1
  - pandas: 2.0.1
  - panel: 0.14.4
  - param: 1.13.0
  - parso: 0.8.3
  - pexpect: 4.8.0
  - pickleshare: 0.7.5
  - pillow: 9.5.0
  - pip: 23.1
  - pkgutil-resolve-name: 1.3.10
  - platformdirs: 3.5.1
  - prompt-toolkit: 3.0.38
  - protobuf: 3.20.3
  - psutil: 5.9.5
  - ptyprocess: 0.7.0
  - pure-eval: 0.2.2
  - pyarrow: 12.0.0
  - pyasn1: 0.5.0
  - pyasn1-modules: 0.3.0
  - pyct: 0.5.0
  - pydantic: 1.10.7
  - pydeck: 0.8.0
  - pygments: 2.15.1
  - pyjwt: 2.7.0
  - pympler: 1.0.1
  - pyparsing: 3.0.9
  - pyrsistent: 0.19.3
  - python-dateutil: 2.8.2
  - python-editor: 1.0.4
  - python-multipart: 0.0.6
  - pytorch-lightning: 2.0.2
  - pytz: 2023.3
  - pyviz-comms: 2.2.1
  - pyyaml: 6.0
  - pyzmq: 25.0.2
  - readchar: 4.0.5
  - redis: 4.5.5
  - requests: 2.30.0
  - rich: 13.3.5
  - rsa: 4.9
  - s3fs: 2022.11.0
  - scipy: 1.10.1
  - setuptools: 67.7.2
  - six: 1.16.0
  - smmap: 5.0.0
  - sniffio: 1.3.0
  - soupsieve: 2.4.1
  - stack-data: 0.6.2
  - starlette: 0.22.0
  - starsessions: 1.3.0
  - streamlit: 1.22.0
  - sympy: 1.12
  - tbb: 2021.9.0
  - tenacity: 8.2.2
  - tensorboardx: 2.6
  - toml: 0.10.2
  - toolz: 0.12.0
  - torch: 2.0.0
  - torch-xla: 2.0
  - torchmetrics: 0.11.4
  - tornado: 6.3.1
  - tqdm: 4.65.0
  - traitlets: 5.9.0
  - triton: 2.0.0
  - typeshed-client: 2.3.0
  - typing-extensions: 4.5.0
  - tzdata: 2023.3
  - tzlocal: 5.0.1
  - uritemplate: 3.0.1
  - urllib3: 1.26.15
  - uvicorn: 0.22.0
  - validators: 0.20.0
  - watchdog: 3.0.0
  - wcwidth: 0.2.6
  - webencodings: 0.5.1
  - websocket-client: 1.5.1
  - websockets: 11.0.3
  - wheel: 0.40.0
  - wrapt: 1.15.0
  - yarl: 1.9.2
  - zipp: 3.15.0
* System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.8.10
  - release: 5.13.0-1027-gcp
  - version: #32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022

More info

No response

cc @awaelchli

awaelchli commented 1 year ago

Hi @Quasar-Kim, this was reported before, and an attempt to fix it was made in #12814, but it was reverted at some point; I can't find the right commit in the history. I'm pretty sure the reason was related to pickling, because a threading lock is not pickle-friendly. I'm not sure what the best fix is here if we want to keep full checkpoint migration.
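For context, a small sketch of why a lock is awkward here: a threading.Lock cannot be pickled, so any object that stores one as an attribute becomes unpicklable too. One common workaround, shown with hypothetical class names, is to drop the lock from the pickled state and create a fresh one on load via __getstate__/__setstate__. This is a generic Python pattern, not necessarily what #12814 did.

```python
import pickle
import threading


class PatchWithLock:
    """Hypothetical object that keeps a lock as an attribute."""

    def __init__(self):
        self._lock = threading.Lock()


class PicklablePatch(PatchWithLock):
    """Same object, but the lock is excluded from the pickled state."""

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["_lock"]  # locks cannot be pickled, so leave it out
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._lock = threading.Lock()  # recreate a fresh, unlocked lock on load


def try_pickle(obj):
    """Return "ok" if obj pickles, otherwise the name of the raised exception."""
    try:
        pickle.dumps(obj)
        return "ok"
    except TypeError as err:
        return type(err).__name__


print(try_pickle(threading.Lock()))  # TypeError
print(try_pickle(PatchWithLock()))   # TypeError
print(try_pickle(PicklablePatch()))  # ok
```

The unpickled copy gets its own new lock, which is usually what you want: lock state from another process or thread is meaningless after loading.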

Note that, unrelated to this issue, you should call load_from_checkpoint on the class, using this syntax:

model = BoringModel.load_from_checkpoint('checkpoint.ckpt')

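to re-instantiate a model from a checkpoint. load_from_checkpoint returns a new model instance; it does not load the weights into the object it is called on.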

Quasar-Kim commented 1 year ago

@awaelchli Thank you for your response! I'll open a PR once I find a fix for this issue.

Quasar-Kim commented 1 year ago

There seems to be no clean and transparent way to fix this issue. One potential solution would be to provide a function responsible for loading legacy checkpoint files. The function would be a thin wrapper over torch.load, providing a custom Unpickler that resolves missing legacy modules. But there are a few downsides to consider:
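A minimal sketch of the custom-Unpickler idea using only the standard library: instead of adding an alias to sys.modules, pickle.Unpickler.find_class is overridden to rewrite outdated module paths at load time, so no global state is mutated while loading. The module names, the Config class, and the make_legacy_pickle helper are hypothetical and exist only to fabricate an "old" pickle for the demo. torch.load accepts a pickle_module argument, which may be one place to plug something like this in, but that wiring is not shown or verified here.

```python
import io
import pickle
import sys
import types

# Hypothetical module names used only for this illustration.
OLD_NAME = "demo_legacy_utils_old"  # path recorded in old pickles; no longer importable
NEW_NAME = "demo_current_utils"     # path where the class lives today


class Config:
    """Stand-in for a class that old checkpoints reference under OLD_NAME."""

    def __init__(self, lr):
        self.lr = lr


# Register the "current" module; in real code this would be an ordinary package.
current = types.ModuleType(NEW_NAME)
current.Config = Config
Config.__module__ = NEW_NAME
sys.modules[NEW_NAME] = current


def make_legacy_pickle(obj):
    """Demo helper: produce bytes that reference OLD_NAME, like an old checkpoint."""
    legacy = types.ModuleType(OLD_NAME)
    legacy.Config = Config
    sys.modules[OLD_NAME] = legacy
    Config.__module__ = OLD_NAME
    try:
        return pickle.dumps(obj)
    finally:
        Config.__module__ = NEW_NAME
        del sys.modules[OLD_NAME]


class LegacyUnpickler(pickle.Unpickler):
    """Redirects outdated module paths during loading; sys.modules is untouched."""

    RENAMES = {OLD_NAME: NEW_NAME}

    def find_class(self, module, name):
        return super().find_class(self.RENAMES.get(module, module), name)


def load_legacy(data):
    return LegacyUnpickler(io.BytesIO(data)).load()


blob = make_legacy_pickle(Config(lr=0.1))
print(load_legacy(blob).lr)  # 0.1 (plain pickle.loads(blob) raises ModuleNotFoundError)
```

Since the rename table is consulted per load, concurrent threads never observe each other's state, which is the property the sys.modules patch lacks.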

Quasar-Kim commented 1 year ago

Because there seems to be no clean solution without thread synchronization, I'm pretty sure re-introducing a lock is the way to go. I found that commit 277b0b811fb1419d6c06e7953941d6f6076eaf6d removed the lock, but it does not explain why. @awaelchli Could you explain why the lock was removed? I can't reproduce the pickling-related issue after adding a lock; all tests pass in CI (see draft PR).

awaelchli commented 1 year ago

@Quasar-Kim I honestly don't remember. It could have been an accident with rebasing. But note that the test that was introduced in #12814 still exists: https://github.com/Lightning-AI/lightning/blob/6f4524a25c4d06bf69993ca8543c58340a1e8088/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py#L70-L88

I'm fine with trying to add it back. Would it be possible to modify your reproduction script to use regular threading, so we can test the failure and the fix without needing the TPU runtime?
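A test built directly on threading.Thread would not fail on its own: an exception raised in a worker thread is only reported by threading.excepthook (that is what produced the "Exception in thread Thread-4 (load_model)" output above), and join() does not re-raise it. A small helper like the hypothetical run_concurrently below collects the exceptions so a test can assert on them. The shared dict in the usage line is only a stand-in; in a real regression test fn would be something like lambda: BoringModel.load_from_checkpoint(path).

```python
import threading
from typing import Callable, List


def run_concurrently(fn: Callable[[], object], num_threads: int = 2) -> List[Exception]:
    """Run fn in several threads at once and return the exceptions they raised."""
    barrier = threading.Barrier(num_threads)
    errors: List[Exception] = []

    def target() -> None:
        barrier.wait()  # release all threads as close together as possible
        try:
            fn()
        except Exception as err:  # collect instead of letting excepthook print it
            errors.append(err)

    threads = [threading.Thread(target=target) for _ in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return errors


# Stand-in for a racy call: two threads pop the same key, so one must fail.
shared = {"key": object()}
errors = run_concurrently(lambda: shared.pop("key"))
print([type(e).__name__ for e in errors])  # ['KeyError']
```

In a pytest test, `assert not run_concurrently(...)` then fails with the collected exceptions visible in the assertion message.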

Quasar-Kim commented 1 year ago

@awaelchli Sure! I also updated it to be self-contained.