lightly-ai / lightly

A python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License

Make benchmarks compatible with PyTorch Lightning 2.0 #1113

Closed: guarin closed this 1 year ago

guarin commented 1 year ago

Follow-up from #1112

The training_epoch_end and validation_epoch_end hooks, which we used in BenchmarkModule, were removed in PyTorch Lightning 2.0.

We can replace training_epoch_end with on_validation_epoch_start. Replacing validation_epoch_end will take more effort, because we use its outputs to calculate the top1 scores.
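If a metric only needs aggregate counts, the per-step outputs don't have to be stored at all: a top-1 accuracy can be computed from two running sums. Below is a minimal, framework-free sketch, assuming each validation step provides plain lists of predicted and true class indices. `Top1Tracker` is a hypothetical helper, not part of lightly or Lightning. In a LightningModule, `update` would be called from `validation_step` and `compute_and_reset` from `on_validation_epoch_end`.

```python
from typing import Sequence


class Top1Tracker:
    """Hypothetical helper that accumulates top-1 counts across validation steps."""

    def __init__(self) -> None:
        self.total_num = 0
        self.total_top1 = 0

    def update(self, pred_labels: Sequence[int], targets: Sequence[int]) -> None:
        # Per step we only need the batch size and the number of correct top-1 predictions
        self.total_num += len(targets)
        self.total_top1 += sum(int(p == t) for p, t in zip(pred_labels, targets))

    def compute_and_reset(self) -> float:
        # Epoch-level accuracy, then reset the counters for the next epoch
        acc = self.total_top1 / self.total_num if self.total_num else 0.0
        self.total_num = 0
        self.total_top1 = 0
        return acc
```

Unlike appending every prediction to a list, the counters use constant memory. The trade-off is that every additional metric needs its own counters, and in multi-GPU runs the sums would also have to be reduced across processes.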

faris-k commented 1 year ago

Hi @guarin, I think I have some suggestions for updating the benchmarking script 😃 In short, you can use on_validation_epoch_start as you suggested, manually collect the outputs of validation_step, and then compute the metrics in on_validation_epoch_end. According to the Lightning 2.0 release notes, outputs are now collected manually like this:

Before

import lightning as L
import torch

class LitModel(L.LightningModule):

    def training_step(self, batch, batch_idx):
        ...
        return {"loss": loss, "banana": banana}

    # `outputs` is a list of the dicts returned by `training_step` during the epoch
    def training_epoch_end(self, outputs):
        avg_banana = torch.cat([out["banana"] for out in outputs]).mean()

Now

import lightning as L
import torch

class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # 1. Create a list to hold the outputs of `*_step`
        self.bananas = []

    def training_step(self, batch, batch_idx):
        ...
        # 2. Add the outputs to the list
        # You should be aware of the implications on memory usage
        self.bananas.append(banana)
        return loss

    # 3. Rename the hook to `on_*_epoch_end`
    def on_train_epoch_end(self):
        # 4. Do something with all outputs
        avg_banana = torch.cat(self.bananas).mean()
        # Don't forget to clear the memory for the next epoch!
        self.bananas.clear()

I've modified the benchmarking module in the past for my own use, and I've made a few more adjustments to make it work with Lightning 2.0. It also uses torchmetrics instead of manually computing performance metrics, but otherwise, it mostly follows BenchmarkModule. I can confirm that this works with the latest versions of PyTorch Lightning and lightly 😁
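The module below passes `average="macro"` to torchmetrics' `MulticlassAccuracy` and `MulticlassF1Score`. Macro averaging computes a score per class and weights every class equally, instead of pooling all samples. Here is a plain-Python sketch of that idea, assuming integer class labels and non-empty inputs. The function names are hypothetical, and the rule for skipping classes that never occur is a simplification that may not match torchmetrics exactly.

```python
from typing import List, Sequence, Tuple


def per_class_counts(
    preds: Sequence[int], targets: Sequence[int], num_classes: int
) -> Tuple[List[int], List[int], List[int]]:
    """Count true positives, false positives and false negatives per class."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes
    for p, t in zip(preds, targets):
        if p == t:
            tp[t] += 1
        else:
            fp[p] += 1  # class p was predicted, but wrongly
            fn[t] += 1  # an example of class t was missed
    return tp, fp, fn


def micro_accuracy(preds: Sequence[int], targets: Sequence[int]) -> float:
    # Pool all samples: plain fraction of correct predictions
    return sum(int(p == t) for p, t in zip(preds, targets)) / len(targets)


def macro_accuracy(preds: Sequence[int], targets: Sequence[int], num_classes: int) -> float:
    tp, fp, fn = per_class_counts(preds, targets, num_classes)
    # Per-class accuracy is recall, tp / (tp + fn); classes with no activity are skipped
    scores = [
        tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] > 0 else 0.0
        for c in range(num_classes)
        if tp[c] + fp[c] + fn[c] > 0
    ]
    return sum(scores) / len(scores)


def macro_f1(preds: Sequence[int], targets: Sequence[int], num_classes: int) -> float:
    tp, fp, fn = per_class_counts(preds, targets, num_classes)
    # Per-class F1 = 2tp / (2tp + fp + fn), then an unweighted mean over classes
    scores = [
        2 * tp[c] / (2 * tp[c] + fp[c] + fn[c])
        for c in range(num_classes)
        if tp[c] + fp[c] + fn[c] > 0
    ]
    return sum(scores) / len(scores)
```

Consider a classifier that always predicts the majority class on `targets = [0, 0, 0, 0, 1]`. It scores 0.8 micro accuracy but only 0.5 macro accuracy and about 0.44 macro F1, which is why macro averaging is the safer choice for imbalanced datasets.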

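import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score

from lightly.utils.benchmarking import knn_predict

# Modified from https://github.com/lightly-ai/lightly/blob/master/lightly/utils/benchmarking.py
# See https://arxiv.org/abs/1805.01978 for more details on kNN feature evaluation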
class KNNBenchmarkModule(pl.LightningModule):
    """A PyTorch Lightning module for automated kNN evaluation, with metrics from torchmetrics.

    Modified from https://github.com/lightly-ai/lightly/blob/master/lightly/utils/benchmarking.py

    At the start of every validation epoch we create a feature bank by feeding
    `dataloader_kNN` through the backbone. At every validation step we embed the
    validation batch and classify it with a kNN classifier that uses the feature
    bank built from the training data. At the end of the validation epoch we
    compute accuracy and F1 score over all predictions.

    Attributes:
        backbone:
            The backbone model used for kNN validation. Make sure that you set the
            backbone when inheriting from `KNNBenchmarkModule`.
        max_accuracy:
            Maximum kNN validation accuracy the benchmarked model has achieved.
        max_f1:
            Maximum kNN validation F1 score the benchmarked model has achieved.
        dataloader_kNN:
            Dataloader used at the start of each validation epoch to create the feature bank.
        num_classes:
            Number of classes. E.g. for CIFAR-10 we have 10 classes.
        knn_k:
            Number of nearest neighbors for kNN (default: 25)
        knn_t:
            Temperature parameter for kNN (default: 0.1)
    """

    def __init__(
        self,
        dataloader_kNN: DataLoader,
        num_classes: int,
        knn_k: int = 25,  # TODO: find a good default value, 200 is too high for class imbalance
        knn_t: float = 0.1,
    ):
        super().__init__()
        self.backbone = nn.Module()
        self.max_accuracy = 0.0
        self.max_f1 = 0.0
        self.dataloader_kNN = dataloader_kNN
        self.num_classes = num_classes
        self.knn_k = knn_k
        self.knn_t = knn_t

        # Initialize metrics for validation; use macro averages for imbalanced classes
        self.val_accuracy = MulticlassAccuracy(num_classes=num_classes, average="macro")
        self.val_f1 = MulticlassF1Score(num_classes=num_classes, average="macro")

        # Dummy param tracks the device the model is using
        self.dummy_param = nn.Parameter(torch.empty(0))

        # `*_epoch_end` hooks were removed in Lightning 2.0; store the outputs of
        # `validation_step` manually and consume them in `on_validation_epoch_end`
        self.all_preds = []
        self.all_targets = []

    # Previously, we used the `training_epoch_end` hook to update the feature bank
    def on_validation_epoch_start(self):
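        # Note that we don't need self.eval() or torch.no_grad() here: Lightning calls
        # on_validation_model_eval() before validation (and on_validation_model_train()
        # afterwards) and runs the validation loop with gradients disabled
        self.feature_bank = []
        self.targets_bank = []
        for data in self.dataloader_kNN:
            img, target, _ = data
            img = img.to(self.dummy_param.device)
            target = target.to(self.dummy_param.device)
            # flatten(start_dim=1) instead of squeeze() keeps the batch dimension for batches of size 1
            feature = self.backbone(img).flatten(start_dim=1)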
            feature = F.normalize(feature, dim=1)
            self.feature_bank.append(feature)
            self.targets_bank.append(target)
        self.feature_bank = torch.cat(self.feature_bank, dim=0).t().contiguous()
        self.targets_bank = torch.cat(self.targets_bank, dim=0).t().contiguous()

    # We'll need to manually store the outputs of the validation step to our lists
    def validation_step(self, batch, batch_idx):
        images, targets, _ = batch
        feature = self.backbone(images).flatten(start_dim=1)
        feature = F.normalize(feature, dim=1)
        pred_labels = knn_predict(
            feature,
            self.feature_bank,
            self.targets_bank,
            self.num_classes,
            self.knn_k,
            self.knn_t,
        )
        preds = pred_labels[:, 0]
        self.all_preds.append(preds)
        self.all_targets.append(targets)

    # Previously, we used `validation_epoch_end(self, outputs)` to compute the metrics
    def on_validation_epoch_end(self):
        # Concatenate all predictions and targets
        all_preds = torch.cat(self.all_preds, dim=0)
        all_targets = torch.cat(self.all_targets, dim=0)

        # Update metrics
        self.val_accuracy(all_preds, all_targets)
        self.val_f1(all_preds, all_targets)
        accuracy = self.val_accuracy.compute().item()
        f1 = self.val_f1.compute().item()

        # Update maxima
        if accuracy > self.max_accuracy:
            self.max_accuracy = accuracy
        if f1 > self.max_f1:
            self.max_f1 = f1

        # Log metrics
        self.log("knn_accuracy", self.val_accuracy, on_epoch=True, prog_bar=True)
        self.log("knn_f1", self.val_f1, on_epoch=True, prog_bar=True)

        # Remember to clear the predictions and targets once we finish the validation epoch!
        self.all_preds.clear()
        self.all_targets.clear()

    def predict_step(self, batch, batch_idx):
        images, _, _ = batch
        return self.backbone(images)
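The classification in `validation_step` is delegated to lightly's `knn_predict`. For a single query, it can be sketched in plain Python as a temperature-weighted vote of the k most similar bank entries, along the lines of the kNN evaluation in https://arxiv.org/abs/1805.01978 cited at the top of the module. This is an illustrative re-implementation, not lightly's code. `weighted_knn_predict` and `l2_normalize` are hypothetical names, and the sketch assumes query and bank features are already L2-normalized, as the module does with `F.normalize`.

```python
import math
from typing import List, Sequence


def l2_normalize(vec: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def weighted_knn_predict(
    feature: Sequence[float],
    feature_bank: Sequence[Sequence[float]],
    bank_labels: Sequence[int],
    num_classes: int,
    knn_k: int,
    knn_t: float,
) -> List[int]:
    # Dot product equals cosine similarity because all vectors are L2-normalized
    sims = [sum(a * b for a, b in zip(feature, bank_vec)) for bank_vec in feature_bank]
    # Indices of the knn_k most similar bank entries
    neighbors = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:knn_k]
    scores = [0.0] * num_classes
    for i in neighbors:
        # Each neighbor votes for its label with weight exp(similarity / temperature);
        # a smaller knn_t gives the closest neighbors more influence
        scores[bank_labels[i]] += math.exp(sims[i] / knn_t)
    # Classes ordered from most to least likely; element 0 is the top-1 prediction
    return sorted(range(num_classes), key=lambda c: scores[c], reverse=True)
```

Taking element 0 of the returned ranking corresponds to `pred_labels[:, 0]` in `validation_step`. For example, with one near-duplicate neighbor of class 1 and two weaker neighbors of class 0, `knn_t=0.1` lets the near-duplicate win. With `knn_t=10` the weights flatten and the vote approaches a plain majority.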
guarin commented 1 year ago

This looks awesome! And thanks a lot for the pointers!

guarin commented 1 year ago

Partially completed in #1136. Only the LARS optimizer remains incompatible with PyTorch Lightning 2.0.

guarin commented 1 year ago

We recently added the LARS optimizer, and Lightly should now be fully compatible with PyTorch Lightning 2.0.