Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Registered buffers not moved to correct device when using DeepSpeed Stage 3 #20258

Open · amorehead opened this issue 1 month ago

amorehead commented 1 month ago

Bug description

I am using the following DeepSpeedStrategy configuration:

_target_: lightning.pytorch.strategies.DeepSpeedStrategy
zero_optimization: true
stage: 3
allgather_bucket_size: 2e8
reduce_bucket_size: 2e8
offload_optimizer: false
offload_parameters: false
partition_activations: false
cpu_checkpointing: false
contiguous_gradients: false
overlap_comm: false
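For anyone reproducing this without Hydra, the config above should map onto keyword arguments of `DeepSpeedStrategy`, because Hydra's `instantiate()` passes every key except `_target_` to the target class. The following is a minimal sketch under that assumption. `DEEPSPEED_KWARGS` and `build_strategy` are hypothetical names, and the YAML value `2e8` is written here as the integer `200_000_000`.

```python
from typing import Any, Dict

# Hypothetical Python equivalent of the Hydra config above: Hydra's
# instantiate() forwards every key except `_target_` as a keyword argument.
DEEPSPEED_KWARGS: Dict[str, Any] = {
    "zero_optimization": True,
    "stage": 3,  # the problem in this issue is reported only for stage 3
    "allgather_bucket_size": 200_000_000,  # 2e8 in the YAML, written as an int
    "reduce_bucket_size": 200_000_000,
    "offload_optimizer": False,
    "offload_parameters": False,
    "partition_activations": False,
    "cpu_checkpointing": False,
    "contiguous_gradients": False,
    "overlap_comm": False,
}


def build_strategy():
    """Create the strategy. This requires lightning and deepspeed to be installed."""
    # Imported lazily so that defining the kwargs needs no extra packages.
    from lightning.pytorch.strategies import DeepSpeedStrategy

    return DeepSpeedStrategy(**DEEPSPEED_KWARGS)
```

You would then pass the result as `Trainer(strategy=build_strategy())`.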

I am experiencing an issue (specifically with DeepSpeed stage 3, not stages 1 or 2) where tensors registered as buffers (via register_buffer()) inside submodules of my LightningModule's main nn.Module, lit_module.network, are not moved to the correct device when training lit_module.network. In particular, I am registering buffers as

from torch import tensor

distance_bins_tensor = tensor([0.0, 1.0, 2.0, 3.0])
self.register_buffer("distance_bins", distance_bins_tensor)  # inside a submodule, e.g. in its __init__

within the various submodules of my lit_module.network. When my optimizer tries to perform a step, I get the error

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:6 and cpu!

when these registered buffers are used, e.g., when they are multiplied by feature tensors that live on (in this case) cuda:6.
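The sketch below illustrates the failure mode using plain PyTorch, with no Lightning or DeepSpeed involved. It defines a hypothetical `DistanceBinner` submodule that registers a buffer like the one above. Normally, `nn.Module.to()` moves buffers along with parameters. Because that apparently does not happen here under ZeRO stage 3, the sketch also shows two possible workarounds. Both are assumptions, not an official fix:

- aligning the buffer with the input's device inside `forward`;
- a hypothetical helper, `move_buffers_`, that moves only buffers and leaves parameters to whatever engine manages them.

```python
from __future__ import annotations

import torch
from torch import nn


class DistanceBinner(nn.Module):
    """Hypothetical submodule that registers a buffer like the one in this issue."""

    def __init__(self) -> None:
        super().__init__()
        # A buffer is module state that is not trained; nn.Module.to()
        # normally moves it together with the module's parameters.
        self.register_buffer("distance_bins", torch.tensor([0.0, 1.0, 2.0, 3.0]))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Workaround 1 (an assumption, not a Lightning fix): align the buffer
        # with the input's device on every call. .to() returns the same tensor
        # when it is already on that device, so there is no copy in that case.
        bins = self.distance_bins.to(features.device)
        return features * bins


def move_buffers_(module: nn.Module, device: torch.device | str) -> None:
    """Workaround 2 (hypothetical helper): move every buffer in `module`,
    recursively, to `device` in place. Parameters are not touched."""
    device = torch.device(device)
    for submodule in module.modules():
        # recurse=False: each submodule handles only its own buffers.
        for name, buf in list(submodule.named_buffers(recurse=False)):
            if buf.device != device:
                # Assigning a tensor to an existing buffer name keeps it
                # registered as a buffer (its persistence flag is unchanged).
                setattr(submodule, name, buf.to(device))
```

With Lightning, the helper could be called from a hook such as `on_fit_start`, e.g. `move_buffers_(self.network, self.device)`. This has not been tested against DeepSpeed stage 3. The per-call `.to()` in `forward` is simpler, but it copies the buffer on every call for as long as the devices differ.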

What version are you seeing the problem on?

v2.4

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
* CUDA: - GPU: - NVIDIA A100 80GB PCIe - NVIDIA A100 80GB PCIe - available: True - version: 11.8
* Lightning: - adam-atan2-pytorch: 0.0.10 - alphafold3-pytorch: 0.0.41 - alphafold3-pytorch-lightning-hydra: 0.1.111 - frame-averaging-pytorch: 0.0.19 - lightning: 2.4.0 - lightning-utilities: 0.11.6 - pytorch-lightning: 2.4.0 - rotary-embedding-torch: 0.6.1 - torch: 2.3.0+cu118 - torch-geometric: 2.5.3 - torchaudio: 2.3.0+cu118 - torchmetrics: 1.4.1 - torchtyping: 0.1.4 - torchvision: 0.18.0+cu118
* Packages: - adam-atan2-pytorch: 0.0.10 - aiofiles: 23.2.1 - aiohttp: 3.9.5 - aiosignal: 1.3.1 - alembic: 1.13.1 - alphafold3-pytorch: 0.0.41 - alphafold3-pytorch-lightning-hydra: 0.1.111 - annotated-types: 0.7.0 - antlr4-python3-runtime: 4.9.3 - anyio: 4.4.0 - appdirs: 1.4.4 - argcomplete: 3.3.0 - asttokens: 2.4.1 - async-timeout: 4.0.3 - attrs: 23.2.0 - autopage: 0.5.2 - beartype: 0.18.5 - beautifulsoup4: 4.12.3 - biopandas: 0.5.1.dev0 - biopython: 1.83 - bioservices: 1.11.2 - cattrs: 23.2.3 - certifi: 2024.8.30 - cfgv: 3.4.0 - chardet: 5.2.0 - charset-normalizer: 3.3.2 - click: 8.1.7 - cliff: 4.7.0 - cmaes: 0.10.0 - cmd2: 2.4.3 - colorama: 0.4.6 - colorlog: 6.8.2 - colt5-attention: 0.11.0 - comm: 0.2.2 - contourpy: 1.2.1 - cycler: 0.12.1 - debugpy: 1.8.1 - decorator: 5.1.1 - deepdiff: 7.0.1 - deepspeed: 0.15.0 - distlib: 0.3.8 - docker-pycreds: 0.4.0 - easydev: 0.13.2 - einops: 0.8.0 - einx: 0.2.2 - environs: 11.0.0 - exceptiongroup: 1.2.1 - executing: 2.0.1 - fastapi: 0.112.2 - ffmpy: 0.4.0 - filelock: 3.13.1 - fonttools: 4.52.4 - frame-averaging-pytorch: 0.0.19 - freetype-py: 2.3.0 - frozendict: 2.4.4 - frozenlist: 1.4.1 - fsspec: 2024.2.0 - gemmi: 0.6.6 - gevent: 24.2.1 - gitdb: 4.0.11 - gitpython: 3.1.43 - gradio: 4.43.0 - gradio-client: 1.3.0 - gradio-molecule3d: 0.0.5 - graphein: 1.7.6 - greenlet: 3.0.3 - grequests: 0.7.0 - h11: 0.14.0 - hjson: 3.1.0 - httpcore: 1.0.5 - httpx: 0.27.2 - huggingface-hub: 0.23.4 - hydra-colorlog: 1.2.0 - hydra-core: 1.3.2 - hydra-optuna-sweeper: 1.2.0 - identify: 2.5.36 - idna: 3.7 - importlib-resources: 6.4.4 - iniconfig: 2.0.0 - ipykernel: 6.29.4 - ipython: 8.24.0 - jaxtyping: 0.2.28 - jedi: 0.19.1 - jinja2: 3.1.3 - joblib: 1.4.2 - jupyter-client: 8.6.2 - jupyter-core: 5.7.2 - kiwisolver: 1.4.5 - lightning: 2.4.0 - lightning-utilities: 0.11.6 - line-profiler: 4.1.3 - local-attention: 1.9.1 - loguru: 0.7.2 - looseversion: 1.1.2 - lxml: 5.2.2 - mako: 1.3.5 - markdown-it-py: 3.0.0 - markupsafe: 2.1.5 - marshmallow: 3.21.3 - matplotlib: 3.8.4 - matplotlib-inline: 0.1.7 - mdurl: 0.1.2 - mmtf-python: 1.1.3 - mpmath: 1.3.0 - msgpack: 1.0.8 - multidict: 6.0.5 - multipledispatch: 1.0.0 - munkres: 1.1.4 - nest-asyncio: 1.6.0 - networkx: 3.2.1 - ninja: 1.11.1.1 - nodeenv: 1.8.0 - numpy: 1.23.5 - nvidia-cublas-cu11: 11.11.3.6 - nvidia-cuda-cupti-cu11: 11.8.87 - nvidia-cuda-nvrtc-cu11: 11.8.89 - nvidia-cuda-runtime-cu11: 11.8.89 - nvidia-cudnn-cu11: 8.7.0.84 - nvidia-cufft-cu11: 10.9.0.58 - nvidia-curand-cu11: 10.3.0.86 - nvidia-cusolver-cu11: 11.4.1.48 - nvidia-cusparse-cu11: 11.7.5.86 - nvidia-ml-py: 12.560.30 - nvidia-nccl-cu11: 2.20.5 - nvidia-nvtx-cu11: 11.8.86 - omegaconf: 2.3.0 - optree: 0.11.0 - optuna: 2.10.1 - ordered-set: 4.1.0 - orjson: 3.10.7 - packaging: 24.0 - pandas: 1.5.3 - parso: 0.8.4 - pbr: 6.0.0 - pdbeccdutils: 0.8.5 - pexpect: 4.9.0 - pillow: 10.2.0 - pip: 24.0 - pipx: 1.5.0 - platformdirs: 4.2.2 - plotly: 5.22.0 - pluggy: 1.5.0 - polars: 1.3.0 - pre-commit: 3.7.1 - prettytable: 3.10.0 - prompt-toolkit: 3.0.45 - protobuf: 4.25.4 - psutil: 5.9.8 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - py-cpuinfo: 9.0.0 - pycairo: 1.26.0 - pydantic: 2.8.2 - pydantic-core: 2.20.1 - pydub: 0.25.1 - pygments: 2.18.0 - pyparsing: 3.1.2 - pyperclip: 1.8.2 - pytest: 8.2.1 - python-dateutil: 2.9.0 - python-dotenv: 1.0.1 - python-multipart: 0.0.9 - pytorch-lightning: 2.4.0 - pytz: 2024.1 - pyyaml: 6.0.1 - pyzmq: 26.0.3 - rdkit: 2024.3.2 - reportlab: 4.1.0 - requests: 2.32.2 - requests-cache: 1.2.0 - retrying: 1.3.4 - rich: 13.7.1 - rich-click: 1.8.2 - rlpycairo: 0.2.0 - rootutils: 1.0.7 - rotary-embedding-torch: 0.6.1 - ruff: 0.6.4 - scikit-learn: 1.5.0 - scipy: 1.13.1 - seaborn: 0.13.2 - semantic-version: 2.10.0 - sentry-sdk: 2.12.0 - setproctitle: 1.3.3 - setuptools: 70.0.0 - sh: 2.0.7 - shellingham: 1.5.4 - shortuuid: 1.0.13 - six: 1.16.0 - smmap: 5.0.1 - sniffio: 1.3.1 - soupsieve: 2.5 - sqlalchemy: 2.0.30 - stack-data: 0.6.3 - starlette: 0.38.4 - stevedore: 5.2.0 - suds-community: 1.1.2 - sympy: 1.12 - taylor-series-linear-attention: 0.1.12 - tenacity: 8.3.0 - threadpoolctl: 3.5.0 - timeout-decorator: 0.5.0 - tomli: 2.0.1 - tomlkit: 0.12.0 - torch: 2.3.0+cu118 - torch-geometric: 2.5.3 - torchaudio: 2.3.0+cu118 - torchmetrics: 1.4.1 - torchtyping: 0.1.4 - torchvision: 0.18.0+cu118 - tornado: 6.4 - tqdm: 4.66.4 - traitlets: 5.14.3 - triton: 2.3.0 - typeguard: 2.13.3 - typer: 0.12.5 - typing-extensions: 4.11.0 - tzdata: 2024.1 - unicodedata2: 15.1.0 - url-normalize: 1.4.3 - urllib3: 2.2.1 - userpath: 1.9.2 - uvicorn: 0.30.6 - virtualenv: 20.26.2 - wandb: 0.16.6 - wcwidth: 0.2.13 - websockets: 12.0 - wget: 3.2 - wheel: 0.43.0 - wrapt: 1.16.0 - xarray: 2024.3.0 - xmltodict: 0.13.0 - yarl: 1.9.4 - zope.event: 5.0 - zope.interface: 6.4.post2
* System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.10.14 - release: 4.18.0-553.16.1.el8_10.x86_64 - version: #1 SMP Thu Aug 8 07:11:46 EDT 2024

More info

No response

cstsunfu commented 3 weeks ago

I have a similar issue. Any updates on this?

amorehead commented 3 weeks ago

Not yet.