Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Schedule model testing every N training epochs #5245

Closed. LucaBonfiglioli closed this issue 3 years ago.

LucaBonfiglioli commented 3 years ago

🚀 Feature

A check_test_every_n_epoch trainer option to schedule model testing every n epochs, just like check_val_every_n_epoch for validation.

Motivation

Sometimes validation and test tasks are very different. For instance, in unsupervised anomaly detection or segmentation, the training and validation sets cannot contain anomalous samples, so the set of metrics that can be computed on them is limited.

Test metrics also cannot be computed on a second validation set that contains a portion of the test data, because all parameter and hyperparameter optimization should be performed on clean (anomaly-free) samples.

The only way for the user to check test metrics is to run a test epoch, but currently pytorch-lightning only allows running a test epoch after the trainer.fit() call has ended, by explicitly calling trainer.test().
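
For illustration, the current workflow looks roughly like this (model and datamodule are placeholders):

from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=100)
trainer.fit(model, datamodule=datamodule)   # no test metrics can be logged during these epochs
trainer.test(model, datamodule=datamodule)  # test runs only once, after fit() has ended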

Pitch

When a Trainer is initialized with check_test_every_n_epoch=x, every x epochs the fitting automatically pauses, a test epoch is executed (computing and logging test metrics), and then fitting resumes.
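
Usage could look something like this; note that check_test_every_n_epoch is the proposed flag and does not exist in the current Trainer:

from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=100,
    check_val_every_n_epoch=1,    # existing behaviour: validate every epoch
    check_test_every_n_epoch=10,  # proposed: pause fitting and run a test epoch every 10 epochs
)
trainer.fit(model, datamodule=datamodule)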

Alternatives

I tried the following:

  1. Calling self.trainer.test() from the on_train_epoch_start hook whenever trainer.current_epoch becomes a multiple of the desired value (see the sketch after this list). At first this appears to work fine, but after each call to test() the epoch counter keeps increasing from its last value while the trainer's global_step is reset to the value it had when test() was last called, creating the beautiful effect shown in the figure below and making the logs unreadable.
  2. Raising KeyboardInterrupt from the same hook to stop the fitting, testing, manually saving a checkpoint, and resuming training from the saved checkpoint in a while loop. Cumbersome, wastes disk space, and training can no longer be stopped gracefully with Ctrl+C.
  3. Computing test metrics in validation_epoch_end by manually looping over the test set with plain PyTorch. Cumbersome, duplicates code, makes the test hooks useless, and hurts readability.
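
A minimal sketch of alternative 1 (assuming the module or an attached datamodule provides a test dataloader, and test_every_n_epochs is an attribute you set yourself); it reproduces the global_step problem described above:

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def on_train_epoch_start(self):
        # Run a test epoch every `test_every_n_epochs` training epochs.
        if self.current_epoch > 0 and self.current_epoch % self.test_every_n_epochs == 0:
            # Side effect: global_step in the logs is reset to the value it had
            # the last time test() was called.
            self.trainer.test(self)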

Additional context

[Figure: TensorBoard plot of the current epoch logged at each step when trying alternative 1.]

github-actions[bot] commented 3 years ago

Hi! Thanks for your contribution, great first issue!

stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it hasn't had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions. (PyTorch Lightning Team)

jinwon-samsung commented 3 years ago

How is this work coming along? I would strongly like this feature to be incorporated into PL.

devfoo-one commented 3 years ago

:+1: I am also interested in this. In my use case I use DDP, and performing validation after every epoch on a per-rank basis is fine. However, every n epochs I would like to test the current model on only one worker, because my test involves initializing a kNN index, which takes a lot of memory.
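
Not a full solution, but a rough sketch of how the heavy single-worker evaluation could be gated (build_knn_index, knn_eval and test_every_n are placeholders for your own code; trainer.is_global_zero restricts the work to one process):

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def on_validation_epoch_end(self):
        # Only the global-zero process builds the memory-hungry kNN index,
        # and only every `test_every_n` epochs.
        if self.trainer.is_global_zero and (self.current_epoch + 1) % self.test_every_n == 0:
            index = self.build_knn_index()  # placeholder: expensive index construction
            score = self.knn_eval(index)    # placeholder: evaluation with the index
            self.log("knn_eval/score", score, rank_zero_only=True)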

danielgafni commented 3 years ago

I'm also interested in this feature

daviddavini commented 3 years ago

Me too, this would be really useful!

ananthsub commented 3 years ago

Thanks for raising this!

Some questions that would need to be addressed are:

fijackokresimir commented 3 years ago

Does anyone have a solution for this?

ashkan-leo commented 3 years ago

I would also like to know how to achieve this. Any help is appreciated!

tomeramit commented 3 years ago

I also want this feature. I would like to know how the model progresses on the test set during training as well.

lamhoangtung commented 3 years ago

Why was this issue tagged "won't fix"? Can we raise some awareness of this?

domef commented 3 years ago

I'm also interested in this feature.

tomeramit commented 3 years ago

Me too, I had to do an ugly hack to overcome this. In my case the validation set is simulation images and the test set is real images, so it is very important for me to run on the test set during training.

ghost commented 3 years ago

@tomeramit Would you mind sharing your solution? Thank you very much.

ChristianHinge commented 3 years ago

What I am really looking for is to schedule model prediction every n epochs, but either way works for me.

I'm currently working on a synthetic volume-to-volume model that is trained to predict 3D patches of volumes, so my train and validation loaders return patches. At inference, multiple overlapping patches of a single volume are passed through the model and stitched together to form a full predicted volume. Because the patches overlap, there is a massive memory advantage in having the inference dataloader return full volumes rather than patches. I want to optimize the similarity metrics of the stitched images, so I also need a way of scheduling prediction/testing during training.

I can share my hack, as I also have very different train and test datasets. My Lightning datamodule creates two dataloaders: one that contains my ordinary validation set and one that is actually my inference set of full volumes. The dataloader_idx argument is passed by Lightning and tells you which dataloader produced the batch.

By checking self.current_epoch, I can achieve the desired result. Downside: the inference dataloader loads the whole dataset at every epoch, even though the data is only used every n epochs.

def validation_step(self, batch, batch_idx, dataloader_idx=None):
    # Ordinary validation (first dataloader)
    if dataloader_idx is None or dataloader_idx == 0:
        condition, real, mask = batch
        for idx in range(len(self.generators) + len(self.discriminators)):
            loss = None

            # Discriminator
            if idx % 2 == 1:
                dis_idx = self._opt_to_dis_idx(idx)
                loss = self._disc_step(real, condition, mask, dis_idx)
                self.log(f"Loss/val.Discriminator_{dis_idx}", loss, on_step=False, on_epoch=True, add_dataloader_idx=False)

            # Generator
            elif idx % 2 == 0:
                gen_idx = self._opt_to_gen_idx(idx)
                loss = self._gen_step(real, condition, mask, gen_idx)
                self.log(f"Loss/val.Generator_{gen_idx}", loss, on_step=False, on_epoch=True, add_dataloader_idx=False)

    # "Test"/inference on the second dataloader, only every n epochs
    elif dataloader_idx == 1 and self.current_epoch % self.hparams.train.test_every_n_epoch == 0:
        mrs, pets, masks, ids = batch
        for mr, pet, mask, id in zip(mrs, pets, masks, ids):
            out = self._predict_and_stitch(mr, idx=-1)  # this is my testing function
            mse = torch.sum(torch.square(out[mask.squeeze() > 0] - pet.squeeze()[mask.squeeze() > 0]))
            self.log(f"Validation/MSE/{id}", mse, on_step=False, on_epoch=True)

LucaBonfiglioli commented 3 years ago

@ChristianHinge I did not know about this feature of pytorch-lightning. It is still a hack, but far better than any of the other "solutions" I have tried.

Maybe with a bit of patience one could actually extend base LightningModule, LightningDataModule and Trainer to automatically perform this in a cleaner way.

jiwidi commented 3 years ago

Would love to have this feature too!

GimmickNG commented 3 years ago

I modified @ChristianHinge's hack a bit to transparently run testing periodically during training for any vanilla LightningModule, since repeating that validation_step pattern in every model class results in a lot of copied-and-pasted boilerplate that obscures the actual model logic. This approach requires a LightningDataModule that returns either just the validation dataloader or both the validation and test dataloaders, depending on whether the "test_every_n" flag is set.

In my LightningDataModule:

def val_dataloader(self):
    val_dataloaders = []
    if self.validation_set:
        val_dataloaders.append(DataLoader(...))
    if self.test_every_n > 0:
        val_dataloaders.append(self.test_dataloader())
    return val_dataloaders

And the wrapper functions call the relevant functions:

def patched_on_validation_epoch_start(self):
    self._validation_started = self._test_started = False

def patched_validation_step(self, batch, batch_idx, dataloader_idx):
    if dataloader_idx == 0:
        if not self._validation_started:
            self.do_on_validation_epoch_start()
            self._validation_started = True
        return self.do_validation(batch, batch_idx)
    elif dataloader_idx == 1 and self.current_epoch % self.test_every_n == 0:
        if not self._test_started:
            self.on_test_epoch_start()
            self._test_started = True
        return self.test_step(batch, batch_idx)

def patched_validation_epoch_end(self, results):
    val_loop, test_loop = results
    on_vend = self.validation_end(val_loop)
    if len(test_loop):
        self.test_epoch_end(test_loop)
    return on_vend

def wrap_model(model, test_every=None):
    """Wraps a model so that it can be tested every N epochs"""

    if test_every:
        model.test_every_n = test_every
        model.do_validation = model.validation_step
        model.validation_end = model.validation_epoch_end
        model.do_on_validation_epoch_start = model.on_validation_epoch_start

        model.validation_step = patched_validation_step.__get__(model, model.__class__)
        model.validation_epoch_end = patched_validation_epoch_end.__get__(model, model.__class__)
        model.on_validation_epoch_start = patched_on_validation_epoch_start.__get__(model, model.__class__)
    return model

def unwrap_model(model):
    """Undoes the wrapping caused by `wrap_model`"""

    try:
        model.validation_step = model.do_validation
        model.validation_epoch_end = model.validation_end
        model.on_validation_epoch_start = model.do_on_validation_epoch_start
        del model.test_every_n, model._validation_started, model._test_started
    except AttributeError:
        pass
    return model

Finally, before calling trainer.fit(), set the flags and call the functions, and unwrap after training:

# tests every 5 epochs
datamodule.test_every_n = 5
wrap_model(model, datamodule.test_every_n)
trainer.fit(model, datamodule=datamodule)
# unwrap after training
unwrap_model(model)

The same downsides still apply: the test dataloader iterates through the entire test dataset at every validation epoch, even if it is not used at all. Still, I find it makes the code a bit cleaner.

geogeo28 commented 3 years ago

One solution I am using is:

Of course, this is a dirty solution as well, since any change to PyTorch Lightning could break this patch or cause issues.

carmocca commented 3 years ago

Hi all! Just found this issue now.

#9254 contains a discussion on why this hasn't been implemented.

Please take a look. You can directly answer there to provide any arguments you think we missed.

danielgafni commented 2 years ago

Hey, I have something to say about this!

Testing every n epochs can definitely be useful.

For example, I want to know whether my early-stopping condition is valid. For this, it can be very useful to compute test metrics alongside validation metrics: this way we learn whether the test metrics behave the same way the validation metrics do. If they don't correlate, or the val and test curves don't have a similar shape, early stopping can be useless!

Modifying the val_dataloaders can be very hard to do for external libraries. For example, Pytorch-Forecasting doesn't support that.

This is why I think it's still a good thing to leave this option open to the user. Lightning can print a warning if the user is testing before training, but why remove this option completely?

Tell me what you think, @carmocca.

carmocca commented 2 years ago

Modifying the val_dataloaders can be very hard to do for external libraries.

You wouldn't need to modify the dataloaders as you can instead pass a list of them.
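
For example, something roughly like this, with placeholder dataloader names and argument names as in recent Lightning versions; inside validation_step, dataloader_idx then tells you which loader produced the batch, as in the snippets earlier in this thread:

import pytorch_lightning as pl

# train_loader, val_loader and extra_eval_loader are placeholders for your own DataLoaders.
trainer = pl.Trainer(max_epochs=100)
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=[val_loader, extra_eval_loader],  # the "test-like" set is just another val loader
)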

This way we learn whether the test metrics behave the same way the validation metrics do. If they don't correlate, or the val and test curves don't have a similar shape, early stopping can be useless!

This is considered "cheating", as your training is being informed by its performance on the test set. The test set is supposed to be held out and only used after the best model (generally selected on validation) has been chosen.

There are many techniques to avoid inconsistencies between the different data partitions, from better shuffling to data drift analysis.

Lightning can print a warning if the user is testing before training, but why remove this option completely?

If the user explicitly chose to do this, we wouldn't want to print a warning because it's his/her choice. However, we are not comfortable adding a flag to the Trainer that might signal our users that this practice is generally okay.

danielgafni commented 2 years ago

Yeah, I thought about it again and I agree with you. I probably only needed this because pytorch-forecasting doesn't support multiple dataloaders (their pl_modules break); otherwise there is no need for testing every epoch. It really should just be another val dataloader.

ZENGYIMING-EAMON commented 2 years ago

Any cleaner solutions for PL 1.6.1? Looking forward to it.

Jaliborc commented 2 years ago

@carmocca This is a ridiculous argument. You're just making it harder for people who need to run tests during training to get earlier results. People are still gonna do it when they need it; you're just making it unnecessarily messier.

Furthermore, there are many problems for which there are no validation metrics. The only way to measure how a model is doing is to test it. In fact, I think it is quite arrogant on your part to assume you know the entirety of the science that deep learning might be used for, and thus to set yourselves up as purveyors of what is and is not acceptable to do.

Shameful.

tommyfuu commented 11 months ago

Seconding this thread to see if there are any updates.

theonlynick0430 commented 7 months ago

This is especially useful for robotics algorithms, where we want to test rollouts every n epochs. This would definitely be a useful feature.