Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Lightning 2.1.0 no longer supports saving/loading FSDP checkpoints with PyTorch < 2.0 #18230

Closed (speediedan closed this issue 1 year ago)

speediedan commented 1 year ago

Bug description

With the latest dev commit as of this writing (0aeeb60566cc0375df3cf1a4458592651f143717), Lightning can no longer save or load FSDP checkpoints with PyTorch < 2.0, because the FSDP strategy imports modules that only exist in PyTorch 2.0:

./tests/tests_pytorch/strategies/test_fsdp.py::test_fsdp_strategy_save_optimizer_states[2] Failed: [undefined]ModuleNotFoundError: No module named 'torch.distributed.fsdp.api'
tmpdir = local('/tmp/pytest-of-speediedan/pytest-807/test_fsdp_strategy_save_optimi0')
wrap_min_params = 2

    @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=False, min_torch="1.12")
    @pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
    def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params):
        """Test to ensure that the full state dict and optimizer states is saved when using FSDP strategy.

        Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the model can
        be restored to DDP, it means that the optimizer states were saved correctly.
        """
        model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params)

        strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
        trainer = Trainer(
            default_root_dir=tmpdir,
            accelerator="gpu",
            devices=2,
            strategy=strategy,
            precision="16-mixed",
            max_epochs=1,
            barebones=True,
        )

        trainer.fit(model)
        model_path = os.path.join(tmpdir, "last.ckpt")
        model_path = trainer.strategy.broadcast(model_path)
>       trainer.save_checkpoint(model_path)

tests/tests_pytorch/strategies/test_fsdp.py:577: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
src/lightning/pytorch/trainer/trainer.py:1360: in save_checkpoint
    checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
src/lightning/pytorch/trainer/connectors/checkpoint_connector.py:433: in dump_checkpoint
    "state_dict": self._get_lightning_module_state_dict(),
src/lightning/pytorch/trainer/connectors/checkpoint_connector.py:491: in _get_lightning_module_state_dict
    return self.trainer.strategy.lightning_module_state_dict()

self = <lightning.pytorch.strategies.fsdp.FSDPStrategy object at 0x7fd315bcb850>

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        """Gathers the full state dict by unsharding all the parameters.

        To avoid OOM, the returned parameters will only be returned on rank 0 and on CPU. All other ranks get an empty
        dict.
        """
        from torch.distributed.fsdp import FullyShardedDataParallel
>       from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType
E       ModuleNotFoundError: No module named 'torch.distributed.fsdp.api'

src/lightning/pytorch/strategies/fsdp.py:177: ModuleNotFoundError

Also note that both the save and load code paths use the `state_dict_type` context manager and attempt to import from the PyTorch 2.0 FSDP module locations even with PyTorch < 2.0: https://github.com/Lightning-AI/lightning/blob/0aeeb60566cc0375df3cf1a4458592651f143717/src/lightning/fabric/strategies/fsdp.py#L792-L819

Finally, I don't believe `FullOptimStateDictConfig` is defined in the FSDP 1.x API, so that would also need to be worked around if support for 1.x FSDP continues.
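For context, a minimal sketch of the kind of version-guarded import this would require (illustrative only, not Lightning's actual code; it assumes the `packaging` package is available for version parsing):

```python
# Illustrative only -- not Lightning's implementation.
import torch
from packaging.version import Version  # assumed available for version parsing

_TORCH_GREATER_EQUAL_2_0 = Version(torch.__version__).release >= (2, 0)

if _TORCH_GREATER_EQUAL_2_0:
    # PyTorch >= 2.0: the config classes live in the dedicated `api` module, and
    # FullOptimStateDictConfig is available for gathering optimizer state.
    from torch.distributed.fsdp.api import FullOptimStateDictConfig, FullStateDictConfig, StateDictType
else:
    # PyTorch 1.12/1.13: FullStateDictConfig and StateDictType are exported from the
    # package root; there is no FullOptimStateDictConfig equivalent to fall back to.
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType
```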

I imagine the above challenges could be surmounted to continue supporting saving/loading of FSDP checkpoints with PyTorch < 2.0, but I wanted to confirm that is the intention. If deprecating this FSDP functionality for PyTorch 1.x is expected, I'll go ahead and begin deprecating it in finetuning-scheduler.

Thanks again for all your invaluable contributions to the open-source ML ecosystem!

What version are you seeing the problem on?

master

How to reproduce the bug

To reproduce, install `torch==1.13.1` and run the following existing test:
./tests/tests_pytorch/strategies/test_fsdp.py::test_fsdp_strategy_save_optimizer_states[2]

Error messages and logs

(The traceback is identical to the one included in the bug description above.)

Environment

Current environment * CUDA: - GPU: - NVIDIA GeForce RTX 4090 - NVIDIA GeForce RTX 2070 SUPER - available: True - version: 11.7 * Lightning: - lightning: 2.1.0.dev0 - lightning-api-access: 0.0.5 - lightning-cloud: 0.5.37 - lightning-fabric: 2.0.6 - lightning-utilities: 0.9.0 - pt-lightning-sphinx-theme: 0.0.31 - pytorch-lightning: 2.0.6 - torch: 1.13.1 - torchmetrics: 1.0.2 - torchvision: 0.14.1 * Packages: - absl-py: 1.4.0 - aiobotocore: 2.5.2 - aiohttp: 3.8.5 - aioitertools: 0.11.0 - aiosignal: 1.3.1 - alabaster: 0.7.13 - altair: 5.0.1 - annotated-types: 0.5.0 - antlr4-python3-runtime: 4.9.3 - anyio: 3.7.1 - apeye: 1.4.0 - apeye-core: 1.1.4 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.2.3 - asttokens: 2.2.1 - async-generator: 1.10 - async-lru: 2.0.4 - async-timeout: 4.0.2 - attrs: 23.1.0 - autodocsumm: 0.2.11 - babel: 2.12.1 - backcall: 0.2.0 - backoff: 2.2.1 - beautifulsoup4: 4.12.2 - bleach: 6.0.0 - blessed: 1.20.0 - blinker: 1.6.2 - bokeh: 3.2.1 - botocore: 1.29.161 - bracex: 2.3.post1 - brotlipy: 0.7.0 - cachecontrol: 0.13.1 - cachetools: 5.3.1 - certifi: 2023.7.22 - cffi: 1.15.1 - charset-normalizer: 2.0.4 - click: 8.1.6 - cloudpickle: 2.2.1 - colorama: 0.4.6 - coloredlogs: 15.0.1 - comm: 0.1.4 - contourpy: 1.1.0 - coverage: 7.2.7 - croniter: 1.4.1 - cryptography: 41.0.2 - cssutils: 2.7.1 - cycler: 0.11.0 - dateutils: 0.6.12 - debugpy: 1.6.7 - decorator: 5.1.1 - deepdiff: 6.3.1 - defusedxml: 0.7.1 - dict2css: 0.3.0 - docker: 6.1.3 - docstring-parser: 0.15 - docutils: 0.17.1 - domdf-python-tools: 3.6.1 - exceptiongroup: 1.1.2 - executing: 1.2.0 - fastapi: 0.100.1 - fastjsonschema: 2.18.0 - filelock: 3.12.2 - fire: 0.5.0 - flatbuffers: 23.5.26 - fonttools: 4.42.0 - fqdn: 1.5.1 - frozenlist: 1.4.0 - fsspec: 2023.6.0 - gitdb: 4.0.10 - gitpython: 3.1.32 - google-auth: 2.22.0 - google-auth-oauthlib: 1.0.0 - greenlet: 2.0.2 - grpcio: 1.56.2 - h11: 0.14.0 - html5lib: 1.1 - httpcore: 0.17.3 - httpx: 0.24.1 - humanfriendly: 10.0 - hydra-core: 1.3.2 - idna: 3.4 - imagesize: 1.4.1 - importlib-metadata: 6.8.0 - importlib-resources: 6.0.0 - iniconfig: 2.0.0 - inquirer: 3.1.3 - ipykernel: 6.25.0 - ipython: 8.6.0 - ipywidgets: 8.1.0 - isoduration: 20.11.0 - itsdangerous: 2.1.2 - jedi: 0.19.0 - jinja2: 3.0.3 - jmespath: 1.0.1 - joblib: 1.3.1 - json5: 0.9.14 - jsonargparse: 4.22.1 - jsonpointer: 2.4 - jsonschema: 4.18.6 - jsonschema-specifications: 2023.7.1 - jupyter-client: 8.3.0 - jupyter-core: 5.3.1 - jupyter-events: 0.7.0 - jupyter-lsp: 2.2.0 - jupyter-server: 2.7.0 - jupyter-server-terminals: 0.4.4 - jupyterlab: 4.0.4 - jupyterlab-pygments: 0.2.2 - jupyterlab-server: 2.24.0 - jupyterlab-widgets: 3.0.8 - kiwisolver: 1.4.4 - lightning: 2.1.0.dev0 - lightning-api-access: 0.0.5 - lightning-cloud: 0.5.37 - lightning-fabric: 2.0.6 - lightning-utilities: 0.9.0 - linkify-it-py: 2.0.2 - livereload: 2.6.3 - lockfile: 0.12.2 - markdown: 3.4.4 - markdown-it-py: 2.2.0 - markupsafe: 2.1.3 - matplotlib: 3.7.2 - matplotlib-inline: 0.1.6 - mdit-py-plugins: 0.3.5 - mdurl: 0.1.2 - mistune: 3.0.1 - mkl-fft: 1.3.6 - mkl-random: 1.2.2 - mkl-service: 2.4.0 - mpmath: 1.3.0 - msgpack: 1.0.5 - multidict: 6.0.4 - myst-parser: 0.18.1 - natsort: 8.4.0 - nbclient: 0.8.0 - nbconvert: 7.7.3 - nbformat: 5.9.2 - nbsphinx: 0.8.9 - nest-asyncio: 1.5.7 - notebook: 7.0.1 - notebook-shim: 0.2.3 - numpy: 1.25.0 - oauthlib: 3.2.2 - omegaconf: 2.3.0 - onnx: 1.12.0 - onnxruntime: 1.15.1 - ordered-set: 4.1.0 - outcome: 1.2.0 - overrides: 7.3.1 - packaging: 23.1 - pandas: 2.0.3 - pandoc: 2.3 - pandocfilters: 1.5.0 - 
panel: 1.2.1 - param: 1.13.0 - parso: 0.8.3 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 9.4.0 - pip: 23.2.1 - platformdirs: 3.10.0 - playwright: 1.35.0 - pluggy: 1.2.0 - plumbum: 1.8.2 - ply: 3.11 - prometheus-client: 0.17.1 - prompt-toolkit: 3.0.39 - protobuf: 3.20.1 - psutil: 5.9.5 - pt-lightning-sphinx-theme: 0.0.31 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - py: 1.11.0 - pyarrow: 12.0.1 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pycparser: 2.21 - pydantic: 2.0.3 - pydantic-core: 2.3.0 - pydeck: 0.8.0 - pyee: 9.0.4 - pygments: 2.15.1 - pyjwt: 2.8.0 - pympler: 1.0.1 - pyopenssl: 23.2.0 - pyparsing: 3.0.9 - pysocks: 1.7.1 - pytest: 7.4.0 - pytest-asyncio: 0.21.1 - pytest-cov: 4.1.0 - pytest-doctestplus: 0.13.0 - pytest-forked: 1.4.0 - pytest-random-order: 1.1.0 - pytest-rerunfailures: 10.3 - pytest-timeout: 2.1.0 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - python-json-logger: 2.0.7 - python-multipart: 0.0.6 - pytorch-lightning: 2.0.6 - pytz: 2023.3 - pytz-deprecation-shim: 0.1.0.post0 - pyviz-comms: 2.3.2 - pyyaml: 6.0.1 - pyzmq: 25.1.0 - readchar: 4.0.5 - redis: 4.6.0 - referencing: 0.30.0 - requests: 2.31.0 - requests-mock: 1.11.0 - requests-oauthlib: 1.3.1 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rich: 13.5.2 - rpds-py: 0.9.2 - rsa: 4.9 - ruamel.yaml: 0.17.32 - ruamel.yaml.clib: 0.2.7 - s3fs: 2023.6.0 - scikit-learn: 1.3.0 - scipy: 1.11.1 - send2trash: 1.8.2 - setuptools: 57.5.0 - six: 1.16.0 - smmap: 5.0.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - sortedcontainers: 2.4.0 - soupsieve: 2.4.1 - sphinx: 4.5.0 - sphinx-autobuild: 2021.3.14 - sphinx-autodoc-typehints: 1.19.1 - sphinx-copybutton: 0.5.2 - sphinx-jinja2-compat: 0.2.0 - sphinx-multiproject: 1.0.0rc1 - sphinx-paramlinks: 0.5.4 - sphinx-prompt: 1.5.0 - sphinx-rtd-dark-mode: 1.2.4 - sphinx-rtd-theme: 1.2.2 - sphinx-tabs: 3.4.0 - sphinx-togglebutton: 0.3.2 - sphinx-toolbox: 3.4.0 - sphinxcontrib-applehelp: 1.0.4 - sphinxcontrib-devhelp: 1.0.2 - sphinxcontrib-fulltoc: 1.2.0 - sphinxcontrib-htmlhelp: 2.0.1 - sphinxcontrib-jquery: 4.1 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-mockautodoc: 0.0.1.dev20130518 - sphinxcontrib-qthelp: 1.0.3 - sphinxcontrib-serializinghtml: 1.1.5 - sphinxcontrib-video: 0.2.0 - stack-data: 0.6.2 - starlette: 0.27.0 - starsessions: 1.3.0 - streamlit: 1.25.0 - sympy: 1.12 - tabulate: 0.9.0 - tenacity: 8.2.2 - tensorboard: 2.13.0 - tensorboard-data-server: 0.7.1 - tensorboardx: 2.6.2 - termcolor: 2.3.0 - terminado: 0.17.1 - threadpoolctl: 3.2.0 - tinycss2: 1.2.1 - toml: 0.10.2 - tomli: 2.0.1 - toolz: 0.12.0 - torch: 1.13.1 - torchmetrics: 1.0.2 - torchvision: 0.14.1 - tornado: 6.3.2 - tqdm: 4.65.0 - traitlets: 5.9.0 - trio: 0.21.0 - typeshed-client: 2.3.0 - typing-extensions: 4.7.1 - tzdata: 2023.3 - tzlocal: 4.3.1 - uc-micro-py: 1.0.2 - uri-template: 1.3.0 - urllib3: 1.26.16 - uvicorn: 0.23.2 - validators: 0.20.0 - watchdog: 3.0.0 - wcmatch: 8.4.1 - wcwidth: 0.2.6 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.6.1 - websockets: 11.0.3 - werkzeug: 2.3.6 - wheel: 0.38.4 - widgetsnbextension: 4.0.8 - wrapt: 1.15.0 - xyzservices: 2023.7.0 - yarl: 1.9.2 - zipp: 3.16.2 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.10.12 - release: 5.4.0-155-generic - version: #172-Ubuntu SMP Fri Jul 7 16:10:02 UTC 2023

More info

No response

cc @awaelchli @carmocca

awaelchli commented 1 year ago

Hi @speediedan, thanks for opening the issue. The PR that changed this was https://github.com/Lightning-AI/lightning/pull/17819. If I recall correctly, we used the latest APIs there and conditioned them on >= 2.0 because we knew that the previous loading logic wasn't really working / was incorrect (see the linked issues in that PR).

We could however try to import from here https://github.com/pytorch/pytorch/blob/v1.13.1/torch/distributed/fsdp/__init__.py and maybe that's enough to make it torch 1.13 compatible.
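Something along those lines might look roughly like this for the model state dict (a sketch under the assumption that only names exported by the 1.13 `__init__` are used; `model` is the FSDP-wrapped module, and this is not the final Lightning implementation):

```python
# Sketch: gather the full model state dict using only symbols that torch.distributed.fsdp
# exports in 1.12/1.13 (they remain importable from the package root in 2.0 as well).
from typing import Any, Dict

from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel, StateDictType


def full_model_state_dict(model: FullyShardedDataParallel) -> Dict[str, Any]:
    # Unshard to CPU and return the full state dict on rank 0 only; other ranks get an empty dict,
    # mirroring what lightning_module_state_dict needs to do.
    config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FullyShardedDataParallel.state_dict_type(model, StateDictType.FULL_STATE_DICT, config):
        return model.state_dict()
```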

speediedan commented 1 year ago

> We could however try to import from here https://github.com/pytorch/pytorch/blob/v1.13.1/torch/distributed/fsdp/__init__.py and maybe that's enough to make it torch 1.13 compatible.

Yeah, I tinkered with conditionally using the old `FullStateDictConfig` and `StateDictType` import locations for a few minutes before I opened this issue. But when I noticed that config classes that don't exist in 1.13.1 (e.g. `FullOptimStateDictConfig`) were being used for the `_get_*_state_dict_context` context managers, I figured it was better to check with the team about what you were planning before investigating further.

I think we could likely backport this functionality to 1.x if we:

  1. Used the old import locations for FullStateDictConfig/StateDictType in a 1.x context
  2. Enhanced the _get_*_state_dict_context context managers to use a customized, backported version of the 2.x state_dict_type context manager that extends the 1.x state_dict_type context manager to set StateDictType appropriately for the optim state dicts of all descendant modules.

That would, however, be a fair amount of custom code in the backport and could prove a fairly ugly/brittle solution. As such, it may be worth considering the alternative of:

  1. Using the old import locations for FullStateDictConfig/StateDictType in a 1.x context only for lightning_module_state_dict imports
  2. In a 1.x context, for optimizer state dict save/load calls, issuing a warning and treating the call as a no-op to indicate that optimizer state saving/loading is not supported with 1.x FSDP (after updating the Lightning documentation accordingly, of course); throwing an exception may be preferable to a warning and no-op (a rough sketch follows this list).
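To make option 2 concrete, here is a rough sketch (illustrative only, not Lightning's actual code: `_gather_full_optim_state_dict` is a placeholder for whatever the existing 2.x code path does, and `rank_zero_warn`/`packaging` are assumed to be importable):

```python
# Illustrative sketch of option 2 -- not Lightning's implementation.
from typing import Any, Dict

import torch
from packaging.version import Version  # assumed available
from lightning.pytorch.utilities import rank_zero_warn  # assumed import location

_TORCH_GREATER_EQUAL_2_0 = Version(torch.__version__).release >= (2, 0)


def optimizer_state(optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
    """Collect the optimizer state for checkpointing (strategy-level hook)."""
    if not _TORCH_GREATER_EQUAL_2_0:
        # Warning + no-op variant: the checkpoint simply omits the optimizer state.
        rank_zero_warn("Saving the FSDP optimizer state requires PyTorch >= 2.0; it will be skipped.")
        return {}
        # Stricter alternative:
        # raise NotImplementedError("FSDP optimizer-state checkpointing requires PyTorch >= 2.0")
    return _gather_full_optim_state_dict(optimizer)  # placeholder for the existing 2.x code path


def _gather_full_optim_state_dict(optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
    # Placeholder only: the real 2.x path uses FSDP's state_dict_type / optimizer-state APIs.
    return optimizer.state_dict()
```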

Open to other thoughts and suggestions of course (of which yours are so often awesome!). What do you think?

awaelchli commented 1 year ago

Thanks for the suggestion. Keeping the loading of the model state compatible with 1.13 seems feasible, and a warning/error for the optimizer state is probably the easiest approach for now. Would that work for you and finetuning-scheduler as well?

speediedan commented 1 year ago

Absolutely, sounds great. finetuning-scheduler already overrides `load_optimizer_state_dict` and `optimizer_state` for other reasons, so I could just mirror the relevant Lightning warnings/errors in that context while continuing to rely upon Lightning's `lightning_module_state_dict` for collecting the model state dict.
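Purely as an illustration of that mirroring (the class name, messages, and guard below are hypothetical, not finetuning-scheduler's actual code), it could look roughly like this:

```python
# Hypothetical sketch -- not finetuning-scheduler code.
from typing import Any, Dict

import torch
from packaging.version import Version  # assumed available
from lightning.pytorch.strategies import FSDPStrategy
from lightning.pytorch.utilities import rank_zero_warn  # assumed import location

_TORCH_GREATER_EQUAL_2_0 = Version(torch.__version__).release >= (2, 0)


class GuardedFSDPStrategy(FSDPStrategy):
    """Mirror the proposed Lightning guards for optimizer-state save/load on PyTorch 1.x."""

    def optimizer_state(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
        if not _TORCH_GREATER_EQUAL_2_0:
            rank_zero_warn("FSDP optimizer-state saving requires PyTorch >= 2.0; skipping.")
            return {}
        return super().optimizer_state(optimizer)

    def load_optimizer_state_dict(self, checkpoint: Dict[str, Any]) -> None:
        if not _TORCH_GREATER_EQUAL_2_0:
            rank_zero_warn("FSDP optimizer-state loading requires PyTorch >= 2.0; skipping.")
            return
        super().load_optimizer_state_dict(checkpoint)
```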

carmocca commented 1 year ago

@speediedan Do you plan to work on this? We'd want to fix this before the next release to avoid breaking these checkpoints.

speediedan commented 1 year ago

> Do you plan to work on this? We'd want to fix this before the next release to avoid breaking these checkpoints.

I'm not sure I'll have the bandwidth in the next few days, and I wouldn't want to hold this up since I know it'll be important to have it in 2.1. Certainly go ahead and implement it. Thanks for checking!

speediedan commented 1 year ago

@awaelchli I ran into #18277 today, which is very close to this issue (#18230) in terms of the code paths it modifies, so I figured it made sense to implement the resolution we discussed here in a PR that addresses both #18277 and #18230. Hope that's okay!