NorbertZheng / read-papers

My paper reading notes.
MIT License

Zichen Wang | Contrasting contrastive loss functions. #57

Closed NorbertZheng closed 1 year ago

NorbertZheng commented 1 year ago

Zichen Wang. Contrasting contrastive loss functions.

NorbertZheng commented 1 year ago

Overview

A comprehensive guide to four contrastive loss functions for contrastive learning.

In a previous post, I wrote about contrastive learning in supervised classification and ran some experiments on MNIST and similar datasets. I found that the two-stage method proposed by Khosla et al. 2020 does significantly improve supervised classification by learning meaningful embeddings with a contrastive loss. Later I found that my experiments had actually used a different contrastive loss function from the one Khosla et al. proposed. Different contrastive loss functions share the same intuition of explicitly contrasting examples against each other with respect to their labels, but each can have its own nuances. In this post, I review a series of contrastive loss functions and compare their performance on supervised classification.

NorbertZheng commented 1 year ago

Preliminary

Contrastive loss functions were invented for metric learning, which intends to

In the context of classification, the desired metric would make a pair of examples with the same label more similar than a pair of examples with different labels. Deep metric learning uses deep neural networks to embed data points nonlinearly into a lower-dimensional space, then uses a contrastive loss function to optimize the network's parameters. Recent research has applied deep metric learning to self-supervised learning, supervised learning, and even reinforcement learning, for example, Contrastively-trained Structured World Models (C-SWMs).

To review different contrastive loss functions in the context of deep metric learning, I use the following formalization. Let $x$ be the input feature vector and $y$ be its label. Let $f(\cdot)$ be an encoder network mapping the input space to the embedding space and let $z=f(x)$ be the embedding vector.

NorbertZheng commented 1 year ago

Types of contrastive loss functions

Here I review four contrastive loss functions in chronological order. I slightly changed the names of a few functions to highlight their distinctive characteristics.

Max margin contrastive loss (Hadsell et al. 2006)

$$\mathcal{L}(z_{i},z_{j})=\mathbb{1}[y_{i}=y_{j}]\,\lVert z_{i}-z_{j}\rVert_{2}^{2}+\mathbb{1}[y_{i}\neq y_{j}]\,\max\left(0,\,m-\lVert z_{i}-z_{j}\rVert_{2}\right)^{2}$$

Max margin contrastive loss takes a pair of embedding vectors $z_{i}$ and $z_{j}$ as inputs. It reduces to the squared Euclidean distance between them if they share the same label ($y_{i}=y_{j}$) and is otherwise a (squared) hinge loss on that distance. Its margin parameter $m>0$ imposes a lower bound on the distance between a pair of samples with different labels.
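To make the two branches concrete, here is a minimal pure-Python sketch of a pairwise max margin loss. It assumes the common form $\mathbb{1}[y_i=y_j]\,d^2+\mathbb{1}[y_i\neq y_j]\max(0,m-d)^2$ with $d=\lVert z_i-z_j\rVert_2$ (Hadsell et al. also include factors of 1/2, omitted here). The function names are hypothetical and not taken from the author's repository.

```python
import math


def euclidean_distance(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def max_margin_contrastive_loss(z_i, z_j, same_label, m=1.0):
    """Pairwise max margin contrastive loss (squared distance / squared hinge form)."""
    d = euclidean_distance(z_i, z_j)
    if same_label:
        # Positive pair: any separation is penalized, pulling the pair together.
        return d ** 2
    # Negative pair: penalized only while the pair is closer than the margin m.
    return max(0.0, m - d) ** 2


print(max_margin_contrastive_loss([0.0, 0.0], [3.0, 4.0], same_label=True))   # 25.0
print(max_margin_contrastive_loss([0.0, 0.0], [3.0, 4.0], same_label=False))  # 0.0
print(max_margin_contrastive_loss([0.0, 0.0], [0.0, 0.5], same_label=False))  # 0.25
```

The same distance of 5 is heavily penalized for a positive pair but costs nothing for a negative pair, because it already exceeds the margin.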

NorbertZheng commented 1 year ago

Max margin contrastive loss should already be enough to train PEs of TEM alone?

NorbertZheng commented 1 year ago

Triplet loss (Weinberger et al. 2006)

Triplet loss operates on a triplet of vectors whose labels satisfy $y_{i}=y_{j}$ and $y_{i}\neq y_{k}$. That is, two of the three vectors ($z_{i}$ and $z_{j}$) share the same label, and the third vector $z_{k}$ has a different label. In the triplet learning literature, they are termed the anchor ($z_{i}$), positive ($z_{j}$), and negative ($z_{k}$), respectively. Triplet loss is defined as:

$$\mathcal{L}(z_{i},z_{j},z_{k})=\max\left(0,\,\lVert z_{i}-z_{j}\rVert_{2}^{2}-\lVert z_{i}-z_{k}\rVert_{2}^{2}+m\right)$$

where $m$ is again a margin parameter, requiring the anchor-negative distance to exceed the anchor-positive distance by at least $m$. The intuition behind this loss function is to push the negative sample outside the anchor's neighborhood by a margin while keeping positive samples within it. The original paper has a great graphical demonstration of the effect of triplet loss:

Figure: Before and after training using triplet loss (from Weinberger et al. 2005).
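A minimal sketch of the triplet loss for a single triplet, assuming squared Euclidean distances inside the hinge, i.e. $\max(0,\lVert z_i-z_j\rVert_2^2-\lVert z_i-z_k\rVert_2^2+m)$. The function names are illustrative; practical implementations operate on whole batches of embeddings.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def triplet_loss(anchor, positive, negative, m=1.0):
    """Hinge on (anchor-positive) minus (anchor-negative) squared distance, plus margin m."""
    d_ap = squared_distance(anchor, positive)
    d_an = squared_distance(anchor, negative)
    # Zero once the negative is at least m farther from the anchor than the positive.
    return max(0.0, d_ap - d_an + m)


anchor, positive = [0.0, 0.0], [1.0, 0.0]
print(triplet_loss(anchor, positive, [3.0, 0.0]))  # 0.0  (negative already far enough)
print(triplet_loss(anchor, positive, [0.5, 0.0]))  # 1.75 (negative closer than the positive)
```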

Triplet mining

Based on the definition of the triplet loss, a triplet may fall into one of the following three scenarios before any training:

Triplet loss has been used to learn embeddings for faces in the FaceNet (Schroff et al. 2015) paper. Schroff et al. argued that triplet mining is crucial for model performance and convergence. They also found that
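A common way to name the three scenarios, following FaceNet's terminology, compares the anchor-negative distance with the anchor-positive distance and the margin. Easy triplets already satisfy the margin (zero loss). Semi-hard triplets satisfy $\lVert z_i-z_j\rVert_2^2<\lVert z_i-z_k\rVert_2^2<\lVert z_i-z_j\rVert_2^2+m$. Hard triplets have the negative at least as close to the anchor as the positive. The sketch below classifies triplets this way and filters candidate negatives down to semi-hard ones; the function names and the tie-breaking at the boundaries are illustrative choices, not from the source.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def triplet_category(anchor, positive, negative, m=1.0):
    """Label a triplet as 'easy', 'semi-hard', or 'hard' relative to margin m."""
    d_ap = squared_distance(anchor, positive)
    d_an = squared_distance(anchor, negative)
    if d_an >= d_ap + m:
        return "easy"       # margin already satisfied: triplet loss is zero
    if d_an > d_ap:
        return "semi-hard"  # negative farther than the positive, but inside the margin
    return "hard"           # negative at least as close to the anchor as the positive


def semi_hard_negatives(anchor, positive, negatives, m=1.0):
    """Keep only the candidate negatives that form semi-hard triplets."""
    return [n for n in negatives if triplet_category(anchor, positive, n, m) == "semi-hard"]


anchor, positive = [0.0, 0.0], [1.0, 0.0]
candidates = [[2.0, 0.0], [1.25, 0.0], [0.5, 0.0]]
print([triplet_category(anchor, positive, n) for n in candidates])
# ['easy', 'semi-hard', 'hard']
print(semi_hard_negatives(anchor, positive, candidates))  # [[1.25, 0.0]]
```

Easy triplets contribute zero loss and hence no gradient, which is why mining focuses training on the harder cases.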

NorbertZheng commented 1 year ago

Multi-class N-pair loss (Sohn 2016)

Multi-class N-pair loss is a generalization of triplet loss that allows joint comparison against more than one negative sample. When applied to a pair of positive samples $z_{i}$ and $z_{j}$ sharing the same label ($y_{i}=y_{j}$) from a mini-batch of $2N$ samples, it is calculated as:

$$\mathcal{L}(z_{i},z_{j})=\log\left(1+\sum_{k=1}^{2N}\mathbb{1}[k\neq i,\,k\neq j]\,\exp\left(z_{i}z_{k}-z_{i}z_{j}\right)\right)$$

where $z_{i}z_{j}$ is the inner product, which is equivalent to cosine similarity when both vectors have unit norm.

As the figure below shows, $N$-pair loss pushes $2(N-1)$ negative samples away simultaneously instead of one at a time:

Figure: Triplet loss (left) and its extension $(N+1)$-triplet loss (right) (from Sohn 2016).

With some algebraic manipulation, multi-class $N$-pair loss can be written as:

$$\mathcal{L}(z_{i},z_{j})=-\log\frac{\exp(z_{i}z_{j})}{\sum_{k=1}^{2N}\mathbb{1}[k\neq i]\,\exp(z_{i}z_{k})}$$

This form of multi-class $N$-pair loss helps us introduce the next loss function.
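The algebraic step between the two forms of multi-class $N$-pair loss can be checked numerically. The sketch below assumes, as in Sohn's batch construction with one pair per class, that every sample other than $z_i$ and $z_j$ is a negative, so the anchor sees $2(N-1)$ negatives. The function names are illustrative.

```python
import math


def dot(a, b):
    """Inner product z_a z_b (cosine similarity if both vectors have unit norm)."""
    return sum(ai * bi for ai, bi in zip(a, b))


def n_pair_loss_log_form(z, i, j):
    """log(1 + sum over k not in {i, j} of exp(z_i z_k - z_i z_j))."""
    s_ij = dot(z[i], z[j])
    total = sum(math.exp(dot(z[i], z[k]) - s_ij) for k in range(len(z)) if k not in (i, j))
    return math.log1p(total)


def n_pair_loss_softmax_form(z, i, j):
    """-log of the softmax probability of the positive j among all k != i."""
    numerator = math.exp(dot(z[i], z[j]))
    denominator = sum(math.exp(dot(z[i], z[k])) for k in range(len(z)) if k != i)
    return -math.log(numerator / denominator)


# 2N = 4 unit vectors from N = 2 classes: (z[0], z[1]) share a label, and so do (z[2], z[3]).
z = [[1.0, 0.0], [0.8, 0.6], [0.0, 1.0], [-0.6, 0.8]]
a = n_pair_loss_log_form(z, 0, 1)
b = n_pair_loss_softmax_form(z, 0, 1)
print(math.isclose(a, b))  # True: the two forms are algebraically identical
```

The softmax form reads as a cross-entropy that asks the anchor to pick out its positive among all other samples in the batch.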

NorbertZheng commented 1 year ago

Supervised NT-Xent loss (Khosla et al. 2020)

Let’s first look at the self-supervised version of NT-Xent loss. The term NT-Xent was coined by Chen et al. 2020 in the SimCLR paper and is short for “normalized temperature-scaled cross-entropy loss”. It modifies the multi-class N-pair loss by adding a temperature parameter ($\tau$) to scale the cosine similarities (self-supervised NT-Xent loss):

$$\mathcal{L}(z_{i},z_{j})=-\log\frac{\exp(z_{i}z_{j}/\tau)}{\sum_{k=1}^{2N}\mathbb{1}[k\neq i]\,\exp(z_{i}z_{k}/\tau)}$$
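A minimal sketch of the self-supervised NT-Xent loss. It L2-normalizes the embeddings so that the inner product is a cosine similarity, and it assumes a hypothetical layout (not from the source) in which `z[2p]` and `z[2p + 1]` are the two augmented views of the same example. The batch loss averages over both orderings of every positive pair, as SimCLR does. The temperature default `tau=0.5` is arbitrary.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit L2 norm so inner products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def nt_xent_anchor(z, i, j, tau=0.5):
    """NT-Xent loss for anchor i and its positive j; all k != i enter the denominator."""
    zn = [l2_normalize(v) for v in z]
    logits = [sum(a * b for a, b in zip(zn[i], zn[k])) / tau for k in range(len(zn))]
    log_denominator = math.log(sum(math.exp(logits[k]) for k in range(len(zn)) if k != i))
    # Same as -log(exp(logits[j]) / denominator), written in log space.
    return log_denominator - logits[j]


def nt_xent_batch(z, tau=0.5):
    """Mean NT-Xent over all 2N anchors, assuming z[2p] and z[2p + 1] are two views of one example."""
    losses = []
    for p in range(0, len(z), 2):
        losses.append(nt_xent_anchor(z, p, p + 1, tau))  # first view as anchor
        losses.append(nt_xent_anchor(z, p + 1, p, tau))  # second view as anchor
    return sum(losses) / len(losses)


aligned = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]     # each pair of views agrees
mismatched = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]]  # each pair of views disagrees
print(nt_xent_batch(aligned) < nt_xent_batch(mismatched))  # True
```

A smaller `tau` sharpens the softmax, so the loss concentrates more on the negatives that are most similar to the anchor.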

Chen et al. found that

In addition, they showed that the optimal temperature differs in different

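Khosla et al. later extended NT-Xent loss to supervised learning (supervised NT-Xent loss):

$$\mathcal{L}=\sum_{i=1}^{2N}\frac{-1}{2N_{y_{i}}-1}\sum_{j=1}^{2N}\mathbb{1}[j\neq i]\,\mathbb{1}[y_{i}=y_{j}]\log\frac{\exp(z_{i}z_{j}/\tau)}{\sum_{k=1}^{2N}\mathbb{1}[k\neq i]\,\exp(z_{i}z_{k}/\tau)}$$

where $2N_{y_{i}}-1$ is the number of samples in the mini-batch, other than $z_{i}$ itself, that share the label $y_{i}$.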
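A sketch of the supervised extension in the same style. Every other sample with the same label counts as a positive, and each anchor's loss averages the log-probabilities over its positives, which is what the normalization by the number of positives does. This version averages over anchors instead of summing (a rescaling only), skips anchors with no same-label partner, and uses an arbitrary `tau=0.1`; all names are illustrative.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit L2 norm so inner products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def supervised_nt_xent(z, labels, tau=0.1):
    """Supervised NT-Xent, averaged over anchors that have at least one same-label partner."""
    zn = [l2_normalize(v) for v in z]
    n = len(zn)
    per_anchor = []
    for i in range(n):
        logits = [sum(a * b for a, b in zip(zn[i], zn[k])) / tau for k in range(n)]
        # Denominator runs over every other sample in the batch, as in the self-supervised loss.
        log_denominator = math.log(sum(math.exp(logits[k]) for k in range(n) if k != i))
        positives = [j for j in range(n) if j != i and labels[j] == labels[i]]
        if not positives:
            continue  # no same-label partner: this anchor contributes nothing
        # Average the log-probability over all positives (the 1 / (2N_{y_i} - 1) factor).
        per_anchor.append(-sum(logits[j] - log_denominator for j in positives) / len(positives))
    if not per_anchor:
        raise ValueError("no anchor in the batch has a same-label partner")
    return sum(per_anchor) / len(per_anchor)


z = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
print(supervised_nt_xent(z, [0, 0, 1, 1]) < supervised_nt_xent(z, [0, 1, 0, 1]))  # True
```

When the labels match the geometry (tight same-label clusters), the loss is much lower than when the same embeddings are labeled against their clusters.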

NorbertZheng commented 1 year ago

Experimental results

Next, I assess whether these contrastive loss functions can help the encoder network learn meaningful representations of the data to aid the classification task. Following exactly the same experimental settings as my previous post, with a small batch size (32) and a low learning rate (0.001), I found

Figure: Performance (accuracy) on the hold-out test sets of the MNIST and Fashion MNIST datasets (results from triplet loss with hard negative mining not shown).

These results confirmed the benefit of using a contrastive loss function to pre-train the encoder part of the network for the subsequent classification. They also underscored the importance of triplet mining for triplet loss. Specifically,

I next experimented with batch sizes of 32, 256, and 2048, using learning rates of 0.001, 0.01, and 0.2, respectively.

The results show that performance diminishes as the batch size increases for all loss functions. Although triplet loss with semi-hard negative mining performs very well on small and medium batches, it is very memory-intensive, and my 16 GB of RAM could not handle it at a batch size of 2048. Supervised NT-Xent loss does tend to perform relatively better at larger batch sizes than its counterparts.

Next, I checked the PCA projections of the embeddings learned with each contrastive loss function to see whether the encoders learned any informative representations during the pre-training stage.

Figure: PCA projections of the embeddings learned by encoder networks with different contrastive loss functions and batch sizes on the MNIST dataset. From left to right: projections learned by 1) max margin loss; 2) triplet loss with semi-hard mining; 3) multi-class N-pair loss; 4) supervised NT-Xent loss. From top to bottom: batch sizes of 32, 256, 2048.

Figure: Joint plots showing the densities in the PCA projections learned by models on the MNIST dataset. From left to right: projections learned by 1) max margin loss; 2) triplet loss with semi-hard mining; 3) multi-class N-pair loss; 4) supervised NT-Xent loss. From top to bottom: batch sizes of 32, 256, 2048.

Judging from the colored PCA projections and their densities, both max margin and supervised NT-Xent loss learn tighter clusters for each class, whereas the clusters from triplet loss with semi-hard mining are the most dilated but still distinct. As the batch size increases, the representation quality degrades for multi-class $N$-pair loss and max margin loss, but

Below are the PCA projections of the learned representations on the more difficult Fashion MNIST dataset. Overall, they show observations similar to those on MNIST.

Figure: PCA projections of the embeddings learned by encoder networks with different contrastive loss functions and batch sizes on the Fashion MNIST dataset. From left to right: projections learned by 1) max margin loss; 2) triplet loss with semi-hard mining; 3) multi-class N-pair loss; 4) supervised NT-Xent loss. From top to bottom: batch sizes of 32, 256, 2048.

NorbertZheng commented 1 year ago

Summary

Contrastive loss functions are extremely helpful for improving supervised classification by learning useful representations. Max margin and supervised NT-Xent loss are the top performers on the datasets tested (MNIST and Fashion MNIST). Additionally, supervised NT-Xent loss is more robust to large batch sizes than the other loss functions.

Of note, all the contrastive loss functions reviewed here have hyperparameters, e.g., the margin, the temperature, and the similarity or distance metric applied to the input vectors. As other studies suggest, these hyperparameters may affect the results drastically and may need to be tuned for each dataset.

NorbertZheng commented 1 year ago

Code

The code used for these experiments is available here: https://github.com/wangz10/contrastive_loss.

NorbertZheng commented 1 year ago

References