lightly-ai / lightly

A python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License
2.93k stars 253 forks source link

SSL Online Evaluator Callback with Custom Data (i.e., different than pre-training data) #977

Closed aqibsaeed closed 1 year ago

aqibsaeed commented 1 year ago

Hi,

Is it possible to run online evaluation of the SSL encoder with a linear classifier on different data than the data being used for pre-training?

guarin commented 1 year ago

Hi @aqibsaeed,

Yes, it is common to evaluate a model on a different dataset than the dataset that was used for pretraining. In this case, you usually also have to fine-tune (train the classifier) on the dataset used for evaluation.

guarin commented 1 year ago

Sorry, it looks like I didn't understand your question correctly.

We currently only have online KNN evaluation implemented as BenchmarkModule. This module supports training and evaluation on different datasets. You can use it as follows:

class SimCLR(BenchmarkModule):
    """SimCLR model that inherits online KNN evaluation from BenchmarkModule.

    Args:
        dataloader_knn_train: Dataloader over the dataset used for generating
            the KNN features.
        num_classes: Forwarded unchanged to BenchmarkModule.
        knn_k: Forwarded unchanged to BenchmarkModule.
        knn_t: Forwarded unchanged to BenchmarkModule.
    """

    def __init__(self, dataloader_knn_train, num_classes, knn_k, knn_t):
        # The base class names its dataloader argument `dataloader_kNN`.
        super().__init__(
            dataloader_kNN=dataloader_knn_train,
            num_classes=num_classes,
            knn_k=knn_k,
            knn_t=knn_t,
        )
        # Placeholder: the model definition itself is omitted in this example.
        ...

# Three separate dataloaders: SSL training, KNN feature generation, and KNN
# evaluation may each use a different dataset.
dataloader_train = ...  # dataset for training SSL model
dataloader_knn_train = ...  # dataset for generating KNN features
dataloader_knn_val = ...  # dataset on which KNN is evaluated

model = SimCLR(
    dataloader_knn_train=dataloader_knn_train,
    num_classes=100,
    knn_k=20,
    knn_t=0.1,
)
trainer = pytorch_lightning.Trainer(...)
# The SSL dataloader drives training; the KNN validation dataloader is passed
# as the validation set.
trainer.fit(
    model,
    train_dataloaders=dataloader_train,
    val_dataloaders=dataloader_knn_val,
)

If you want to do online linear classification you could try something like this:

# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import torch
from torch import nn
import torchvision
from torchvision import transforms as T

from lightly.data import LightlyDataset
from lightly.data import SimCLRCollateFunction
from lightly.data.collate import imagenet_normalize
from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead

class SimCLR(nn.Module):
    """SimCLR model with an extra linear classifier for online evaluation.

    The backbone and projection head are trained with the self-supervised
    loss via ``forward``. The linear classifier is trained separately on
    frozen backbone features via ``forward_classifier``.

    Args:
        backbone: Feature extractor. Its output, flattened from dimension 1
            onwards, must have ``feature_dim`` values per sample.
        feature_dim: Size of the flattened backbone output. Defaults to 512.
        num_classes: Number of outputs of the linear classifier. Defaults
            to 10.
    """

    def __init__(self, backbone, feature_dim=512, num_classes=10):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(feature_dim, 512, 128)
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return the projection-head output for a batch of images."""
        features = self.backbone(x).flatten(start_dim=1)
        return self.projection_head(features)

    def forward_classifier(self, x):
        """Return classifier logits computed on frozen backbone features.

        The backbone is evaluated without gradient tracking and in eval mode,
        so a call never changes the backbone: neither its parameters (no
        backprop) nor its buffers such as BatchNorm running statistics. The
        previous train/eval mode of every backbone module is restored before
        returning. Gradients still flow into ``self.classifier``.
        """
        # Remember the mode of each module so it can be restored exactly,
        # even if submodules were put into different modes by the caller.
        previous_modes = [
            (module, module.training) for module in self.backbone.modules()
        ]
        # In training mode, normalization layers would update their running
        # statistics with the classifier data and use batch statistics, which
        # differs from the features seen at evaluation time.
        self.backbone.eval()
        try:
            with torch.no_grad():
                # do not backprop to backbone
                features = self.backbone(x).flatten(start_dim=1)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training
        return self.classifier(features)

# Use every child module of a torchvision ResNet-18 except the last one as
# the backbone (the last child is the fully connected classification layer).
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*tuple(resnet.children())[:-1])

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# nn.Module.to moves the parameters in place and returns the module itself.
model = SimCLR(backbone).to(device)

# setup ssl data

# CIFAR-10 is loaded without a transform here and wrapped in a LightlyDataset;
# the collate function below is responsible for producing the batches.
# Alternatively, create the dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
dataset = LightlyDataset.from_torch_dataset(cifar10)

# NOTE(review): gaussian_blur=0.0 presumably turns the blur augmentation off
# for the small 32x32 inputs; confirm against the lightly documentation.
collate_fn = SimCLRCollateFunction(input_size=32, gaussian_blur=0.0)

# Each batch from this loader is unpacked as ((x0, x1), _, _) by the SSL
# training loop.
dataloader_train_ssl = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
    collate_fn=collate_fn,
)

# setup classifier data
# The classifier only gets tensor conversion and normalization with the
# statistics from lightly's `imagenet_normalize`; no augmentations are applied.
classifier_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(
        mean=imagenet_normalize['mean'],
        std=imagenet_normalize['std'],
    )
])

# Train and validation splits of CIFAR-10 share the same root directory and
# the same transform.
cifar10_train = torchvision.datasets.CIFAR10("datasets/cifar10", train=True, download=True, transform=classifier_transform)
cifar10_val = torchvision.datasets.CIFAR10("datasets/cifar10", train=False, download=True, transform=classifier_transform)

dataloader_train_classifier = torch.utils.data.DataLoader(
    cifar10_train,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

# Validation must cover every sample: drop_last=True would silently discard
# the final partial batch, and shuffling has no benefit for evaluation.
dataloader_val_classifier = torch.utils.data.DataLoader(
    cifar10_val,
    batch_size=256,
    shuffle=False,
    drop_last=False,
    num_workers=8,
)

# setup optimizers
# Two losses and two optimizers keep the objectives separate: the SSL
# optimizer updates only the backbone and the projection head, the classifier
# optimizer updates only the linear classifier.
criterion_ssl = NTXentLoss()
criterion_classifier = nn.CrossEntropyLoss()

optimizer_ssl = torch.optim.SGD(
    [*model.backbone.parameters(), *model.projection_head.parameters()],
    lr=0.06,
    momentum=0.9,
)
optimizer_classifier = torch.optim.SGD(
    model.classifier.parameters(),
    lr=0.01,
    momentum=0.9,
)

print("Starting Training")
for epoch in range(10):
    # The evaluation phase at the end of each epoch switches the model to
    # eval mode, so training mode has to be re-enabled here.
    model.train()

    # train ssl model
    total_ssl_loss = 0
    # Each batch provides two views (x0, x1); the other two items of the
    # batch tuple are not used by the SSL objective.
    for (x0, x1), _, _ in dataloader_train_ssl:
        x0 = x0.to(device)
        x1 = x1.to(device)
        z0 = model(x0)
        z1 = model(x1)
        loss = criterion_ssl(z0, z1)
        # detach so the running total does not keep the autograd graph alive
        total_ssl_loss += loss.detach()
        loss.backward()
        optimizer_ssl.step()
        optimizer_ssl.zero_grad()
    avg_ssl_loss = total_ssl_loss / len(dataloader_train_ssl)

    # train classifier
    total_classifier_train_loss = 0
    for x, labels in dataloader_train_classifier:
        x = x.to(device)
        labels = labels.to(device)
        # forward_classifier does not backprop into the backbone, so only the
        # classifier receives gradients here.
        pred = model.forward_classifier(x)
        loss = criterion_classifier(pred, labels)
        total_classifier_train_loss += loss.detach()
        loss.backward()
        optimizer_classifier.step()
        optimizer_classifier.zero_grad()
    avg_classifier_train_loss = total_classifier_train_loss / len(dataloader_train_classifier)

    # eval classifier
    total_classifier_val_loss = 0
    num_val_samples = 0
    model.eval()
    # Nothing is optimized during evaluation, so do not build an autograd
    # graph for the classifier either.
    with torch.no_grad():
        for x, labels in dataloader_val_classifier:
            x = x.to(device)
            labels = labels.to(device)
            pred = model.forward_classifier(x)
            loss = criterion_classifier(pred, labels)
            # The criterion returns the mean over the batch. Weight it by the
            # batch size so a smaller final batch does not skew the average.
            total_classifier_val_loss += loss * len(labels)
            num_val_samples += len(labels)
    avg_classifier_val_loss = total_classifier_val_loss / num_val_samples

    print(f"epoch: {epoch:>02}, ssl_loss: {avg_ssl_loss:.5f}, cls_train_loss: {avg_classifier_train_loss:.5f}, cls_val_loss: {avg_classifier_val_loss:.5f}")

If your ssl_train and classifier_train datasets are the same, you can combine the SSL training and classifier training loops into a single loop and forward the images through the backbone only once. Is this what you are trying to do?