vturrisi / solo-learn

solo-learn: a library of self-supervised methods for visual representation learning powered by PyTorch Lightning
MIT License

class loss in ssl method #354

Closed: cheliu-computation closed this issue 1 year ago

cheliu-computation commented 1 year ago

Hi vturrisi,

First, thanks for your effort on this amazing SSL repo.

In some of the methods, such as Barlow Twins and BYOL, I find two losses during the pretraining stage: https://github.com/vturrisi/solo-learn/blob/3ee84db28d8760a6eb2bba922cad8de35ee1d2b6/solo/methods/barlow_twins.py#L129

This line returns two losses for updating the pretrained model (I am not sure whether both losses are used). If the model is pretrained with class_loss, does that mean it is not fully SSL? I ask because of this line: https://github.com/vturrisi/solo-learn/blob/3ee84db28d8760a6eb2bba922cad8de35ee1d2b6/solo/methods/base.py#L471

The model has a classifier that performs a classification task using labels. If the model is optimized with both the Barlow Twins loss and class_loss, doesn't that introduce extra (label) information into pretraining?

gcanat commented 1 year ago

class_loss is only computed for data with a known target class, and it is used for online evaluation. It is not backpropagated to the backbone because the features fed to the classifier head are detached from the graph: logits = self.classifier(feats.detach()). All data used strictly for SSL has a class index of -1, which is ignored by class_loss.
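A minimal PyTorch sketch of the two mechanisms described above. This is not solo-learn's actual code: the toy `backbone` and `classifier` modules, their sizes, and the random data are assumptions for illustration. Detaching the features means class_loss only produces gradients for the classifier, and `ignore_index=-1` makes samples labeled -1 drop out of the loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy stand-ins for a backbone and an online linear classifier.
backbone = nn.Linear(8, 4)
classifier = nn.Linear(4, 3)

x = torch.randn(6, 8)
# -1 marks samples without a known label (pure SSL data); they are ignored below.
targets = torch.tensor([0, 2, -1, 1, -1, -1])

feats = backbone(x)
# Detaching the features cuts the graph, so class_loss can only reach the classifier.
logits = classifier(feats.detach())
class_loss = F.cross_entropy(logits, targets, ignore_index=-1)
class_loss.backward()

print(backbone.weight.grad)          # None: no gradient flowed into the backbone
print(classifier.weight.grad.shape)  # torch.Size([3, 4])
```

With the default mean reduction, the loss is averaged over the labeled samples only, so it equals the cross-entropy computed on the subset whose target is not -1.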

cheliu-computation commented 1 year ago

Many thanks for your reply! Is the online evaluation what shows up as 'train acc' on wandb? Does it use the CIFAR-100 test set or the training set? I also see 'val_acc' on wandb, and their results seem similar, but I am not sure whether they are evaluated on the same dataset.

Another question: during pretraining, is the linear classifier randomly initialized at every evaluation stage, or is it updated at every step (isolated from the backbone)?

gcanat commented 1 year ago

1/ I'm pretty sure train_acc is computed on the training data and val_acc on the validation data, as their names suggest.
2/ From what I remember, the linear classifier is learned throughout the whole training procedure. But I will let @vturrisi confirm or correct me if needed.

vturrisi commented 1 year ago

@gcanat Exactly. The backbone and the linear classifier are trained together throughout training, but the classification loss does not update the backbone.
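A toy sketch of that joint setup, assuming a single optimizer over both modules and a training step that returns the sum of the two losses. `toy_ssl_loss` is a hypothetical placeholder (a plain MSE between two views), not Barlow Twins or BYOL, and the modules and data are made up for illustration. The classifier is created once and keeps being updated, and the backbone gradient from the combined loss matches the gradient from the SSL loss alone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

backbone = nn.Linear(8, 4)
classifier = nn.Linear(4, 3)
# One optimizer over both modules: the classifier persists and is updated every step.
optimizer = torch.optim.SGD(
    list(backbone.parameters()) + list(classifier.parameters()), lr=0.1
)

def toy_ssl_loss(z1, z2):
    # Hypothetical placeholder for an SSL objective: pulls the two views together.
    return F.mse_loss(z1, z2)

def training_step(x1, x2, targets):
    z1, z2 = backbone(x1), backbone(x2)
    ssl_loss = toy_ssl_loss(z1, z2)
    logits = classifier(z1.detach())  # online head only sees detached features
    class_loss = F.cross_entropy(logits, targets, ignore_index=-1)
    return ssl_loss + class_loss, ssl_loss

x1, x2 = torch.randn(6, 8), torch.randn(6, 8)
targets = torch.tensor([0, 2, -1, 1, -1, -1])

# Backbone gradient from the combined loss...
optimizer.zero_grad()
total, _ = training_step(x1, x2, targets)
total.backward()
grad_total = backbone.weight.grad.clone()

# ...matches the backbone gradient from the SSL loss alone.
optimizer.zero_grad()
_, ssl_only = training_step(x1, x2, targets)
ssl_only.backward()
grad_ssl = backbone.weight.grad.clone()
print(torch.allclose(grad_total, grad_ssl))  # True

# A real update step: the classifier weights change, driven by class_loss.
w_before = classifier.weight.detach().clone()
optimizer.zero_grad()
total, _ = training_step(x1, x2, targets)
total.backward()
optimizer.step()
```

Summing the losses is convenient because one backward call and one optimizer step serve both objectives, while the detach keeps the label signal out of the backbone.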