Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Support checkpointing for Sub-Epoch period #3646

Closed monney closed 4 years ago

monney commented 4 years ago

Question

When period is set to a fractional value, checkpointing doesn’t trigger correctly. Additionally, I think period should default to val_check_interval, if it doesn’t already.

To Reproduce

Steps to reproduce the behavior:

Run any model and set the checkpoint period to a fractional value. Only the first checkpoint will be saved.

Expected behavior

A checkpoint should be saved every specified period

awaelchli commented 4 years ago

I think you need to set period=0, then it could work, just looking at the code:

        if (
            self.epoch_last_check is not None
            and (epoch - self.epoch_last_check) < self.period
        ):
            # skipping in this term
            return

Try it :) period can only be an integer. Setting it to the val_check_interval does not make sense.
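
The effect of that check can be reproduced in isolation. The sketch below is a simplified stand-in, not the library's code: `PeriodGate`, `should_save` and `count_saves` are hypothetical names, and it assumes validation runs four times per epoch (as with val_check_interval=0.25) while the epoch counter stays an integer.

```python
class PeriodGate:
    """Hypothetical stand-in for the epoch-based skip check quoted above."""

    def __init__(self, period):
        self.period = period
        self.epoch_last_check = None

    def should_save(self, epoch):
        # Same condition as the quoted snippet: skip if too few epochs have passed.
        if (
            self.epoch_last_check is not None
            and (epoch - self.epoch_last_check) < self.period
        ):
            return False
        self.epoch_last_check = epoch
        return True


def count_saves(period, epochs=3, checks_per_epoch=4):
    """Count how many checks pass the gate (assumed: 4 checks per epoch)."""
    gate = PeriodGate(period)
    return sum(
        gate.should_save(epoch)
        for epoch in range(epochs)
        for _ in range(checks_per_epoch)
    )


if __name__ == "__main__":
    for period in (1, 0.25, 0):
        print(f"period={period}: {count_saves(period)} saves")
```

Under these assumptions a fractional period behaves like period=1: the difference `epoch - self.epoch_last_check` is 0 for every later check in the same epoch, so only the first check per epoch passes. With period=0 the comparison `0 < 0` is false, so every check saves.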

monney commented 4 years ago

@awaelchli Thanks! This works, but I find it a bit unintuitive. I didn’t get any warning when setting it to a fractional value (the same as val_check_interval), so I assumed that’s what I had to do. Perhaps we could set this automatically when val_check_interval is below 1? Or trigger a warning stating the correct value to set in this case?

awaelchli commented 4 years ago

Yes, I agree. period=0 only works because of an implementation detail, and it is not meant to be used like that. It's a hack. Sub-epoch checkpointing is not currently supported. We're looking into that. If you're feeling lucky, give it a try and send a draft PR? :rocket: It is a tricky one though.

Let's classify this as a feature request instead of a bug?

monney commented 4 years ago

I'll give it a look and see if I'm able to do it. Feature request sounds good, since it's working as intended.

ananthsub commented 4 years ago

@awaelchli what do you think about these options for checkpointing? I think this suite could be really helpful:

awaelchli commented 4 years ago

I think yes, these are all fine use cases. Given that the current ModelCheckpoint callback is quite complex, it may be hard or become impossible to maintain all these options in a single class. We could consider splitting the functionality into several callbacks. A combination of these features would mean passing several callbacks to the Trainer. But then there are new challenges, like clashing filenames etc.

Support for checkpointing on training epoch end if validation steps aren't supported

is that not already supported?

monney commented 4 years ago

@awaelchli I believe this is fixed on master (at least for the case of checkpointing with sub-epoch validations), since it now checks that we haven’t already saved on the same global step, instead of the same epoch. Can you confirm?
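
A minimal sketch of the idea described in this comment, assuming the guard simply remembers the global step of the last save. `StepGate`, `last_global_step_saved` and `should_save` are illustrative names, not necessarily the attributes used on master, and the step values are made up.

```python
class StepGate:
    """Hypothetical guard that skips a save only if this exact step was already saved."""

    def __init__(self):
        self.last_global_step_saved = None

    def should_save(self, global_step):
        if global_step == self.last_global_step_saved:
            # Already saved at this step (e.g. two hooks firing back to back).
            return False
        self.last_global_step_saved = global_step
        return True


def simulate(steps):
    """Return the save decision for each global step, in order."""
    gate = StepGate()
    return [gate.should_save(step) for step in steps]


if __name__ == "__main__":
    # Four sub-epoch validation checks, then a repeated call at the same step.
    print(simulate([100, 200, 300, 400, 400]))
```

Because the global step keeps advancing within an epoch, each sub-epoch validation run passes this guard, whereas an epoch-based comparison would treat them all as the same epoch.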

awaelchli commented 4 years ago

from argparse import ArgumentParser

import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from torchvision.datasets.mnist import MNIST
from torchvision import transforms

class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser

def cli_main():
    pl.seed_everything(1234)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument('--batch_size', default=32, type=int)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    args = parser.parse_args()

    # ------------
    # data
    # ------------
    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])

    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

    # ------------
    # model
    # ------------
    model = LitClassifier(args.hidden_dim, args.learning_rate)

    # ------------
    # training
    # ------------
    trainer = pl.Trainer.from_argparse_args(
        args,
        max_epochs=3,
        val_check_interval=0.25,
        gpus=1,
        checkpoint_callback=ModelCheckpoint(
            filepath="lightning_logs/test/{epoch:d}-{valid_loss:.2f}",
            save_top_k=-1
        )
    )
    trainer.fit(model, train_loader, val_loader)

    # ------------
    # testing
    # ------------
    trainer.test(test_dataloaders=test_loader)

if __name__ == '__main__':
    cli_main()

Yes! I just checked it. Above is the code that I tested with val_check_interval=0.25. It saves 4 checkpoints per epoch.
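
The figure of 4 checkpoints per epoch can be sanity-checked with a little arithmetic. This is a rough sketch, assuming the trainer converts the fractional interval into a whole number of batches by truncation and validates each time that many batches have been processed. `validation_batches` is a hypothetical helper, and the inputs are the 55000 training samples and the default batch size of 32 from the script above.

```python
import math


def validation_batches(num_samples, batch_size, val_check_interval):
    """Return the 1-based batch counts at which validation would run (assumed rule)."""
    num_batches = math.ceil(num_samples / batch_size)
    # Assumption: the fractional interval is truncated to a whole number of batches.
    check_every = max(1, int(num_batches * val_check_interval))
    return [n for n in range(1, num_batches + 1) if n % check_every == 0]


if __name__ == "__main__":
    checks = validation_batches(55000, 32, 0.25)
    print(len(checks), checks)
```

With save_top_k=-1 every one of those validation runs keeps its checkpoint, which is consistent with the count reported here.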

monney commented 4 years ago

Awesome, closing for now. The other features can likely be added separately.