Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

`StepLR` doesn't work as expected after loading from checkpoint using `Trainer.fit(ckpt_path=...)` #17296

Closed · rafathasan closed this 1 year ago

rafathasan commented 1 year ago

Bug description

Well! The title speaks for itself. When `trainer.fit(ckpt_path=...)` is called with a checkpoint, it breaks `StepLR`: the lr is no longer changed by the lr scheduler. I have provided a highly reproducible script below, so no detailed explanation should be required.
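For reference, with `learning_rate=2e-4` and `StepLR(optimizer, step_size=1, gamma=0.9)`, the lr should shrink by a factor of 0.9 at the end of every epoch, and resuming from a checkpoint should continue that decay rather than reset it. A minimal sketch of the expected schedule (plain PyTorch, no Lightning involved):

```python
import torch
from torch.optim.lr_scheduler import StepLR

# Same hyperparameters as the repro script below.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=2e-4)
scheduler = StepLR(optimizer, step_size=1, gamma=0.9)

for epoch in range(10):
    optimizer.step()  # stand-in for one epoch of training
    scheduler.step()  # one scheduler step per epoch
    print(f"epoch {epoch}: lr => {optimizer.param_groups[0]['lr']}")
# epoch 0: lr => 0.00018, epoch 1: lr => 0.000162, ... exactly the values in the
# logs below; after resuming from last.ckpt, epoch 5 should print
# 0.00010628820000000004, not 0.0002.
```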

What version are you seeing the problem on?

2.0+

How to reproduce the bug

```python
#!/opt/conda/bin/python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR


class DemoModel(pl.LightningModule):
    def __init__(self, hidden_dim=64, learning_rate=2e-4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
        self.fc1 = nn.Linear(28 * 28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    def on_train_epoch_end(self):
        self.log("lr", self.optimizers().param_groups[0]['lr'], prog_bar=True, sync_dist=True)
        print(f"lr => {self.optimizers().param_groups[0]['lr']}")

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.9)
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def prepare_data(self):
        # Download dataset
        MNIST(root='data/', train=True, download=True)
        MNIST(root='data/', train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(root='data/', train=True, transform=transforms.ToTensor())
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(root='data/', train=False, transform=transforms.ToTensor())

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=512, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=512, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64, num_workers=4)


# Initialize the model
model = DemoModel()

# Initialize the ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints',
    save_last=True,
)

# Initialize the trainer
trainer = pl.Trainer(devices="0,1,2,3",
                     accelerator="cuda",
                     strategy="ddp",
                     max_epochs=5,
                     callbacks=[checkpoint_callback],
                     log_every_n_steps=1,
                     )

# Train the model with the ModelCheckpoint callback
trainer.fit(model)

print("################################## loading checkpoint #############################################")

# Re-initialize the model
model = DemoModel()

# Re-initialize the trainer, now training to epoch 10
trainer = pl.Trainer(devices="0,1,2,3",
                     accelerator="cuda",
                     strategy="ddp",
                     max_epochs=10,
                     log_every_n_steps=1)

# Resume training from the last checkpoint
trainer.fit(model, ckpt_path="./checkpoints/last.ckpt")
```
Error messages and logs

```
[per-rank duplicate output and progress bars trimmed]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]  (same for ranks 1-3)

  | Name | Type   | Params
--------------------------------
0 | fc1  | Linear | 50.2 K
1 | fc2  | Linear | 650
--------------------------------
50.9 K    Trainable params
0         Non-trainable params
50.9 K    Total params
0.204     Total estimated model params size (MB)

Sanity Checking DataLoader 0: 100% 2/2 [00:00<00:00, 4.86it/s]
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.

Epoch 0: 100% 27/27 [00:01<00:00, 21.20it/s, v_num=1]              lr => 0.00018
Epoch 1: 100% 27/27 [00:01<00:00, 20.63it/s, v_num=1, lr=0.00018]  lr => 0.000162
Epoch 2: 100% 27/27 [00:01<00:00, 20.76it/s, v_num=1, lr=0.000162] lr => 0.00014580000000000002
Epoch 3: 100% 27/27 [00:01<00:00, 19.85it/s, v_num=1, lr=0.000146] lr => 0.00013122000000000003
Epoch 4: 100% 27/27 [00:01<00:00, 20.67it/s, v_num=1, lr=0.000131] lr => 0.00011809800000000003
`Trainer.fit` stopped: `max_epochs=5` reached.

################################## loading checkpoint #############################################
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at ./checkpoints/last.ckpt
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:337: UserWarning: The dirpath has changed from '/src/notebooks/checkpoints' to '/src/notebooks/lightning_logs/version_2/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.  (repeated on each rank)
Restored all states from the checkpoint at ./checkpoints/last.ckpt
Epoch 5: 100% 27/27 [00:01<00:00, 21.26it/s, v_num=2]            lr => 0.0002
Epoch 6: 100% 27/27 [00:01<00:00, 21.34it/s, v_num=2, lr=0.0002] lr => 0.0002
Epoch 7: 100% 27/27 [00:01<00:00, 20.50it/s, v_num=2, lr=0.0002] lr => 0.0002
Epoch 8: 100% 27/27 [00:01<00:00, 21.41it/s, v_num=2, lr=0.0002] lr => 0.0002
Epoch 9: 100% 27/27 [00:01<00:00, 20.53it/s, v_num=2, lr=0.0002] lr => 0.0002
`Trainer.fit` stopped: `max_epochs=10` reached.
```

Environment

Current environment

```
* CUDA:
  - GPU: Tesla K80 (x8)
  - available: True
  - version: 11.7
* Lightning:
  - lightning-utilities: 0.8.0
  - pytorch-lightning: 2.0.0
  - torch: 2.0.0
  - torchdata: 0.6.0
  - torchelastic: 0.2.2
  - torchmetrics: 0.11.4
  - torchsummary: 1.5.1
  - torchtext: 0.15.1
  - torchvision: 0.15.1
* Packages:
  aiohttp 3.8.4, aiosignal 1.3.1, antlr4-python3-runtime 4.9.3, anyio 3.6.2, appdirs 1.4.4,
  argon2-cffi 21.3.0, argon2-cffi-bindings 21.2.0, arrow 1.2.3, asttokens 2.0.5, astunparse 1.6.3,
  async-timeout 4.0.2, attrs 22.1.0, autopep8 2.0.2, backcall 0.2.0, beautifulsoup4 4.11.1,
  bleach 6.0.0, brotlipy 0.7.0, certifi 2022.9.24, cffi 1.15.1, chardet 4.0.0,
  charset-normalizer 2.0.4, click 8.1.3, cmake 3.26.1, comm 0.1.3, conda 22.11.1,
  conda-build 3.23.3, conda-package-handling 1.9.0, contourpy 1.0.7, cryptography 38.0.1,
  cycler 0.11.0, debugpy 1.6.6, decorator 5.1.1, defusedxml 0.7.1, dnspython 2.2.1,
  docker-pycreds 0.4.0, docopt 0.6.2, exceptiongroup 1.0.4, executing 0.8.3, expecttest 0.1.4,
  fastjsonschema 2.16.3, filelock 3.6.0, flit-core 3.6.0, fonttools 4.39.2, fqdn 1.5.1,
  frozenlist 1.3.3, fsspec 2023.3.0, future 0.18.2, gdown 4.7.1, gitdb 4.0.10, gitpython 3.1.31,
  glob2 0.7, hypothesis 6.61.0, idna 3.4, ipykernel 6.22.0, ipython 8.11.0, ipython-genutils 0.2.0,
  ipywidgets 8.0.5, isoduration 20.11.0, jedi 0.18.1, jinja2 3.1.2, joblib 1.2.0, jsonpointer 2.3,
  jsonschema 4.17.3, jupyter 1.0.0, jupyter-client 8.1.0, jupyter-console 6.6.3, jupyter-core 5.3.0,
  jupyter-events 0.6.3, jupyter-server 2.5.0, jupyter-server-terminals 0.4.4,
  jupyterlab-pygments 0.2.2, jupyterlab-widgets 3.0.6, kiwisolver 1.4.4, libarchive-c 2.9,
  lightning-utilities 0.8.0, lit 16.0.0, markupsafe 2.0.1, matplotlib 3.7.1, matplotlib-inline 0.1.6,
  mistune 2.0.5, mkl-fft 1.3.1, mkl-random 1.2.2, mkl-service 2.4.0, mpmath 1.2.1, multidict 6.0.4,
  nbclassic 0.5.3, nbclient 0.7.2, nbconvert 7.2.10, nbformat 5.8.0, nest-asyncio 1.5.6,
  networkx 3.0, notebook 6.5.3, notebook-shim 0.2.2, numpy 1.24.2, nvidia-cublas-cu11 11.10.3.66,
  nvidia-cuda-cupti-cu11 11.7.101, nvidia-cuda-nvrtc-cu11 11.7.99, nvidia-cuda-runtime-cu11 11.7.99,
  nvidia-cudnn-cu11 8.5.0.96, nvidia-cufft-cu11 10.9.0.58, nvidia-curand-cu11 10.2.10.91,
  nvidia-cusolver-cu11 11.4.0.1, nvidia-cusparse-cu11 11.7.4.91, nvidia-nccl-cu11 2.14.3,
  nvidia-nvtx-cu11 11.7.91, omegaconf 2.3.0, onedrivedownloader 1.1.3, packaging 23.0,
  pandocfilters 1.5.0, parso 0.8.3, pathtools 0.1.2, pexpect 4.8.0, pickleshare 0.7.5, pillow 9.4.0,
  pip 22.3.1, pipreqs 0.4.11, pkginfo 1.8.3, platformdirs 3.2.0, pluggy 1.0.0,
  prometheus-client 0.16.0, prompt-toolkit 3.0.38, protobuf 4.22.1, psutil 5.9.0, ptyprocess 0.7.0,
  pure-eval 0.2.2, pycodestyle 2.10.0, pycosat 0.6.4, pycparser 2.21, pygments 2.11.2,
  pyopenssl 22.0.0, pyparsing 3.0.9, pyrsistent 0.19.3, pysocks 1.7.1, python-dateutil 2.8.2,
  python-etcd 0.4.5, python-json-logger 2.0.7, pytorch-lightning 2.0.0, pytz 2022.1, pyyaml 6.0,
  pyzmq 25.0.2, qtconsole 5.4.1, qtpy 2.3.0, requests 2.28.1, rfc3339-validator 0.1.4,
  rfc3986-validator 0.1.1, ruamel.yaml 0.17.21, ruamel.yaml.clib 0.2.6, scikit-learn 1.2.2,
  scipy 1.10.1, send2trash 1.8.0, sentry-sdk 1.17.0, setproctitle 1.3.2, setuptools 65.5.0,
  six 1.16.0, smmap 5.0.0, sniffio 1.3.0, sortedcontainers 2.4.0, soupsieve 2.3.2.post1,
  stack-data 0.2.0, sympy 1.11.1, terminado 0.17.1, thop 0.1.1.post2209072238, threadpoolctl 3.1.0,
  tinycss2 1.2.1, toml 0.10.2, tomli 2.0.1, toolz 0.12.0, torch 2.0.0, torchdata 0.6.0,
  torchelastic 0.2.2, torchmetrics 0.11.4, torchsummary 1.5.1, torchtext 0.15.1, torchvision 0.15.1,
  tornado 6.2, tqdm 4.65.0, traitlets 5.7.1, triton 2.0.0, types-dataclasses 0.6.6,
  typing-extensions 4.4.0, uri-template 1.2.0, urllib3 1.26.13, wandb 0.14.0, wcwidth 0.2.5,
  webcolors 1.12, webencodings 0.5.1, websocket-client 1.5.1, wheel 0.37.1,
  widgetsnbextension 4.0.6, yarg 0.1.9, yarl 1.8.2
* System:
  - OS: Linux
  - architecture: 64bit
  - processor: x86_64
  - python: 3.10.8
  - version: #76~20.04.1-Ubuntu SMP Mon Mar 20 15:54:19 UTC 2023
```

More info

No response

cc @awaelchli

ryan597 commented 1 year ago

I can get the lr updates with both of these calls instead:

    def on_train_epoch_end(self):
        self.log("lr", self.lr_schedulers().get_last_lr()[0],  prog_bar=True, sync_dist=True)
        print(f"lr => {self.lr_schedulers().state_dict()['_last_lr']}")

I have checked that the optimizer hook uses the correct learning rate (i.e. it continues to step the lr) by looking at this step in the `lightning/pytorch/loops/optimization/automatic.py` file

https://github.com/Lightning-AI/lightning/blob/fd4697c62c059fc7b9946e84d91625ecb6efdbe5/src/lightning/pytorch/loops/optimization/automatic.py#L237-L271

and checking the lr with `print("OPT PARMS", optimizer.optimizer.param_groups[0]['lr'])`.

The trainer keeps track of its own scheduler configs through `trainer.lr_scheduler_configs`, as seen in `lightning/pytorch/trainer/connectors/checkpoint_connector.py`

https://github.com/Lightning-AI/lightning/blob/fd4697c62c059fc7b9946e84d91625ecb6efdbe5/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py#L383-L391

So the problem seems to be that the optimizer returned by `self.optimizers()` (the one you read `param_groups[0]['lr']` from) is the one that is not updated.
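A quick way to see the mismatch in one place (a sketch; `lr_scheduler_configs` is the trainer attribute referenced above, with attribute names taken from the 2.0.0 source):

    def on_train_epoch_end(self):
        scheduler = self.trainer.lr_schedulers_configs[0].scheduler if False else self.trainer.lr_scheduler_configs[0].scheduler  # the restored StepLR
        print("scheduler lr =>", scheduler.get_last_lr()[0])                # keeps decaying after resume
        print("wrapper   lr =>", self.optimizers().param_groups[0]['lr'])   # stuck at the initial lr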

Edit: I looked into this further and found that by passing `use_pl_optimizer=False`, the optimizer values are correct again:

    def on_train_epoch_end(self):
        print(f"lr => {self.optimizers(use_pl_optimizer=False).param_groups[0]['lr']}")

But I still get different losses if I run the first training for 10 epochs, versus running 5 epochs and then fitting from the checkpoint for another 5, so something may be affecting the RNG state differently; see the seeding sketch below.
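One thing worth trying to rule out dataloader randomness (a sketch, not verified to close the gap): seed everything, including the DataLoader worker processes, before each run.

    import lightning.pytorch as pl

    # Seeds python, numpy and torch; workers=True also derives reproducible
    # seeds for DataLoader worker processes.
    pl.seed_everything(100, workers=True)

The testing script below already calls `seed_everything(100)`, so any remaining difference would have to come from state that seeding does not cover.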

Further Testing Script

```python
import os

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
# import pytorch_lightning as pl
# from pytorch_lightning.callbacks import ModelCheckpoint
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR


class DemoModel(pl.LightningModule):
    def __init__(self, hidden_dim=64, learning_rate=2e-4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
        self.fc1 = nn.Linear(28 * 28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        print(f"lr => {self.optimizers(use_pl_optimizer=False).param_groups[0]['lr']}")

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.9)
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def prepare_data(self):
        # Download dataset
        MNIST(root='data/', train=True, download=True)
        MNIST(root='data/', train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(root='data/', train=True, transform=transforms.ToTensor())
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(root='data/', train=False, transform=transforms.ToTensor())

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=512, num_workers=4,
                          shuffle=False, pin_memory=True, persistent_workers=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=512, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64, num_workers=4)


def test_load(fit, refit):
    # Initialize the ModelCheckpoint callback
    pl.seed_everything(100)
    torch.manual_seed(100)
    checkpoint_callback = ModelCheckpoint(
        dirpath='./checkpoints',
        save_last=True
    )

    if fit:
        model = DemoModel()
        trainer = pl.Trainer(devices=1,
                             accelerator="cuda",
                             strategy="auto",
                             max_epochs=fit,
                             callbacks=[checkpoint_callback],
                             log_every_n_steps=1,
                             enable_model_summary=False
                             )
        trainer.fit(model)

    if refit:
        print("## loading checkpoint ##")
        model = DemoModel()
        trainer = pl.Trainer(devices=1,
                             accelerator="cuda",
                             strategy="auto",
                             max_epochs=10 if refit == "ckpt_path" else 5,
                             callbacks=[checkpoint_callback],
                             log_every_n_steps=1,
                             enable_model_summary=False
                             )
        if refit == "from_checkpoint":
            model = DemoModel.load_from_checkpoint("./checkpoints/last.ckpt", map_location=torch.device('cuda'))
            trainer.fit(model)
        if refit == "ckpt_path":
            trainer.fit(model, ckpt_path="./checkpoints/last.ckpt")

    os.remove("./checkpoints/last.ckpt")


if __name__ == "__main__":
    torch.set_float32_matmul_precision('high')
    print("FROM CHECKPOINT")
    test_load(5, "from_checkpoint")
    print("\n\nCKPT PATH")
    test_load(5, "ckpt_path")
    print("\n\nCONTINUOUS")
    test_load(10, False)
```

Further Testing Logs

```
[progress bars and redraw lines trimmed]
FROM CHECKPOINT
Epoch 0: 100% 108/108 [00:00<00:00, 243.36it/s, v_num=165, loss=2.050]  lr => 0.00018
Epoch 1: 100% 108/108 [00:00<00:00, 260.28it/s, v_num=165, loss=1.850]  lr => 0.000162
Epoch 2: 100% 108/108 [00:00<00:00, 250.01it/s, v_num=165, loss=1.740]  lr => 0.00014580000000000002
Epoch 3: 100% 108/108 [00:00<00:00, 220.72it/s, v_num=165, loss=1.670]  lr => 0.00013122000000000003
Epoch 4: 100% 108/108 [00:00<00:00, 249.70it/s, v_num=165, loss=1.650]  lr => 0.00011809800000000003

## loading checkpoint ##
Epoch 0: 100% 108/108 [00:00<00:00, 231.37it/s, v_num=166, loss=1.610]  lr => 0.00018
Epoch 1: 100% 108/108 [00:00<00:00, 237.92it/s, v_num=166, loss=1.600]  lr => 0.000162
Epoch 2: 100% 108/108 [00:00<00:00, 239.84it/s, v_num=166, loss=1.590]  lr => 0.00014580000000000002
Epoch 3: 100% 108/108 [00:00<00:00, 249.72it/s, v_num=166, loss=1.580]  lr => 0.00013122000000000003
Epoch 4: 100% 108/108 [00:00<00:00, 234.87it/s, v_num=166, loss=1.570]  lr => 0.00011809800000000003

CKPT PATH
Epoch 0: 100% 108/108 [00:00<00:00, 235.67it/s, v_num=167, loss=2.050]  lr => 0.00018
Epoch 1: 100% 108/108 [00:00<00:00, 222.51it/s, v_num=167, loss=1.850]  lr => 0.000162
Epoch 2: 100% 108/108 [00:00<00:00, 236.36it/s, v_num=167, loss=1.740]  lr => 0.00014580000000000002
Epoch 3: 100% 108/108 [00:00<00:00, 242.70it/s, v_num=167, loss=1.670]  lr => 0.00013122000000000003
Epoch 4: 100% 108/108 [00:00<00:00, 230.25it/s, v_num=167, loss=1.650]  lr => 0.00011809800000000003

## loading checkpoint ##
Epoch 5: 100% 108/108 [00:00<00:00, 213.97it/s, v_num=168, loss=1.590]  lr => 0.00010628820000000004
Epoch 6: 100% 108/108 [00:00<00:00, 234.14it/s, v_num=168, loss=1.580]  lr => 9.565938000000004e-05
Epoch 7: 100% 108/108 [00:00<00:00, 216.25it/s, v_num=168, loss=1.580]  lr => 8.609344200000004e-05
Epoch 8: 100% 108/108 [00:00<00:00, 245.35it/s, v_num=168, loss=1.570]  lr => 7.748409780000004e-05
Epoch 9: 100% 108/108 [00:00<00:00, 229.21it/s, v_num=168, loss=1.570]  lr => 6.973568802000003e-05

CONTINUOUS
Epoch 0: 100% 108/108 [00:00<00:00, 203.86it/s, v_num=169, loss=2.050]  lr => 0.00018
Epoch 1: 100% 108/108 [00:00<00:00, 230.33it/s, v_num=169, loss=1.850]  lr => 0.000162
Epoch 2: 100% 108/108 [00:00<00:00, 243.14it/s, v_num=169, loss=1.740]  lr => 0.00014580000000000002
Epoch 3: 100% 108/108 [00:00<00:00, 219.94it/s, v_num=169, loss=1.670]  lr => 0.00013122000000000003
Epoch 4: 100% 108/108 [00:00<00:00, 245.71it/s, v_num=169, loss=1.650]  lr => 0.00011809800000000003
Epoch 5: 100% 108/108 [00:00<00:00, 229.97it/s, v_num=169, loss=1.630]  lr => 0.00010628820000000004
Epoch 6: 100% 108/108 [00:00<00:00, 235.31it/s, v_num=169, loss=1.620]  lr => 9.565938000000004e-05
Epoch 7: 100% 108/108 [00:00<00:00, 249.66it/s, v_num=169, loss=1.610]  lr => 8.609344200000004e-05
Epoch 8: 100% 108/108 [00:00<00:00, 240.83it/s, v_num=169, loss=1.600]  lr => 7.748409780000004e-05
Epoch 9: 100% 108/108 [00:00<00:00, 245.06it/s, v_num=169, loss=1.600]  lr => 6.973568802000003e-05
```

rafathasan commented 1 year ago

@ryan597 I have a question. If `self.automatic_optimization = False` is set, I have to manually call `loss.backward()`, `self.optimizers().step()` and `self.optimizers().zero_grad()`. There, calling `self.optimizers()` reaches the correct optimizer without explicitly passing `use_pl_optimizer=False`. So, the question is: isn't this making it more ambiguous? (See the snippet below.)

    def __init__(self, hidden_dim=64, learning_rate=2e-4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate

        self.fc1 = nn.Linear(28 * 28, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)

        loss.backward()

        self.optimizers().step()
        self.optimizers().zero_grad()

        return loss

    def on_train_epoch_end(self):
        self.lr_schedulers().step()
        self.log("lr", self.optimizers(use_pl_optimizer=False).param_groups[0]['lr'], prog_bar=True, sync_dist=True)
        print(f"lr => {self.optimizers(use_pl_optimizer=False).param_groups[0]['lr']}")

logs

```
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:lightning_fabric.utilities.distributed:Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name | Type   | Params
--------------------------------
0 | fc1  | Linear | 50.2 K
1 | fc2  | Linear | 650
--------------------------------
50.9 K    Trainable params
0         Non-trainable params
50.9 K    Total params
0.204     Total estimated model params size (MB)
/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py:561: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
Epoch 4: 100% 108/108 [00:07<00:00, 15.23it/s, v_num=9, lr=0.000131]
lr => 0.00018
lr => 0.000162
lr => 0.00014580000000000002
lr => 0.00013122000000000003
lr => 0.00011809800000000003
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.

################################## loading checkpoint #############################################
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at ./checkpoints/last.ckpt
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:337: UserWarning: The dirpath has changed from '/content/checkpoints' to '/content/lightning_logs/version_10/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.
INFO:pytorch_lightning.utilities.rank_zero:Restored all states from the checkpoint at ./checkpoints/last.ckpt
Epoch 9: 100% 108/108 [00:07<00:00, 14.95it/s, v_num=10, lr=7.75e-5]
lr => 0.00010628820000000004
lr => 9.565938000000004e-05
lr => 8.609344200000004e-05
lr => 7.748409780000004e-05
lr => 6.973568802000003e-05
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
```
ryan597 commented 1 year ago

@rafathasan I haven't looked specifically at that, but if you have told it you want to do manual optimization then it won't connect/wrap the PL optimizer, so it's going to return the plain optimizer without needing to set `use_pl_optimizer=False`.

I do agree though: you should be getting the same lr regardless of whether you pass `use_pl_optimizer=False` or not.

rafathasan commented 1 year ago

@ryan597 I think I should clarify my question a bit further. When I try to get the lr with `self.optimizers().param_groups[0]['lr']`, the problem still persists while `self.automatic_optimization = False` is set. It only works with `self.optimizers(use_pl_optimizer=False).param_groups[0]['lr']`. So my question was: how come I can use the optimizer correctly by manually calling `self.optimizers().step()` and `self.optimizers().zero_grad()` without passing `use_pl_optimizer=False`, but it does not work for `self.optimizers().param_groups[0]['lr']`?

awaelchli commented 1 year ago

Hey everyone. This was fixed in #18280 and released in 2.0.7. But no worries, the scheduler and optimizer were always correctly reloaded. The only bug was that the optimizer wrapper returned by `self.optimizers()` had an outdated state; the internal optimizer was always using the correct state, and that is the one used for training. The PR I linked makes sure the wrapper correctly represents the state of the user's optimizer.
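In other words, on the affected versions (2.0.0 up to, but not including, 2.0.7) the correct state was always reachable through the underlying optimizer, for example:

```python
# Both of these bypass the stale wrapper state:
lr = self.optimizers(use_pl_optimizer=False).param_groups[0]['lr']
lr = self.optimizers().optimizer.param_groups[0]['lr']  # .optimizer is the wrapped torch optimizer
```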

I'm closing the issue because I was able to use the provided repro (thanks a ton, way to go!) to verify the fix. Cheers!