A comprehensive guide to four contrastive loss functions for contrastive learning.
In a previous post, I wrote about contrastive learning in supervised classification and performed some experiments on the MNIST dataset and similar ones, finding that the two-stage method proposed in the Khosla et al. 2020 paper indeed shows significant improvement on the supervised classification task by learning meaningful embeddings with a contrastive loss. Later, I found that my experiments actually used a different contrastive loss function than the one Khosla et al. proposed. Although they share the same intuition of explicitly contrasting examples against each other with respect to their labels, different contrastive loss functions have their own nuances. In this post, I review a series of contrastive loss functions and compare their performance on a supervised classification task.
Contrastive loss functions were invented for metric learning, which intends to learn a distance or similarity metric over the data such that similar pairs of examples end up close together and dissimilar pairs end up far apart.
In the context of classification, the desired metric would render a pair of examples with the same label more similar than a pair of examples with different labels. Deep metric learning uses deep neural networks to nonlinearly embed data points into a lower-dimensional space, then optimizes the network parameters with a contrastive loss function. Recent research projects have applied deep metric learning to self-supervised learning, supervised learning, and even reinforcement learning, for example Contrastively-trained Structured World Models (C-SWMs).
To review different contrastive loss functions in the context of deep metric learning, I use the following formalization. Let $x$ be the input feature vector and $y$ be its label. Let $f(\cdot)$ be an encoder network mapping the input space to the embedding space and let $z=f(x)$ be the embedding vector.
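To make this notation concrete, below is a minimal PyTorch sketch of an encoder $f(\cdot)$ for flattened MNIST-sized inputs; the architecture and dimensions are purely illustrative and not the network used in the experiments below.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """z = f(x): maps a flattened 28x28 image to a low-dimensional embedding."""
    def __init__(self, in_dim: int = 784, emb_dim: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU(),
            nn.Linear(256, emb_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # flatten everything except the batch dimension before embedding
        return self.net(x.flatten(start_dim=1))
```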
Here I review four contrastive loss functions in chronological order. I slightly changed the names of a few functions to highlight their distinctive characteristics.
Max margin contrastive loss takes a pair of embedding vectors $z_i$ and $z_j$ as inputs. It essentially equals the squared Euclidean distance between them if they have the same label ($y_i = y_j$) and is otherwise equivalent to a hinge loss on that distance. It has a margin parameter $m > 0$ that imposes a lower bound on the distance between a pair of samples with different labels.
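A minimal PyTorch sketch of this pairwise loss (a sketch of the Hadsell et al. 2006 formulation, not the author's implementation); the construction of example pairs and the 0/1 `same_label` indicator is assumed to happen upstream:

```python
import torch

def max_margin_contrastive_loss(z_i, z_j, same_label, margin: float = 1.0):
    """Pairwise contrastive loss: pull same-label pairs together, push
    different-label pairs at least `margin` apart.

    z_i, z_j:    (B, emb_dim) embeddings of the paired examples
    same_label:  (B,) float tensor, 1.0 where y_i == y_j, else 0.0
    """
    d = torch.norm(z_i - z_j, p=2, dim=1)        # Euclidean distance per pair
    pos_term = same_label * d.pow(2)             # squared distance for positive pairs
    neg_term = (1.0 - same_label) * torch.clamp(margin - d, min=0.0).pow(2)  # hinge for negative pairs
    return (pos_term + neg_term).mean()
```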
Max margin contrastive loss should already be enough to train PEs of TEM alone?
Triplet loss operates on a triplet of vectors whose labels follow $y_i = y_j$ and $y_i \neq y_k$. That is to say, two of the three vectors, $(z_i, z_j)$, share the same label and a third vector $z_k$ has a different label. In the triplet-learning literature, they are termed the anchor ($z_i$), positive ($z_j$), and negative ($z_k$), respectively. Triplet loss is defined as:

$$L(z_i, z_j, z_k) = \max\left(0,\ \lVert z_i - z_j\rVert_2^2 - \lVert z_i - z_k\rVert_2^2 + m\right)$$
where $m$ again is a margin parameter, here requiring the gap between the anchor-negative and anchor-positive distances to be larger than $m$. The intuition for this loss function is to push the negative sample outside of the neighborhood by a margin while keeping positive samples within the neighborhood. The original paper has a great graphical demonstration of this effect:
Before and after training using triplet loss (from Weinberger et al. 2005).
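A minimal PyTorch sketch of the triplet loss with squared Euclidean distances (PyTorch's built-in `nn.TripletMarginLoss` is similar but uses unsquared distances by default); how the triplets are formed is covered next:

```python
import torch

def triplet_loss(anchor, positive, negative, margin: float = 1.0):
    """Triplet loss over (anchor, positive, negative) batches of embeddings,
    each of shape (B, emb_dim)."""
    d_ap = (anchor - positive).pow(2).sum(dim=1)   # squared anchor-positive distance
    d_an = (anchor - negative).pow(2).sum(dim=1)   # squared anchor-negative distance
    return torch.clamp(d_ap - d_an + margin, min=0.0).mean()
```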
Triplet mining. Based on the definition of the triplet loss, a triplet may fall into one of three scenarios before any training (writing $d_{ap}$ and $d_{an}$ for the anchor-positive and anchor-negative distance terms in the loss):
- Easy triplets: $d_{an} > d_{ap} + m$. The loss is already 0, so such triplets contribute nothing to learning.
- Hard triplets: $d_{an} < d_{ap}$. The negative is closer to the anchor than the positive.
- Semi-hard triplets: $d_{ap} < d_{an} < d_{ap} + m$. The negative is farther from the anchor than the positive, but still within the margin, so the loss remains positive.
Triplet mining refers to selecting which of these triplets to construct from a mini-batch during training.
Triplet loss has been used to learn embeddings for faces in the FaceNet paper (Schroff et al. 2015). Schroff et al. argued that triplet mining is crucial for model performance and convergence. They also found that selecting only the hardest negatives can lead to bad local minima early in training, and that picking semi-hard negatives works better in practice.
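To make the idea concrete, here is a rough PyTorch sketch of batch-wise semi-hard negative mining combined with the triplet loss. It only illustrates the mining logic and is not the implementation used in the experiments below; the hardest-positive selection and the fallback to the hardest negative are simplifying assumptions of this sketch.

```python
import torch

def semi_hard_triplet_loss(z, y, margin: float = 1.0):
    """For each anchor: take its hardest positive, then prefer a semi-hard
    negative (farther than that positive but within the margin); if none
    exists, fall back to the hardest (closest) negative."""
    dist = torch.cdist(z, z, p=2)                      # (B, B) pairwise Euclidean distances
    same = y.unsqueeze(0) == y.unsqueeze(1)            # label-equality matrix
    eye = torch.eye(len(y), dtype=torch.bool, device=z.device)
    pos_mask = same & ~eye                             # positives: same label, not self
    neg_mask = ~same                                   # negatives: different label

    # hardest positive per anchor (largest same-label distance)
    d_pos = dist.masked_fill(~pos_mask, 0.0).max(dim=1).values

    # semi-hard candidates satisfy d_pos < d_neg < d_pos + margin
    d_neg = dist.masked_fill(~neg_mask, float("inf"))
    semi = (d_neg > d_pos.unsqueeze(1)) & (d_neg < d_pos.unsqueeze(1) + margin)
    d_semi = d_neg.masked_fill(~semi, float("inf")).min(dim=1).values

    # fall back to the hardest negative when no semi-hard candidate exists
    d_hard = d_neg.min(dim=1).values
    d_sel = torch.where(torch.isinf(d_semi), d_hard, d_semi)

    return torch.clamp(d_pos - d_sel + margin, min=0.0).mean()
```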
Multi-class N-pair loss is a generalization of triplet loss allowing joint comparison among more than one negative sample. When applied to a pair of positive samples $z_i$ and $z_j$ sharing the same label ($y_i = y_j$) from a mini-batch with $2N$ samples, it is calculated as:

$$L(z_i, z_j) = \log\left(1 + \sum_{k:\, y_k\neq y_i}\exp(z_i^\top z_k - z_i^\top z_j)\right)$$

where $z_i^\top z_j$ is the inner product, which is equivalent to cosine similarity when both vectors have unit norm.
As the figure below shows, $N$-pair loss pushes $2(N-1)$ negative samples away simultaneously instead of one at a time:
Triplet loss (left) and its extension, the $(N+1)$-tuplet loss (right) (from Sohn 2016).
With some algebraic manipulation, multi-class $N$-pair loss can be rewritten as:

$$L(z_i, z_j) = -\log\frac{\exp(z_i^\top z_j)}{\exp(z_i^\top z_j) + \sum_{k:\, y_k\neq y_i}\exp(z_i^\top z_k)}$$

This form of multi-class $N$-pair loss helps us introduce the next loss function.
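Before moving on, here is a minimal PyTorch sketch of the multi-class $N$-pair loss in this cross-entropy form. It assumes the mini-batch is arranged as $N$ (anchor, positive) pairs drawn from $N$ distinct classes, as in Sohn 2016, so the positives of the other pairs serve as negatives:

```python
import torch
import torch.nn.functional as F

def multiclass_n_pair_loss(anchors, positives):
    """Multi-class N-pair loss for N (anchor, positive) pairs, one per class.

    anchors, positives: (N, emb_dim); anchors[i] and positives[i] share a label,
    and every other positives[k] (k != i) acts as a negative for anchors[i].
    """
    logits = anchors @ positives.T                                  # (N, N) inner-product similarities
    targets = torch.arange(anchors.size(0), device=anchors.device)  # matching pair sits on the diagonal
    # row-wise cross-entropy equals -log softmax at the matching pair,
    # i.e. the -log form of the multi-class N-pair loss above
    return F.cross_entropy(logits, targets)
```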
Let’s first look at the self-supervised version of NT-Xent loss. NT-Xent was coined by Chen et al. 2020 in the SimCLR paper and is short for “normalized temperature-scaled cross-entropy loss”. It is a modification of the multi-class N-pair loss with the addition of a temperature parameter ($\tau$) to scale the cosine similarities:

$$L(z_i, z_j) = -\log\frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N}\mathbf{1}_{[k\neq i]}\exp(\mathrm{sim}(z_i, z_k)/\tau)}$$

Self-supervised NT-Xent loss, where $\mathrm{sim}(\cdot,\cdot)$ denotes cosine similarity.
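A minimal PyTorch sketch of the self-supervised NT-Xent loss, assuming the batch consists of two augmented views `z1` and `z2` of the same $N$ images as in SimCLR (this is not the SimCLR reference implementation):

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature: float = 0.5):
    """Self-supervised NT-Xent loss for two augmented views of the same N images.

    z1, z2: (N, emb_dim) embeddings of the two views; row i of z1 and z2
    come from the same original image.
    """
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)   # (2N, d), unit norm
    sim = z @ z.T / temperature                          # scaled cosine similarities
    # exclude self-similarity from the denominator
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # the positive of sample i is its other augmented view (index i + N or i - N)
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)
```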
Chen et al. found that $\ell_2$ normalization together with an appropriate temperature effectively weights different examples and helps the model learn from hard negatives, and that NT-Xent outperformed the alternative objectives they compared against, such as margin and logistic losses.
In addition, they showed that the optimal temperature differs across experimental settings, so $\tau$ should be treated as a hyperparameter to tune.
Khosla et al. later extended NT-Xent loss for supervised learning:

$$L_i = \frac{-1}{\lvert P(i)\rvert}\sum_{p\in P(i)}\log\frac{\exp(z_i^\top z_p/\tau)}{\sum_{k\neq i}\exp(z_i^\top z_k/\tau)}$$

Supervised NT-Xent loss, where $P(i)$ is the set of indices of all other samples in the mini-batch that share the anchor's label $y_i$.
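A minimal PyTorch sketch of the supervised NT-Xent loss, averaging the log-probabilities over every positive of each anchor within the batch; this follows the spirit of Khosla et al.'s formulation rather than their exact implementation:

```python
import torch
import torch.nn.functional as F

def supervised_nt_xent_loss(z, y, temperature: float = 0.1):
    """Supervised NT-Xent loss: average -log probability over all positives
    (other samples in the batch sharing the anchor's label).

    z: (B, emb_dim) embeddings; y: (B,) integer class labels.
    """
    z = F.normalize(z, dim=1)                            # unit-norm embeddings
    sim = z @ z.T / temperature                          # scaled cosine similarities
    self_mask = torch.eye(len(y), dtype=torch.bool, device=z.device)
    pos_mask = (y.unsqueeze(0) == y.unsqueeze(1)) & ~self_mask

    # log-softmax over all other samples (denominator excludes the anchor itself)
    sim = sim.masked_fill(self_mask, float("-inf"))
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)

    # mean of -log p over each anchor's positives; skip anchors with no positive
    pos_counts = pos_mask.sum(dim=1)
    loss = -log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1) / pos_counts.clamp(min=1)
    return loss[pos_counts > 0].mean()
```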
Next, I assess whether these contrastive loss functions can help the encoder network learn meaningful representations of the data to aid the classification task. Following the exact same experimental settings as my previous post, with a small batch size (32) and a low learning rate (0.001), I obtained the following results:
Performance (accuracy) on the hold-out test sets of the MNIST and Fashion MNIST datasets (results from triplet loss with hard negative mining not shown).
These results confirmed the benefit of using a contrastive loss function to pre-train the encoder part of the network for the subsequent classification task. They also underscored the importance of triplet mining for triplet loss: specifically, semi-hard mining works best in these experiments, which is consistent with the FaceNet paper.
Both Chen et al. (SimCLR) and Khosla et al. use very large batch sizes and higher learning rates with the NT-Xent loss to achieve better performance.
I next experimented with batch sizes of 32, 256, and 2048, using learning rates of 0.001, 0.01, and 0.2, respectively.
The results show that performance diminishes as the batch size increases for all loss functions. Although triplet loss with semi-hard negative mining performs very well on small and medium batches, it is very memory intensive, and my 16 GB of RAM could not handle it with a batch size of 2048. Supervised NT-Xent loss does tend to perform relatively better at larger batch sizes than its counterparts.
Next, I checked the PCA projections of the embeddings learned with the contrastive loss functions to see whether they capture any informative representations during the pre-training stage.
PCA projections of the embeddings learned by encoder networks with different contrastive loss functions and batch sizes on the MNIST dataset. From left to right: projections learned by 1) max margin loss; 2) triplet loss with semi-hard mining; 3) multi-class N-pair loss; 4) supervised NT-Xent loss. From top to bottom: batch sizes of 32, 256, 2048.
Joint plots showing the densities in the PCA projections learned by models on MNIST dataset. From left to right: projections learned by 1) max margin loss; 2) triplet loss with semi-hard mining; 3) multi-class N-pair loss; 4) supervised NT-Xent loss. From top to bottom: batch sizes of 32, 256, 2048.
Judging from the colored PCA projections and their densities, we can see that both max margin and supervised NT-Xent loss learn tighter clusters for each class, whereas the clusters from triplet loss with semi-hard mining are more spread out but still distinct. As the batch size increases, the representation quality degenerates for multi-class $N$-pair loss and max margin loss, but the clusters learned with supervised NT-Xent loss remain comparatively well separated.
Below are the PCA projections of the representations learned on the more difficult Fashion MNIST dataset. Overall, the observations are similar to those on MNIST.
PCA projections of the embeddings learned by encoder networks with different contrastive loss functions and batch sizes on the Fashion MNIST dataset. From left to right: projections learned by 1) max margin loss; 2) triplet loss with semi-hard mining; 3) multi-class N-pair loss; 4) supervised NT-Xent loss. From top to bottom: batch sizes of 32, 256, 2048.
Contrastive loss functions are extremely helpful for improving supervised classification tasks by learning useful representations. Max margin and supervised NT-Xent loss were the top performers on the datasets tested (MNIST and Fashion MNIST). Additionally, NT-Xent loss is robust to large batch sizes.
Of note, all the contrastive loss functions reviewed here have hyperparameters, e.g. the margin, the temperature, and the similarity/distance metric for comparing embedding vectors. These hyperparameters may affect the results drastically, as suggested by other studies, and should be optimized for each dataset.
Code used for these experiments is available here: https://github.com/wangz10/contrastive_loss.
Zichen Wang. Contrasting contrastive loss functions.