awaelchli opened this issue 1 year ago
Bug description
While writing the new FSDP guide for the Trainer in #18326, I measured a suspiciously slow iteration speed when enabling CPU offload (see https://github.com/Lightning-AI/lightning/pull/18326#discussion_r1296185766).
| Setup | Iterations per second |
|---|---|
| Fabric FSDP + Offload | 0.3 |
| Trainer FSDP + Offload | 0.02 |
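The numbers above amount to roughly a 15× slowdown (0.3 / 0.02). When re-measuring, it helps to time both loops the same way and to exclude warm-up steps. The following is a minimal, framework-agnostic sketch; the helpers `iterations_per_second` and `slowdown` are hypothetical and not part of Lightning, and the clock is injectable so the helper can be checked without real timing.

```python
import time
from typing import Callable


def iterations_per_second(
    step: Callable[[], None],
    num_steps: int,
    warmup_steps: int = 0,
    clock: Callable[[], float] = time.perf_counter,
) -> float:
    """Hypothetical helper: run `step` repeatedly and return iterations per second.

    The first `warmup_steps` calls run but are not timed, so one-off costs
    (e.g. CUDA context creation, allocator warm-up) do not skew the result.
    """
    if num_steps <= 0:
        raise ValueError("num_steps must be positive")
    for _ in range(warmup_steps):
        step()
    start = clock()
    for _ in range(num_steps):
        step()
    elapsed = clock() - start
    return num_steps / elapsed


def slowdown(baseline_its: float, candidate_its: float) -> float:
    """Hypothetical helper: how many times slower `candidate` is than `baseline`."""
    return baseline_its / candidate_its


# Numbers reported in this issue: Fabric vs. Trainer, both with FSDP + CPU offload.
print(f"Trainer is {slowdown(0.3, 0.02):.1f}x slower than Fabric")
```

On a GPU, the timed `step` should end with `torch.cuda.synchronize()`; otherwise asynchronous kernels may still be running when the clock is read.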
What version are you seeing the problem on?
master
How to reproduce the bug
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
from lightning.pytorch.demos import Transformer, WikiText2


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = None

    def configure_model(self):
        self.model = self.model or Transformer(  # 1B parameters
            vocab_size=self.vocab_size,
            nlayers=32,
            nhid=4096,
            ninp=1024,
            nhead=64,
        )

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.1)


L.seed_everything(42)

# Data
dataset = WikiText2()
train_dataloader = DataLoader(dataset)

# Model
model = LanguageModel(vocab_size=dataset.vocab_size)

policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
strategy = FSDPStrategy(
    auto_wrap_policy=policy,
    cpu_offload=True,
)

# Trainer
trainer = L.Trainer(accelerator="cuda", devices=2, strategy=strategy)
trainer.fit(model, train_dataloader)
trainer.print(torch.cuda.memory_summary())
```
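The `# 1B parameters` comment in the script can be sanity-checked by hand. The sketch below is a back-of-the-envelope count under assumptions that are not taken from this issue: that `lightning.pytorch.demos.Transformer` wraps a standard `torch.nn.Transformer` with `nlayers` encoder and `nlayers` decoder layers, plus an `nn.Embedding` input and an `nn.Linear` output projection over the vocabulary, and that the WikiText2 vocabulary has 33,278 tokens. It also converts the count into fp32 bytes for parameters, gradients, and Adam's two moment buffers, which is roughly the training state that CPU offload keeps in host memory (activations and communication buffers are ignored).

```python
def encoder_layer_params(d_model: int, dim_ff: int) -> int:
    # Self-attention: packed QKV projection (3*d*d + 3*d) plus output projection (d*d + d).
    attention = 4 * d_model * d_model + 4 * d_model
    # Feed-forward: Linear(d, ff) and Linear(ff, d), both with bias.
    feed_forward = 2 * d_model * dim_ff + dim_ff + d_model
    # Two LayerNorms, each with a weight and a bias vector.
    norms = 2 * 2 * d_model
    return attention + feed_forward + norms


def decoder_layer_params(d_model: int, dim_ff: int) -> int:
    # Self-attention plus cross-attention, each laid out like the encoder's attention.
    attention = 2 * (4 * d_model * d_model + 4 * d_model)
    feed_forward = 2 * d_model * dim_ff + dim_ff + d_model
    # Three LayerNorms.
    norms = 3 * 2 * d_model
    return attention + feed_forward + norms


def model_params(vocab_size: int, ninp: int, nhid: int, nlayers: int) -> int:
    # ASSUMPTION: the demo model is Embedding -> nn.Transformer -> Linear, as described above.
    body = nlayers * (encoder_layer_params(ninp, nhid) + decoder_layer_params(ninp, nhid))
    final_norms = 2 * 2 * ninp  # final encoder and decoder LayerNorm inside nn.Transformer
    embedding = vocab_size * ninp
    output_projection = ninp * vocab_size + vocab_size
    return body + final_norms + embedding + output_projection


# Sizes from the reproduction script; ASSUMPTION: vocab_size=33278 for WikiText2.
n = model_params(vocab_size=33278, ninp=1024, nhid=4096, nlayers=32)
fp32_bytes = 4
# Parameters + gradients + Adam's exp_avg and exp_avg_sq, all stored in fp32.
training_state_gb = n * fp32_bytes * 4 / 1e9
print(f"{n:,} parameters, ~{training_state_gb:.1f} GB of fp32 training state")
```

Under these assumptions the count comes out just above one billion, consistent with the comment in the script; if the demo model is laid out differently, only the per-layer helpers need adjusting.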
Error messages and logs
No errors.
cc @borda @awaelchli @carmocca
Do you observe the same results with Fabric?
I opened the issue precisely because the difference from Fabric is so noticeable. The numbers are in the description above and on the docs pages.
Environment
* CUDA:
    - GPU:
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
        - NVIDIA A100-SXM4-40GB
    - available: True
    - version: 11.8
* Lightning:
    - lightning: 2.0.6
    - lightning-api-access: 0.0.5
    - lightning-cloud: 0.5.37
    - lightning-fabric: 1.9.3
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.6
    - pytorch-triton: 2.1.0+e6216047b8
    - torch: 2.1.0.dev20230817+cu118
    - torchmetrics: 1.0.0
    - torchvision: 0.16.0.dev20230817+cu118
* Packages:
    - absl-py: 1.4.0
    - aiobotocore: 2.4.2
    - aiohttp: 3.8.4
    - aioitertools: 0.11.0
    - aiosignal: 1.3.1
    - altair: 4.2.2
    - antlr4-python3-runtime: 4.9.3
    - anyio: 3.6.2
    - appdirs: 1.4.4
    - arrow: 1.2.3
    - asttokens: 2.2.1
    - async-generator: 1.10
    - async-timeout: 4.0.2
    - attrs: 22.2.0
    - backcall: 0.2.0
    - backoff: 2.2.1
    - backports.functools-lru-cache: 1.6.5
    - beautifulsoup4: 4.11.2
    - black: 23.3.0
    - bleach: 6.0.0
    - blessed: 1.20.0
    - blinker: 1.5
    - bokeh: 2.4.3
    - botocore: 1.27.59
    - cachetools: 5.3.0
    - certifi: 2023.5.7
    - cffi: 1.15.1
    - cfgv: 3.3.1
    - charset-normalizer: 3.0.1
    - click: 8.1.3
    - cloudpickle: 2.2.1
    - cmake: 3.26.0
    - codecov: 2.1.12
    - coloredlogs: 15.0.1
    - contourpy: 1.0.7
    - coverage: 6.5.0
    - croniter: 1.3.8
    - cryptography: 40.0.2
    - cycler: 0.11.0
    - datasets: 2.12.0
    - dateutils: 0.6.12
    - debugpy: 1.5.1
    - decorator: 5.1.1
    - deepdiff: 6.2.3
    - deepspeed: 0.9.3
    - dill: 0.3.6
    - distlib: 0.3.6
    - dnspython: 2.3.0
    - docker: 6.0.1
    - docker-pycreds: 0.4.0
    - docstring-parser: 0.15
    - email-validator: 1.3.1
    - entrypoints: 0.4
    - exceptiongroup: 1.1.0
    - executing: 1.2.0
    - fairscale: 0.4.13
    - fastapi: 0.100.0
    - filelock: 3.9.0
    - flatbuffers: 23.1.21
    - fonttools: 4.38.0
    - frozenlist: 1.3.3
    - fsspec: 2022.11.0
    - gitdb: 4.0.10
    - gitpython: 3.1.31
    - google-auth: 2.16.1
    - google-auth-oauthlib: 0.4.6
    - greenlet: 2.0.1
    - grpcio: 1.51.3
    - h11: 0.14.0
    - hjson: 3.1.0
    - httpcore: 0.16.3
    - httptools: 0.5.0
    - httpx: 0.23.3
    - huggingface-hub: 0.14.1
    - humanfriendly: 10.0
    - hydra-core: 1.3.2
    - identify: 2.5.18
    - idna: 3.4
    - importlib-metadata: 6.7.0
    - importlib-resources: 5.12.0
    - iniconfig: 2.0.0
    - inquirer: 3.1.2
    - ipykernel: 6.14.0
    - ipython: 8.14.0
    - isort: 5.12.0
    - itsdangerous: 2.1.2
    - jedi: 0.18.2
    - jinja2: 3.1.2
    - jmespath: 1.0.1
    - joblib: 1.2.0
    - jsonargparse: 4.20.0
    - jsonschema: 4.17.3
    - jupyter-client: 7.3.4
    - jupyter-core: 5.3.1
    - kiwisolver: 1.4.4
    - lightning: 2.0.6
    - lightning-api-access: 0.0.5
    - lightning-cloud: 0.5.37
    - lightning-fabric: 1.9.3
    - lightning-utilities: 0.8.0
    - lit: 15.0.7
    - markdown: 3.4.1
    - markdown-it-py: 2.2.0
    - markupsafe: 2.1.2
    - matplotlib: 3.7.0
    - matplotlib-inline: 0.1.6
    - mdurl: 0.1.2
    - mpmath: 1.2.1
    - multidict: 6.0.4
    - multiprocess: 0.70.14
    - mypy: 1.4.1
    - mypy-extensions: 1.0.0
    - nest-asyncio: 1.5.6
    - networkx: 3.0
    - ninja: 1.11.1
    - nodeenv: 1.7.0
    - numpy: 1.24.2
    - nvidia-cublas-cu11: 11.10.3.66
    - nvidia-cuda-cupti-cu11: 11.7.101
    - nvidia-cuda-nvrtc-cu11: 11.7.99
    - nvidia-cuda-runtime-cu11: 11.7.99
    - nvidia-cudnn-cu11: 8.5.0.96
    - nvidia-cufft-cu11: 10.9.0.58
    - nvidia-curand-cu11: 10.2.10.91
    - nvidia-cusolver-cu11: 11.4.0.1
    - nvidia-cusparse-cu11: 11.7.4.91
    - nvidia-nccl-cu11: 2.14.3
    - nvidia-nvtx-cu11: 11.7.91
    - oauthlib: 3.2.2
    - omegaconf: 2.3.0
    - onnx: 1.12.0
    - onnxruntime: 1.14.1
    - ordered-set: 4.1.0
    - orjson: 3.8.6
    - outcome: 1.2.0
    - packaging: 23.1
    - pandas: 1.5.3
    - panel: 0.14.3
    - param: 1.12.3
    - parso: 0.8.3
    - pathspec: 0.11.1
    - pathtools: 0.1.2
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 9.4.0
    - pip: 23.1.2
    - platformdirs: 3.9.1
    - playwright: 1.30.0
    - pluggy: 1.0.0
    - pre-commit: 2.20.0
    - prompt-toolkit: 3.0.39
    - protobuf: 3.20.1
    - psutil: 5.9.4
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py: 1.11.0
    - py-cpuinfo: 9.0.0
    - py3nvml: 0.2.7
    - pyarrow: 11.0.0
    - pyasn1: 0.4.8
    - pyasn1-modules: 0.2.8
    - pycparser: 2.21
    - pyct: 0.5.0
    - pydantic: 1.10.5
    - pydeck: 0.8.0
    - pyee: 9.0.4
    - pygments: 2.15.1
    - pyjwt: 2.6.0
    - pympler: 1.0.1
    - pynvml: 11.5.0
    - pyparsing: 3.0.9
    - pyrsistent: 0.19.3
    - pytest: 7.2.0
    - pytest-asyncio: 0.20.3
    - pytest-cov: 4.0.0
    - pytest-doctestplus: 0.12.1
    - pytest-forked: 1.4.0
    - pytest-rerunfailures: 10.3
    - pytest-timeout: 2.1.0
    - python-dateutil: 2.8.2
    - python-dotenv: 1.0.0
    - python-editor: 1.0.4
    - python-multipart: 0.0.6
    - pytorch-lightning: 2.0.6
    - pytorch-triton: 2.1.0+e6216047b8
    - pytz: 2022.7.1
    - pytz-deprecation-shim: 0.1.0.post0
    - pyviz-comms: 2.2.1
    - pyyaml: 6.0
    - pyzmq: 25.1.0
    - readchar: 4.0.3
    - redis: 4.5.1
    - regex: 2023.3.23
    - requests: 2.28.2
    - requests-mock: 1.10.0
    - requests-oauthlib: 1.3.1
    - responses: 0.18.0
    - rfc3986: 1.5.0
    - rich: 13.3.1
    - rsa: 4.9
    - s3fs: 2022.11.0
    - scikit-learn: 1.2.1
    - scipy: 1.10.1
    - semver: 2.13.0
    - sentencepiece: 0.1.99
    - sentry-sdk: 1.28.1
    - setproctitle: 1.3.2
    - setuptools: 60.9.3
    - six: 1.16.0
    - smmap: 5.0.0
    - sniffio: 1.3.0
    - sortedcontainers: 2.4.0
    - soupsieve: 2.4
    - sqlalchemy: 1.4.41
    - sqlalchemy2-stubs: 0.0.2a32
    - sqlmodel: 0.0.8
    - stack-data: 0.6.2
    - starlette: 0.27.0
    - starsessions: 1.3.0
    - streamlit: 1.19.0
    - sympy: 1.11.1
    - tensorboard: 2.12.0
    - tensorboard-data-server: 0.7.0
    - tensorboard-plugin-wit: 1.8.1
    - tensorboardx: 2.6
    - threadpoolctl: 3.1.0
    - tokenizers: 0.13.3
    - toml: 0.10.2
    - tomli: 2.0.1
    - toolz: 0.12.0
    - torch: 2.1.0.dev20230817+cu118
    - torchmetrics: 1.0.0
    - torchvision: 0.16.0.dev20230817+cu118
    - tornado: 6.2
    - tqdm: 4.64.1
    - traitlets: 5.9.0
    - transformers: 4.28.1
    - trio: 0.22.0
    - triton: 2.0.0
    - types-cachetools: 5.3.0.5
    - types-croniter: 1.3.2.9
    - types-decorator: 5.1.8.3
    - types-protobuf: 4.22.0.2
    - types-pyopenssl: 23.1.0.2
    - types-python-dateutil: 2.8.19.12
    - types-pytz: 2023.3.0.0
    - types-pyyaml: 6.0.12.9
    - types-redis: 4.5.4.1
    - types-requests: 2.28.11.17
    - types-setuptools: 57.4.9
    - types-six: 1.16.21.8
    - types-tabulate: 0.9.0.2
    - types-toml: 0.10.8.6
    - types-tzlocal: 4.3.0.0
    - types-ujson: 5.7.0.3
    - types-urllib3: 1.26.25.10
    - typeshed-client: 2.2.0
    - typing-extensions: 4.7.0
    - tzdata: 2022.7
    - tzlocal: 4.2
    - ujson: 5.7.0
    - urllib3: 1.26.14
    - uvicorn: 0.20.0
    - uvloop: 0.17.0
    - validators: 0.20.0
    - virtualenv: 20.19.0
    - wandb: 0.15.5
    - watchdog: 2.3.0
    - watchfiles: 0.18.1
    - watermark: 2.4.2
    - wcwidth: 0.2.6
    - webencodings: 0.5.1
    - websocket-client: 1.5.1
    - websockets: 10.4
    - werkzeug: 2.2.3
    - wheel: 0.37.1
    - wrapt: 1.15.0
    - xmltodict: 0.13.0
    - xxhash: 3.2.0
    - yapf: 0.40.1
    - yarl: 1.8.2
    - zipp: 3.15.0
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.10.9
    - release: 5.15.0-75-generic
    - version: #82-Ubuntu SMP Tue Jun 6 23:10:23 UTC 2023

More info
No response