Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Support non-conventional optimizers #16143

Open simonpokorny opened 1 year ago

simonpokorny commented 1 year ago

Bug description

I turned off automatic optimization because I am using the SAM optimizer (https://github.com/davda54/sam). After that, the trainer's global_step is no longer updated at each training step, so the checkpoint callback is never called even though it is passed to the Trainer.

Used callback:

pl.callbacks.ModelCheckpoint(save_weights_only=True, save_top_k=3, monitor="val_acc", mode="max", save_on_train_epoch_end=False)

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment:
- PyTorch Lightning Version: 1.8.4
- PyTorch Version: 1.13
- Python version: 3.9

More info

No response

cc @tchaton @justusschock @awaelchli @borda @carmocca

carmocca commented 1 year ago

Can you provide more details? This example shows it working:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False  # manual optimization, as in the reported setup

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        print(self.trainer.global_step)  # should increase once per optimizer step
        opt = self.optimizers()
        opt.zero_grad()
        loss = self(batch).sum()
        loss.backward()
        opt.step()  # this wrapped call is what advances global_step
        return loss.detach()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        limit_train_batches=3,
        enable_model_summary=False,
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_dataloaders=train_data)

if __name__ == "__main__":
    run()

simonpokorny commented 1 year ago

Thanks, sure.

I used your example with the custom optimizer (see below), and the global step is not increasing.


import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from classifiers.sam import SAM  # local copy of the SAM optimizer (full code pasted below)

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)
        self.labels = torch.randint(low=0, high=2, size=(length,))  # one label per sample

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(32, 2)
        self.automatic_optimization = False
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):

        data, labels = batch

        opt = self.optimizers()

        # first forward-backward pass
        pred = self.model(data)
        loss_1 = self.loss_fn(pred, labels)
        self.manual_backward(loss_1)
        opt.first_step(zero_grad=True)

        # second forward-backward pass
        pred = self.model(data)
        loss_2 = self.loss_fn(pred, labels)
        self.manual_backward(loss_2)
        opt.second_step(zero_grad=True)

        print(self.trainer.global_step)
        return loss_2

    def configure_optimizers(self):
        base_optimizer = torch.optim.Adam
        optimizer = SAM(self.parameters(), base_optimizer, rho=1, adaptive=True, lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

def run():
    train_data = DataLoader(RandomDataset(size=32, length=64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        limit_train_batches=3,
        enable_model_summary=False,
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_dataloaders=train_data)

if __name__ == "__main__":
    run()

Where the SAM optimizer is from https://github.com/davda54/sam.


class SAM(torch.optim.Optimizer):
    """
    SAM Optimizer
    https://github.com/davda54/sam
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        # put everything on the same device, in case of model parallelism
        shared_device = self.param_groups[0]["params"][0].device
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

carmocca commented 1 year ago

Okay. This happens because we assume there will be an optimizer.step() call, which is what we wrap to inject the strategy-specific logic (e.g. DDP): https://github.com/Lightning-AI/lightning/blob/50331e08e111d6b9ebb25a21a86b7170b46c5f1f/src/pytorch_lightning/core/optimizer.py#L101-L173

The call chain is LightningModule.training_step() -> _LightningOptimizer.step() -> Strategy.optimizer_step() -> PrecisionPlugin.optimizer_step() -> Optimizer.step()

Your use of the SAM optimizer violates this assumption, as you are calling two different step methods ({first,second}_step) which are not wrapped like .step(). It's not clear to me whether you would expect global_step to increase after each of them or only after second_step().
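
For illustration only, here is a heavily simplified sketch of the situation (not Lightning's actual internals; the wrapper class and counter below are stand-ins for _LightningOptimizer and the trainer's progress tracking):

class StepCountingWrapper:
    """Toy stand-in for Lightning's optimizer wrapper: only .step() is intercepted."""

    def __init__(self, optimizer):
        self._optimizer = optimizer
        self.step_count = 0  # stands in for trainer.global_step

    def step(self, closure=None):
        # strategy- and precision-specific hooks would run around this call in Lightning
        out = self._optimizer.step(closure)
        self.step_count += 1  # only the wrapped .step() advances the count
        return out

    def __getattr__(self, name):
        # any other method (e.g. SAM's first_step/second_step) passes straight
        # through to the underlying optimizer, so the count never increases
        return getattr(self._optimizer, name)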

To resolve this, we would need some mechanism to indicate which method we should wrap. cc @awaelchli @justusschock in case they have suggestions in this regard.

Another example is https://github.com/ludwigwinkler/JaxLightning/blob/8585863be636152b6adba77a0436ff7509fb92f3/BNN/JaxLightning_BNN.py#L215-L217 (cc @ludwigwinkler), which suffers from the same problem because the JAX optimizer uses .update() instead of .step().
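
For context, if the JAX side uses an optax-style optimizer (an assumption; the linked code may differ in detail), the parameter update looks roughly like the sketch below, so there is no .step() call for Lightning to intercept at all:

import jax.numpy as jnp
import optax

params = {"w": jnp.zeros((32, 2))}
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

def apply_gradients(params, opt_state, grads):
    # optax optimizers expose .update(), not .step(), so a step-wrapping
    # mechanism keyed on .step() never fires
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state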

simonpokorny commented 1 year ago

The SAM optimizer training step can be rewritten into the classical form with a single closure-based step() call:

    def training_step(self, batch, batch_idx):
        data, labels = batch
        opt = self.optimizers()

        def closure():
            # re-evaluate the loss at the perturbed weights (second forward-backward pass)
            loss = self.loss_fn(self.model(data), labels)
            loss.backward()
            return loss

        # first forward-backward pass
        loss = self.loss_fn(self.model(data), labels)
        loss.backward()
        opt.step(closure)  # SAM.step() runs first_step, the closure, then second_step
        opt.zero_grad()

        print(self.trainer.global_step)
        return loss

With this form, PL is able to wrap the .step() call and self.trainer.global_step increases as expected.

awaelchli commented 11 months ago

If I understand this correctly, my proposal is to add a check in our LightningOptimizer wrapper that a step method is available. If not, raise an error suggesting the user do optimizer.step = optimizer.real_step_method in e.g. the configure_optimizers hook to have it supported in Lightning. IMO this is the easiest option and doesn't require new APIs.
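
A minimal sketch of what that suggested workaround could look like in user code; SomeOptimizerWithoutStep and its update() method are hypothetical stand-ins for an optimizer that exposes no step():

    def configure_optimizers(self):
        # hypothetical optimizer that only exposes .update(), not .step()
        optimizer = SomeOptimizerWithoutStep(self.parameters(), lr=0.001)
        # proposed workaround: alias the real update method to .step so that
        # Lightning's optimizer wrapper (and global_step tracking) can hook it
        optimizer.step = optimizer.update
        return optimizer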

carmocca commented 11 months ago

The suggestion to "have a check in our LightningOptimizer wrapper that the step method is available" is not foolproof: the SAM optimizer shown above offers first_step, second_step, and step. If the user didn't know about this limitation and called first_step and second_step, they would hit this issue, but such a check wouldn't trigger because the optimizer also defines a step.

But I don't have a better suggestion that doesn't involve a complex solution, such as wrapping all optimizer methods and checking whether the parameters changed.