Lightning-AI / pytorch-lightning


Fabric wrappers for optimizers do not load state dict #18482

Closed · nikvaessen closed this issue 1 year ago

nikvaessen commented 1 year ago

Bug description

When load_state_dict() is called on an optimizer object returned by fabric.setup(...), the loaded state is silently dropped: a subsequent call to state_dict() on that optimizer returns an empty "state" dict.
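
In short, a condensed view of the reproduction script below (fabric, network, and opt are set up exactly as in that script):

network, opt = fabric.setup(network, opt)   # opt is now a Fabric optimizer wrapper
ckpt = fabric.load("first_session.ckpt")    # checkpoint contains a populated optimizer state
opt.load_state_dict(ckpt["opt"])            # the state never reaches the underlying optimizer
assert len(opt.state_dict()["state"]) > 0   # fails: "state" comes back empty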

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import lightning
import torch

USE_FABRIC_SETUP = True

class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = torch.nn.Linear(1000, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.fc1(x))

def first_session():
    fabric = lightning.Fabric(accelerator="gpu", devices=1)
    fabric.launch()

    network = Network()
    opt = torch.optim.Adam(network.parameters())

    if USE_FABRIC_SETUP:
        network, opt = fabric.setup(network, opt)
    else:
        network = network.cuda()

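    # take a few optimizer steps so Adam accumulates per-parameter state (exp_avg, exp_avg_sq)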
    for i in range(5):
        opt.zero_grad()

        inp = torch.rand((8, 1000)).cuda()
        pred = network(inp)

        target = torch.randint(low=0, high=9, size=(8,)).cuda()
        loss = torch.nn.functional.cross_entropy(pred, target)

        fabric.backward(loss)
        opt.step()

    fabric.save("first_session.ckpt", {"network": network, "opt": opt})

def second_session():
    fabric = lightning.Fabric(accelerator="gpu", devices=1)
    fabric.launch()

    network = Network()
    opt = torch.optim.Adam(network.parameters())

    if USE_FABRIC_SETUP:
        network, opt = fabric.setup(network, opt)

    ckpt = fabric.load("first_session.ckpt")

    network.load_state_dict(ckpt["network"])
    opt.load_state_dict(ckpt["opt"])

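    # the restored optimizer state should have as many entries as the checkpointed one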
    assert len(opt.state_dict()["state"]) == len(ckpt["opt"]["state"])

def main():
    print(lightning.__version__)
    torch.manual_seed(123)

    first_session()
    second_session()

if __name__ == "__main__":
    main()

Error messages and logs

Traceback (most recent call last):
  File "/home/nvaessen/phd/repo/nanow2v2/playground/loading_bug/buggy_loading.py", line 71, in <module>
    main()
  File "/home/nvaessen/phd/repo/nanow2v2/playground/loading_bug/buggy_loading.py", line 67, in main
    second_session()
  File "/home/nvaessen/phd/repo/nanow2v2/playground/loading_bug/buggy_loading.py", line 59, in second_session
    assert len(opt.state_dict()["state"]) == len(ckpt["opt"]["state"])
AssertionError

The checkpoint (ckpt['opt']):

{'state': {0: {'step': tensor(5.), 'exp_avg': tensor([[ 0.0008,  0.0006,  0.0015,  ...,  0.0010,  0.0012,  0.0011],
        [ 0.0013,  0.0017,  0.0026,  ...,  0.0018,  0.0029,  0.0013],
        [-0.0038, -0.0014, -0.0019,  ..., -0.0006, -0.0010, -0.0031],
        ...,
        [ 0.0196,  0.0150,  0.0168,  ...,  0.0152,  0.0192,  0.0213],
        [ 0.0039,  0.0045,  0.0067,  ...,  0.0053,  0.0063,  0.0055],
        [ 0.0006,  0.0004,  0.0009,  ...,  0.0004,  0.0010,  0.0003]]), 'exp_avg_sq': tensor([[1.5035e-07, 8.4814e-08, 5.0681e-07,  ..., 2.2370e-07, 3.5894e-07,
         2.7821e-07],
        [4.1096e-07, 6.4151e-07, 1.5764e-06,  ..., 7.8267e-07, 1.9212e-06,
         3.8671e-07],
        [8.4621e-06, 6.0896e-06, 1.3199e-05,  ..., 6.2065e-06, 9.0778e-06,
         1.1578e-05],
        ...,
        [2.4836e-05, 2.4092e-05, 2.5735e-05,  ..., 2.5762e-05, 3.9960e-05,
         3.7594e-05],
        [1.9858e-06, 2.4212e-06, 6.8515e-06,  ..., 3.5609e-06, 6.0807e-06,
         4.0056e-06],
        [9.6174e-08, 3.3862e-08, 1.6834e-07,  ..., 3.2796e-08, 2.2325e-07,
         1.7551e-08]])}, 1: {'step': tensor(5.), 'exp_avg': tensor([0.0032, 0.0051, 0.0019, 0.0000, 0.0007, 0.0042, 0.0058, 0.0310, 0.0122,
        0.0016]), 'exp_avg_sq': tensor([2.3620e-06, 5.9441e-06, 2.4069e-05, 0.0000e+00, 1.1375e-07, 4.0922e-06,
        7.9036e-06, 1.0822e-04, 2.1972e-05, 5.7628e-07])}}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}

The result of calling state_dict() after calling opt.load_state_dict():

{'state': {}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}
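
A possible interim workaround (not verified against the affected release): bypass the Fabric wrapper and load directly into the underlying torch optimizer, which the wrapper exposes as opt.optimizer (the same attribute is used further down in this thread):

ckpt = fabric.load("first_session.ckpt")
opt.optimizer.load_state_dict(ckpt["opt"])  # hypothetical bypass; loads into the plain torch.optim.Adam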

Environment

Current environment * CUDA: - GPU: - NVIDIA GeForce GTX 1080 Ti - available: True - version: 11.7 * Lightning: - lightning: 2.0.8 - lightning-cloud: 0.5.37 - lightning-utilities: 0.8.0 - pytorch-lightning: 2.0.3 - torch: 2.0.1 - torch-tb-profiler: 0.4.1 - torchaudio: 2.0.1 - torchdata: 0.6.1 - torchmetrics: 0.11.4 - torchvision: 0.15.1 * Packages: - absl-py: 1.4.0 - aiohttp: 3.8.4 - aiosignal: 1.3.1 - alembic: 1.11.1 - antlr4-python3-runtime: 4.9.3 - anyio: 3.7.0 - appdirs: 1.4.4 - arrow: 1.2.3 - async-timeout: 4.0.2 - attrs: 23.1.0 - autopage: 0.5.1 - backoff: 2.2.1 - beautifulsoup4: 4.12.2 - blessed: 1.20.0 - cachetools: 5.3.1 - certifi: 2023.5.7 - cffi: 1.15.1 - charset-normalizer: 3.1.0 - click: 8.1.3 - cliff: 4.3.0 - cloudpickle: 2.2.1 - cmaes: 0.9.1 - cmake: 3.26.4 - cmd2: 2.4.3 - colorlog: 6.7.0 - contourpy: 1.1.0 - croniter: 1.3.15 - cycler: 0.11.0 - dateutils: 0.6.12 - deepdiff: 6.3.0 - docker-pycreds: 0.4.0 - exceptiongroup: 1.1.1 - fastapi: 0.98.0 - filelock: 3.12.2 - fonttools: 4.40.0 - frozenlist: 1.3.3 - fsspec: 2023.6.0 - gitdb: 4.0.10 - gitpython: 3.1.31 - google-auth: 2.20.0 - google-auth-oauthlib: 0.4.6 - greenlet: 2.0.2 - grpcio: 1.54.2 - h11: 0.14.0 - huggingface-hub: 0.15.1 - hydra-core: 1.3.2 - hydra-optuna-sweeper: 1.2.0 - hydra-submitit-launcher: 1.2.0 - idna: 3.4 - importlib-metadata: 6.7.0 - iniconfig: 2.0.0 - inquirer: 3.1.3 - itsdangerous: 2.1.2 - jinja2: 3.1.2 - jiwer: 3.0.0 - joblib: 1.2.0 - kiwisolver: 1.4.4 - lightning: 2.0.8 - lightning-cloud: 0.5.37 - lightning-utilities: 0.8.0 - lit: 16.0.6 - mako: 1.2.4 - markdown: 3.4.3 - markdown-it-py: 3.0.0 - markupsafe: 2.1.3 - matplotlib: 3.7.1 - mdurl: 0.1.2 - mpmath: 1.3.0 - multidict: 6.0.4 - nanow2v2: 1.0 - networkx: 3.1 - numpy: 1.25.0 - nvidia-cublas-cu11: 11.10.3.66 - nvidia-cuda-cupti-cu11: 11.7.101 - nvidia-cuda-nvrtc-cu11: 11.7.99 - nvidia-cuda-runtime-cu11: 11.7.99 - nvidia-cudnn-cu11: 8.5.0.96 - nvidia-cufft-cu11: 10.9.0.58 - nvidia-curand-cu11: 10.2.10.91 - nvidia-cusolver-cu11: 11.4.0.1 - nvidia-cusparse-cu11: 11.7.4.91 - nvidia-nccl-cu11: 2.14.3 - nvidia-nvtx-cu11: 11.7.91 - oauthlib: 3.2.2 - omegaconf: 2.3.0 - optuna: 2.10.1 - ordered-set: 4.1.0 - packaging: 23.1 - pandas: 2.0.2 - pathtools: 0.1.2 - pbr: 5.11.1 - pillow: 9.5.0 - pip: 23.1.2 - pluggy: 1.2.0 - polars: 0.18.3 - prettytable: 3.8.0 - protobuf: 4.23.3 - psutil: 5.9.5 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pycparser: 2.21 - pydantic: 1.10.9 - pygments: 2.15.1 - pyjwt: 2.7.0 - pyparsing: 3.1.0 - pyperclip: 1.8.2 - pytest: 7.4.0 - python-dateutil: 2.8.2 - python-dotenv: 1.0.0 - python-editor: 1.0.4 - python-multipart: 0.0.6 - pytorch-lightning: 2.0.3 - pytz: 2023.3 - pyyaml: 6.0 - rapidfuzz: 2.13.7 - readchar: 4.0.5 - regex: 2023.6.3 - requests: 2.31.0 - requests-oauthlib: 1.3.1 - rich: 13.4.2 - rsa: 4.9 - safetensors: 0.3.1 - scikit-learn: 1.2.2 - scipy: 1.10.1 - seaborn: 0.12.2 - sentry-sdk: 1.25.1 - setproctitle: 1.3.2 - setuptools: 65.5.0 - six: 1.16.0 - smmap: 5.0.0 - sniffio: 1.3.0 - soundfile: 0.12.1 - soupsieve: 2.4.1 - sqlalchemy: 2.0.17 - starlette: 0.27.0 - starsessions: 1.3.0 - stevedore: 5.1.0 - submitit: 1.4.5 - sympy: 1.12 - tensorboard: 2.12.0 - tensorboard-data-server: 0.7.1 - tensorboard-plugin-wit: 1.8.1 - threadpoolctl: 3.1.0 - tokenizers: 0.13.3 - tomli: 2.0.1 - torch: 2.0.1 - torch-tb-profiler: 0.4.1 - torchaudio: 2.0.1 - torchdata: 0.6.1 - torchmetrics: 0.11.4 - torchvision: 0.15.1 - tqdm: 4.65.0 - traitlets: 5.9.0 - transformers: 4.30.2 - triton: 2.0.0 - typing-extensions: 4.6.3 - tzdata: 2023.3 - urllib3: 1.26.16 
- uvicorn: 0.22.0 - wandb: 0.15.5 - wcwidth: 0.2.6 - websocket-client: 1.6.0 - websockets: 11.0.3 - werkzeug: 2.3.6 - wheel: 0.40.0 - yarl: 1.9.2 - zipp: 3.15.0 * System: - OS: Linux - architecture: - 64bit - ELF - processor: - python: 3.10.12 - release: 6.4.12-arch1-1 - version: #1 SMP PREEMPT_DYNAMIC Thu, 24 Aug 2023 00:38:14 +0000

More info

No response

cc @carmocca @justusschock @awaelchli

awaelchli commented 1 year ago

Hey @nikvaessen, this is very odd. I'll look into it.

Btw, in case you didn't know, the recommended way to load in Fabric is:

fabric.load("first_session.ckpt", {"network": network, "opt": opt})

because this generalizes across all strategies and accelerators and offers a convenient way to make scripts stateful in general. And this will pass your assertion. So you can use this way as a workaround until I make the bugfix.
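
For example, a sketch following the names in the reproduction script (the second argument is restored in place, and any checkpoint entries without a matching key are returned):

state = {"network": network, "opt": opt}
remainder = fabric.load("first_session.ckpt", state)  # restores network and opt in place
# remainder holds any leftover checkpoint entries not listed in `state`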

Thanks for reporting!

nikvaessen commented 1 year ago

> And this will pass your assertion. So you can use this way as a workaround until I make the bugfix.

I've tried the following modification to the reloading part of the reproduction code sample:

def second_session():
    fabric = lightning.Fabric(accelerator="gpu", devices=1)
    fabric.launch()

    network = Network()
    opt = torch.optim.Adam(network.parameters())

    if USE_FABRIC_SETUP:
        network, opt = fabric.setup(network, opt)
    else:
        network.cuda()

    state = {"network": network, "opt": opt}
    remainder = fabric.load("first_session.ckpt", state)

    print("remainder:", remainder)
    print("wrapper", opt.state_dict())
    print("optimizer", opt.optimizer.state_dict())
    print("checkpoint\n", torch.load("first_session.ckpt"))

    assert len(opt.state_dict()["state"]) > 0  # the state should be non-empty after loading

This still results in

remainder: {}
wrapper {'state': {}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}
optimizer {'state': {}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}

If I run my code with pip install git+https://github.com/Lightning-AI/lightning@bugfix/fabric-optimizer-load-state, the state dictionary is correctly loaded. So thanks for the bugfix :)

awaelchli commented 1 year ago

@nikvaessen Thanks for confirming!