monney closed this issue 4 years ago
I think you need to set period=0, then it could work. Just looking at the code:
if (
    self.epoch_last_check is not None
    and (epoch - self.epoch_last_check) < self.period
):
    # skipping in this term
    return
Try it :) period can only be an integer. Setting it to the val_check_interval does not make sense.
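For reference, here is a minimal sketch of that workaround, reusing the argument names from the example later in this thread (filepath, save_top_k, period, checkpoint_callback, val_check_interval); treat it as a sketch against the Lightning API of that time, not a recommended configuration:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# With period=0 the "(epoch - epoch_last_check) < period" guard quoted above
# never triggers, so the callback saves on every validation run.
checkpoint = ModelCheckpoint(
    filepath="lightning_logs/{epoch:d}-{valid_loss:.2f}",
    save_top_k=-1,
    period=0,
)
trainer = Trainer(
    val_check_interval=0.25,  # validate (and therefore checkpoint) 4x per epoch
    checkpoint_callback=checkpoint,
)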
@awaelchli Thanks! This works, but I find it a bit unintuitive. I didn't get any warnings or anything when setting it to a fractional value (the same as val_check_interval), so I assumed that's what I had to do. Perhaps we could auto-set this if val_check_interval is below 1? Or trigger a warning stating the correct value to set in this case?
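A hypothetical sketch of what such a warning could look like (not actual Lightning code; the helper name is made up):
import warnings

def _warn_on_fractional_period(period):
    # Hypothetical check: a fractional period leads to checkpoints being
    # skipped (see the report above), so point users at val_check_interval.
    if isinstance(period, float) and not float(period).is_integer():
        warnings.warn(
            "ModelCheckpoint period must be an integer number of epochs; "
            "for sub-epoch checkpointing use Trainer(val_check_interval=...) instead.",
            UserWarning,
        )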
Yes I agree, period=0 only works because of an implementation detail and it is not meant to be used like that. It's a hack. Sub-epoch checkpointing is not supported currently. We're looking into that. If you're feeling lucky, give it a try and send a draft PR? :rocket: It is a tricky one though.
Let's classify this as a feature request instead of a bug?
I'll give it a look over and see if I'm able to do it. Feature request sounds good, since it's working as intended.
@awaelchli what do you think about these options for checkpointing? I think this suite could be really helpful:
I think yes, these are all fine use cases. Given that the current ModelCheckpoint callback is quite complex, it may be hard or become impossible to maintain all these options in a single class. We could consider splitting the functionality into several callbacks. A combination of these features would mean passing several callbacks to the Trainer. But then there are new challenges, like clashing filenames etc.
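For example, combining two checkpoint callbacks might look like the sketch below, with distinct filename templates so their outputs don't clash; whether a given Lightning version supports several ModelCheckpoint callbacks at once is an assumption here, and the argument names follow the ones used elsewhere in this thread.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Distinct templates so the two callbacks never write to the same file.
best_by_loss = ModelCheckpoint(
    filepath="checkpoints/best-{epoch:d}-{valid_loss:.2f}",
    monitor="valid_loss",
    save_top_k=3,
)
every_epoch = ModelCheckpoint(
    filepath="checkpoints/epoch-{epoch:d}",
    save_top_k=-1,
)
trainer = Trainer(callbacks=[best_by_loss, every_epoch])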
Support for checkpointing on training epoch end if validation steps aren't supported
Is that not already supported?
@awaelchli I believe this is fixed on master (at least for the case of checkpointing with sub-epoch validations), since it now checks to make sure we haven't saved on the same global step, instead of the same epoch. Can you confirm?
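As a rough illustration of that guard (a simplified sketch, not the actual ModelCheckpoint code): remember the global step of the last save and skip a save only if it hasn't advanced.
class GlobalStepGuard:
    """Skip a save if nothing has been trained since the last save."""

    def __init__(self):
        self.last_global_step_saved = -1

    def should_save(self, trainer):
        if trainer.global_step == self.last_global_step_saved:
            return False  # already saved at this global step
        self.last_global_step_saved = trainer.global_step
        return True  # global step advanced, so a new checkpoint is warranted
The complete script that was actually tested follows below.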
import torch
import pytorch_lightning as pl
from argparse import ArgumentParser
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser
def cli_main():
    pl.seed_everything(1234)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument('--batch_size', default=32, type=int)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    args = parser.parse_args()

    # ------------
    # data
    # ------------
    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

    # ------------
    # model
    # ------------
    model = LitClassifier(args.hidden_dim, args.learning_rate)

    # ------------
    # training
    # ------------
    trainer = pl.Trainer.from_argparse_args(
        args,
        max_epochs=3,
        val_check_interval=0.25,
        gpus=1,
        checkpoint_callback=ModelCheckpoint(
            filepath="lightning_logs/test/{epoch:d}-{valid_loss:.2f}",
            save_top_k=-1
        )
    )
    trainer.fit(model, train_loader, val_loader)

    # ------------
    # testing
    # ------------
    trainer.test(test_dataloaders=test_loader)


if __name__ == '__main__':
    cli_main()
Yes! I just checked it. Above is the code that I tested with val_check_interval=0.25. It saves 4 checkpoints per epoch.
Awesome, closing for now, the other features can likely be added separately
Question
When setting period to a fractional value, checkpointing doesn't trigger correctly. Additionally, I think period should default to val_check_interval, if it doesn't already.
To Reproduce
Steps to reproduce the behavior:
Run any model and set the checkpoint period to a fractional value. Only the first checkpoint will be saved.
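A hypothetical minimal configuration for this repro (the fractional value here is illustrative, mirroring the val_check_interval used later in the thread, not taken verbatim from the report):
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Passing a fractional period (mirroring val_check_interval) raises no warning,
# but checkpoints stop being saved after the first one.
trainer = Trainer(
    max_epochs=3,
    val_check_interval=0.25,
    checkpoint_callback=ModelCheckpoint(save_top_k=-1, period=0.25),
)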
Expected behavior
A checkpoint should be saved every specified period
Environment
Installation (conda, pip, source): pip