Hey @nikvaessen, this is very odd. I'll look into it.
Btw, in case you didn't know, the recommended way to load in Fabric is
```python
fabric.load("first_session.ckpt", {"network": network, "opt": opt})
```
because this generalizes across all strategies and accelerators and offers a convenient way to make scripts stateful in general. And this will pass your assertion. So you can use this way as a workaround until I make the bugfix.
Thanks for reporting!
> And this will pass your assertion. So you can use this way as a workaround until I make the bugfix.
I've tried the following modification to the reloading part of the reproduction code sample:
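```python
import lightning
import torch

# `Network` and `USE_FABRIC_SETUP` are defined in the reproduction script from the issue.


def second_session():
    fabric = lightning.Fabric(accelerator="gpu", devices=1)
    fabric.launch()

    network = Network()
    opt = torch.optim.Adam(network.parameters())
    if USE_FABRIC_SETUP:
        network, opt = fabric.setup(network, opt)
    else:
        network.cuda()

    # Recommended loading path: `fabric.load` restores the objects in `state` in place.
    state = {"network": network, "opt": opt}
    remainder = fabric.load("first_session.ckpt", state)
    print("remainder:", remainder)

    print("wrapper", opt.state_dict())
    # `.optimizer` is the wrapped optimizer; it only exists when `fabric.setup` was used.
    print("optimizer", opt.optimizer.state_dict())
    print("checkpoint\n", torch.load("first_session.ckpt"))

    # The loaded optimizer state should not be empty.
    assert len(opt.state_dict()["state"]) > 0
```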
This still results in:
```
remainder: {}
wrapper {'state': {}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}
optimizer {'state': {}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}
```
If I run my code with `pip install git+https://github.com/Lightning-AI/lightning@bugfix/fabric-optimizer-load-state`, the state dictionary is correctly loaded. So thanks for the bugfix :)
@nikvaessen Thanks for confirming!
Bug description
When `load_state_dict()` is called on an optimizer object returned by `fabric.setup(...)`, the resulting `state_dict` will be empty.
What version are you seeing the problem on?
v2.0
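For comparison, the expected behavior is what a plain PyTorch optimizer does without Fabric. The sketch below uses a hypothetical `torch.nn.Linear` in place of the report's `Network`. It shows that Adam's `state` is empty until the first `step()`, and that `load_state_dict()` on a freshly constructed optimizer restores it. The bug is that the optimizer returned by `fabric.setup(...)` still reports an empty `state` after loading.

```python
import torch


def optimizer_state_round_trip():
    # Hypothetical tiny model standing in for the report's `Network`.
    model = torch.nn.Linear(4, 2)
    opt = torch.optim.Adam(model.parameters())

    # Adam creates per-parameter state lazily, on the first step.
    before_step = len(opt.state_dict()["state"])
    model(torch.randn(8, 4)).sum().backward()
    opt.step()
    saved = opt.state_dict()

    # A freshly constructed optimizer starts empty; `load_state_dict` restores it.
    fresh = torch.optim.Adam(model.parameters())
    before_load = len(fresh.state_dict()["state"])
    fresh.load_state_dict(saved)
    return before_step, before_load, len(fresh.state_dict()["state"])
```

For this two-parameter model (weight and bias), `optimizer_state_round_trip()` returns `(0, 0, 2)`.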
How to reproduce the bug
Error messages and logs
The checkpoint (`ckpt['opt']`):
The result of calling `state_dict()` after calling `opt.load_state_dict()`:
Environment
Current environment
* CUDA:
  - GPU:
    - NVIDIA GeForce GTX 1080 Ti
  - available: True
  - version: 11.7
* Lightning:
  - lightning: 2.0.8
  - lightning-cloud: 0.5.37
  - lightning-utilities: 0.8.0
  - pytorch-lightning: 2.0.3
  - torch: 2.0.1
  - torch-tb-profiler: 0.4.1
  - torchaudio: 2.0.1
  - torchdata: 0.6.1
  - torchmetrics: 0.11.4
  - torchvision: 0.15.1
* Packages:
  - absl-py: 1.4.0
  - aiohttp: 3.8.4
  - aiosignal: 1.3.1
  - alembic: 1.11.1
  - antlr4-python3-runtime: 4.9.3
  - anyio: 3.7.0
  - appdirs: 1.4.4
  - arrow: 1.2.3
  - async-timeout: 4.0.2
  - attrs: 23.1.0
  - autopage: 0.5.1
  - backoff: 2.2.1
  - beautifulsoup4: 4.12.2
  - blessed: 1.20.0
  - cachetools: 5.3.1
  - certifi: 2023.5.7
  - cffi: 1.15.1
  - charset-normalizer: 3.1.0
  - click: 8.1.3
  - cliff: 4.3.0
  - cloudpickle: 2.2.1
  - cmaes: 0.9.1
  - cmake: 3.26.4
  - cmd2: 2.4.3
  - colorlog: 6.7.0
  - contourpy: 1.1.0
  - croniter: 1.3.15
  - cycler: 0.11.0
  - dateutils: 0.6.12
  - deepdiff: 6.3.0
  - docker-pycreds: 0.4.0
  - exceptiongroup: 1.1.1
  - fastapi: 0.98.0
  - filelock: 3.12.2
  - fonttools: 4.40.0
  - frozenlist: 1.3.3
  - fsspec: 2023.6.0
  - gitdb: 4.0.10
  - gitpython: 3.1.31
  - google-auth: 2.20.0
  - google-auth-oauthlib: 0.4.6
  - greenlet: 2.0.2
  - grpcio: 1.54.2
  - h11: 0.14.0
  - huggingface-hub: 0.15.1
  - hydra-core: 1.3.2
  - hydra-optuna-sweeper: 1.2.0
  - hydra-submitit-launcher: 1.2.0
  - idna: 3.4
  - importlib-metadata: 6.7.0
  - iniconfig: 2.0.0
  - inquirer: 3.1.3
  - itsdangerous: 2.1.2
  - jinja2: 3.1.2
  - jiwer: 3.0.0
  - joblib: 1.2.0
  - kiwisolver: 1.4.4
  - lightning: 2.0.8
  - lightning-cloud: 0.5.37
  - lightning-utilities: 0.8.0
  - lit: 16.0.6
  - mako: 1.2.4
  - markdown: 3.4.3
  - markdown-it-py: 3.0.0
  - markupsafe: 2.1.3
  - matplotlib: 3.7.1
  - mdurl: 0.1.2
  - mpmath: 1.3.0
  - multidict: 6.0.4
  - nanow2v2: 1.0
  - networkx: 3.1
  - numpy: 1.25.0
  - nvidia-cublas-cu11: 11.10.3.66
  - nvidia-cuda-cupti-cu11: 11.7.101
  - nvidia-cuda-nvrtc-cu11: 11.7.99
  - nvidia-cuda-runtime-cu11: 11.7.99
  - nvidia-cudnn-cu11: 8.5.0.96
  - nvidia-cufft-cu11: 10.9.0.58
  - nvidia-curand-cu11: 10.2.10.91
  - nvidia-cusolver-cu11: 11.4.0.1
  - nvidia-cusparse-cu11: 11.7.4.91
  - nvidia-nccl-cu11: 2.14.3
  - nvidia-nvtx-cu11: 11.7.91
  - oauthlib: 3.2.2
  - omegaconf: 2.3.0
  - optuna: 2.10.1
  - ordered-set: 4.1.0
  - packaging: 23.1
  - pandas: 2.0.2
  - pathtools: 0.1.2
  - pbr: 5.11.1
  - pillow: 9.5.0
  - pip: 23.1.2
  - pluggy: 1.2.0
  - polars: 0.18.3
  - prettytable: 3.8.0
  - protobuf: 4.23.3
  - psutil: 5.9.5
  - pyasn1: 0.5.0
  - pyasn1-modules: 0.3.0
  - pycparser: 2.21
  - pydantic: 1.10.9
  - pygments: 2.15.1
  - pyjwt: 2.7.0
  - pyparsing: 3.1.0
  - pyperclip: 1.8.2
  - pytest: 7.4.0
  - python-dateutil: 2.8.2
  - python-dotenv: 1.0.0
  - python-editor: 1.0.4
  - python-multipart: 0.0.6
  - pytorch-lightning: 2.0.3
  - pytz: 2023.3
  - pyyaml: 6.0
  - rapidfuzz: 2.13.7
  - readchar: 4.0.5
  - regex: 2023.6.3
  - requests: 2.31.0
  - requests-oauthlib: 1.3.1
  - rich: 13.4.2
  - rsa: 4.9
  - safetensors: 0.3.1
  - scikit-learn: 1.2.2
  - scipy: 1.10.1
  - seaborn: 0.12.2
  - sentry-sdk: 1.25.1
  - setproctitle: 1.3.2
  - setuptools: 65.5.0
  - six: 1.16.0
  - smmap: 5.0.0
  - sniffio: 1.3.0
  - soundfile: 0.12.1
  - soupsieve: 2.4.1
  - sqlalchemy: 2.0.17
  - starlette: 0.27.0
  - starsessions: 1.3.0
  - stevedore: 5.1.0
  - submitit: 1.4.5
  - sympy: 1.12
  - tensorboard: 2.12.0
  - tensorboard-data-server: 0.7.1
  - tensorboard-plugin-wit: 1.8.1
  - threadpoolctl: 3.1.0
  - tokenizers: 0.13.3
  - tomli: 2.0.1
  - torch: 2.0.1
  - torch-tb-profiler: 0.4.1
  - torchaudio: 2.0.1
  - torchdata: 0.6.1
  - torchmetrics: 0.11.4
  - torchvision: 0.15.1
  - tqdm: 4.65.0
  - traitlets: 5.9.0
  - transformers: 4.30.2
  - triton: 2.0.0
  - typing-extensions: 4.6.3
  - tzdata: 2023.3
  - urllib3: 1.26.16
  - uvicorn: 0.22.0
  - wandb: 0.15.5
  - wcwidth: 0.2.6
  - websocket-client: 1.6.0
  - websockets: 11.0.3
  - werkzeug: 2.3.6
  - wheel: 0.40.0
  - yarl: 1.9.2
  - zipp: 3.15.0
* System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor:
  - python: 3.10.12
  - release: 6.4.12-arch1-1
  - version: #1 SMP PREEMPT_DYNAMIC Thu, 24 Aug 2023 00:38:14 +0000

More info
No response
cc @carmocca @justusschock @awaelchli