Hi @guarin, I think I might have some suggestions for updating the benchmarking script 😃 In short, I think you can simply use `on_validation_epoch_start` as you suggested, then manually collect the outputs of `validation_step`, and finally compute the metrics in the `on_validation_epoch_end` hook. As per the Lightning 2.0 release notes, manual collection of outputs is now handled like so:
Before, in Lightning 1.x:

```python
import lightning as L

class LitModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        ...
        return {"loss": loss, "banana": banana}

    # `outputs` is a list of all bananas returned in the epoch
    def training_epoch_end(self, outputs):
        avg_banana = torch.cat([out["banana"] for out in outputs]).mean()
```
In Lightning 2.0 this becomes:

```python
import lightning as L

class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # 1. Create a list to hold the outputs of `*_step`
        self.bananas = []

    def training_step(self, batch, batch_idx):
        ...
        # 2. Add the outputs to the list
        # You should be aware of the implications on memory usage
        self.bananas.append(banana)
        return loss

    # 3. Rename the hook to `on_*_epoch_end`
    def on_train_epoch_end(self):
        # 4. Do something with all outputs
        avg_banana = torch.cat(self.bananas).mean()
        # Don't forget to clear the memory for the next epoch!
        self.bananas.clear()
```
I've modified the benchmarking module in the past for my own use, and I've made a few more adjustments to make it work with Lightning 2.0. It also uses torchmetrics instead of manually computing performance metrics, but otherwise it mostly follows `BenchmarkModule`. I can confirm that this works with the latest versions of PyTorch Lightning and lightly 😁
```python
# Modified from https://github.com/lightly-ai/lightly/blob/master/lightly/utils/benchmarking.py
# See https://arxiv.org/abs/1805.01978 for more details on kNN feature evaluation
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score

from lightly.utils.benchmarking import knn_predict


class KNNBenchmarkModule(pl.LightningModule):
    """A PyTorch Lightning module for automated kNN evaluation with torchmetrics support.

    Modified from https://github.com/lightly-ai/lightly/blob/master/lightly/utils/benchmarking.py

    At the start of every validation epoch we create a feature bank by feeding
    the `dataloader_kNN` passed to the module through the backbone.
    At every validation step we predict features on the validation data.
    After all predictions on validation data, we evaluate the predictions with a
    kNN classifier on the validation data, using the feature_bank features from
    the train data.

    Attributes:
        backbone:
            The backbone model used for kNN validation. Make sure that you set the
            backbone when inheriting from `KNNBenchmarkModule`.
        max_accuracy:
            Maximum test accuracy the benchmarked model has achieved.
        max_f1:
            Maximum test F1 score the benchmarked model has achieved.
        dataloader_kNN:
            Dataloader to be used after each training epoch to create the feature bank.
        num_classes:
            Number of classes. E.g. for CIFAR-10 we have 10 classes.
        knn_k:
            Number of nearest neighbors for kNN (default: 25)
        knn_t:
            Temperature parameter for kNN (default: 0.1)
    """

    def __init__(
        self,
        dataloader_kNN: DataLoader,
        num_classes: int,
        knn_k: int = 25,  # TODO: find a good default value, 200 is too high for class imbalance
        knn_t: float = 0.1,
    ):
        super().__init__()
        self.backbone = nn.Module()
        self.max_accuracy = 0.0
        self.max_f1 = 0.0
        self.dataloader_kNN = dataloader_kNN
        self.num_classes = num_classes
        self.knn_k = knn_k
        self.knn_t = knn_t
        # Initialize metrics for validation; use macro averages for imbalanced classes
        self.val_accuracy = MulticlassAccuracy(num_classes=num_classes, average="macro")
        self.val_f1 = MulticlassF1Score(num_classes=num_classes, average="macro")
        # Dummy param tracks the device the model is using
        self.dummy_param = nn.Parameter(torch.empty(0))
        # The `*_epoch_end` hooks were removed; we manually store the outputs of
        # `validation_step` in these lists instead
        self.all_preds = []
        self.all_targets = []

    # Previously, we used the `training_epoch_end` hook to update the feature bank
    def on_validation_epoch_start(self):
        # Note that we don't need to call self.eval() or torch.no_grad() here:
        # Lightning handles this via on_validation_model_eval() and on_validation_model_train()
        self.feature_bank = []
        self.targets_bank = []
        for data in self.dataloader_kNN:
            img, target, _ = data
            img = img.to(self.dummy_param.device)
            target = target.to(self.dummy_param.device)
            feature = self.backbone(img).squeeze()
            feature = F.normalize(feature, dim=1)
            self.feature_bank.append(feature)
            self.targets_bank.append(target)
        self.feature_bank = torch.cat(self.feature_bank, dim=0).t().contiguous()
        self.targets_bank = torch.cat(self.targets_bank, dim=0).t().contiguous()

    # We manually store the outputs of each validation step in our lists
    def validation_step(self, batch, batch_idx):
        images, targets, _ = batch
        feature = self.backbone(images).squeeze()
        feature = F.normalize(feature, dim=1)
        pred_labels = knn_predict(
            feature,
            self.feature_bank,
            self.targets_bank,
            self.num_classes,
            self.knn_k,
            self.knn_t,
        )
        preds = pred_labels[:, 0]
        self.all_preds.append(preds)
        self.all_targets.append(targets)

    # Previously, we used `validation_epoch_end(self, outputs)` to compute the metrics
    def on_validation_epoch_end(self):
        # Concatenate all predictions and targets
        all_preds = torch.cat(self.all_preds, dim=0)
        all_targets = torch.cat(self.all_targets, dim=0)
        # Update metrics
        self.val_accuracy(all_preds, all_targets)
        self.val_f1(all_preds, all_targets)
        accuracy = self.val_accuracy.compute().item()
        f1 = self.val_f1.compute().item()
        # Update maxima
        self.max_accuracy = max(self.max_accuracy, accuracy)
        self.max_f1 = max(self.max_f1, f1)
        # Log metrics; passing the metric objects lets Lightning handle compute/reset
        self.log("knn_accuracy", self.val_accuracy, on_epoch=True, prog_bar=True)
        self.log("knn_f1", self.val_f1, on_epoch=True, prog_bar=True)
        # Remember to clear the predictions and targets once we finish the validation epoch!
        self.all_preds.clear()
        self.all_targets.clear()

    def predict_step(self, batch, batch_idx):
        images, _, _ = batch
        return self.backbone(images)
```
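For reference, `knn_predict` is the helper from `lightly.utils.benchmarking` (the same file linked above). Here is a rough sketch of the temperature-weighted kNN voting it performs, following the InstDisc paper (https://arxiv.org/abs/1805.01978); treat this as an illustration of the idea rather than an exact copy of lightly's current implementation:

```python
import torch


def knn_predict(feature, feature_bank, feature_labels, num_classes, knn_k=200, knn_t=0.1):
    """Sketch of weighted kNN prediction in the style of InstDisc.

    Args:
        feature: normalized query features, shape [B, D]
        feature_bank: normalized bank features, shape [D, N] (note: transposed)
        feature_labels: integer labels of the bank entries, shape [N]

    Returns:
        Class labels sorted by descending score, shape [B, num_classes]
    """
    # Cosine similarity between each query and the bank: [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # Keep the k most similar bank entries: [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # Look up the labels of those neighbors: [B, K]
    sim_labels = torch.gather(
        feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices
    )
    # Temperature-scaled reweighting of the similarities
    sim_weight = (sim_weight / knn_t).exp()
    # One-hot encode the neighbor labels: [B * K, C]
    one_hot_label = torch.zeros(
        feature.size(0) * knn_k, num_classes, device=sim_labels.device
    )
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # Weighted vote per class: [B, C]
    pred_scores = torch.sum(
        one_hot_label.view(feature.size(0), -1, num_classes) * sim_weight.unsqueeze(dim=-1),
        dim=1,
    )
    return pred_scores.argsort(dim=-1, descending=True)
```

Because the returned labels are sorted by descending score, `pred_labels[:, 0]` in `validation_step` above is the top-1 prediction. And as with `BenchmarkModule`, the class is meant to be subclassed: your model still has to set `self.backbone` and provide its own `training_step` and `configure_optimizers`.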
This looks awesome! And thanks a lot for the pointers!
Partially completed in #1136. Only the LARS optimizer remains incompatible with PyTorch Lightning 2.0.
We added the LARS optimizer recently, and Lightly should now be fully compatible with PyTorch Lightning 2.0.
Follow-up from #1112
The `training_epoch_end` and `validation_epoch_end` hooks, which we used in `BenchmarkModule`, were removed in PyTorch Lightning 2.0. We can replace `training_epoch_end` with `on_validation_epoch_start`. But replacing `validation_epoch_end` will be more effort, as we use the `outputs` to calculate the top1 scores.
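In other words, the fix is to apply the stored-outputs pattern from the 2.0 release notes to the validation hooks. A minimal self-contained sketch of computing a top-1 score that way (toy model and illustrative names, not the actual `BenchmarkModule`):

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class Top1Example(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 3)
        # Replaces the `outputs` argument of the removed `validation_epoch_end`
        self.val_hits = []

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self.layer(x).argmax(dim=1)
        # Store per-sample top-1 hits instead of returning them
        self.val_hits.append((preds == y).float())

    def on_validation_epoch_end(self):
        # Equivalent of the old `validation_epoch_end(self, outputs)` body
        top1 = torch.cat(self.val_hits).mean()
        self.log("val_top1", top1, prog_bar=True)
        self.val_hits.clear()  # don't leak memory across epochs

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```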