NorbertZheng / read-papers

My paper reading notes.
MIT License

Zichen Wang | Contrastive loss for supervised classification. #56

Closed NorbertZheng closed 1 year ago

NorbertZheng commented 1 year ago

Zichen Wang. Contrastive loss for supervised classification.

NorbertZheng commented 1 year ago

Overview

Contrasting cross-entropy loss and contrastive loss.

Recently, researchers (Khosla et al.) from Google Research and MIT published a paper entitled “Supervised Contrastive Learning”. The paper presented a new loss function, the supervised contrastive loss, for training supervised deep networks based on contrastive learning. They demonstrated that the supervised contrastive loss performs significantly better than the conventional cross-entropy loss for classification across a range of neural architectures and data augmentation regimes on the ImageNet dataset.

I found this paper particularly interesting because the idea is very concise and could potentially be an easy and universal improvement for supervised classification tasks, or even for regression. Before getting into the details, let’s first review the basics of loss functions (aka objective functions) used for supervised classification.

NorbertZheng commented 1 year ago

Cross entropy loss

Cross entropy, also known as log loss or logistic loss, is arguably the most commonly used loss function for classification. As the name suggests, it comes from information theory, where it measures the difference between two probability distributions, $p$ and $q$. It is closely related to the Kullback–Leibler divergence and can be written as the sum of the entropy $H(p)$ and the KL divergence $D_{KL}(p||q)$:

$$ H(p,q)=H(p)+D_{KL}(p||q). $$
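The decomposition above can be checked numerically. The sketch below uses two small, made-up discrete distributions (illustrative values only) and natural logarithms; it assumes both distributions are strictly positive so that every logarithm is defined.

```python
import numpy as np

def entropy(p):
    """Shannon entropy H(p) in nats; assumes all entries of p are positive."""
    return -np.sum(p * np.log(p))

def cross_entropy(p, q):
    """Cross entropy H(p, q) = -sum_i p_i log q_i."""
    return -np.sum(p * np.log(q))

def kl_divergence(p, q):
    """KL divergence D_KL(p || q) = sum_i p_i log(p_i / q_i)."""
    return np.sum(p * np.log(p / q))

p = np.array([0.7, 0.2, 0.1])  # illustrative "true" distribution
q = np.array([0.5, 0.3, 0.2])  # illustrative "predicted" distribution

# H(p, q) should equal H(p) + D_KL(p || q) up to floating-point error
print(np.isclose(cross_entropy(p, q), entropy(p) + kl_divergence(p, q)))  # True
```

Because $D_{KL}(p||q)\ge 0$, the cross entropy is never smaller than $H(p)$, and minimizing it over $q$ is equivalent to minimizing the KL divergence.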

When used as a loss function for classification, cross entropy measures the difference between the ground-truth class distribution and the distribution of the predicted classes:

$$ L_{CE}=-\sum_{c=1}^{M}y_{c}\log p(y=c|x), $$

where $M$ is the number of classes, $y_{c}$ is a binary indicator that equals 1 if the class label is $c$ and 0 otherwise, and $p(y=c|x)$ is the probability the classifier assigns to label $c$ given the input feature vector $x$.
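As a minimal sketch of $L_{CE}$ for a batch, the code below assumes the classifier outputs unnormalized scores (logits) that are turned into $p(y=c|x)$ with a softmax, and averages the loss over the batch. The function names are hypothetical and not taken from the author's code.

```python
import numpy as np

def log_softmax(logits):
    """Row-wise log p(y=c|x); subtracting the row max keeps exp() stable."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def cross_entropy_loss(logits, labels, num_classes):
    """Batch mean of L_CE = -sum_c y_c log p(y=c|x).

    logits: (N, M) unnormalized classifier scores
    labels: (N,) integer class labels in [0, M)
    """
    log_probs = log_softmax(logits)
    y = np.eye(num_classes)[labels]          # one-hot binary indicators y_c
    per_sample = -np.sum(y * log_probs, axis=1)
    return per_sample.mean()

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
labels = np.array([0, 2])
print(cross_entropy_loss(logits, labels, num_classes=3))
```

Since only the true class has $y_{c}=1$, each per-sample term reduces to $-\log p(y=\text{true class}|x)$; a classifier that outputs uniform probabilities over $M$ classes gets a loss of $\log M$.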

NorbertZheng commented 1 year ago

Contrastive loss

Contrastive loss is widely used in unsupervised and self-supervised learning. It was originally developed by Hadsell et al. in 2006 in Yann LeCun’s group.

It defines a binary indicator $Y$ for each pair of samples stating whether they should be deemed dissimilar ($Y=0$ if $x_{1}, x_{2}$ are deemed similar; $Y=1$ otherwise), and a learnable distance function $D_{W}(x_{1},x_{2})$ between a pair of samples $x_{1},x_{2}$, parameterized by the weights $W$ of the neural network. The contrastive loss is defined as:

$$ L_{contrast}=\frac{1}{2}\left[(1-Y)\cdot (D_{W})^{2}+Y\cdot \max(0, m-D_{W})^{2}\right], $$

where $m>0$ is a margin. The margin defines a radius around the embedding of a sample, so that a dissimilar pair of samples only contributes to the contrastive loss if its distance $D_{W}$ is within the margin.

Intuitively, this loss function encourages the neural network to learn an embedding to place samples with the same labels close to each other, while distancing the samples with different labels in the embedding space.
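A minimal sketch of $L_{contrast}$ averaged over a batch of pairs, assuming $D_{W}$ is the Euclidean distance between the two embeddings produced by the network (the network itself is omitted, so the embeddings are passed in directly). The function name and the toy inputs are hypothetical.

```python
import numpy as np

def pairwise_contrastive_loss(emb1, emb2, y, margin=1.0):
    """Batch mean of 1/2 [(1-Y) D_W^2 + Y max(0, m - D_W)^2].

    emb1, emb2: (N, d) embeddings of the two samples in each pair
    y: (N,) with 0 for similar pairs and 1 for dissimilar pairs
    margin: m > 0
    """
    d = np.linalg.norm(emb1 - emb2, axis=1)                # D_W as Euclidean distance
    similar_term = (1 - y) * d ** 2                         # pulls similar pairs together
    dissimilar_term = y * np.maximum(0.0, margin - d) ** 2  # pushes dissimilar pairs out to m
    return np.mean(0.5 * (similar_term + dissimilar_term))

a = np.array([[0.0, 0.0], [0.0, 0.0]])
b = np.array([[0.0, 0.0], [3.0, 4.0]])
y = np.array([1, 0])  # first pair dissimilar, second pair similar
print(pairwise_contrastive_loss(a, b, y, margin=1.0))  # 6.5
```

In the example, the dissimilar pair sits at distance 0, inside the margin, and contributes $\frac{1}{2}m^{2}=0.5$; the similar pair sits at distance 5 and contributes $\frac{1}{2}\cdot 25=12.5$. A dissimilar pair farther apart than $m$ would contribute nothing.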

NorbertZheng commented 1 year ago

Contrastive loss for self-supervised and supervised learning

In a self-supervised setting, where labels are unavailable and the goal is to learn a useful embedding for the data, the contrastive loss is used together with data augmentation: different augmentations of the same sample are treated as similar, and all other samples as dissimilar.

As illustrated in the Khosla paper, different augmentations of a Bichon picture should be close to each other and far from any other pictures, be they other dogs or cats. However, as the authors point out, this formulation is problematic when labels are available,

because it introduces false negatives: an embedding of a Bichon should also be in the same neighborhood as embeddings of other dogs, be it a Corgi or a Spaniel.

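To adapt contrastive loss to supervised learning, Khosla and colleagues developed a two-stage procedure that combines labels with contrastive loss: in stage 1, an encoder is trained with a supervised contrastive loss in which samples sharing the same label are treated as positives; in stage 2, the encoder is frozen and a classifier is trained on top of it with cross-entropy loss.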
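The stage-1 objective can be sketched as follows, following the supervised contrastive loss formulation in Khosla et al. For each anchor $i$, the positives $P(i)$ are the other samples with the same label, and the softmax denominator runs over all samples except the anchor. This version makes several simplifying assumptions. It uses one view per sample (no augmented copies), it averages rather than sums over anchors, and it uses an illustrative temperature of 0.1. It is not necessarily what the author's repository implements.

```python
import numpy as np

def supervised_contrastive_loss(embeddings, labels, temperature=0.1):
    """Supervised contrastive loss, one view per sample (simplified sketch).

    embeddings: (N, d) encoder outputs; L2-normalized here
    labels: (N,) integer class labels
    temperature: tau; 0.1 is an assumed, illustrative value
    """
    z = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sim = z @ z.T / temperature                                   # scaled cosine similarities
    n = len(labels)
    not_self = ~np.eye(n, dtype=bool)                             # A(i): all samples but the anchor
    positives = (labels[:, None] == labels[None, :]) & not_self   # P(i): same label, not the anchor

    # log [ exp(sim_ip) / sum_{a in A(i)} exp(sim_ia) ], computed with a stable log-sum-exp
    sim_masked = np.where(not_self, sim, -np.inf)
    row_max = sim_masked.max(axis=1, keepdims=True)
    log_denom = row_max + np.log(np.exp(sim_masked - row_max).sum(axis=1, keepdims=True))
    log_prob = sim - log_denom

    # average log-probability over each anchor's positives; skip anchors with none
    pos_counts = positives.sum(axis=1)
    has_pos = pos_counts > 0
    summed = np.where(positives, log_prob, 0.0).sum(axis=1)
    return -np.mean(summed[has_pos] / pos_counts[has_pos])

emb = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
good = supervised_contrastive_loss(emb, np.array([0, 0, 1, 1]))  # classes cluster together
bad = supervised_contrastive_loss(emb, np.array([0, 1, 0, 1]))   # labels cut across clusters
print(good, bad)
```

Unlike the pairwise loss, every same-label sample in the batch acts as a positive for the anchor. This is what removes the false negatives described above.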

The experiments in the original study were performed on ImageNet. The authors also evaluated different data augmentation options and various types of state-of-the-art architectures such as VGG-19 and ResNets.

NorbertZheng commented 1 year ago

Experiments on MNIST and Fashion MNIST datasets

Not being a computer vision expert, I am more interested in whether the proposed supervised contrastive loss is superior to cross entropy for general-purpose classification problems. To run experiments comparing cross-entropy and contrastive loss on my laptop without burning my lap, I chose the MNIST and Fashion MNIST datasets, which are much smaller.

I also ignored the spatial relationships among the pixels and avoided using any convolutional layers, to simulate tabular datasets. The architecture I used for the following experiments is rather simple: two dense hidden layers with 256 neurons each, connected by leaky ReLU activations. The output of the last hidden layer was normalized to place the output vector on the unit hypersphere. The same architecture is used for both the encoder trained with contrastive loss and the MLP baseline. The final classification layer has 10 units, corresponding to the 10 classes.
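A forward-pass sketch of this architecture in plain NumPy follows. Several details are assumptions rather than facts from the post: the 784-dimensional input (flattened 28×28 images), the leaky ReLU slope of 0.01, applying the activation after the second hidden layer before normalization, and the random initialization in the hypothetical `init_dense` helper. Training is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)

def leaky_relu(x, alpha=0.01):
    """Leaky ReLU; alpha=0.01 is an assumed slope."""
    return np.where(x > 0, x, alpha * x)

def init_dense(n_in, n_out):
    """Hypothetical helper: He-style random weights and zero biases."""
    return rng.normal(0.0, np.sqrt(2.0 / n_in), size=(n_in, n_out)), np.zeros(n_out)

# 784 = 28 x 28 flattened pixels, treating each image as a tabular row
W1, b1 = init_dense(784, 256)
W2, b2 = init_dense(256, 256)
W_out, b_out = init_dense(256, 10)   # classification layer: 10 classes

def encode(x):
    """Two dense hidden layers of 256 units with leaky ReLU, then L2 normalization."""
    h = leaky_relu(x @ W1 + b1)
    h = leaky_relu(h @ W2 + b2)
    return h / np.linalg.norm(h, axis=1, keepdims=True)  # place outputs on the unit hypersphere

def classify(x):
    """Logits from the 10-unit classification layer on top of the encoder."""
    return encode(x) @ W_out + b_out

x = rng.random((5, 784))  # stand-in for a batch of 5 flattened images
print(encode(x).shape, classify(x).shape)  # (5, 256) (5, 10)
```

In the contrastive model, `encode` would be trained in stage 1 and `W_out`, `b_out` in stage 2. In the baseline, all weights would be trained jointly with cross-entropy.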

Following the procedure proposed in the Khosla paper, the encoder was pre-trained in stage 1 using contrastive loss, then frozen during stage 2 while only the classifier was trained. As the results below show, the model with supervised contrastive loss indeed outperforms the MLP baseline trained with cross-entropy loss, with exactly the same architecture, on both datasets. The improvement is larger on Fashion MNIST, which is a more difficult task than MNIST.

Figure: Performance (accuracy) on the hold-out test sets of the MNIST and Fashion MNIST datasets.

Another interesting observation is that the learning curves are smoother for models with supervised contrastive loss, perhaps because the classifier is trained on top of a well pre-trained representation.
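Stage 2 of the two-stage procedure can be sketched as plain softmax regression on frozen features. Only the classifier's weights are updated, and the encoder is represented by precomputed embeddings. The synthetic two-cluster data, full-batch gradient descent, learning rate, and epoch count are all illustrative assumptions, not the author's setup.

```python
import numpy as np

def train_linear_classifier(features, labels, num_classes, lr=0.5, epochs=200):
    """Fit a softmax classifier on frozen encoder outputs with gradient descent.

    features: (N, d) embeddings from the frozen stage-1 encoder (never updated here)
    labels: (N,) integer labels
    lr, epochs: illustrative hyperparameters
    """
    n, d = features.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    y = np.eye(num_classes)[labels]
    losses = []
    for _ in range(epochs):
        logits = features @ W + b
        logits -= logits.max(axis=1, keepdims=True)       # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        losses.append(-np.mean(np.sum(y * np.log(probs + 1e-12), axis=1)))
        grad_logits = (probs - y) / n                      # gradient of mean cross entropy
        W -= lr * features.T @ grad_logits                 # only classifier parameters change
        b -= lr * grad_logits.sum(axis=0)
    return W, b, losses

rng = np.random.default_rng(0)
labels = np.repeat([0, 1], 50)
centers = np.array([[1.0, 0.0], [0.0, 1.0]])
features = centers[labels] + 0.1 * rng.normal(size=(100, 2))  # stand-in for frozen embeddings
W, b, losses = train_linear_classifier(features, labels, num_classes=2)
print(losses[0], losses[-1])
```

Because the features are fixed, stage 2 is a convex problem. That is one plausible reason the learning curves look smoother than when the whole network is trained end to end.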

NorbertZheng commented 1 year ago

Maybe we can use such a training paradigm to improve the performance of the MEG decoding task?

NorbertZheng commented 1 year ago

To check whether the performance improvement was due to a better embedding space, I examined the PCA projections of the embeddings from the MLP baseline trained with cross-entropy and from the encoder trained with contrastive loss. As a negative control, I also examined the PCA projection of the original data space without any embedding.

Figure: PCA projections of the embeddings (or lack thereof) learned by the models. From left to right: raw data space; last hidden layer of the baseline MLP; projections learned using contrastive loss.

As the scatter plots show, both the MLP model and the contrastive model cluster samples with the same label better than the original data does, but the clusters in the contrastive embedding are much tighter. To make this more salient, I plotted the densities of the PCA projections, in which one can clearly see 10 distinct clusters in the contrastive embedding:

Figure: Joint plots showing the densities of the PCA projections learned by the models. From left to right: raw data space; last hidden layer of the baseline MLP; projections learned using contrastive loss.

The plots above were produced on MNIST. Similar observations hold on the Fashion MNIST dataset, although a few of the classes are mingled together.

Figure: PCA of the representations learned by different models on the Fashion MNIST dataset.
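The post does not say which tool produced the PCA projections. The following is a plain NumPy version via the singular value decomposition, offered as one way to compute the 2-D coordinates behind such scatter plots. The random input only stands in for raw pixels or learned embeddings.

```python
import numpy as np

def pca_project(x, n_components=2):
    """Project the rows of x onto their top principal components (PCA via SVD)."""
    centered = x - x.mean(axis=0)
    # rows of vt are principal directions, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T

rng = np.random.default_rng(0)
data = rng.normal(size=(200, 10))  # stand-in for raw pixels or learned embeddings
coords = pca_project(data)
print(coords.shape)  # (200, 2)
```

The same function applies unchanged to the raw data, the baseline MLP's last hidden layer, and the contrastive embedding, which keeps the three panels comparable.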

To conclude, my experiments clearly reproduced the main finding of the original study: supervised contrastive loss outperforms cross-entropy loss for classification.

The fact that it also works on much smaller datasets, without data augmentation and without domain-specific architectures such as convolutions, is very encouraging. It hints that supervised contrastive loss could be used as a universal technique for supervised tasks. I look forward to applying it to future ML tasks and exploring its applicability to regression.

NorbertZheng commented 1 year ago

Code

The code used for these experiments is available here: https://github.com/wangz10/contrastive_loss.

NorbertZheng commented 1 year ago

References