Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Bug of SingleDeviceStrategy: incoherent device between accelerator and strategy when accelerator="auto" #18902

Open ZekunZh opened 1 year ago

ZekunZh commented 1 year ago

Bug description

Hello Lightning team,

I observed a strange behaviour of SingleDeviceStrategy: when using Trainer(accelerator="auto", strategy="single_device"), the strategy's root device does not match the accelerator that was detected.

For example, the following code snippet:

    trainer = Trainer(accelerator="auto", strategy="single_device")
    print("accelerator:", trainer.accelerator)
    print("strategy's root device:", trainer.strategy.root_device)

When running on a machine with a GPU, I got:

accelerator: <lightning.pytorch.accelerators.cuda.CUDAAccelerator object at 0x7faefe6c4890>
strategy's root device: cpu

This mismatch between the accelerator and the strategy's device raises the following error when running inference:

lightning.fabric.utilities.exceptions.MisconfigurationException: Device should be GPU, got cpu instead
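The error is raised when the strategy hands its root device to the accelerator during setup_environment() (see setup_device in cuda.py in the traceback below). The following is a simplified, dependency-free model of that check, not Lightning's actual code: Device stands in for torch.device, and MisconfigurationError and cuda_setup_device are hypothetical names.

```python
from dataclasses import dataclass


class MisconfigurationError(Exception):
    """Hypothetical stand-in for Lightning's MisconfigurationException."""


@dataclass(frozen=True)
class Device:
    """Minimal stand-in for torch.device, built from strings like 'cuda:0' or 'cpu'."""

    spec: str

    @property
    def type(self) -> str:
        # 'cuda:0' -> 'cuda', 'cpu' -> 'cpu'
        return self.spec.split(":", 1)[0]

    def __str__(self) -> str:
        return self.spec


def cuda_setup_device(device: Device) -> None:
    """Model of the CUDA accelerator's check: reject any non-CUDA root device."""
    if device.type != "cuda":
        raise MisconfigurationError(f"Device should be GPU, got {device} instead")


# The strategy's root device (cpu) reaches the CUDA accelerator, which rejects it.
try:
    cuda_setup_device(Device("cpu"))
except MisconfigurationError as err:
    print(err)  # Device should be GPU, got cpu instead

# A matching CUDA device passes silently.
cuda_setup_device(Device("cuda:0"))
```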

What version are you seeing the problem on?

v2.0

How to reproduce the bug

from lightning import LightningModule
from lightning.pytorch import Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

class SimpleDataset(Dataset):
    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        return torch.randn((1, 28, 28))

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(28 * 28, 512)
        self.layer2 = nn.Linear(512, 512)
        self.layer3 = nn.Linear(512, 512)
        self.layer4 = nn.Linear(512, 512)
        self.layer5 = nn.Linear(512, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        x = F.relu(self.layer4(x))
        x = self.layer5(x)
        return x

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

if __name__ == '__main__':
    trainer = Trainer(accelerator="auto", strategy="single_device")
    print("accelerator:", trainer.accelerator)
    print("strategy's root device:", trainer.strategy.root_device)

    dataset = SimpleDataset()
    dataloader = DataLoader(dataset, batch_size=32)
    model = SimpleModel()

    # accelerator defaults to "auto" here as well, so the same mismatch occurs
    trainer = Trainer(strategy="single_device")
    trainer.predict(model, dataloader)

Error messages and logs

zekun_zhang_gleamer_ai@gpu-2:~/ai-gleamer$ poetry run python minimal_reproducible_examples/device_conflict_for_single_device_strategy.py
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
accelerator: <lightning.pytorch.accelerators.cuda.CUDAAccelerator object at 0x7faefe6c4890>
strategy's root device: cpu
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Traceback (most recent call last):
  File "/home/zekun_zhang_gleamer_ai/ai-gleamer/minimal_reproducible_examples/device_conflict_for_single_device_strategy.py", line 52, in <module>
    trainer.predict(model, dataloader)
  File "/home/zekun_zhang_gleamer_ai/.cache/pypoetry/virtualenvs/glmlib-BFgLEs1Y-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 805, in predict
    return call._call_and_handle_interrupt(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zekun_zhang_gleamer_ai/.cache/pypoetry/virtualenvs/glmlib-BFgLEs1Y-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zekun_zhang_gleamer_ai/.cache/pypoetry/virtualenvs/glmlib-BFgLEs1Y-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 847, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zekun_zhang_gleamer_ai/.cache/pypoetry/virtualenvs/glmlib-BFgLEs1Y-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 893, in _run
    self.strategy.setup_environment()
  File "/home/zekun_zhang_gleamer_ai/.cache/pypoetry/virtualenvs/glmlib-BFgLEs1Y-py3.11/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 127, in setup_environment
    self.accelerator.setup_device(self.root_device)
  File "/home/zekun_zhang_gleamer_ai/.cache/pypoetry/virtualenvs/glmlib-BFgLEs1Y-py3.11/lib/python3.11/site-packages/lightning/pytorch/accelerators/cuda.py", line 42, in setup_device
    raise MisconfigurationException(f"Device should be GPU, got {device} instead")
lightning.fabric.utilities.exceptions.MisconfigurationException: Device should be GPU, got cpu instead

Environment

Current environment
* CUDA:
    - GPU:
        - Tesla T4
    - available: True
    - version: 11.7
* Lightning:
    - lightning: 2.0.0
    - lightning-cloud: 0.5.36
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.3
    - torch: 2.0.0
    - torchinfo: 1.5.3
    - torchmetrics: 0.11.4
    - torchvision: 0.15.1
* Packages: - absl-py: 1.4.0 - adal: 1.2.7 - addict: 2.4.0 - aiofiles: 23.1.0 - aiohttp: 3.8.4 - aiohttp-retry: 2.8.3 - aiosignal: 1.3.1 - albumentations: 1.1.0 - amqp: 5.1.1 - antlr4-python3-runtime: 4.9.3 - anyio: 3.7.0 - appdirs: 1.4.4 - argcomplete: 2.1.2 - arrow: 1.2.3 - async-timeout: 4.0.2 - asyncssh: 2.13.1 - atpublic: 4.0 - attrs: 23.1.0 - autoflake: 2.2.0 - azure-common: 1.1.28 - azure-core: 1.27.0 - azure-graphrbac: 0.61.1 - azure-mgmt-authorization: 3.0.0 - azure-mgmt-containerregistry: 10.1.0 - azure-mgmt-core: 1.4.0 - azure-mgmt-keyvault: 10.2.2 - azure-mgmt-resource: 22.0.0 - azure-mgmt-storage: 21.0.0 - azure-nspkg: 3.0.2 - azure-storage: 0.36.0 - azure-storage-blob: 1.1.0 - azure-storage-common: 1.1.0 - azure-storage-nspkg: 3.1.0 - azureml-core: 1.50.0 - backports.tempfile: 1.0 - backports.weakref: 1.0.post1 - bcrypt: 4.0.1 - beautifulsoup4: 4.12.2 - bigjson: 1.0.9 - billiard: 3.6.4.0 - black: 23.3.0 - blessed: 1.20.0 - blindspin: 2.0.1 - boto3: 1.26.149 - botocore: 1.29.149 - cachetools: 5.3.1 - celery: 5.2.2 - certifi: 2023.5.7 - cffi: 1.15.1 - cfgv: 3.3.1 - charset-normalizer: 3.1.0 - chumpy: 0.71 - clearml: 1.3.2 - click: 8.0.2 - click-didyoumean: 0.3.0 - click-plugins: 1.1.1 - click-repl: 0.2.0 - clickclick: 20.10.2 - cloudpickle: 2.2.1 - cmake: 3.26.4 - colorama: 0.4.6 - configobj: 5.0.8 - connexion: 2.14.2 - contextlib2: 0.5.5 - contourpy: 1.0.7 - coverage: 7.2.5 - crayons: 0.4.0 - croniter: 1.3.15 - cryptography: 3.4.8 - cycler: 0.11.0 - cython: 0.29.33 - dacite: 1.7.0 - dateutils: 0.6.12 - decorator: 5.1.1 - deepdiff: 6.3.0 - deprecated: 1.2.14 - detectron2: 0.7+cu118 - detrex: 0.3.0 - dictdiffer: 0.9.0 - dill: 0.3.6 - diskcache: 5.6.1 - distlib: 0.3.6 - distro: 1.8.0 -
dnspython: 2.3.0 - docker: 6.1.3 - docker-pycreds: 0.4.0 - dpath: 2.1.6 - dulwich: 0.21.5 - dvc: 2.46.0 - dvc-data: 0.42.3 - dvc-gs: 2.22.0 - dvc-http: 2.30.2 - dvc-objects: 0.22.0 - dvc-render: 0.5.3 - dvc-studio-client: 0.10.0 - dvc-task: 0.2.1 - einops: 0.6.1 - et-xmlfile: 1.1.0 - eventlet: 0.33.3 - fairscale: 0.4.13 - fastapi: 0.86.0 - fiftyone: 0.20.0 - fiftyone-brain: 0.11.0 - fiftyone-db: 0.4.0 - filelock: 3.12.0 - flake8: 6.0.0 - flask: 2.2.5 - flask-testing: 0.8.1 - flatten-dict: 0.4.2 - flufl.lock: 7.1.1 - fonttools: 4.39.4 - frozenlist: 1.3.3 - fsspec: 2023.5.0 - ftfy: 6.1.1 - funcy: 2.0 - furl: 2.1.3 - future: 0.18.3 - fvcore: 0.1.5.post20220506 - gcsfs: 2023.5.0 - gitdb: 4.0.10 - gitdb2: 2.0.6 - gitpython: 3.1.31 - glmlib: 1.0.0 - glob2: 0.7 - google-api-core: 1.34.0 - google-auth: 2.19.1 - google-auth-oauthlib: 1.0.0 - google-cloud-core: 2.3.2 - google-cloud-pubsub: 1.0.2 - google-cloud-storage: 1.43.0 - google-crc32c: 1.5.0 - google-resumable-media: 1.3.0 - googleapis-common-protos: 1.59.0 - gputil: 1.4.0 - grandalf: 0.8 - graphql-core: 3.2.3 - greenlet: 2.0.2 - grpc-google-iam-v1: 0.12.6 - grpcio: 1.54.2 - grpcio-status: 1.48.2 - h11: 0.14.0 - h2: 4.1.0 - hpack: 4.0.0 - httpcore: 0.17.2 - httpx: 0.24.1 - huggingface-hub: 0.15.1 - humanfriendly: 10.0 - hydra-core: 1.3.2 - hydra-zen: 0.11.0 - hypercorn: 0.14.3 - hyperframe: 6.0.1 - identify: 2.5.24 - idna: 3.4 - ijson: 3.2.3 - imageio: 2.31.0 - imgaug: 0.4.0 - inflection: 0.5.1 - iniconfig: 2.0.0 - inquirer: 3.1.3 - iopath: 0.1.9 - isodate: 0.6.1 - isort: 5.12.0 - iterative-telemetry: 0.0.8 - itsdangerous: 2.1.2 - jaraco.classes: 3.3.0 - jeepney: 0.8.0 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.2.0 - json-tricks: 3.17.0 - jsonpickle: 3.0.1 - jsonschema: 4.10.0 - kaleido: 0.2.1 - keyring: 24.2.0 - keyrings.google-artifactregistry-auth: 1.1.2 - kili: 2.120.0 - kiwisolver: 1.4.4 - knack: 0.10.1 - kombu: 5.3.0 - lazy-loader: 0.2 - lightning: 2.0.0 - lightning-cloud: 0.5.36 - lightning-utilities: 0.8.0 
- lit: 16.0.5.post0 - markdown: 3.4.3 - markdown-it-py: 2.2.0 - markupsafe: 2.1.3 - matplotlib: 3.7.1 - mccabe: 0.7.0 - mdurl: 0.1.2 - mmcv: 1.4.2 - mmpose: 0.21.0 - monai: 0.9.1 - mongoengine: 0.24.2 - more-itertools: 8.8.0 - motor: 3.1.2 - mpmath: 1.3.0 - msal: 1.22.0 - msal-extensions: 1.0.0 - msrest: 0.7.1 - msrestazure: 0.6.4 - multidict: 6.0.4 - munkres: 1.1.4 - mypy-extensions: 1.0.0 - nanotime: 0.5.2 - ndg-httpsclient: 0.5.1 - ndjson: 0.3.1 - networkx: 3.1 - nibabel: 3.2.1 - nodeenv: 1.8.0 - numpy: 1.24.2 - nvidia-cublas-cu11: 11.10.3.66 - nvidia-cuda-cupti-cu11: 11.7.101 - nvidia-cuda-nvrtc-cu11: 11.7.99 - nvidia-cuda-runtime-cu11: 11.7.99 - nvidia-cudnn-cu11: 8.5.0.96 - nvidia-cufft-cu11: 10.9.0.58 - nvidia-curand-cu11: 10.2.10.91 - nvidia-cusolver-cu11: 11.4.0.1 - nvidia-cusparse-cu11: 11.7.4.91 - nvidia-nccl-cu11: 2.14.3 - nvidia-nvtx-cu11: 11.7.91 - oauthlib: 3.2.2 - omegaconf: 2.2.1 - opencv-python: 4.7.0.72 - opencv-python-headless: 4.7.0.72 - openpyxl: 3.0.7 - ordered-set: 4.1.0 - orderedmultidict: 1.0.1 - orjson: 3.9.0 - packaging: 23.0 - pandas: 2.0.2 - paramiko: 3.2.0 - pathlib2: 2.3.7.post1 - pathspec: 0.11.1 - pathtools: 0.1.2 - patool: 1.12 - pika: 1.1.0 - pillow: 9.5.0 - pip: 23.2.1 - pkginfo: 1.9.6 - platformdirs: 3.5.1 - plotly: 5.14.1 - pluggy: 1.0.0 - portalocker: 2.7.0 - pprintpp: 0.4.0 - pre-commit: 3.2.2 - priority: 2.0.0 - prompt-toolkit: 3.0.38 - protobuf: 3.20.3 - psutil: 5.9.5 - pyaescrypt: 0.4.3 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pybind11: 2.11.1 - pycocotools: 2.0.6 - pycodestyle: 2.10.0 - pycparser: 2.21 - pydantic: 1.10.9 - pydicom: 2.0.0 - pydot: 1.4.2 - pyelftools: 0.27 - pyflakes: 3.0.1 - pygit2: 1.12.1 - pygments: 2.15.1 - pygtrie: 2.5.0 - pyjwt: 2.1.0 - pymongo: 4.3.3 - pympler: 1.0.1 - pynacl: 1.5.0 - pyopenssl: 21.0.0 - pyparsing: 3.0.9 - pyrsistent: 0.19.3 - pysocks: 1.7.1 - pytest: 7.2.2 - pytest-mock: 3.10.0 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - python-gdcm: 3.0.21 - python-multipart: 0.0.6 - 
pytorch-lightning: 2.0.3 - pytz: 2023.3 - pywavelets: 1.4.1 - pyyaml: 6.0 - qudida: 0.0.4 - readchar: 4.0.5 - regex: 2023.6.3 - requests: 2.30.0 - requests-oauthlib: 1.3.1 - retrying: 1.3.4 - rich: 13.4.1 - rsa: 4.9 - ruamel.yaml: 0.17.21 - ruff: 0.0.270 - s3transfer: 0.6.1 - schema: 0.7.0 - scikit-image: 0.20.0 - scikit-learn: 1.2.2 - scipy: 1.10.1 - scmrepo: 0.2.1 - secretstorage: 3.3.3 - sentry-sdk: 1.25.1 - setproctitle: 1.3.2 - setuptools: 67.2.0 - shapely: 2.0.1 - shortuuid: 1.0.11 - shtab: 1.6.1 - six: 1.16.0 - smmap: 5.0.0 - smmap2: 3.0.1 - sniffio: 1.3.0 - sortedcontainers: 2.4.0 - soupsieve: 2.4.1 - sqltrie: 0.4.0 - sse-starlette: 0.10.3 - sseclient-py: 1.7.2 - starlette: 0.20.4 - starsessions: 1.3.0 - strawberry-graphql: 0.138.1 - submitit: 1.4.5 - sympy: 1.12 - tabulate: 0.9.0 - tenacity: 8.2.2 - tensorboard: 2.13.0 - tensorboard-data-server: 0.7.0 - termcolor: 2.3.0 - testcontainers: 3.0.0 - threadpoolctl: 3.1.0 - tifffile: 2023.4.12 - timm: 0.6.13 - toml: 0.10.2 - tomli: 2.0.1 - tomlkit: 0.11.8 - torch: 2.0.0 - torchinfo: 1.5.3 - torchmetrics: 0.11.4 - torchvision: 0.15.1 - tqdm: 4.64.0 - traitlets: 5.9.0 - triton: 2.0.0 - typeguard: 4.0.0 - typing-extensions: 4.6.3 - tzdata: 2023.3 - tzlocal: 5.0.1 - universal-analytics-python3: 1.1.1 - urllib3: 1.26.16 - uvicorn: 0.22.0 - vine: 5.0.0 - virtualenv: 20.23.0 - voluptuous: 0.13.1 - voxel51-eta: 0.8.4 - wandb: 0.15.0 - wcwidth: 0.2.6 - websocket-client: 1.5.2 - websockets: 11.0.3 - werkzeug: 2.3.5 - wheel: 0.40.0 - wrapt: 1.15.0 - wsproto: 1.2.0 - xmltodict: 0.13.0 - xtcocotools: 1.13 - yacs: 0.1.8 - yapf: 0.33.0 - yarl: 1.9.2 - zc.lockfile: 3.0.post1
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.11.5
    - release: 5.15.0-1045-gcp
    - version: #53~20.04.2-Ubuntu SMP Wed Oct 18 12:59:20 UTC 2023

More info

Sorry to bother you again @awaelchli, but you gave me great help in another issue. I'm grateful for any ideas and suggestions! 😉

cc @tchaton

awaelchli commented 1 year ago

Hey @ZekunZh, the usage strategy="single_device" is not documented. That's because the default strategy="auto" already means single device if devices=1. It's just registered for completeness, I guess, but we don't expect users to set it explicitly.

I'd suggest that we remove it from the registry completely.
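The point about strategy="auto" can be illustrated with a simplified, hypothetical model of how an "auto" strategy might be resolved: with a single device, a single-device strategy is built on the device the accelerator already selected, so the two cannot disagree. The function name resolve_auto_strategy and the "ddp" result for several devices are assumptions of this sketch, not Lightning's exact selection logic.

```python
from typing import Optional, Tuple


def resolve_auto_strategy(
    accelerator_device: str, num_devices: int
) -> Tuple[str, Optional[str]]:
    """Hypothetical model of strategy='auto' resolution.

    Returns (strategy_name, root_device). With one device, the single-device
    strategy reuses the accelerator's device instead of a hard-coded default.
    """
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    if num_devices == 1:
        return "single_device", accelerator_device
    # Simplification: several devices -> a data-parallel strategy whose
    # per-process device is not modeled in this sketch.
    return "ddp", None


print(resolve_auto_strategy("cuda:0", 1))  # ('single_device', 'cuda:0')
print(resolve_auto_strategy("cpu", 1))     # ('single_device', 'cpu')
```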

ZekunZh commented 1 year ago

Thanks a lot for the explanation!

ZekunZh commented 1 year ago

Which strategy would you recommend for production, @awaelchli? In our production environment, each machine has either one CUDA GPU or only a CPU.

We want to choose a specific strategy because we would like to "fix" the OOM problem mentioned in this issue by modifying the teardown method.
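For machines with either one CUDA GPU or only a CPU, one option consistent with the earlier answer is to keep strategy="auto" with devices=1. If a custom strategy is needed (for example, a SingleDeviceStrategy subclass that overrides teardown), passing an instance built with an explicit device, such as SingleDeviceStrategy(device="cuda:0"), avoids the "cpu" default. The sketch below is a hypothetical, dependency-free helper (pick_accelerator_and_device is an illustrative name) that picks the accelerator name and device string from one decision so they stay consistent; in practice cuda_available would come from torch.cuda.is_available().

```python
from typing import Tuple


def pick_accelerator_and_device(cuda_available: bool) -> Tuple[str, str]:
    """Hypothetical helper: choose an accelerator name and a matching device.

    Deriving both values from the same check keeps them consistent, which is
    what the explicit strategy='single_device' alias failed to guarantee.
    """
    if cuda_available:
        return "gpu", "cuda:0"
    return "cpu", "cpu"


# GPU machine vs. CPU-only machine:
print(pick_accelerator_and_device(True))   # ('gpu', 'cuda:0')
print(pick_accelerator_and_device(False))  # ('cpu', 'cpu')
```

The returned pair could then feed Trainer(accelerator=..., devices=1) and the device argument of a strategy instance; this is a sketch of the idea, not an officially recommended setup.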

JStyborski commented 2 months ago

Commenting to add my experience:

Windows 11 with a CUDA-enabled GPU, via PyCharm. I am running pytorch-lightning 1.9.5 (though I checked that the code is similar in the current version, 2.4.0) and Python 3.8.5.

I received a similar error when calling Trainer(strategy='single_device', accelerator='gpu', devices=[0]). The issue seems to stem from SingleDeviceStrategy (in single_device.py) being initialized with its default device='cpu', and no other arguments, when strategy='single_device' is passed. Changing the Trainer argument to strategy='auto' fixed the issue.
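The failure mode described here can be reproduced in a simplified, hypothetical model: an alias registry maps the name to a class that is constructed with its default arguments, so the strategy's device falls back to "cpu" regardless of what the accelerator detected. SingleDeviceStrategyModel, REGISTRY, and build_from_alias are illustrative names only; Lightning's real strategy registry and accelerator connector are more involved.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class SingleDeviceStrategyModel:
    """Illustrative stand-in: like SingleDeviceStrategy, device defaults to 'cpu'."""

    device: str = "cpu"


# Hypothetical alias registry: each entry is a factory called with no arguments.
REGISTRY: Dict[str, Callable[[], SingleDeviceStrategyModel]] = {
    "single_device": SingleDeviceStrategyModel,
}


def build_from_alias(name: str) -> SingleDeviceStrategyModel:
    """Only the alias name is known here, so constructor defaults apply."""
    return REGISTRY[name]()


accelerator_device = "cuda:0"  # what the accelerator picks on a GPU machine
strategy = build_from_alias("single_device")
print(strategy.device)                        # cpu
print(strategy.device == accelerator_device)  # False: the mismatch from the report
```

The same model shows why passing a configured instance instead of the alias avoids the problem: SingleDeviceStrategyModel(device="cuda:0") carries the intended device from the start.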

I would recommend updating the documentation for strategy aliases. None of the aliases listed in the strategy documentation (https://lightning.ai/docs/pytorch/stable/extensions/strategy.html) fit my use case, so I found the 'single_device' alias by digging through the code (specifically accelerator_connector.py and single_device.py). A clearer explanation of the 'auto' alias would have helped.