Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

`No inf checks were recorded for this optimizer` when using SWA together with batch norm layers #17245

Open · JanSellner opened this issue 1 year ago

JanSellner commented 1 year ago

Bug description

When SWA is used together with a model that has batch norm layers, the assertion `No inf checks were recorded for this optimizer.` is raised in the last epoch (the SWA epoch).

This worked fine with torch<2.0 but I am not sure whether it is a torch or lightning issue.

How to reproduce the bug

import os

import torch
from torch.utils.data import DataLoader, Dataset

from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import StochasticWeightAveraging
import torch.nn as nn

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(100, 20)

        # This line is the problem even though the batch norm layer is not even used
        self.norm = nn.BatchNorm1d(10)

        self.ce_loss = nn.CrossEntropyLoss(weight=torch.ones(20))

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self.ce_loss(self(batch), self(batch))
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.layer.parameters(), lr=0.1)

def run():
    train_data = DataLoader(RandomDataset(100, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(100, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=2,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=10,
        enable_model_summary=False,
        callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],
        precision=16,
        accelerator='gpu',
        devices=1,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

if __name__ == "__main__":
    run()

Error messages and logs

Traceback (most recent call last):                                                                                                                                                                                                                                                                                                                                  
  File "/mnt/ssd_8tb/htc/src/tt.py", line 75, in <module>
    run()
  File "/mnt/ssd_8tb/htc/src/tt.py", line 70, in run
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 935, in _run
    results = self._run_stage()
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 978, in _run_stage
    self.fit_loop.run()
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 201, in run
    self.advance()
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 354, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 133, in run
    self.advance(data_fetcher)
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 218, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 185, in run
    self._optimizer_step(kwargs.get("batch_idx", 0), closure)
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 261, in _optimizer_step
    call._call_lightning_module_hook(
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 142, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1265, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py", line 158, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 224, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/amp.py", line 83, in optimizer_step
    step_output = self.scaler.step(optimizer, **kwargs)
  File "/home/j562r/miniconda3/envs/htc2/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 370, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
Epoch 10:  67%|██████▋   | 2/3 [00:00<00:00, 14.17it/s, v_num=0]    

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): 2.0.1
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0): 2.0
#- Python version (e.g., 3.9): 3.10
#- OS (e.g., Linux): Ubuntu
#- CUDA/cuDNN version: 10.8
#- GPU models and configuration: 3090
#- How you installed Lightning(`conda`, `pip`, source): pip
#- Running environment of LightningApp (e.g. local, cloud):
```

More info

No response

JanSellner commented 1 year ago

Digging a bit further into this: according to this line in the SWA implementation, https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/callbacks/stochastic_weight_avg.py#L254, the backward pass should be skipped in the last SWA epoch. The variable `_skip_backward` defined in https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/loops/fit_loop.py#L133 is responsible for skipping the backward pass. However, this does not seem to be enough, because the optimization loop still calls `self._optimizer_step` even though `closure._backward_fn` is `None`: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/loops/optimization/automatic.py#L185.
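
The assertion itself can be reproduced with plain PyTorch when the scaler is asked to step an optimizer without a preceding backward pass. A minimal standalone sketch (my own illustration, not Lightning code):

```python
import torch

# Requires a CUDA device, as in the repro above, so that the GradScaler is enabled.
model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

# One normal iteration so the scaler gets initialized (as in the epochs before SWA).
loss = model(torch.randn(8, 4, device="cuda")).sum()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

# Next iteration: the backward pass is skipped (as in the last SWA epoch), so no
# gradients exist and no inf checks are recorded for this optimizer.
try:
    scaler.step(optimizer)
except AssertionError as err:
    print(err)  # No inf checks were recorded for this optimizer.
```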

Further, the variable `skipped_backward` in the `MixedPrecisionPlugin` class only considers the closure result but not `closure._backward_fn`. Maybe this is the error? I.e., changing `skipped_backward = closure_result is None` to `skipped_backward = closure._backward_fn is None` would solve the problem.

So I think this might indeed be an issue with lightning and not with PyTorch but for some reason it only happens with the latest PyTorch version.

As a workaround, we can switch to manual optimization in the SWA epoch:

def on_train_epoch_start(self) -> None:
    if self.current_epoch == self.trainer.max_epochs - 1:
        # Workaround to always save the last epoch until the bug is fixed in lightning (https://github.com/Lightning-AI/lightning/issues/4539)
        self.trainer.check_val_every_n_epoch = 1

        # Disable backward pass for SWA until the bug is fixed in lightning (https://github.com/Lightning-AI/lightning/issues/17245)
        self.automatic_optimization = False
KamiCalcium commented 1 year ago
> on_train_epoch_start

Hi,

Did you test your workaround solution? Sorry, I'm very new to lightning. I'm wondering where I should add the solution. Under https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/callbacks/stochastic_weight_avg.py#L177? Thanks!

JanSellner commented 1 year ago

I added the on_train_epoch_start method directly to my lightning module, e.g. as part of the BoringModel in the initial example from above:

Updated BoringModel

```python
class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(100, 20)

        # This line is the problem even though the batch norm layer is not even used
        self.norm = nn.BatchNorm1d(10)

        self.ce_loss = nn.CrossEntropyLoss(weight=torch.ones(20))

    def forward(self, x):
        return self.layer(x)

    def on_train_epoch_start(self) -> None:
        if self.current_epoch == self.trainer.max_epochs - 1:
            # Workaround to always save the last epoch until the bug is fixed in lightning (https://github.com/Lightning-AI/lightning/issues/4539)
            self.trainer.check_val_every_n_epoch = 1

            # Disable backward pass for SWA until the bug is fixed in lightning (https://github.com/Lightning-AI/lightning/issues/17245)
            self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self.ce_loss(self(batch), self(batch))
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.layer.parameters(), lr=0.1)
```

This is working so far. I guess it would also be possible to write a custom SWA callback which inherits from StochasticWeightAveraging and override the method there, but I have not tested it.
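
For reference, an untested sketch of that alternative (the subclass name `PatchedSWA` is made up; it just applies the same two workaround lines from the callback instead of the LightningModule):

```python
from lightning.pytorch.callbacks import StochasticWeightAveraging

class PatchedSWA(StochasticWeightAveraging):
    # Untested sketch: same workaround as above, applied from the callback.
    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        if trainer.current_epoch == trainer.max_epochs - 1:
            # Always run validation in the last epoch (see issue #4539)
            trainer.check_val_every_n_epoch = 1
            # Skip the automatic backward/step so the AMP GradScaler is not
            # stepped without recorded inf checks (this issue, #17245)
            pl_module.automatic_optimization = False
```

It would then be passed as `callbacks=[PatchedSWA(swa_lrs=1e-2)]` to the Trainer from the repro.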

Borda commented 1 year ago

@JanSellner what PL version are you using?

JanSellner commented 1 year ago

2.0.1 at the time of creating this issue but also just reproduced with 2.0.2.

snipdome commented 1 year ago

I can confirm the undesired behaviour with pl 2.0.2.

brendanartley commented 1 year ago

Seems related to this issue regarding AMP on the torch forums. Maybe this helps?

I can confirm that this solution by @JanSellner stops the error from being thrown, but I am not sure if this is the expected behaviour for SWA.

felixdivo commented 9 months ago

This has been tagged as 2.0.x, but there seems to be no fix for it there. When is this targeted to be solved?

b5y commented 3 months ago

I am facing the same error if I try to use `accumulate_grad_batches` with a value larger than 1. This argument works only when `automatic_optimization = True`.

So I downgraded lightning to 1.9.4 and pytorch to 1.13.1. It didn't help. I also tried downgrading lightning first; that didn't help either.

Any ideas how to use `accumulate_grad_batches` when `automatic_optimization = False`?

I know that we should call the optimizer and `manual_backward` in `training_step` when `automatic_optimization = False`.
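
Something like the following rough sketch is what I have in mind for doing the accumulation by hand inside `training_step` (the module and the `accumulate_every` value are made up for illustration):

```python
import torch
from lightning import LightningModule

class ManualAccumulationModel(LightningModule):
    def __init__(self, accumulate_every: int = 4):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.accumulate_every = accumulate_every
        self.layer = torch.nn.Linear(100, 20)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        # Scale the loss so the accumulated update matches a large-batch update
        loss = self.layer(batch).sum() / self.accumulate_every
        self.manual_backward(loss)
        # Step and clear gradients only every `accumulate_every` batches
        if (batch_idx + 1) % self.accumulate_every == 0:
            opt.step()
            opt.zero_grad()
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.layer.parameters(), lr=0.1)
```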

UPDATE: setting `precision='bf16'` instead of `precision='16'` fixed the problem (latest lightning and pytorch).
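
Presumably this works because with `precision='bf16'` Lightning does not create a `torch.cuda.amp.GradScaler` at all (a scaler is only needed for fp16 mixed precision), so the failing `scaler.step()` assertion can never be reached. A sketch of the only change compared to the repro's Trainer (bf16 requires a recent GPU such as the 3090 mentioned above):

```python
from lightning import Trainer
from lightning.pytorch.callbacks import StochasticWeightAveraging

trainer = Trainer(
    max_epochs=10,
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],
    precision="bf16",   # bf16 autocast uses no GradScaler, unlike fp16
    accelerator="gpu",
    devices=1,
)
```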