Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Bug of SingleDeviceStrategy: incoherent device between accelerator and strategy when accelerator="auto" #18902

Open ZekunZh opened 10 months ago

ZekunZh commented 10 months ago

Bug description

Hello Lightning team,

I observed a strange behaviour of SingleDeviceStrategy: when using Trainer(accelerator="auto", strategy="single_device"), the strategy's accelerator is not correctly detected.

For example, the following code snippet:

    trainer = Trainer(accelerator="auto", strategy="single_device")
    print("accelerator:", trainer.accelerator)
    print("strategy's root device:", trainer.strategy.root_device)

When running on a machine with GPU, I got:

accelerator: <lightning.pytorch.accelerators.cuda.CUDAAccelerator object at 0x7faefe6c4890>
strategy's root device: cpu

There's a mismatch between the accelerator and the strategy's device, which raises this error when running inference:

lightning.fabric.utilities.exceptions.MisconfigurationException: Device should be GPU, got cpu instead
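
A pre-flight check along these lines could surface the mismatch before `trainer.predict` is called. This is only a sketch: `devices_agree` and `EXPECTED_DEVICE_TYPE` are hypothetical names, not part of Lightning, and the mapping assumes the accelerator class names `CUDAAccelerator` and `CPUAccelerator`. Stand-in objects replace the real Trainer so the snippet runs without a GPU.

```python
from types import SimpleNamespace

# Assumed mapping from accelerator class name to the device type it expects.
EXPECTED_DEVICE_TYPE = {"CUDAAccelerator": "cuda", "CPUAccelerator": "cpu"}


def devices_agree(trainer) -> bool:
    """Return True when the strategy's root device matches the accelerator."""
    expected = EXPECTED_DEVICE_TYPE.get(type(trainer.accelerator).__name__)
    if expected is None:
        return True  # accelerator not covered by this sketch
    return trainer.strategy.root_device.type == expected


# Stand-in objects that mimic the state reported above, so the check can be
# exercised without a GPU or a Lightning install.
class CUDAAccelerator:
    pass


reported = SimpleNamespace(
    accelerator=CUDAAccelerator(),
    strategy=SimpleNamespace(root_device=SimpleNamespace(type="cpu")),
)
print(devices_agree(reported))  # False
```

With a real Trainer, the same function would read `trainer.accelerator` and `trainer.strategy.root_device`, the two attributes printed in the snippet above.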

What version are you seeing the problem on?

v2.0

How to reproduce the bug

from lightning import LightningModule
from lightning.pytorch import Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

class SimpleDataset(Dataset):
    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        return torch.randn((1, 28, 28))

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(28 * 28, 512)
        self.layer2 = nn.Linear(512, 512)
        self.layer3 = nn.Linear(512, 512)
        self.layer4 = nn.Linear(512, 512)
        self.layer5 = nn.Linear(512, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        x = F.relu(self.layer4(x))
        x = self.layer5(x)
        return x

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        return self(batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

if __name__ == '__main__':
    trainer = Trainer(accelerator="auto", strategy="single_device")
    # Labels match the log output shown under "Error messages and logs".
    print("accelerator:", trainer.accelerator)
    print("strategy's root device:", trainer.strategy.root_device)

    dataset = SimpleDataset()
    dataloader = DataLoader(dataset, batch_size=32)
    model = SimpleModel()

    # A fresh Trainer with the default accelerator="auto" hits the same mismatch.
    trainer = Trainer(strategy="single_device")
    trainer.predict(model, dataloader)

Error messages and logs

zekun_zhang_gleamer_ai@gpu-2:~/ai-gleamer$ poetry run python minimal_reproducible_examples/device_conflict_for_single_device_strategy.py
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
accelerator: <lightning.pytorch.accelerators.cuda.CUDAAccelerator object at 0x7faefe6c4890>
strategy's root device: cpu
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Traceback (most recent call last):
  File "/home/zekun_zhang_gleamer_ai/ai-gleamer/minimal_reproducible_examples/device_conflict_for_single_device_strategy.py", line 52, in <module>
    trainer.predict(model, dataloader)
  File "/home/zekun_zhang_gleamer_ai/.cache/pypoetry/virtualenvs/glmlib-BFgLEs1Y-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 805, in predict
    return call._call_and_handle_interrupt(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zekun_zhang_gleamer_ai/.cache/pypoetry/virtualenvs/glmlib-BFgLEs1Y-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zekun_zhang_gleamer_ai/.cache/pypoetry/virtualenvs/glmlib-BFgLEs1Y-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 847, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zekun_zhang_gleamer_ai/.cache/pypoetry/virtualenvs/glmlib-BFgLEs1Y-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 893, in _run
    self.strategy.setup_environment()
  File "/home/zekun_zhang_gleamer_ai/.cache/pypoetry/virtualenvs/glmlib-BFgLEs1Y-py3.11/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 127, in setup_environment
    self.accelerator.setup_device(self.root_device)
  File "/home/zekun_zhang_gleamer_ai/.cache/pypoetry/virtualenvs/glmlib-BFgLEs1Y-py3.11/lib/python3.11/site-packages/lightning/pytorch/accelerators/cuda.py", line 42, in setup_device
    raise MisconfigurationException(f"Device should be GPU, got {device} instead")
lightning.fabric.utilities.exceptions.MisconfigurationException: Device should be GPU, got cpu instead

Environment

Current environment * CUDA: - GPU: - Tesla T4 - available: True - version: 11.7 * Lightning: - lightning: 2.0.0 - lightning-cloud: 0.5.36 - lightning-utilities: 0.8.0 - pytorch-lightning: 2.0.3 - torch: 2.0.0 - torchinfo: 1.5.3 - torchmetrics: 0.11.4 - torchvision: 0.15.1 * Packages: - absl-py: 1.4.0 - adal: 1.2.7 - addict: 2.4.0 - aiofiles: 23.1.0 - aiohttp: 3.8.4 - aiohttp-retry: 2.8.3 - aiosignal: 1.3.1 - albumentations: 1.1.0 - amqp: 5.1.1 - antlr4-python3-runtime: 4.9.3 - anyio: 3.7.0 - appdirs: 1.4.4 - argcomplete: 2.1.2 - arrow: 1.2.3 - async-timeout: 4.0.2 - asyncssh: 2.13.1 - atpublic: 4.0 - attrs: 23.1.0 - autoflake: 2.2.0 - azure-common: 1.1.28 - azure-core: 1.27.0 - azure-graphrbac: 0.61.1 - azure-mgmt-authorization: 3.0.0 - azure-mgmt-containerregistry: 10.1.0 - azure-mgmt-core: 1.4.0 - azure-mgmt-keyvault: 10.2.2 - azure-mgmt-resource: 22.0.0 - azure-mgmt-storage: 21.0.0 - azure-nspkg: 3.0.2 - azure-storage: 0.36.0 - azure-storage-blob: 1.1.0 - azure-storage-common: 1.1.0 - azure-storage-nspkg: 3.1.0 - azureml-core: 1.50.0 - backports.tempfile: 1.0 - backports.weakref: 1.0.post1 - bcrypt: 4.0.1 - beautifulsoup4: 4.12.2 - bigjson: 1.0.9 - billiard: 3.6.4.0 - black: 23.3.0 - blessed: 1.20.0 - blindspin: 2.0.1 - boto3: 1.26.149 - botocore: 1.29.149 - cachetools: 5.3.1 - celery: 5.2.2 - certifi: 2023.5.7 - cffi: 1.15.1 - cfgv: 3.3.1 - charset-normalizer: 3.1.0 - chumpy: 0.71 - clearml: 1.3.2 - click: 8.0.2 - click-didyoumean: 0.3.0 - click-plugins: 1.1.1 - click-repl: 0.2.0 - clickclick: 20.10.2 - cloudpickle: 2.2.1 - cmake: 3.26.4 - colorama: 0.4.6 - configobj: 5.0.8 - connexion: 2.14.2 - contextlib2: 0.5.5 - contourpy: 1.0.7 - coverage: 7.2.5 - crayons: 0.4.0 - croniter: 1.3.15 - cryptography: 3.4.8 - cycler: 0.11.0 - cython: 0.29.33 - dacite: 1.7.0 - dateutils: 0.6.12 - decorator: 5.1.1 - deepdiff: 6.3.0 - deprecated: 1.2.14 - detectron2: 0.7+cu118 - detrex: 0.3.0 - dictdiffer: 0.9.0 - dill: 0.3.6 - diskcache: 5.6.1 - distlib: 0.3.6 - distro: 1.8.0 - 
dnspython: 2.3.0 - docker: 6.1.3 - docker-pycreds: 0.4.0 - dpath: 2.1.6 - dulwich: 0.21.5 - dvc: 2.46.0 - dvc-data: 0.42.3 - dvc-gs: 2.22.0 - dvc-http: 2.30.2 - dvc-objects: 0.22.0 - dvc-render: 0.5.3 - dvc-studio-client: 0.10.0 - dvc-task: 0.2.1 - einops: 0.6.1 - et-xmlfile: 1.1.0 - eventlet: 0.33.3 - fairscale: 0.4.13 - fastapi: 0.86.0 - fiftyone: 0.20.0 - fiftyone-brain: 0.11.0 - fiftyone-db: 0.4.0 - filelock: 3.12.0 - flake8: 6.0.0 - flask: 2.2.5 - flask-testing: 0.8.1 - flatten-dict: 0.4.2 - flufl.lock: 7.1.1 - fonttools: 4.39.4 - frozenlist: 1.3.3 - fsspec: 2023.5.0 - ftfy: 6.1.1 - funcy: 2.0 - furl: 2.1.3 - future: 0.18.3 - fvcore: 0.1.5.post20220506 - gcsfs: 2023.5.0 - gitdb: 4.0.10 - gitdb2: 2.0.6 - gitpython: 3.1.31 - glmlib: 1.0.0 - glob2: 0.7 - google-api-core: 1.34.0 - google-auth: 2.19.1 - google-auth-oauthlib: 1.0.0 - google-cloud-core: 2.3.2 - google-cloud-pubsub: 1.0.2 - google-cloud-storage: 1.43.0 - google-crc32c: 1.5.0 - google-resumable-media: 1.3.0 - googleapis-common-protos: 1.59.0 - gputil: 1.4.0 - grandalf: 0.8 - graphql-core: 3.2.3 - greenlet: 2.0.2 - grpc-google-iam-v1: 0.12.6 - grpcio: 1.54.2 - grpcio-status: 1.48.2 - h11: 0.14.0 - h2: 4.1.0 - hpack: 4.0.0 - httpcore: 0.17.2 - httpx: 0.24.1 - huggingface-hub: 0.15.1 - humanfriendly: 10.0 - hydra-core: 1.3.2 - hydra-zen: 0.11.0 - hypercorn: 0.14.3 - hyperframe: 6.0.1 - identify: 2.5.24 - idna: 3.4 - ijson: 3.2.3 - imageio: 2.31.0 - imgaug: 0.4.0 - inflection: 0.5.1 - iniconfig: 2.0.0 - inquirer: 3.1.3 - iopath: 0.1.9 - isodate: 0.6.1 - isort: 5.12.0 - iterative-telemetry: 0.0.8 - itsdangerous: 2.1.2 - jaraco.classes: 3.3.0 - jeepney: 0.8.0 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.2.0 - json-tricks: 3.17.0 - jsonpickle: 3.0.1 - jsonschema: 4.10.0 - kaleido: 0.2.1 - keyring: 24.2.0 - keyrings.google-artifactregistry-auth: 1.1.2 - kili: 2.120.0 - kiwisolver: 1.4.4 - knack: 0.10.1 - kombu: 5.3.0 - lazy-loader: 0.2 - lightning: 2.0.0 - lightning-cloud: 0.5.36 - lightning-utilities: 0.8.0 
- lit: 16.0.5.post0 - markdown: 3.4.3 - markdown-it-py: 2.2.0 - markupsafe: 2.1.3 - matplotlib: 3.7.1 - mccabe: 0.7.0 - mdurl: 0.1.2 - mmcv: 1.4.2 - mmpose: 0.21.0 - monai: 0.9.1 - mongoengine: 0.24.2 - more-itertools: 8.8.0 - motor: 3.1.2 - mpmath: 1.3.0 - msal: 1.22.0 - msal-extensions: 1.0.0 - msrest: 0.7.1 - msrestazure: 0.6.4 - multidict: 6.0.4 - munkres: 1.1.4 - mypy-extensions: 1.0.0 - nanotime: 0.5.2 - ndg-httpsclient: 0.5.1 - ndjson: 0.3.1 - networkx: 3.1 - nibabel: 3.2.1 - nodeenv: 1.8.0 - numpy: 1.24.2 - nvidia-cublas-cu11: 11.10.3.66 - nvidia-cuda-cupti-cu11: 11.7.101 - nvidia-cuda-nvrtc-cu11: 11.7.99 - nvidia-cuda-runtime-cu11: 11.7.99 - nvidia-cudnn-cu11: 8.5.0.96 - nvidia-cufft-cu11: 10.9.0.58 - nvidia-curand-cu11: 10.2.10.91 - nvidia-cusolver-cu11: 11.4.0.1 - nvidia-cusparse-cu11: 11.7.4.91 - nvidia-nccl-cu11: 2.14.3 - nvidia-nvtx-cu11: 11.7.91 - oauthlib: 3.2.2 - omegaconf: 2.2.1 - opencv-python: 4.7.0.72 - opencv-python-headless: 4.7.0.72 - openpyxl: 3.0.7 - ordered-set: 4.1.0 - orderedmultidict: 1.0.1 - orjson: 3.9.0 - packaging: 23.0 - pandas: 2.0.2 - paramiko: 3.2.0 - pathlib2: 2.3.7.post1 - pathspec: 0.11.1 - pathtools: 0.1.2 - patool: 1.12 - pika: 1.1.0 - pillow: 9.5.0 - pip: 23.2.1 - pkginfo: 1.9.6 - platformdirs: 3.5.1 - plotly: 5.14.1 - pluggy: 1.0.0 - portalocker: 2.7.0 - pprintpp: 0.4.0 - pre-commit: 3.2.2 - priority: 2.0.0 - prompt-toolkit: 3.0.38 - protobuf: 3.20.3 - psutil: 5.9.5 - pyaescrypt: 0.4.3 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pybind11: 2.11.1 - pycocotools: 2.0.6 - pycodestyle: 2.10.0 - pycparser: 2.21 - pydantic: 1.10.9 - pydicom: 2.0.0 - pydot: 1.4.2 - pyelftools: 0.27 - pyflakes: 3.0.1 - pygit2: 1.12.1 - pygments: 2.15.1 - pygtrie: 2.5.0 - pyjwt: 2.1.0 - pymongo: 4.3.3 - pympler: 1.0.1 - pynacl: 1.5.0 - pyopenssl: 21.0.0 - pyparsing: 3.0.9 - pyrsistent: 0.19.3 - pysocks: 1.7.1 - pytest: 7.2.2 - pytest-mock: 3.10.0 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - python-gdcm: 3.0.21 - python-multipart: 0.0.6 - 
pytorch-lightning: 2.0.3 - pytz: 2023.3 - pywavelets: 1.4.1 - pyyaml: 6.0 - qudida: 0.0.4 - readchar: 4.0.5 - regex: 2023.6.3 - requests: 2.30.0 - requests-oauthlib: 1.3.1 - retrying: 1.3.4 - rich: 13.4.1 - rsa: 4.9 - ruamel.yaml: 0.17.21 - ruff: 0.0.270 - s3transfer: 0.6.1 - schema: 0.7.0 - scikit-image: 0.20.0 - scikit-learn: 1.2.2 - scipy: 1.10.1 - scmrepo: 0.2.1 - secretstorage: 3.3.3 - sentry-sdk: 1.25.1 - setproctitle: 1.3.2 - setuptools: 67.2.0 - shapely: 2.0.1 - shortuuid: 1.0.11 - shtab: 1.6.1 - six: 1.16.0 - smmap: 5.0.0 - smmap2: 3.0.1 - sniffio: 1.3.0 - sortedcontainers: 2.4.0 - soupsieve: 2.4.1 - sqltrie: 0.4.0 - sse-starlette: 0.10.3 - sseclient-py: 1.7.2 - starlette: 0.20.4 - starsessions: 1.3.0 - strawberry-graphql: 0.138.1 - submitit: 1.4.5 - sympy: 1.12 - tabulate: 0.9.0 - tenacity: 8.2.2 - tensorboard: 2.13.0 - tensorboard-data-server: 0.7.0 - termcolor: 2.3.0 - testcontainers: 3.0.0 - threadpoolctl: 3.1.0 - tifffile: 2023.4.12 - timm: 0.6.13 - toml: 0.10.2 - tomli: 2.0.1 - tomlkit: 0.11.8 - torch: 2.0.0 - torchinfo: 1.5.3 - torchmetrics: 0.11.4 - torchvision: 0.15.1 - tqdm: 4.64.0 - traitlets: 5.9.0 - triton: 2.0.0 - typeguard: 4.0.0 - typing-extensions: 4.6.3 - tzdata: 2023.3 - tzlocal: 5.0.1 - universal-analytics-python3: 1.1.1 - urllib3: 1.26.16 - uvicorn: 0.22.0 - vine: 5.0.0 - virtualenv: 20.23.0 - voluptuous: 0.13.1 - voxel51-eta: 0.8.4 - wandb: 0.15.0 - wcwidth: 0.2.6 - websocket-client: 1.5.2 - websockets: 11.0.3 - werkzeug: 2.3.5 - wheel: 0.40.0 - wrapt: 1.15.0 - wsproto: 1.2.0 - xmltodict: 0.13.0 - xtcocotools: 1.13 - yacs: 0.1.8 - yapf: 0.33.0 - yarl: 1.9.2 - zc.lockfile: 3.0.post1 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.11.5 - release: 5.15.0-1045-gcp - version: #53~20.04.2-Ubuntu SMP Wed Oct 18 12:59:20 UTC 2023

More info

Sorry to bother you again @awaelchli, but you were a great help to me on another issue. I'm grateful for any ideas and suggestions! 😉

cc @tchaton

awaelchli commented 10 months ago

Hey @ZekunZh, the usage strategy="single_device" is not documented. That's because the default strategy="auto" already means single device if devices=1. It's just registered for completeness, I guess, but we don't expect users to set it explicitly.

I'd suggest that we remove it from the registry completely.
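
As a rough illustration of the documented path described in this comment, the Trainer can be created without a `strategy` argument and with `devices=1`. The names `TRAINER_KWARGS` and `build_default_trainer` are hypothetical, and the printed values depend on the machine and Lightning version, so none are shown here.

```python
# Arguments for the documented path: no explicit strategy, one device.
TRAINER_KWARGS = {"accelerator": "auto", "devices": 1}


def build_default_trainer():
    """Create a Trainer that leaves strategy at its default of "auto"."""
    # Imported lazily so the module loads even where Lightning is not installed.
    from lightning.pytorch import Trainer

    trainer = Trainer(**TRAINER_KWARGS)
    # Inspect what "auto" resolved to on this machine.
    print("strategy:", type(trainer.strategy).__name__)
    print("root device:", trainer.strategy.root_device)
    return trainer
```

Calling `build_default_trainer()` on the GPU machine from the report would be one way to confirm whether the accelerator and root device agree on this path.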

ZekunZh commented 10 months ago

Thanks a lot for the explanation!

ZekunZh commented 10 months ago

So what strategy would you recommend for production, @awaelchli? In our production environment, each machine has either one CUDA GPU or only a CPU.

The reason for choosing a specific strategy is that we would like to "fix" the OOM problem mentioned in this issue by modifying the teardown method.
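
One possible way to keep a custom teardown while avoiding the string alias is to pass a strategy instance with an explicit device. This is a sketch under assumptions, not a maintainer-endorsed recipe: `pick_device`, `build_trainer` and `CustomTeardownStrategy` are hypothetical names, the teardown body is a placeholder because the OOM fix itself is not shown in this thread, and it assumes `SingleDeviceStrategy` accepts a `device` argument that sets its root device.

```python
def pick_device(cuda_available: bool) -> tuple[str, str]:
    """Return (accelerator name, device string) for a 1-GPU-or-CPU machine."""
    return ("cuda", "cuda:0") if cuda_available else ("cpu", "cpu")


def build_trainer():
    """Build a Trainer whose strategy device is chosen explicitly."""
    # Imported lazily so pick_device stays usable without torch or Lightning.
    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import SingleDeviceStrategy

    class CustomTeardownStrategy(SingleDeviceStrategy):
        def teardown(self) -> None:
            # Placeholder: project-specific cleanup would go here.
            super().teardown()

    accelerator, device = pick_device(torch.cuda.is_available())
    # Accelerator and strategy device are derived from the same decision,
    # so they cannot disagree the way they do in the report above.
    strategy = CustomTeardownStrategy(device=torch.device(device))
    return Trainer(accelerator=accelerator, devices=1, strategy=strategy)
```

Deriving both values from one function is the design point here: the accelerator and the strategy's root device come from the same check rather than from two independent defaults.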