Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
28.01k stars 3.36k forks

Cross validation feature #839

Closed BraveDistribution closed 2 years ago

BraveDistribution commented 4 years ago

🚀 Feature

Cross-validation is a crucial model validation technique for assessing how a model generalizes to new data.

Motivation

Research papers usually require cross-validation. From my point of view, this kind of feature would simplify the work of researchers.

Pitch

I want to pass a parameter to the Trainer object to specify that I want to train the model on K-folds.

In the case that nobody wants to make a PR, I can start working on that.

Borda commented 4 years ago

I think the cleaner way would be some abstraction above the dataloader, because cross-validation is just systematic train/test on a particular dataset... Anyway, a PR is welcome! @BraveDistribution could you please describe in a bit more detail how you plan to implement it, or make a draft PR so we can talk about it there :robot:

BraveDistribution commented 4 years ago

@Borda, I don't have a plan for how to implement it yet, because I haven't worked on it until now.

If I have any questions I will post them here; if not, I will make a PR directly.

williamFalcon commented 4 years ago

what if we just integrate with sklearn cross validation? this can be the start of supporting sklearn interop

BraveDistribution commented 4 years ago

How would you propose that @williamFalcon?

In my "own" library I split the datasets into K folders by using my own script (you can use k-fold or stratified k-fold or any of the scikit methods).

dataset/k_0/train dataset/k_0/test

dataset/k_1/train dataset/k_1/test

Then I trained and evaluated K neural networks, and finally I gathered all the results and saved the mean of acc, f1 and other metrics.

That of course means you waste HDD space equal to (K-1) * size of the dataset. We shouldn't be implementing that approach.
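One way to avoid the duplicated copies on disk is to split indices rather than files, so that each fold is just a pair of index lists into the single original dataset. The sketch below is a minimal pure-Python illustration of that idea. `kfold_indices` and `evaluate_fold` are hypothetical names, and the value returned by `evaluate_fold` is a placeholder, not a real metric.

```python
from statistics import mean


def kfold_indices(n_samples, k):
    """Yield (train_idx, test_idx) pairs for k contiguous folds.

    Only indices are produced, so the dataset itself is stored once.
    """
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")
    # Spread the remainder over the first folds so sizes differ by at most 1.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        test_idx = list(range(start, start + size))
        train_idx = list(range(0, start)) + list(range(start + size, n_samples))
        yield train_idx, test_idx
        start += size


def evaluate_fold(train_idx, test_idx):
    """Hypothetical stand-in for 'train on train_idx, score on test_idx'."""
    return len(train_idx) / (len(train_idx) + len(test_idx))  # placeholder value


folds = list(kfold_indices(n_samples=10, k=5))
scores = [evaluate_fold(train_idx, test_idx) for train_idx, test_idx in folds]
mean_score = mean(scores)  # aggregate over the K runs, as described above
```

In a real run, `evaluate_fold` would train one network per fold and return acc, f1, or whichever metric is being averaged.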


I think we should add a new parameter to the trainer, something like the cv parameter of GridSearchCV in scikit-learn:

cv: int, cross-validation generator or an iterable, optional. Determines the cross-validation splitting strategy. Possible inputs for cv are:

  - None, to use the default 5-fold cross validation,
  - integer, to specify the number of folds in a (Stratified)KFold,
  - CV splitter,
  - An iterable yielding (train, test) splits as arrays of indices.

For integer/None inputs, if the estimator is a classifier and y is either binary or multiclass, StratifiedKFold is used. In all other cases, KFold is used.
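The resolution rule quoted above can be observed directly through scikit-learn's `check_cv` utility, which turns a `cv` argument into a splitter object. This is only an illustration of the scikit-learn behaviour, not a proposal for how Lightning would implement it; the toy labels are made up.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, check_cv

y = np.array([0, 1] * 10)  # toy binary labels (assumption for illustration)
X = np.zeros((len(y), 1))

# Integer cv + classifier + binary/multiclass y -> StratifiedKFold
stratified = check_cv(cv=5, y=y, classifier=True)

# Integer cv without a classifier -> plain KFold
plain = check_cv(cv=5, y=y, classifier=False)

# An explicit splitter object is passed through unchanged
custom = check_cv(cv=KFold(n_splits=3), y=y, classifier=True)

n_folds = [splitter.get_n_splits(X, y) for splitter in (stratified, plain, custom)]
```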

Ir1d commented 4 years ago

what if we just integrate with sklearn cross validation? this can be the start of supporting sklearn interop

@williamFalcon skorch has a nice implementation. https://github.com/skorch-dev/skorch/blob/f94466e272f6f325898359fecb9a7c004354af7f/skorch/dataset.py#L212

Borda commented 4 years ago

check use case in #1393

Anjum48 commented 4 years ago

By passing data loaders directly to the Trainer my CV loop looks like this:

for fold, (train_idx, valid_idx) in enumerate(kfold.split(train_df)):
    # Build one pair of loaders per fold from the index arrays
    train_loader = create_dataloader(train_df.iloc[train_idx])
    valid_loader = create_dataloader(train_df.iloc[valid_idx])

    # Folder hack: one logger version (and checkpoint directory) per fold
    tb_logger = TensorBoardLogger(save_dir=OUTPUT_PATH, name=f'{args.model_name}', version=f'fold_{fold + 1}')
    os.makedirs(OUTPUT_PATH / f'{args.model_name}', exist_ok=True)
    checkpoint_callback = ModelCheckpoint(filepath=tb_logger.log_dir + "/{epoch:02d}-{val_metric:.4f}",
                                          monitor='val_metric', mode='max')

    # A fresh model and Trainer for every fold
    model = YourPLModule(args)
    trainer = pl.Trainer(logger=tb_logger, early_stop_callback=early_stop_callback, checkpoint_callback=checkpoint_callback)
    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=valid_loader)

Note that the folder hack is from https://github.com/PyTorchLightning/pytorch-lightning/issues/1207

Borda commented 4 years ago

it could be a nice feature as we have now the LR finder... @PyTorchLightning/core-contributors any other suggestions? @Anjum48, I would say draft a PR would be nice...

justusschock commented 4 years ago

I wouldn't integrate this to fit or trainer init, but to a separate function internally calling fit

Borda commented 4 years ago

I wouldn't integrate this to fit or trainer init, but to a separate function internally calling fit

I agree, that's why I proposed to do it similar as LR finder... lol

BraveDistribution commented 4 years ago

We should also somehow include the CV results into tensorboard, to provide scientists easy way to check the quality of their models. I don't know much about tensorboard, so I don't know whether that's possible.

Or, we should at least save the final results into json / pickle file.
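As a rough sketch of the JSON option, the per-fold metrics could be collected in a list and written out together with their mean and standard deviation. The metric values below are placeholders for illustration only, and the file name `cv_results.json` is an assumption.

```python
import json
import tempfile
from pathlib import Path
from statistics import mean, stdev

# Placeholder per-fold metrics; in practice these would come from each fold's run.
fold_metrics = [
    {"fold": 0, "val_acc": 0.80},
    {"fold": 1, "val_acc": 0.90},
    {"fold": 2, "val_acc": 0.85},
]

accs = [m["val_acc"] for m in fold_metrics]
summary = {
    "folds": fold_metrics,
    "val_acc_mean": mean(accs),
    "val_acc_std": stdev(accs),  # sample standard deviation across folds
}

# Write to a temporary directory here; a real run would use its output folder.
out_path = Path(tempfile.mkdtemp()) / "cv_results.json"
out_path.write_text(json.dumps(summary, indent=2))

reloaded = json.loads(out_path.read_text())
```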

axkoenig commented 4 years ago

Is there any news on this?

Borda commented 4 years ago

@axkoenig how would you do it? Write a wrapper over a Trainer and perform the fold splitting followed by train-test?

justusschock commented 4 years ago

I think, we could have something like that in bolts, but it is very hard to generalize this, since it always depends on how you want to split your data.

SkafteNicki commented 4 years ago

I think we could provide two options:

  1. Either users provide a single train_dataloader that we split into K new dataloaders with non-overlapping subsets of data, and perform the cross validation from them
  2. Users provide K train_dataloaders and K test_dataloaders and we run cross validation on them (basically calling trainer.fit iteratively)

justusschock commented 4 years ago

@SkafteNicki I think this would be a good idea to start.

However, we might also want to have some stratified splitting and not just random splitting, which may become more difficult, since we would have to assume things (like structure, dtype etc.) about these batches.

In general, we should also keep in mind, that we may not want to only split for train and test but also for validation sets/data loaders

SkafteNicki commented 4 years ago

@justusschock completely agree, I think that v1 of this feature should be very simple: just random splitting. My proposed option 2 would allow the user to provide their own stratified dataloaders.

In v2 we can begin to figure out how to do more advanced stuff/better integration. The main problem (in my view) is that we are working with dataloaders and not datasets, so to get dataset statistics (like class balance for stratified splitting) we need to explicitly run over the dataset and enforce a lot of structure in the batches (as you mention).
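To make the cost concrete: once the labels have been collected (which is the full pass over the dataset mentioned above), a stratified assignment can be sketched in a few lines. This is a simplified pure-Python illustration, not scikit-learn's StratifiedKFold algorithm; `stratified_kfold_indices` is a hypothetical name and the labels are toy data.

```python
from collections import defaultdict


def stratified_kfold_indices(labels, k):
    """Return k (train_idx, val_idx) pairs that roughly preserve class balance.

    `labels` holds one class label per sample, so it must already have been
    gathered by iterating over the whole dataset.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Deal the indices of each class round-robin into the k folds.
    fold_members = [[] for _ in range(k)]
    for indices in by_class.values():
        for position, idx in enumerate(indices):
            fold_members[position % k].append(idx)

    splits = []
    for fold in range(k):
        val_idx = sorted(fold_members[fold])
        train_idx = sorted(i for f in range(k) if f != fold for i in fold_members[f])
        splits.append((train_idx, val_idx))
    return splits


labels = [0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0]  # 8 samples of class 0, 4 of class 1
splits = stratified_kfold_indices(labels, k=4)
```

With these toy labels every validation fold ends up with two samples of class 0 and one of class 1, matching the 2:1 ratio of the full set.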

astenuz commented 4 years ago

Hi! Is there an update on this issue? Due to the ubiquity of the cross val strategy it could be a quite significant addition to pl

SkafteNicki commented 4 years ago

@astenuz so we currently have a freeze on new features until the v1.0 release, since we want to focus on getting a very stable release. After v1.0 this is definitely something we would like to be a part of lightning.

ananyahjha93 commented 3 years ago

@SkafteNicki should this be a DataModule feature, as mentioned in #4287 ? Like the DataModule itself provides k dataloaders like you mentioned here.

williamFalcon commented 3 years ago

cc @edenafek

let’s pick this back up now

SkafteNicki commented 3 years ago

@ananyahjha93 the first question is how it should be integrated in lightning: 1) should trainer have a k_fold init argument? 2) should fit have a k_fold argument? 3) should trainer have a new method (cross_validate) 4) should this be a plugin? 5) should this be a completely new object wrapping around trainer (CV(Trainer(...)))?

justusschock commented 3 years ago

I actually like the idea of having a separate class (CV) and some function in the data module for that. This way we would still have the trainer to train separate networks, but don't further bloat its state.

However I'd prefer the interface to have the CV construct trainers internally by passed args. So something like this:

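from copy import deepcopy

from pytorch_lightning import Trainer


class CV:
    def __init__(self, *args, **kwargs):
        # Store the Trainer arguments so a fresh Trainer can be built per fold
        self.trainer_args = args
        self.trainer_kwargs = kwargs

    def fit(self, model, data_module):
        # get_kfold() is the proposed data module method, it does not exist yet
        for loaders in data_module.get_kfold():
            fold_model = deepcopy(model)  # every fold starts from the same initial state
            yield Trainer(*self.trainer_args, **self.trainer_kwargs).fit(fold_model, loaders)

SkafteNicki commented 3 years ago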

I am also in favor of a new separate class. Another thing is that the CV object probably will have some parameters of its own: 1) should the fitting be done in parallel (then we need to figure out how to map individual fit to each device) 2) should the cv be stratified (maybe not in v1 of this feature) 3) ...
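For point 1, the simplest mapping I can think of is round-robin: fold i goes to device i modulo the number of devices. The sketch below only illustrates that bookkeeping. `run_fold` is a hypothetical stand-in for fitting one fold, the device strings are just labels, and a real multi-GPU setup would probably need separate processes rather than threads.

```python
from concurrent.futures import ThreadPoolExecutor


def run_fold(fold_idx, device):
    """Hypothetical stand-in for fitting one fold on `device`."""
    return {"fold": fold_idx, "device": device}


def run_folds(n_folds, devices, parallel=False):
    # Fold i is assigned to device i modulo the number of devices.
    assignments = [(fold, devices[fold % len(devices)]) for fold in range(n_folds)]
    if not parallel:
        return [run_fold(fold, device) for fold, device in assignments]
    # At most one worker per device; this alone does not guarantee that two
    # folds never share a device at the same time.
    with ThreadPoolExecutor(max_workers=len(devices)) as pool:
        futures = [pool.submit(run_fold, fold, device) for fold, device in assignments]
        return [future.result() for future in futures]


results = run_folds(n_folds=5, devices=["cuda:0", "cuda:1"], parallel=True)
```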

gaceladri commented 3 years ago

I think that integration with optuna cross-validation would be a great match.

williamFalcon commented 3 years ago

that’s already supported today. I think they have tutorials about it as well, no?

but generally we want to make sure we build general tools that support any option like optuna.

gaceladri commented 3 years ago

I have not seen tutorials doing cross validation with pytorch-lightning, nor with pytorch-lightning + Optuna cross-val.

I agree with you that the feature should be general.

justusschock commented 3 years ago

@SkafteNicki I think for v1 the folds could run sequentially and the data_module could have a method which creates the loader (probably without stratification in v1, but can be overwritten by user). Also it is not possible to stratify every kind of training :D

jaimergp commented 3 years ago

Any specific plans on this? I have been trying to implement something like https://github.com/PyTorchLightning/pytorch-lightning/issues/839#issuecomment-714273956 but I am running into some rough edges like managing the loggers across folds, or checkpoints. There's also open questions about how to deal with the test parts.

I'd be happy to work on a PR given some guidance on how you'd like this implemented!

marcosfelt commented 3 years ago

Any specific plans on this? I have been trying to implement something like https://github.com/PyTorchLightning/pytorch-lightning/issues/839#issuecomment-714273956 but I am running into some rough edges like managing the loggers across folds, or checkpoints. There's also open questions about how to deal with the test parts.

I'd be happy to work on a PR given some guidance on how you'd like this implemented!

Same here!

Shubhamai commented 3 years ago

Any specific plans on this? I have been trying to implement something like #839 (comment) but I am running into some rough edges like managing the loggers across folds, or checkpoints. There's also open questions about how to deal with the test parts.

I'd be happy to work on a PR given some guidance on how you'd like this implemented!

Same!

Svito-zar commented 3 years ago

Looking forward to seeing this feature!

appleparan commented 3 years ago

I support the 2nd approach from @SkafteNicki.

  2. Users provide K train_dataloaders and K test_dataloaders and we run cross validation on them (basically calling trainer.fit iteratively)

There are some other CV methods, such as Blocked Cross Validation for time series forecasting. Providing dataloaders for well-known CV methods gives users not only convenience but also a lot of room for customization.
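As an illustration of why a user-supplied split is useful, here is a sketch of one common reading of blocked cross-validation: the series is cut into contiguous blocks, and inside each block the earlier part is used for training and the later part for validation, so no validation sample precedes its training data. The function name and the `train_fraction` parameter are assumptions made for this example.

```python
def blocked_time_series_splits(n_samples, n_blocks, train_fraction=0.8):
    """Yield (train_idx, val_idx) per contiguous block, keeping temporal order."""
    if not 0.0 < train_fraction < 1.0:
        raise ValueError("train_fraction must be strictly between 0 and 1")
    block_size = n_samples // n_blocks
    if block_size < 2:
        raise ValueError("each block needs at least 2 samples")
    for block in range(n_blocks):
        start = block * block_size
        # The last block absorbs any leftover samples.
        stop = n_samples if block == n_blocks - 1 else start + block_size
        cut = start + int((stop - start) * train_fraction)
        cut = min(max(cut, start + 1), stop - 1)  # keep both parts non-empty
        yield list(range(start, cut)), list(range(cut, stop))


splits = list(blocked_time_series_splits(n_samples=20, n_blocks=4, train_fraction=0.8))
```

Each pair of index lists could then back one train dataloader and one test dataloader, which is exactly what option 2 expects.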

If you need K-Fold CV, you might implement a custom dataset and dataloaders, then concatenate the k-fold datasets with ConcatDataset from torch.utils.data [Ref]; providing that to your trainer should solve your problem.

jbschiratti commented 3 years ago

I'm also interested in this feature (I would use it on a regular basis). Starting from the computer vision example in the pl_examples folder, I wrote an example of K-Fold CV with Pytorch-Lightning. It's certainly not perfect but it's working.

from copy import deepcopy
from pathlib import Path

from sklearn.model_selection import KFold, StratifiedKFold
from torch import nn, sigmoid, optim
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.utils.data import ConcatDataset, Subset, DataLoader
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_and_extract_archive

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import NeptuneLogger, LoggerCollection

DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"

class KFoldHelper:
    """Split data for (Stratified) K-Fold Cross-Validation."""
    def __init__(self,
                 n_splits=5,
                 stratify=False):
        super().__init__()
        self.n_splits = n_splits
        self.stratify = stratify

    def __call__(self, data):
        data.prepare_data()

        if self.stratify:
            labels = data.get_data_labels()
            splitter = StratifiedKFold(n_splits=self.n_splits)
        else:
            labels = None
            splitter = KFold(n_splits=self.n_splits)

        dataset = data.get_dataset()
        n_samples = len(dataset)
        for train_idx, val_idx in splitter.split(X=range(n_samples), y=labels):

            _train = Subset(dataset, train_idx)
            train_dataset = _WrappedDataset(_train, data.train_transform)
            train_loader = DataLoader(dataset=train_dataset,
                                      batch_size=data.batch_size,
                                      shuffle=True,
                                      num_workers=data.num_workers)

            _val = Subset(dataset, val_idx)
            val_dataset = _WrappedDataset(_val, data.val_transform)
            val_loader = DataLoader(dataset=val_dataset,
                                    batch_size=data.batch_size,
                                    shuffle=False,
                                    num_workers=data.num_workers)

            yield train_loader, val_loader

class _WrappedDataset:
    """Allows to add transforms to a given Dataset."""
    def __init__(self,
                 dataset,
                 transform=None):
        super().__init__()
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, label = self.dataset[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

class CV:
    """(Stratified) K-Fold Cross-validation wrapper for a Trainer."""
    def __init__(self,
                 trainer,
                 n_splits=5,
                 stratify=False):
        super().__init__()
        self.trainer = trainer
        self.n_splits = n_splits
        self.stratify = stratify

    @staticmethod
    def _update_logger(logger, fold_idx):
        if hasattr(logger, 'experiment_name'):
            logger_key = 'experiment_name'
        elif hasattr(logger, 'name'):
            logger_key = 'name'
        else:
            raise AttributeError('The logger associated with the trainer '
                                 'should have an `experiment_name` or `name` '
                                 'attribute.')
        new_experiment_name = getattr(logger, logger_key) + f'/{fold_idx}'
        setattr(logger, logger_key, new_experiment_name)

    @staticmethod
    def update_modelcheckpoint(model_ckpt_callback, fold_idx):
        _default_filename = '{epoch}-{step}'
        _suffix = f'_fold{fold_idx}'
        if model_ckpt_callback.filename is None:
            new_filename = _default_filename + _suffix
        else:
            new_filename = model_ckpt_callback.filename + _suffix
        setattr(model_ckpt_callback, 'filename', new_filename)

    def update_logger(self, trainer, fold_idx):
        if not isinstance(trainer.logger, LoggerCollection):
            _loggers = [trainer.logger]
        else:
            _loggers = trainer.logger

        # Update loggers:
        for _logger in _loggers:
            self._update_logger(_logger, fold_idx)

    def fit(self, model, data):
        split_func = KFoldHelper(n_splits=self.n_splits, stratify=self.stratify)
        cv_data = split_func(data)
        for fold_idx, loaders in enumerate(cv_data):
            # Clone model & trainer:
            _model = deepcopy(model)
            _trainer = deepcopy(self.trainer)

            # Update loggers and callbacks:
            self.update_logger(_trainer, fold_idx)
            for callback in _trainer.callbacks:
                if isinstance(callback, ModelCheckpoint):
                    self.update_modelcheckpoint(callback, fold_idx)

            # Fit:
            _trainer.fit(_model, *loaders)

class CatsDogsData:
    """Cats & dogs toy dataset."""

    def __init__(self,
                 data_dir,
                 num_workers: int = 16,
                 batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.batch_size = batch_size

    def prepare_data(self):
        """Download the raw data."""
        download_and_extract_archive(url=DATA_URL,
                                     download_root=self.data_dir,
                                     remove_finished=True)

    @property
    def normalize_transform(self):
        return transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

    @property
    def train_transform(self):
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            self.normalize_transform,
        ])

    @property
    def val_transform(self):
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            self.normalize_transform
        ])

    def get_dataset(self):
        """Create the complete dataset."""
        train_data_path = Path(self.data_dir).joinpath('cats_and_dogs_filtered', 'train')
        train_dataset = ImageFolder(root=train_data_path)
        valid_data_path = Path(self.data_dir).joinpath('cats_and_dogs_filtered', 'validation')
        valid_dataset = ImageFolder(root=valid_data_path)
        return ConcatDataset([train_dataset, valid_dataset])

    def get_data_labels(self):
        dataset = self.get_dataset()
        return [int(sample[1]) for sample in dataset]

class MyCustomModel(LightningModule):
    """Custom classification model."""

    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr

        self.__build_model()

    def __build_model(self):
        # Classifier:
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.fc1 = nn.Linear(3 * 3 * 64, 10)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(10, 1)
        self.relu = nn.ReLU()

        # Loss:
        self.loss_func = binary_cross_entropy_with_logits

        # Metrics:
        self.train_acc = Accuracy()
        self.valid_acc = Accuracy()

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)
        out = self.relu(self.fc1(out))
        out = self.fc2(out)
        return out

    def loss(self, logits, labels):
        return self.loss_func(input=logits, target=labels)

    def training_step(self, batch, batch_idx):
        # 1. Forward pass:
        x, y = batch
        y_logits = self.forward(x)
        y_true = y.view((-1, 1)).type_as(x)

        # 2. Compute loss
        train_loss = self.loss(y_logits, y_true)

        # 3. Compute accuracy:
        train_accuracy = self.train_acc(sigmoid(y_logits), y_true.int())
        self.log("train_acc", train_accuracy, prog_bar=True)

        return train_loss

    def validation_step(self, batch, batch_idx):
        # 1. Forward pass:
        x, y = batch
        y_logits = self.forward(x)
        y_true = y.view((-1, 1)).type_as(x)

        # 2. Compute loss
        self.log("val_loss", self.loss(y_logits, y_true), prog_bar=True)

        # 3. Compute accuracy:
        valid_accuracy = self.valid_acc(sigmoid(y_logits), y_true.int())
        self.log("val_acc", valid_accuracy, prog_bar=True)

    def configure_optimizers(self):
        parameters = list(self.parameters())
        trainable_parameters = list(filter(lambda p: p.requires_grad, parameters))
        optimizer = optim.Adam(trainable_parameters, lr=self.lr)
        return optimizer

if __name__ == '__main__':

    # Trainer
    neptune_logger = NeptuneLogger(project_name=NEPTUNE_PROJECT_NAME,
                                   experiment_name=NEPTUNE_EXPERIMENT_NAME)

    model_checkpoint = ModelCheckpoint(dirpath=MODEL_CHECKPOINT_DIR_PATH,
                                       monitor='val_acc',
                                       save_top_k=1,
                                       mode='max',
                                       filename='custom_model_{epoch}',)

    pl_trainer = Trainer(weights_summary=None,
                         progress_bar_refresh_rate=1,
                         num_sanity_val_steps=0,
                         gpus=[0],
                         max_epochs=10,
                         logger=neptune_logger,
                         callbacks=[model_checkpoint])

    # LightningModule
    clf = MyCustomModel(lr=1e-3)

    # Run a 5-fold cross-validation experiment:
    image_data = CatsDogsData(data_dir=DATA_DIR)

    cv = CV(trainer=pl_trainer,
            n_splits=5,
            stratify=False)

    cv.fit(clf, image_data)

The main ingredients are:

What do you think about this example?

cc @SkafteNicki @Borda

jbschiratti commented 3 years ago

Up ⬆️ :-)

justusschock commented 3 years ago

@jbschiratti Thanks for coming up with this.

I see some points, where we probably need to improve a bit:

1.) Your example only runs the models sequentially, but I feel that there should be an option to also do this in parallel (can be added later as discussed above, just mentioning it here)

2.) You only construct model and transforms once. I feel we should recreate them instead of deepcopy in case there are some dependencies on the data for initialization of transforms and model

3.) Should we really pass the trainer or just the arguments so that the trainer will also be created every time?

4.) This is only one possible way to do a k-fold. We need to sort out which other versions there are and whether we want to support them/which of them we want to support.

5.) Your helper class should be part of the CV class, so that I can simply overwrite the parts necessary. Currently I have to overwrite both classes since the Helper class is kind of hardcoded there.

But first we should really discuss whether in general we want to add this here.

cc @tchaton @Borda @carmocca @ananthsub @SkafteNicki

marcelschilling commented 3 years ago

Very nice, thanks a lot @jbschiratti. One question: do you also see CUDA memory leak issues when calling trainer.fit(...) several times? It seems to me that some subprocesses don't get killed (the garbage collector as well as delete/torch.cuda.empty_cache() did not help), so the allocated GPU memory increases with each trainer.fit(...) call. Thanks in advance!

SkafteNicki commented 3 years ago

1.) Your example only runs the models sequentially, but I feel that there should be an option to also do this in parallel (can be added later as discussed above, just mentioning it here)

Agree, that this should be added, but not in the first version. We should allow both running in parallel on different devices but also on same device (sometimes multiple models can be fit on same gpu)

2.) You only construct model and transforms once. I feel we should recreate them instead of deepcopy in case there are some dependencies on the data for initialization of transforms and model

Agree

3.) Should we really pass the trainer or just the arguments so that the trainer will also be created every time?

IMO we should pass the arguments. Much is going on in the trainer, and if we do not correctly reset it between runs it may screw things up.

4.) This is only one possible way to do a k-fold. We need to sort out which other versions there are and whether we want to support them/which of them we want to support.

IMO v1 should be similar to https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html: pass in trainer, model and single train dataloader and it will be split into K folds.

5.) Your helper class should be part of the CV class, so that I can simply overwrite the parts necessary. Currently I have to overwrite both classes since the Helper class is kind of hardcoded there.

Agree

But first we should really discuss whether in general we want to add this here.

Just based on the number of thumbs up, this is probably our most requested feature (also one of the oldest). The clear argument for having this in lightning is that it reduces boilerplate (which is one of our core values). The argument against is having to maintain an additional feature.

jbschiratti commented 3 years ago

@marcelschilling I took a quick look and it does not seem like there are CUDA memory leaks during the training but it should be investigated more thoroughly.

@justusschock @SkafteNicki Thank you for the feedback. Although the example I proposed is far from perfect, I'm glad it triggered this discussion. I agree with @SkafteNicki that a lot of people requested this feature. Personally, I would use it on a regular (daily?) basis. As a next step, shall I address your comments and initiate a PR (for further discussions)?

jbschiratti commented 3 years ago

You only construct model and transforms once. I feel we should recreate them instead of deepcopy in case there are some dependencies on the data for initialization of transforms and model

@justusschock do you have an example (where data transforms would need to be re-created at each split)?

justusschock commented 3 years ago

@jbschiratti E.g. in medical image processing you have transforms depending on the spacing of the training data (voxel size and distance). And since the train data changes here, we would have to recreate the transforms as well.

But this could probably be done with #6776 more easily

evancasey commented 3 years ago

@jbschiratti another, more widely used case is z-score normalization that's based on the train split statistics. In your example, a constant mean/std is used for every split but I could see cases where we want to calculate mean/std individually for each split
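A tiny numeric sketch of what I mean: the statistics are fitted on the train indices of each fold only and then reused for that fold's validation data, so they change from fold to fold. The data values and the two toy folds are made up for illustration, and `fit_zscore`/`apply_zscore` are hypothetical helper names.

```python
from statistics import mean, pstdev


def fit_zscore(train_values):
    """Compute normalization statistics from the training split only."""
    mu = mean(train_values)
    sigma = pstdev(train_values)
    if sigma == 0:
        sigma = 1.0  # avoid division by zero for constant data
    return mu, sigma


def apply_zscore(values, mu, sigma):
    return [(v - mu) / sigma for v in values]


data = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0]
folds = [([0, 1, 2, 3], [4, 5]), ([2, 3, 4, 5], [0, 1])]  # toy (train_idx, val_idx) pairs

per_fold_stats = []
per_fold_val = []
for train_idx, val_idx in folds:
    mu, sigma = fit_zscore([data[i] for i in train_idx])
    # Validation data is normalized with the *training* statistics of this fold.
    per_fold_val.append(apply_zscore([data[i] for i in val_idx], mu, sigma))
    per_fold_stats.append((mu, sigma))
```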

jbschiratti commented 3 years ago

I took your comments into account and addressed the points raised by @marcelschilling and @SkafteNicki.

Your example only runs the models sequentially

For now, the cross-validation is done sequentially. Let's keep things simple for now. Once a first version of this feature is implemented, we may think about making things more complicated :-)

Should we really pass the trainer or just the arguments so that the trainer will also be created every time?

The trainer arguments are passed instead of the trainer itself. A new trainer is created for each CV split.

You only construct model and transforms once

In the example below, the model and transforms are created once. However, I use the example of z-score normalization proposed by @evancasey to show how transforms can be updated with each data split. In this example, the mean and std of the images in the training split are recomputed for each fold to allow for a different normalization transform each time.

Your helper class should be part of the CV class

It's now part of the data class. Users may overwrite the get_splits method of this class to do more elaborate stuff than (stratified) K-Fold. By default, K-Fold is used for the cross-validation.

Just based on the number of thumbs up, this is probably our most requested feature (also one of the oldest)

@SkafteNicki @Borda Shall we move on with this?

from copy import deepcopy
from pathlib import Path
from typing import Union, Optional, Callable

from sklearn.model_selection import KFold, StratifiedKFold
from torch import nn, sigmoid, optim
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.utils.data import ConcatDataset, Subset, DataLoader, Dataset
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_and_extract_archive

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import NeptuneLogger, LoggerCollection

DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"

class _WrappedDataset:
    """Allows to add transforms to a given Dataset."""
    def __init__(self,
                 dataset: Dataset,
                 transform: Optional[Callable] = None):
        super().__init__()
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx: int):
        sample, label = self.dataset[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

class CatsDogsDataCV:
    """Cats & dogs toy dataset for cross-validation."""
    def __init__(self,
                 data_dir: Union[str, Path],
                 num_workers: int = 16,
                 batch_size: int = 32,
                 n_splits: int = 5,
                 stratify: bool = False):
        super().__init__()
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.batch_size = batch_size

        # Cross-validation
        self.n_splits = n_splits
        self.stratify = stratify

        # Data normalization
        self._mean = [0.485, 0.456, 0.406]
        self._std = [0.229, 0.224, 0.225]

    def prepare_data(self):
        """Download the raw data."""
        download_and_extract_archive(url=DATA_URL,
                                     download_root=str(self.data_dir),
                                     remove_finished=True)

    @property
    def normalize_transform(self):
        return transforms.Normalize(mean=self._mean, std=self._std)

    @property
    def train_transform(self):
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            self.normalize_transform,
        ])

    @property
    def val_transform(self):
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            self.normalize_transform
        ])

    def get_splits(self):
        if self.stratify:
            labels = self.get_data_labels()
            cv_ = StratifiedKFold(n_splits=self.n_splits)
        else:
            labels = None
            cv_ = KFold(n_splits=self.n_splits)

        dataset = self.get_dataset()
        n_samples = len(dataset)
        for train_idx, val_idx in cv_.split(X=range(n_samples), y=labels):
            _train = Subset(dataset, train_idx)
            self._update_mean_std(dataset=_train)
            train_dataset = _WrappedDataset(_train, self.train_transform)
            train_loader = DataLoader(dataset=train_dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers)

            _val = Subset(dataset, val_idx)
            val_dataset = _WrappedDataset(_val, self.val_transform)
            val_loader = DataLoader(dataset=val_dataset,
                                    batch_size=self.batch_size,
                                    shuffle=False,
                                    num_workers=self.num_workers)

            yield train_loader, val_loader

    def _update_mean_std(self, dataset):
        """Computes the mean and std of the given (image) dataset.

        Instantiates a dataloader to compute the mean and std from batches.
        """
        _dataset = _WrappedDataset(dataset=dataset,
                                   transform=transforms.Compose([transforms.Resize((224, 224)),
                                                                 transforms.ToTensor()]))
        _dataloader = DataLoader(dataset=_dataset,
                                 batch_size=self.batch_size,
                                 shuffle=False,
                                 num_workers=self.num_workers)
        mean, std, n_samples = 0., 0., 0.
        for images, _ in _dataloader:
            batch_samples = images.size(0)
            data = images.view(batch_samples, images.size(1), -1)
            mean += data.mean(2).sum(0)
            std += data.std(2).sum(0)
            n_samples += batch_samples
        self._mean = mean / n_samples
        self._std = std / n_samples

    def get_dataset(self):
        """Creates and returns the complete dataset."""
        train_data_path = Path(self.data_dir).joinpath('cats_and_dogs_filtered', 'train')
        train_dataset = ImageFolder(root=train_data_path)
        valid_data_path = Path(self.data_dir).joinpath('cats_and_dogs_filtered', 'validation')
        valid_dataset = ImageFolder(root=valid_data_path)
        return ConcatDataset([train_dataset, valid_dataset])

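    def get_data_labels(self):
        """Returns the labels of all samples without decoding any image."""
        dataset = self.get_dataset()
        # `ImageFolder.targets` already holds the class index of every sample.
        return [int(label) for ds in dataset.datasets for label in ds.targets]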

class CV:
    """Cross-validation with a LightningModule."""
    def __init__(self,
                 *trainer_args,
                 **trainer_kwargs):
        super().__init__()
        self.trainer_args = trainer_args
        self.trainer_kwargs = trainer_kwargs

    @staticmethod
    def _update_logger(logger, fold_idx: int):
        if hasattr(logger, 'experiment_name'):
            logger_key = 'experiment_name'
        elif hasattr(logger, 'name'):
            logger_key = 'name'
        else:
            raise AttributeError('The logger associated with the trainer '
                                 'should have an `experiment_name` or `name` '
                                 'attribute.')
        new_experiment_name = getattr(logger, logger_key) + f'/{fold_idx}'
        setattr(logger, logger_key, new_experiment_name)

    @staticmethod
    def update_modelcheckpoint(model_ckpt_callback, fold_idx):
        _default_filename = '{epoch}-{step}'
        _suffix = f'_fold{fold_idx}'
        if model_ckpt_callback.filename is None:
            new_filename = _default_filename + _suffix
        else:
            new_filename = model_ckpt_callback.filename + _suffix
        setattr(model_ckpt_callback, 'filename', new_filename)

    def update_logger(self, trainer: Trainer, fold_idx: int):
        if not isinstance(trainer.logger, LoggerCollection):
            _loggers = [trainer.logger]
        else:
            _loggers = trainer.logger

        # Update loggers:
        for _logger in _loggers:
            self._update_logger(_logger, fold_idx)

    def fit(self, model: LightningModule, data: CatsDogsDataCV):
        splits = data.get_splits()
        for fold_idx, loaders in enumerate(splits):

            # Clone model & instantiate a new trainer:
            _model = deepcopy(model)
            trainer = Trainer(*self.trainer_args, **self.trainer_kwargs)

            # Update loggers and callbacks:
            self.update_logger(trainer, fold_idx)
            for callback in trainer.callbacks:
                if isinstance(callback, ModelCheckpoint):
                    self.update_modelcheckpoint(callback, fold_idx)

            # Fit:
            trainer.fit(_model, *loaders)

class MyCustomModel(LightningModule):
    """Custom classification model."""

    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr

        self.__build_model()

    def __build_model(self):
        # Classifier:
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.fc1 = nn.Linear(3 * 3 * 64, 10)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(10, 1)
        self.relu = nn.ReLU()

        # Loss:
        # Stored as `loss_func` so it does not shadow the `loss` method below.
        self.loss_func = binary_cross_entropy_with_logits

        # Metrics:
        self.train_acc = Accuracy()
        self.valid_acc = Accuracy()

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)
        out = self.relu(self.fc1(out))
        out = self.fc2(out)
        return out

    def loss(self, logits, labels):
        return self.loss_func(input=logits, target=labels)

    def training_step(self, batch, batch_idx):
        # 1. Forward pass:
        x, y = batch
        y_logits = self.forward(x)
        y_true = y.view((-1, 1)).type_as(x)

        # 2. Compute loss
        train_loss = self.loss(y_logits, y_true)

        # 3. Compute accuracy:
        train_accuracy = self.train_acc(sigmoid(y_logits), y_true.int())
        self.log("train_acc", train_accuracy, prog_bar=True)

        return train_loss

    def validation_step(self, batch, batch_idx):
        # 1. Forward pass:
        x, y = batch
        y_logits = self.forward(x)
        y_true = y.view((-1, 1)).type_as(x)

        # 2. Compute loss
        self.log("val_loss", self.loss(y_logits, y_true), prog_bar=True)

        # 3. Compute accuracy:
        valid_accuracy = self.valid_acc(sigmoid(y_logits), y_true.int())
        self.log("val_acc", valid_accuracy, prog_bar=True)

    def configure_optimizers(self):
        parameters = list(self.parameters())
        trainable_parameters = list(filter(lambda p: p.requires_grad, parameters))
        optimizer = optim.Adam(trainable_parameters, lr=self.lr)
        return optimizer

if __name__ == '__main__':

    # Trainer
    neptune_logger = NeptuneLogger(project_name=NEPTUNE_PROJECT_NAME,
                                   experiment_name=NEPTUNE_EXPERIMENT_NAME)

    model_checkpoint = ModelCheckpoint(dirpath=MODEL_CHECKPOINT_DIR_PATH,
                                       monitor='val_acc',
                                       save_top_k=1,
                                       mode='max',
                                       filename='custom_model_{epoch}',)

    trainer_kwargs_ = {'weights_summary': None,
                       'progress_bar_refresh_rate': 1,
                       'num_sanity_val_steps': 0,
                       'gpus': [0],
                       'max_epochs': 10,
                       'logger': neptune_logger,
                       'callbacks': [model_checkpoint]}

    cv = CV(**trainer_kwargs_)

    # LightningModule
    clf = MyCustomModel(lr=1e-3)

    # Run a 5-fold cross-validation experiment:
    image_data = CatsDogsDataCV(data_dir=DATA_DIR, n_splits=5, stratify=False)

    cv.fit(clf, image_data)

jbschiratti commented 3 years ago

Sadly, people seem to have lost interest in this issue...

Borda commented 3 years ago

@jbschiratti are you interested in bringing it up and implementing it? :raccoon:

jbschiratti commented 3 years ago

@Borda Sure, I am! If you think that this feature should be added to lightning, of course!

CarlosUziel commented 3 years ago

I am also very interested in this feature. As argued earlier in this thread, CV adds real value to research and has become a sort of standard for giving credibility and robustness to ML results. That said, I have tried to implement my own version, heavily based on @jbschiratti's great contribution.

For me it was more intuitive to create a CV abstraction that can be applied to any data module (which is how I decided to structure my data, since I really liked this idea from the library). I have made a quick and dirty UML diagram to show how I imagine cross-validation could be implemented, based on what I found intuitive.

UML Diagram (incomplete, focus on class relationships)

image

As you can see above, a CVTrainer takes an already-initialized Trainer, which serves as a base trainer that can then be deep-copied in each k-fold iteration. To fit a CVTrainer, one needs a LightningModule and a LightningCVDataModule that can provide each train/val split. With this solution, one can easily switch between single-model training (LightningDataModule) and k-fold training (LightningCVDataModule) without making any changes to the classes already built for each data set.

My current research is on tabular data, so keep in mind that my current assumptions might not work for other data types. Please do feel free to point that out to work towards a more generic solution.

Simple CVDataModule implementation:

"""
    Cross validation for Pytorch Lightning Data Modules
"""

import os
from abc import abstractmethod, ABC
from typing import Tuple

import pytorch_lightning as pl
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, ConcatDataset, Subset

class CVDataModule(ABC):

    def __init__(self,
                 data_module: pl.LightningDataModule,
                 n_splits: int = 10,
                 shuffle: bool = True):
        self.data_module = data_module
        self._n_splits = n_splits
        self._shuffle = shuffle

    @abstractmethod
    def split(self):
        pass

class KFoldCVDataModule(CVDataModule):
    """
        K-fold cross-validation data module

    Args:
        data_module: data module containing data to be split
        n_splits: number of k-fold iterations/data splits
    """

    def __init__(self,
                 data_module: pl.LightningDataModule,
                 n_splits: int = 10):
        super().__init__(data_module, n_splits)
        self._k_fold = KFold(n_splits=self._n_splits, shuffle=self._shuffle)

        # fall back to an empty dict if the data module has no `dataloader_kwargs`
        # (the default one does not); copy so the module's own dict is not mutated
        self.dataloader_kwargs = dict(getattr(data_module, 'dataloader_kwargs', None) or {})

        # set important defaults if not present
        self.dataloader_kwargs['batch_size'] = self.dataloader_kwargs.get('batch_size', 32)
        self.dataloader_kwargs['num_workers'] = self.dataloader_kwargs.get('num_workers', os.cpu_count())
        self.dataloader_kwargs['shuffle'] = self.dataloader_kwargs.get('shuffle', True)

    def get_data(self):
        """
            Extract and concatenate training and validation datasets from data module.
        """
        self.data_module.setup()
        train_ds = self.data_module.train_dataloader().dataset
        val_ds = self.data_module.val_dataloader().dataset
        return ConcatDataset([train_ds, val_ds])

    def split(self) -> Tuple[DataLoader, DataLoader]:
        """
            Split data into k-folds and yield each pair
        """
        # 0. Get data to split
        data = self.get_data()

        # 1. Iterate through splits
        for train_idx, val_idx in self._k_fold.split(range(len(data))):
            train_dl = DataLoader(Subset(data, train_idx),
                                  **self.dataloader_kwargs)
            # the validation fold is never shuffled
            val_dl = DataLoader(Subset(data, val_idx),
                                **{**self.dataloader_kwargs, 'shuffle': False})

            yield train_dl, val_dl
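
To make the splitting step concrete, here is a minimal, dependency-free sketch of the index pairs that a k-fold split produces. It assumes contiguous, unshuffled folds (the module above uses `shuffle=True`, so its folds are drawn from a permutation instead), and `kfold_indices` is a hypothetical helper written for illustration, not part of sklearn or Lightning:

```python
from typing import Iterator, List, Tuple


def kfold_indices(n_samples: int, n_splits: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_idx, val_idx) pairs for contiguous, unshuffled folds."""
    if not 2 <= n_splits <= n_samples:
        raise ValueError("n_splits must be between 2 and n_samples")
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        # The first `extra` folds receive one additional sample.
        stop = start + base + (1 if fold < extra else 0)
        val_idx = list(range(start, stop))
        # Everything outside the validation window is used for training.
        train_idx = list(range(0, start)) + list(range(stop, n_samples))
        yield train_idx, val_idx
        start = stop


if __name__ == "__main__":
    for train_idx, val_idx in kfold_indices(n_samples=7, n_splits=3):
        print("val:", val_idx, "train:", train_idx)
```

Each sample lands in exactly one validation fold, which is why wrapping the concatenated dataset in `Subset(data, train_idx)` and `Subset(data, val_idx)` is enough and no data has to be copied on disk.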

Simple CVTrainer implementation:

"""
    todo
"""
from copy import deepcopy

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import LoggerCollection, LightningLoggerBase

from data.cv_modules import CVDataModule

class CVTrainer:

    def __init__(self, trainer: Trainer):
        super().__init__()
        self._trainer = trainer

    @staticmethod
    def _update_logger(logger: LightningLoggerBase, fold_idx: int):
        """
            Change a logger's parameters to log a new fold
        Args:
            logger: Logger to update
            fold_idx: Fold ID
        """
        if hasattr(logger, 'experiment_name'):
            logger_key = 'experiment_name'
        elif hasattr(logger, 'name'):
            logger_key = 'name'
        else:
            raise AttributeError('The logger associated with the trainer '
                                 'should have an `experiment_name` or `name` '
                                 'attribute.')
        new_experiment_name = getattr(logger, logger_key) + f'/fold_{fold_idx}'
        setattr(logger, logger_key, new_experiment_name)

    @staticmethod
    def update_modelcheckpoint(model_ckpt_callback: ModelCheckpoint, fold_idx: int):
        """
            Update model checkpoint object with fold information
        Args:
            model_ckpt_callback: Model checkpoint object
            fold_idx: Fold ID
        """
        _default_filename = '{epoch}-{step}'
        _suffix = f'_fold{fold_idx}'
        if model_ckpt_callback.filename is None:
            new_filename = _default_filename + _suffix
        else:
            new_filename = model_ckpt_callback.filename + _suffix
        setattr(model_ckpt_callback, 'filename', new_filename)

    def update_loggers(self, trainer: Trainer, fold_idx: int):
        """
            Change the parameters of the trainer's loggers to log a new fold
        Args:
            trainer: Trainer whose logger to update
            fold_idx: Fold ID
        """
        if not isinstance(trainer.logger, LoggerCollection):
            _loggers = [trainer.logger]
        else:
            _loggers = trainer.logger

        # Update loggers:
        for _logger in _loggers:
            self._update_logger(_logger, fold_idx)

    def fit(self, model: pl.LightningModule, data: CVDataModule):
        for fold_idx, loaders in enumerate(data.split()):

            # Clone model & trainer:
            _model = deepcopy(model)
            _trainer = deepcopy(self._trainer)

            # Update loggers and callbacks:
            # Disabled: the logger's `name` is a read-only property (at least in v1.2.8).
            # self.update_loggers(_trainer, fold_idx)
            for callback in _trainer.callbacks:
                if isinstance(callback, ModelCheckpoint):
                    self.update_modelcheckpoint(callback, fold_idx)

            # fit
            _trainer.fit(_model, *loaders)

Again, heavily based on @jbschiratti's contribution (thank you!). A few points I changed or ran into:

  1. Passing the trainer arguments directly, in order to build the trainer later on, felt very unintuitive to me: it hides from the user what these arguments should be and when and for what they are used. I find it more natural to just pass a Trainer, the same one you would use to train your model a single time. This trainer is never used for fitting itself; it only serves as a "base trainer" that is deep-copied for each training iteration. Do you see any situation where this implementation could lead to problems? Do you have a better suggestion?
  2. I was not able to update the logger on each iteration, since the name attribute is a property without a setter, at least in v1.2.8. How did you do it, @jbschiratti? What am I missing?
  3. How the model is copied needs more attention. When I have just enough memory to run one training iteration, I get a CUDA out-of-memory error in the second one. Maybe the previous model needs to be deleted before the next iteration starts?
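
One possible way around points 2 and 3 is to build everything fresh for each fold instead of deep-copying and mutating. Below is a minimal sketch, not a tested Lightning integration. `run_folds`, `make_model` and `make_trainer` are hypothetical names (none of them exist in Lightning). It also assumes that the out-of-memory error comes from the previous fold's model still being referenced, which I have not verified:

```python
import gc
from typing import Any, Callable, Iterable, Tuple

try:  # torch is optional so the sketch also runs without a GPU stack
    import torch
except ImportError:
    torch = None


def run_folds(
    splits: Iterable[Tuple[Any, Any]],
    make_model: Callable[[], Any],
    make_trainer: Callable[[int], Any],
) -> int:
    """Fit one fresh model per fold and release it before the next fold.

    `make_model` returns a new, untrained model. `make_trainer` builds a
    trainer (with its own logger and callbacks) for the given fold index,
    so nothing has to be renamed after construction.
    Returns the number of folds that were run.
    """
    n_folds = 0
    for fold_idx, (train_loader, val_loader) in enumerate(splits):
        model = make_model()
        trainer = make_trainer(fold_idx)
        trainer.fit(model, train_loader, val_loader)
        n_folds += 1

        # Drop our references before the next fold allocates its model.
        del model, trainer
        gc.collect()
        if torch is not None and torch.cuda.is_available():
            # Return cached, unused GPU blocks to the driver.
            torch.cuda.empty_cache()
    return n_folds
```

With Lightning, `make_trainer` could be something like `lambda i: Trainer(max_epochs=10, logger=TensorBoardLogger('logs', name=f'exp/fold_{i}'))`. The fold index is passed to the logger at construction time, so no setter on `name` is needed. The cost is that the user has to wrap the trainer in a function instead of passing a ready-made instance.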

Please let me know what you think, I look forward to your feedback and suggestions!

jbschiratti commented 3 years ago

@CarlosUziel Thank you for your post. Overall what you propose is quite similar to what I proposed. If I understood correctly, the main differences are:

@Borda I think that we're close to having a proof of concept. Shall we formalize it and start a PR?

Borda commented 3 years ago

I think that we're close to having a proof of concept. Shall we formalize it and start a PR?

yeah, great talk to you today, go ahead!

hengee commented 3 years ago

So any update?