Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Investigate FSDP + CPU Offload performance in Trainer #18336

Open awaelchli opened 1 year ago

awaelchli commented 1 year ago

Bug description

While writing the new FSDP guide for the Trainer in #18326, I noticed suspiciously slow iteration speed with CPU offload enabled (see https://github.com/Lightning-AI/lightning/pull/18326#discussion_r1296185766).

Iterations per second with FSDP + CPU offload:
- Fabric: 0.3
- Trainer: 0.02

In other words, the Trainer is roughly 15x slower than Fabric.
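To put the gap in perspective, the two reported rates can be converted to seconds per iteration. The sketch below also shows one simple, framework-agnostic way to measure iterations per second for an arbitrary step function. The helper names `measure_iters_per_sec` and `slowdown` are hypothetical and not part of Lightning; the helper skips a few warm-up calls so that one-time costs (allocation, lazy initialization) do not skew the rate.

```python
import time
from typing import Callable


def measure_iters_per_sec(step: Callable[[], None], n_iters: int = 20, warmup: int = 3) -> float:
    """Hypothetical helper: average iterations per second of `step`, excluding warm-up calls."""
    for _ in range(warmup):
        step()  # one-time costs are excluded from the timed region
    start = time.perf_counter()
    for _ in range(n_iters):
        step()
    elapsed = time.perf_counter() - start
    return n_iters / elapsed


def slowdown(fast_its: float, slow_its: float) -> float:
    """How many times longer one iteration takes at `slow_its` than at `fast_its`."""
    return fast_its / slow_its


# Rates reported in this issue (iterations per second).
fabric_its, trainer_its = 0.3, 0.02
print(f"Fabric:  {1 / fabric_its:.1f} s/it")
print(f"Trainer: {1 / trainer_its:.1f} s/it")
print(f"Trainer is {slowdown(fabric_its, trainer_its):.0f}x slower")

# Usage with a dummy step that sleeps ~10 ms, so the measured rate is at most ~100 it/s.
rate = measure_iters_per_sec(lambda: time.sleep(0.01), n_iters=10, warmup=1)
print(f"dummy step: {rate:.1f} it/s")
```

If the step launches asynchronous CUDA work, it should call `torch.cuda.synchronize()` before returning; otherwise the timing only captures kernel launches rather than the actual GPU and host-device transfer time.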

What version are you seeing the problem on?

master

How to reproduce the bug

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
from lightning.pytorch.demos import Transformer, WikiText2

class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = None

    def configure_model(self):
        # This hook can run more than once (e.g. per stage); only build the model the first time
        self.model = self.model or Transformer(  # 1B parameters
            vocab_size=self.vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64
        )

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.1)

L.seed_everything(42)

# Data
dataset = WikiText2()
train_dataloader = DataLoader(dataset)

# Model
model = LanguageModel(vocab_size=dataset.vocab_size)

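# Wrap every Transformer encoder/decoder layer in its own FSDP unit
policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
strategy = FSDPStrategy(
    auto_wrap_policy=policy,
    cpu_offload=True,  # keep the sharded parameters and gradients in CPU memory
)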

# Trainer
trainer = L.Trainer(accelerator="cuda", devices=2, strategy=strategy)
trainer.fit(model, train_dataloader)
trainer.print(torch.cuda.memory_summary())
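The `# 1B parameters` comment in the repro can be sanity-checked by counting the weights of one encoder and one decoder layer. This is a back-of-the-envelope sketch under stated assumptions: it assumes the demo `Transformer` wraps `torch.nn.Transformer` with `nlayers` encoder and `nlayers` decoder layers (`d_model=ninp`, `dim_feedforward=nhid`), plus an input `nn.Embedding` and an output `nn.Linear` over the vocabulary, and it assumes a vocabulary size of 33,278 for `WikiText2`. It also estimates the fp32 parameters, gradients and Adam state each of the 2 ranks is responsible for after FSDP sharding.

```python
# Back-of-the-envelope parameter count for the repro model.
# Assumption: the demo Transformer is an nn.Transformer (nlayers encoder + nlayers decoder
# layers) plus an nn.Embedding and an output nn.Linear over the vocabulary.


def attention_params(d: int) -> int:
    # nn.MultiheadAttention: packed q/k/v projection (3*d*d + 3*d) + output projection (d*d + d)
    return 4 * d * d + 4 * d


def feedforward_params(d: int, hidden: int) -> int:
    # linear1 (d -> hidden) and linear2 (hidden -> d), both with bias
    return (d * hidden + hidden) + (hidden * d + d)


def layernorm_params(d: int) -> int:
    return 2 * d  # weight + bias


def encoder_layer_params(d: int, hidden: int) -> int:
    # self-attention + feed-forward + 2 LayerNorms
    return attention_params(d) + feedforward_params(d, hidden) + 2 * layernorm_params(d)


def decoder_layer_params(d: int, hidden: int) -> int:
    # self-attention + cross-attention + feed-forward + 3 LayerNorms
    return 2 * attention_params(d) + feedforward_params(d, hidden) + 3 * layernorm_params(d)


def transformer_params(vocab: int, ninp: int, nhid: int, nlayers: int) -> int:
    layers = nlayers * (encoder_layer_params(ninp, nhid) + decoder_layer_params(ninp, nhid))
    final_norms = 2 * layernorm_params(ninp)  # nn.Transformer's final encoder and decoder norms
    embedding = vocab * ninp
    output_proj = ninp * vocab + vocab
    return layers + final_norms + embedding + output_proj


VOCAB = 33_278  # assumption: WikiText2 vocabulary size
n_params = transformer_params(VOCAB, ninp=1024, nhid=4096, nlayers=32)
print(f"parameters: {n_params / 1e9:.2f} B")

# fp32 bytes per parameter: weight (4) + gradient (4) + Adam exp_avg and exp_avg_sq (8)
bytes_per_param = 4 + 4 + 8
world_size = 2  # devices=2 in the repro
per_rank_gb = n_params * bytes_per_param / world_size / 1e9
print(f"sharded fp32 params + grads + Adam state per rank: {per_rank_gb:.1f} GB")
```

Under these assumptions the model has about 1.0 billion parameters, so each rank handles roughly 8 GB of fp32 weights, gradients and optimizer state. With `cpu_offload=True` that sharded state lives in host memory, so host-device copies and the CPU-side optimizer step are likely candidates to profile when comparing the Trainer against Fabric.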

Error messages and logs

No errors.

Environment

Current environment
* CUDA:
  - GPU:
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
  - available: True
  - version: 11.8
* Lightning:
  - lightning: 2.0.6
  - lightning-api-access: 0.0.5
  - lightning-cloud: 0.5.37
  - lightning-fabric: 1.9.3
  - lightning-utilities: 0.8.0
  - pytorch-lightning: 2.0.6
  - pytorch-triton: 2.1.0+e6216047b8
  - torch: 2.1.0.dev20230817+cu118
  - torchmetrics: 1.0.0
  - torchvision: 0.16.0.dev20230817+cu118
* Packages:
  - absl-py: 1.4.0
  - aiobotocore: 2.4.2
  - aiohttp: 3.8.4
  - aioitertools: 0.11.0
  - aiosignal: 1.3.1
  - altair: 4.2.2
  - antlr4-python3-runtime: 4.9.3
  - anyio: 3.6.2
  - appdirs: 1.4.4
  - arrow: 1.2.3
  - asttokens: 2.2.1
  - async-generator: 1.10
  - async-timeout: 4.0.2
  - attrs: 22.2.0
  - backcall: 0.2.0
  - backoff: 2.2.1
  - backports.functools-lru-cache: 1.6.5
  - beautifulsoup4: 4.11.2
  - black: 23.3.0
  - bleach: 6.0.0
  - blessed: 1.20.0
  - blinker: 1.5
  - bokeh: 2.4.3
  - botocore: 1.27.59
  - cachetools: 5.3.0
  - certifi: 2023.5.7
  - cffi: 1.15.1
  - cfgv: 3.3.1
  - charset-normalizer: 3.0.1
  - click: 8.1.3
  - cloudpickle: 2.2.1
  - cmake: 3.26.0
  - codecov: 2.1.12
  - coloredlogs: 15.0.1
  - contourpy: 1.0.7
  - coverage: 6.5.0
  - croniter: 1.3.8
  - cryptography: 40.0.2
  - cycler: 0.11.0
  - datasets: 2.12.0
  - dateutils: 0.6.12
  - debugpy: 1.5.1
  - decorator: 5.1.1
  - deepdiff: 6.2.3
  - deepspeed: 0.9.3
  - dill: 0.3.6
  - distlib: 0.3.6
  - dnspython: 2.3.0
  - docker: 6.0.1
  - docker-pycreds: 0.4.0
  - docstring-parser: 0.15
  - email-validator: 1.3.1
  - entrypoints: 0.4
  - exceptiongroup: 1.1.0
  - executing: 1.2.0
  - fairscale: 0.4.13
  - fastapi: 0.100.0
  - filelock: 3.9.0
  - flatbuffers: 23.1.21
  - fonttools: 4.38.0
  - frozenlist: 1.3.3
  - fsspec: 2022.11.0
  - gitdb: 4.0.10
  - gitpython: 3.1.31
  - google-auth: 2.16.1
  - google-auth-oauthlib: 0.4.6
  - greenlet: 2.0.1
  - grpcio: 1.51.3
  - h11: 0.14.0
  - hjson: 3.1.0
  - httpcore: 0.16.3
  - httptools: 0.5.0
  - httpx: 0.23.3
  - huggingface-hub: 0.14.1
  - humanfriendly: 10.0
  - hydra-core: 1.3.2
  - identify: 2.5.18
  - idna: 3.4
  - importlib-metadata: 6.7.0
  - importlib-resources: 5.12.0
  - iniconfig: 2.0.0
  - inquirer: 3.1.2
  - ipykernel: 6.14.0
  - ipython: 8.14.0
  - isort: 5.12.0
  - itsdangerous: 2.1.2
  - jedi: 0.18.2
  - jinja2: 3.1.2
  - jmespath: 1.0.1
  - joblib: 1.2.0
  - jsonargparse: 4.20.0
  - jsonschema: 4.17.3
  - jupyter-client: 7.3.4
  - jupyter-core: 5.3.1
  - kiwisolver: 1.4.4
  - lightning: 2.0.6
  - lightning-api-access: 0.0.5
  - lightning-cloud: 0.5.37
  - lightning-fabric: 1.9.3
  - lightning-utilities: 0.8.0
  - lit: 15.0.7
  - markdown: 3.4.1
  - markdown-it-py: 2.2.0
  - markupsafe: 2.1.2
  - matplotlib: 3.7.0
  - matplotlib-inline: 0.1.6
  - mdurl: 0.1.2
  - mpmath: 1.2.1
  - multidict: 6.0.4
  - multiprocess: 0.70.14
  - mypy: 1.4.1
  - mypy-extensions: 1.0.0
  - nest-asyncio: 1.5.6
  - networkx: 3.0
  - ninja: 1.11.1
  - nodeenv: 1.7.0
  - numpy: 1.24.2
  - nvidia-cublas-cu11: 11.10.3.66
  - nvidia-cuda-cupti-cu11: 11.7.101
  - nvidia-cuda-nvrtc-cu11: 11.7.99
  - nvidia-cuda-runtime-cu11: 11.7.99
  - nvidia-cudnn-cu11: 8.5.0.96
  - nvidia-cufft-cu11: 10.9.0.58
  - nvidia-curand-cu11: 10.2.10.91
  - nvidia-cusolver-cu11: 11.4.0.1
  - nvidia-cusparse-cu11: 11.7.4.91
  - nvidia-nccl-cu11: 2.14.3
  - nvidia-nvtx-cu11: 11.7.91
  - oauthlib: 3.2.2
  - omegaconf: 2.3.0
  - onnx: 1.12.0
  - onnxruntime: 1.14.1
  - ordered-set: 4.1.0
  - orjson: 3.8.6
  - outcome: 1.2.0
  - packaging: 23.1
  - pandas: 1.5.3
  - panel: 0.14.3
  - param: 1.12.3
  - parso: 0.8.3
  - pathspec: 0.11.1
  - pathtools: 0.1.2
  - pexpect: 4.8.0
  - pickleshare: 0.7.5
  - pillow: 9.4.0
  - pip: 23.1.2
  - platformdirs: 3.9.1
  - playwright: 1.30.0
  - pluggy: 1.0.0
  - pre-commit: 2.20.0
  - prompt-toolkit: 3.0.39
  - protobuf: 3.20.1
  - psutil: 5.9.4
  - ptyprocess: 0.7.0
  - pure-eval: 0.2.2
  - py: 1.11.0
  - py-cpuinfo: 9.0.0
  - py3nvml: 0.2.7
  - pyarrow: 11.0.0
  - pyasn1: 0.4.8
  - pyasn1-modules: 0.2.8
  - pycparser: 2.21
  - pyct: 0.5.0
  - pydantic: 1.10.5
  - pydeck: 0.8.0
  - pyee: 9.0.4
  - pygments: 2.15.1
  - pyjwt: 2.6.0
  - pympler: 1.0.1
  - pynvml: 11.5.0
  - pyparsing: 3.0.9
  - pyrsistent: 0.19.3
  - pytest: 7.2.0
  - pytest-asyncio: 0.20.3
  - pytest-cov: 4.0.0
  - pytest-doctestplus: 0.12.1
  - pytest-forked: 1.4.0
  - pytest-rerunfailures: 10.3
  - pytest-timeout: 2.1.0
  - python-dateutil: 2.8.2
  - python-dotenv: 1.0.0
  - python-editor: 1.0.4
  - python-multipart: 0.0.6
  - pytorch-lightning: 2.0.6
  - pytorch-triton: 2.1.0+e6216047b8
  - pytz: 2022.7.1
  - pytz-deprecation-shim: 0.1.0.post0
  - pyviz-comms: 2.2.1
  - pyyaml: 6.0
  - pyzmq: 25.1.0
  - readchar: 4.0.3
  - redis: 4.5.1
  - regex: 2023.3.23
  - requests: 2.28.2
  - requests-mock: 1.10.0
  - requests-oauthlib: 1.3.1
  - responses: 0.18.0
  - rfc3986: 1.5.0
  - rich: 13.3.1
  - rsa: 4.9
  - s3fs: 2022.11.0
  - scikit-learn: 1.2.1
  - scipy: 1.10.1
  - semver: 2.13.0
  - sentencepiece: 0.1.99
  - sentry-sdk: 1.28.1
  - setproctitle: 1.3.2
  - setuptools: 60.9.3
  - six: 1.16.0
  - smmap: 5.0.0
  - sniffio: 1.3.0
  - sortedcontainers: 2.4.0
  - soupsieve: 2.4
  - sqlalchemy: 1.4.41
  - sqlalchemy2-stubs: 0.0.2a32
  - sqlmodel: 0.0.8
  - stack-data: 0.6.2
  - starlette: 0.27.0
  - starsessions: 1.3.0
  - streamlit: 1.19.0
  - sympy: 1.11.1
  - tensorboard: 2.12.0
  - tensorboard-data-server: 0.7.0
  - tensorboard-plugin-wit: 1.8.1
  - tensorboardx: 2.6
  - threadpoolctl: 3.1.0
  - tokenizers: 0.13.3
  - toml: 0.10.2
  - tomli: 2.0.1
  - toolz: 0.12.0
  - torch: 2.1.0.dev20230817+cu118
  - torchmetrics: 1.0.0
  - torchvision: 0.16.0.dev20230817+cu118
  - tornado: 6.2
  - tqdm: 4.64.1
  - traitlets: 5.9.0
  - transformers: 4.28.1
  - trio: 0.22.0
  - triton: 2.0.0
  - types-cachetools: 5.3.0.5
  - types-croniter: 1.3.2.9
  - types-decorator: 5.1.8.3
  - types-protobuf: 4.22.0.2
  - types-pyopenssl: 23.1.0.2
  - types-python-dateutil: 2.8.19.12
  - types-pytz: 2023.3.0.0
  - types-pyyaml: 6.0.12.9
  - types-redis: 4.5.4.1
  - types-requests: 2.28.11.17
  - types-setuptools: 57.4.9
  - types-six: 1.16.21.8
  - types-tabulate: 0.9.0.2
  - types-toml: 0.10.8.6
  - types-tzlocal: 4.3.0.0
  - types-ujson: 5.7.0.3
  - types-urllib3: 1.26.25.10
  - typeshed-client: 2.2.0
  - typing-extensions: 4.7.0
  - tzdata: 2022.7
  - tzlocal: 4.2
  - ujson: 5.7.0
  - urllib3: 1.26.14
  - uvicorn: 0.20.0
  - uvloop: 0.17.0
  - validators: 0.20.0
  - virtualenv: 20.19.0
  - wandb: 0.15.5
  - watchdog: 2.3.0
  - watchfiles: 0.18.1
  - watermark: 2.4.2
  - wcwidth: 0.2.6
  - webencodings: 0.5.1
  - websocket-client: 1.5.1
  - websockets: 10.4
  - werkzeug: 2.2.3
  - wheel: 0.37.1
  - wrapt: 1.15.0
  - xmltodict: 0.13.0
  - xxhash: 3.2.0
  - yapf: 0.40.1
  - yarl: 1.8.2
  - zipp: 3.15.0
* System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.10.9
  - release: 5.15.0-75-generic
  - version: #82-Ubuntu SMP Tue Jun 6 23:10:23 UTC 2023


cc @borda @awaelchli @carmocca

carmocca commented 1 year ago

Do you observe the same results with Fabric?

awaelchli commented 1 year ago

I opened the issue precisely because the difference from Fabric is so noticeable. The numbers are in the description above and in the docs pages.