Closed: speediedan closed this issue 1 year ago.
Hi @speediedan, thanks for opening the issue. The PR that changed this is https://github.com/Lightning-AI/lightning/pull/17819. If I recall correctly, we switched to the latest APIs there and conditioned them on torch >= 2.0 because we knew the previous loading logic wasn't working correctly (see the issues linked in that PR).
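To make the ">= 2.0" conditioning concrete, here is a minimal, torch-free sketch of a version gate. The helper name `torch_version_at_least` is hypothetical and is not Lightning's actual utility; it assumes version strings shaped like `torch.__version__` (for example `1.13.1+cu117` or `2.1.0.dev20230801`).

```python
import re
from typing import Tuple


def torch_version_at_least(version: str, minimum: Tuple[int, int]) -> bool:
    """Return True if a torch version string is >= (major, minor).

    Only the leading major.minor numbers are compared, so local suffixes
    ("+cu117") and dev suffixes (".dev20230801") are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= minimum


# Example: the torch version from this issue's environment report.
TORCH_GREATER_EQUAL_2_0 = torch_version_at_least("1.13.1", (2, 0))  # False
```

Comparing only major.minor treats a dev build such as `2.0.0.dev...` as 2.0, which is usually what feature gating wants but differs from strict PEP 440 ordering.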
We could, however, try importing from the package namespace here: https://github.com/pytorch/pytorch/blob/v1.13.1/torch/distributed/fsdp/__init__.py, and maybe that's enough to make it compatible with torch 1.13.
Yeah, I tinkered with conditionally using the old `FullStateDictConfig` and `StateDictType` import locations for a few minutes before I opened this issue, but when I noticed that config classes that don't exist in 1.13.1 (e.g. `FullOptimStateDictConfig`) were being used for the `_get_*_state_dict_context` context managers, I figured it was better to check with the team to see what you were planning before investigating further.
I think we could likely backport this functionality to 1.x if we:
- conditionally use the old `FullStateDictConfig`/`StateDictType` import locations in a 1.x context, and
- modify the `_get_*_state_dict_context` context managers to use a customized, backported version of the 2.x `state_dict_type` context manager that extends the 1.x `state_dict_type` context manager to set `StateDictType` appropriately for the optim state dicts of all descendant modules.

That would be a fair amount of custom code in the backport, though, and could prove a fairly ugly/brittle solution. As such, it may be worth considering the alternative of conditionally using the old `FullStateDictConfig`/`StateDictType` imports in a 1.x context only for `lightning_module_state_dict`.

Open to other thoughts and suggestions of course (yours are so often awesome!). What do you think?
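The alternative (model state supported on 1.x, optimizer state gated on 2.x) can be sketched as a small policy function. This is a hypothetical illustration, not Lightning code: `fsdp_checkpoint_parts` and its `strict` flag are assumptions, chosen to show both the "error" and "warning" options.

```python
import warnings
from typing import Dict


def fsdp_checkpoint_parts(
    torch_ge_2_0: bool, want_optimizer_state: bool, strict: bool = True
) -> Dict[str, bool]:
    """Decide which parts of an FSDP checkpoint to save/load.

    The model state dict is assumed to work on 1.13 and 2.x via the package-namespace
    imports; optimizer state is only handled on torch >= 2.0.
    """
    parts = {"model": True, "optimizer": want_optimizer_state}
    if want_optimizer_state and not torch_ge_2_0:
        msg = (
            "FSDP optimizer state checkpointing requires torch >= 2.0; "
            "only the model state dict is supported on torch 1.x."
        )
        if strict:
            raise NotImplementedError(msg)  # fail early instead of writing a partial checkpoint
        warnings.warn(msg)
        parts["optimizer"] = False  # skip optimizer state, keep the model state
    return parts
```

Raising by default avoids silently producing a checkpoint that cannot fully resume training; `strict=False` is the softer warning path.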
Thanks for the suggestion. Keeping the loading for the model state compatible with 1.13 seems feasible, and warning/error for optimizer state is probably the easiest for now. Would that work for you and the finetuning-scheduler as well?
Absolutely, sounds great. `finetuning-scheduler` is already overriding `load_optimizer_state_dict` and `optimizer_state` for other reasons, so I could just mirror the relevant Lightning warnings/errors there while continuing to rely on Lightning's `lightning_module_state_dict` to collect the model state dict.
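Mirroring an upstream guard in overridden hooks might look like the sketch below. `StrategyStub` is a stand-in defined here so the example runs without Lightning; it is not Lightning's real `FSDPStrategy`, and `ScheduledFSDPStrategy` is hypothetical. Only the hook names `optimizer_state` and `load_optimizer_state_dict` come from the thread.

```python
from typing import Any, Dict, Mapping


class StrategyStub:
    """Stand-in base class; NOT Lightning's actual strategy implementation."""

    def optimizer_state(self, optimizer: Any) -> Dict[str, Any]:
        return {"state": {}, "param_groups": []}

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        self.loaded = dict(checkpoint)


class ScheduledFSDPStrategy(StrategyStub):
    """Hypothetical subclass that mirrors a torch < 2.0 guard in its overrides."""

    def __init__(self, torch_ge_2_0: bool) -> None:
        self.torch_ge_2_0 = torch_ge_2_0

    def _check_optimizer_state_supported(self) -> None:
        if not self.torch_ge_2_0:
            raise NotImplementedError("FSDP optimizer state checkpointing requires torch >= 2.0.")

    def optimizer_state(self, optimizer: Any) -> Dict[str, Any]:
        self._check_optimizer_state_supported()
        # ...any subclass-specific handling would go here...
        return super().optimizer_state(optimizer)

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        self._check_optimizer_state_supported()
        super().load_optimizer_state_dict(checkpoint)
```

Keeping the check in one helper means the message stays consistent with whatever the upstream strategy raises, and model-state collection is left to the base class.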
@speediedan Do you plan to work on this? We'd want to fix this before the next release to avoid breaking these checkpoints.
Not sure I'll have the bandwidth in the next few days, and I wouldn't want to hold this up since I know it's important to get it into 2.1. Certainly go ahead and implement it. Thanks for checking!
@awaelchli I ran into #18277 today which is very close to this issue (#18230) in terms of modified code-path intersection so I figured it made sense to implement the discussed resolution to this issue in a PR that addresses both #18277 and #18230. Hope that's okay!
Bug description
With the latest dev commit as of this writing (0aeeb60566cc0375df3cf1a4458592651f143717), Lightning imports do not allow saving/loading of FSDP checkpoints with PyTorch < 2.0:
Also note that both save and load code-paths use the
state_dict_type
context manager and attempt to import from FSDP PyTorch 2.0 locations even with PyTorch < 2.0. https://github.com/Lightning-AI/lightning/blob/0aeeb60566cc0375df3cf1a4458592651f143717/src/lightning/fabric/strategies/fsdp.py#L792-L819Finally, I don't believe
FullOptimStateDictConfig
is defined in the FSDP 1.x API so that may need to be worked around if support for 1.x FSDP continues.I imagine the above challenges could be surmounted to continue providing support for saving/loading FSDP checkpoints with PyTorch < 2.0 but I wanted to ensure that was the intention. If deprecation of this FSDP functionality for PyTorch 1.x is expected I'll go ahead and begin deprecating this functionality in finetuning-scheduler.
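One way to work around the missing `FullOptimStateDictConfig` is to build the keyword arguments for `FSDP.state_dict_type(module, state_dict_type, ...)` per version. This sketch assumes that on 2.x the context manager accepts `state_dict_config` and `optim_state_dict_config`, while on 1.13 it only accepts `state_dict_config`; `state_dict_type_kwargs` is a hypothetical helper and the config objects are passed in opaquely so it runs without torch.

```python
from typing import Any, Dict, Optional


def state_dict_type_kwargs(
    torch_ge_2_0: bool,
    state_dict_config: Any,
    optim_state_dict_config: Optional[Any] = None,
) -> Dict[str, Any]:
    """Build kwargs for FSDP.state_dict_type(module, state_dict_type, **kwargs)."""
    kwargs: Dict[str, Any] = {"state_dict_config": state_dict_config}
    if optim_state_dict_config is None:
        return kwargs  # model-only context: valid on 1.13 and 2.x
    if not torch_ge_2_0:
        # Assumed: the 1.x context manager has no optim_state_dict_config parameter.
        raise NotImplementedError("optim_state_dict_config requires torch >= 2.0")
    kwargs["optim_state_dict_config"] = optim_state_dict_config
    return kwargs
```

Raising instead of silently dropping the optimizer config keeps a 1.x caller from believing optimizer state was handled.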
Thanks again for all your invaluable contributions to the open-source ML ecosystem!
What version are you seeing the problem on?
master
How to reproduce the bug
Error messages and logs
Environment
Current environment
* CUDA: - GPU: - NVIDIA GeForce RTX 4090 - NVIDIA GeForce RTX 2070 SUPER - available: True - version: 11.7 * Lightning: - lightning: 2.1.0.dev0 - lightning-api-access: 0.0.5 - lightning-cloud: 0.5.37 - lightning-fabric: 2.0.6 - lightning-utilities: 0.9.0 - pt-lightning-sphinx-theme: 0.0.31 - pytorch-lightning: 2.0.6 - torch: 1.13.1 - torchmetrics: 1.0.2 - torchvision: 0.14.1 * Packages: - absl-py: 1.4.0 - aiobotocore: 2.5.2 - aiohttp: 3.8.5 - aioitertools: 0.11.0 - aiosignal: 1.3.1 - alabaster: 0.7.13 - altair: 5.0.1 - annotated-types: 0.5.0 - antlr4-python3-runtime: 4.9.3 - anyio: 3.7.1 - apeye: 1.4.0 - apeye-core: 1.1.4 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.2.3 - asttokens: 2.2.1 - async-generator: 1.10 - async-lru: 2.0.4 - async-timeout: 4.0.2 - attrs: 23.1.0 - autodocsumm: 0.2.11 - babel: 2.12.1 - backcall: 0.2.0 - backoff: 2.2.1 - beautifulsoup4: 4.12.2 - bleach: 6.0.0 - blessed: 1.20.0 - blinker: 1.6.2 - bokeh: 3.2.1 - botocore: 1.29.161 - bracex: 2.3.post1 - brotlipy: 0.7.0 - cachecontrol: 0.13.1 - cachetools: 5.3.1 - certifi: 2023.7.22 - cffi: 1.15.1 - charset-normalizer: 2.0.4 - click: 8.1.6 - cloudpickle: 2.2.1 - colorama: 0.4.6 - coloredlogs: 15.0.1 - comm: 0.1.4 - contourpy: 1.1.0 - coverage: 7.2.7 - croniter: 1.4.1 - cryptography: 41.0.2 - cssutils: 2.7.1 - cycler: 0.11.0 - dateutils: 0.6.12 - debugpy: 1.6.7 - decorator: 5.1.1 - deepdiff: 6.3.1 - defusedxml: 0.7.1 - dict2css: 0.3.0 - docker: 6.1.3 - docstring-parser: 0.15 - docutils: 0.17.1 - domdf-python-tools: 3.6.1 - exceptiongroup: 1.1.2 - executing: 1.2.0 - fastapi: 0.100.1 - fastjsonschema: 2.18.0 - filelock: 3.12.2 - fire: 0.5.0 - flatbuffers: 23.5.26 - fonttools: 4.42.0 - fqdn: 1.5.1 - frozenlist: 1.4.0 - fsspec: 2023.6.0 - gitdb: 4.0.10 - gitpython: 3.1.32 - google-auth: 2.22.0 - google-auth-oauthlib: 1.0.0 - greenlet: 2.0.2 - grpcio: 1.56.2 - h11: 0.14.0 - html5lib: 1.1 - httpcore: 0.17.3 - httpx: 0.24.1 - humanfriendly: 10.0 - hydra-core: 1.3.2 - idna: 3.4 - 
imagesize: 1.4.1 - importlib-metadata: 6.8.0 - importlib-resources: 6.0.0 - iniconfig: 2.0.0 - inquirer: 3.1.3 - ipykernel: 6.25.0 - ipython: 8.6.0 - ipywidgets: 8.1.0 - isoduration: 20.11.0 - itsdangerous: 2.1.2 - jedi: 0.19.0 - jinja2: 3.0.3 - jmespath: 1.0.1 - joblib: 1.3.1 - json5: 0.9.14 - jsonargparse: 4.22.1 - jsonpointer: 2.4 - jsonschema: 4.18.6 - jsonschema-specifications: 2023.7.1 - jupyter-client: 8.3.0 - jupyter-core: 5.3.1 - jupyter-events: 0.7.0 - jupyter-lsp: 2.2.0 - jupyter-server: 2.7.0 - jupyter-server-terminals: 0.4.4 - jupyterlab: 4.0.4 - jupyterlab-pygments: 0.2.2 - jupyterlab-server: 2.24.0 - jupyterlab-widgets: 3.0.8 - kiwisolver: 1.4.4 - lightning: 2.1.0.dev0 - lightning-api-access: 0.0.5 - lightning-cloud: 0.5.37 - lightning-fabric: 2.0.6 - lightning-utilities: 0.9.0 - linkify-it-py: 2.0.2 - livereload: 2.6.3 - lockfile: 0.12.2 - markdown: 3.4.4 - markdown-it-py: 2.2.0 - markupsafe: 2.1.3 - matplotlib: 3.7.2 - matplotlib-inline: 0.1.6 - mdit-py-plugins: 0.3.5 - mdurl: 0.1.2 - mistune: 3.0.1 - mkl-fft: 1.3.6 - mkl-random: 1.2.2 - mkl-service: 2.4.0 - mpmath: 1.3.0 - msgpack: 1.0.5 - multidict: 6.0.4 - myst-parser: 0.18.1 - natsort: 8.4.0 - nbclient: 0.8.0 - nbconvert: 7.7.3 - nbformat: 5.9.2 - nbsphinx: 0.8.9 - nest-asyncio: 1.5.7 - notebook: 7.0.1 - notebook-shim: 0.2.3 - numpy: 1.25.0 - oauthlib: 3.2.2 - omegaconf: 2.3.0 - onnx: 1.12.0 - onnxruntime: 1.15.1 - ordered-set: 4.1.0 - outcome: 1.2.0 - overrides: 7.3.1 - packaging: 23.1 - pandas: 2.0.3 - pandoc: 2.3 - pandocfilters: 1.5.0 - panel: 1.2.1 - param: 1.13.0 - parso: 0.8.3 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 9.4.0 - pip: 23.2.1 - platformdirs: 3.10.0 - playwright: 1.35.0 - pluggy: 1.2.0 - plumbum: 1.8.2 - ply: 3.11 - prometheus-client: 0.17.1 - prompt-toolkit: 3.0.39 - protobuf: 3.20.1 - psutil: 5.9.5 - pt-lightning-sphinx-theme: 0.0.31 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - py: 1.11.0 - pyarrow: 12.0.1 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pycparser: 2.21 - 
pydantic: 2.0.3 - pydantic-core: 2.3.0 - pydeck: 0.8.0 - pyee: 9.0.4 - pygments: 2.15.1 - pyjwt: 2.8.0 - pympler: 1.0.1 - pyopenssl: 23.2.0 - pyparsing: 3.0.9 - pysocks: 1.7.1 - pytest: 7.4.0 - pytest-asyncio: 0.21.1 - pytest-cov: 4.1.0 - pytest-doctestplus: 0.13.0 - pytest-forked: 1.4.0 - pytest-random-order: 1.1.0 - pytest-rerunfailures: 10.3 - pytest-timeout: 2.1.0 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - python-json-logger: 2.0.7 - python-multipart: 0.0.6 - pytorch-lightning: 2.0.6 - pytz: 2023.3 - pytz-deprecation-shim: 0.1.0.post0 - pyviz-comms: 2.3.2 - pyyaml: 6.0.1 - pyzmq: 25.1.0 - readchar: 4.0.5 - redis: 4.6.0 - referencing: 0.30.0 - requests: 2.31.0 - requests-mock: 1.11.0 - requests-oauthlib: 1.3.1 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rich: 13.5.2 - rpds-py: 0.9.2 - rsa: 4.9 - ruamel.yaml: 0.17.32 - ruamel.yaml.clib: 0.2.7 - s3fs: 2023.6.0 - scikit-learn: 1.3.0 - scipy: 1.11.1 - send2trash: 1.8.2 - setuptools: 57.5.0 - six: 1.16.0 - smmap: 5.0.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - sortedcontainers: 2.4.0 - soupsieve: 2.4.1 - sphinx: 4.5.0 - sphinx-autobuild: 2021.3.14 - sphinx-autodoc-typehints: 1.19.1 - sphinx-copybutton: 0.5.2 - sphinx-jinja2-compat: 0.2.0 - sphinx-multiproject: 1.0.0rc1 - sphinx-paramlinks: 0.5.4 - sphinx-prompt: 1.5.0 - sphinx-rtd-dark-mode: 1.2.4 - sphinx-rtd-theme: 1.2.2 - sphinx-tabs: 3.4.0 - sphinx-togglebutton: 0.3.2 - sphinx-toolbox: 3.4.0 - sphinxcontrib-applehelp: 1.0.4 - sphinxcontrib-devhelp: 1.0.2 - sphinxcontrib-fulltoc: 1.2.0 - sphinxcontrib-htmlhelp: 2.0.1 - sphinxcontrib-jquery: 4.1 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-mockautodoc: 0.0.1.dev20130518 - sphinxcontrib-qthelp: 1.0.3 - sphinxcontrib-serializinghtml: 1.1.5 - sphinxcontrib-video: 0.2.0 - stack-data: 0.6.2 - starlette: 0.27.0 - starsessions: 1.3.0 - streamlit: 1.25.0 - sympy: 1.12 - tabulate: 0.9.0 - tenacity: 8.2.2 - tensorboard: 2.13.0 - tensorboard-data-server: 0.7.1 - tensorboardx: 2.6.2 - termcolor: 2.3.0 - 
terminado: 0.17.1 - threadpoolctl: 3.2.0 - tinycss2: 1.2.1 - toml: 0.10.2 - tomli: 2.0.1 - toolz: 0.12.0 - torch: 1.13.1 - torchmetrics: 1.0.2 - torchvision: 0.14.1 - tornado: 6.3.2 - tqdm: 4.65.0 - traitlets: 5.9.0 - trio: 0.21.0 - typeshed-client: 2.3.0 - typing-extensions: 4.7.1 - tzdata: 2023.3 - tzlocal: 4.3.1 - uc-micro-py: 1.0.2 - uri-template: 1.3.0 - urllib3: 1.26.16 - uvicorn: 0.23.2 - validators: 0.20.0 - watchdog: 3.0.0 - wcmatch: 8.4.1 - wcwidth: 0.2.6 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.6.1 - websockets: 11.0.3 - werkzeug: 2.3.6 - wheel: 0.38.4 - widgetsnbextension: 4.0.8 - wrapt: 1.15.0 - xyzservices: 2023.7.0 - yarl: 1.9.2 - zipp: 3.16.2 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.10.12 - release: 5.4.0-155-generic - version: #172-Ubuntu SMP Fri Jul 7 16:10:02 UTC 2023
More info
No response
cc @awaelchli @carmocca