We all know that researchers have been trying to develop methods that work with partially labeled data for years. Quite a few semi-supervised learning techniques handle partially labeled data reasonably well, but most of them still struggle significantly in the deep learning setting. In this blog post, we are going to discuss a strategy that doesn't require any labels at all, called Contrastive Learning. So, without further ado, let's dive deep into the concept of contrastive learning.
Note: 80% of the time spent in a supervised learning ML project is invested in acquiring and cleaning the data for model training.
Contrastive Learning is a technique generally used in vision tasks that lack labeled data. By contrasting samples against each other, it learns general features of the data without relying on labels.
As the name suggests, samples are contrasted against each other: those belonging to the same distribution or class are pulled together in the embedding space, while those belonging to different distributions are pushed apart.
Therefore, contrastive learning is generally considered to be a form of self-supervised learning, because it does not require labeled data from external sources in order to train the model to predict the difference or relationship between two input items. It is often used for representation learning, where the goal is to learn useful and meaningful representations of the input data.
Image from Contrastive Self-Supervised Learning | Ankesh Anand.
In contrastive learning, the model is presented with pairs of items and is trained to predict whether the two items are related or not. For example, the model might be presented with pairs of images and asked to predict whether the images are of the same object or not. The model is then trained to minimize the error in its predictions by adjusting its internal representations of the input data.
In particular, contrastive learning can be used to learn features that are invariant to certain transformations, such as translation or rotation, which are important for recognizing objects in natural images.
Basically, contrastive learning tries to put similar things into the same basket and to keep dissimilar things out of it. This is very similar to how humans understand the world: we don't need to be shown every car in the world to identify a new car. We build up a set of features associated with cars in our mind, and anything that exhibits similar features is categorized as a car.
Positive and negative sample.
The basic principle behind contrastive learning is: pick an anchor sample, a positive sample that is similar to the anchor, and one or more negative samples that are dissimilar to it, then train the model to pull the anchor and the positive together while pushing the anchor and the negatives apart.
But how do we actually push and pull different samples? In this method, the positive sample is typically produced by applying a transformation to the anchor itself, and the negatives are simply other samples from the dataset.
For example, if we select an image of a human as the anchor, we can jitter the image or convert it to grayscale to use as the positive sample. The negative sample can be any other image in the dataset.
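As a rough sketch of this idea (the helper names, augmentation choices, and parameter values here are illustrative assumptions, not a fixed recipe), a positive view can be produced by jittering and grayscaling the anchor, while a negative is just another image drawn from the dataset:

```python
import tensorflow as tf

def make_positive(anchor):
    # Illustrative positive view: color jitter plus grayscale conversion,
    # applied to the anchor image (pixel values assumed to be in [0, 1]).
    view = tf.image.random_brightness(anchor, max_delta=0.2)
    view = tf.image.random_contrast(view, lower=0.8, upper=1.2)
    view = tf.image.rgb_to_grayscale(view)    # drop color information
    view = tf.image.grayscale_to_rgb(view)    # back to 3 channels
    return tf.clip_by_value(view, 0.0, 1.0)

def sample_triplet(images):
    # Pick an anchor, build its positive view, and use a random other image
    # from the same batch as the negative sample.
    idx = tf.random.uniform([], 0, tf.shape(images)[0], dtype=tf.int32)
    neg_idx = tf.random.uniform([], 0, tf.shape(images)[0], dtype=tf.int32)
    anchor = images[idx]
    return anchor, make_positive(anchor), images[neg_idx]
```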
The framework of instance discrimination-based contrastive learning.
Different types of image transformation:
Another approach breaks a single image into multiple patches of a fixed dimension (overlapping patches are allowed). Different patches of the same image are used as positive samples, while patches from other images are used as negative samples.
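A minimal sketch of the patch-extraction step (the patch size and stride are illustrative assumptions), using `tf.image.extract_patches`:

```python
import tensorflow as tf

def extract_patches(images, patch_size=8, stride=4):
    # Split a batch of images into overlapping fixed-size patches.
    # Patches from the same image serve as positives; patches from
    # other images serve as negatives.
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, stride, stride, 1],
        rates=[1, 1, 1, 1],
        padding='VALID')
    batch = tf.shape(images)[0]
    channels = images.shape[-1]  # assumes a statically known channel count (e.g. 3)
    # Reshape to (batch, num_patches, patch_size, patch_size, channels).
    return tf.reshape(patches, [batch, -1, patch_size, patch_size, channels])
```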
SimCLR, a model developed by Google Brain, is a framework for contrastive learning of visual representations. Its basic working principle is to maximize agreement between two differently augmented views of the same image via a contrastive loss in the latent space.
The framework of the SimCLR method is shown below.
A simple framework for contrastive learning of visual representations. Two separate data augmentation operators are sampled from the same family of augmentations ($t \sim \mathcal{T}$ and $t' \sim \mathcal{T}$) and applied to each data example to obtain two correlated views. A base encoder network $f(\cdot)$ and a projection head $g(\cdot)$ are trained to maximize agreement using a contrastive loss. After training is completed, we throw away the projection head $g(\cdot)$ and use encoder $f(\cdot)$ and representation $h$ for downstream tasks.
A nicer illustration is as follows: Image from The Illustrated SimCLR Framework (amitness.com).
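To make the roles of $f(\cdot)$, $g(\cdot)$, and $h$ concrete, here is a minimal Keras sketch of how the pieces fit together (the layer sizes and the small convolutional backbone are illustrative assumptions, not the architecture used in the paper):

```python
import tensorflow as tf

# Base encoder f(.): maps an augmented image to a representation h.
base_encoder = tf.keras.Sequential([
    tf.keras.layers.Conv2D(64, 3, activation='relu', input_shape=(32, 32, 3)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(128, 3, activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
], name='f')

# Projection head g(.): maps h to the space z where the contrastive loss is applied.
projection_head = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(64),
], name='g')

x = tf.keras.Input(shape=(32, 32, 3))
h = base_encoder(x)        # representation kept for downstream tasks
z = projection_head(h)     # used only during contrastive pre-training
simclr_model = tf.keras.Model(x, z)
```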
Data augmentation module: transforms a given data sample (image) randomly to create two views of the same example ($x_{i}$ and $x_{j}$ in the diagram above). These represent the positive pair. The SimCLR framework applies the following three augmentations: random cropping followed by resizing back to the original size, random color distortion, and random Gaussian blur.
According to the results obtained by the authors, random cropping and color distortion are essential for achieving good performance.
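A sketch of such an augmentation module (the crop size and distortion strengths are illustrative assumptions rather than the exact settings from the paper; the Gaussian blur is omitted for brevity since it is not a built-in `tf.image` op):

```python
import tensorflow as tf

def augment(image):
    # Produce one random view of an image (pixel values assumed to be in [0, 1]).
    # Random crop, then resize back to the original resolution.
    image = tf.image.random_crop(image, size=[28, 28, 3])
    image = tf.image.resize(image, [32, 32])
    # Random color distortion.
    image = tf.image.random_brightness(image, max_delta=0.4)
    image = tf.image.random_saturation(image, lower=0.6, upper=1.4)
    return tf.clip_by_value(image, 0.0, 1.0)

def two_views(image):
    # Apply the augmentation twice to obtain the positive pair (x_i, x_j).
    return augment(image), augment(image)
```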
The loss function used here is called Normalized Temperature-scaled Cross-Entropy, or NT-Xent, loss. It is a modification of the multi-class $N$-pair loss with the addition of a temperature parameter ($T$). As in the multi-class $N$-pair loss, each anchor is compared against one positive and many negatives at once; NT-Xent additionally normalizes the embeddings, scales the similarities by the temperature, and uses all other augmented examples in the batch as the negatives.
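Concretely, for a positive pair $(i, j)$ coming from the two augmented views of the same image in a batch of $N$ images ($2N$ views in total), the NT-Xent loss from the SimCLR paper is

$$
\ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_{i}, z_{j}) / T\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_{i}, z_{k}) / T\right)},
$$

where $\mathrm{sim}(u, v) = u^{\top} v / (\lVert u \rVert \lVert v \rVert)$ is the cosine similarity between two projected embeddings and $T$ is the temperature.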
Image from The Illustrated SimCLR Framework (amitness.com).
It's worth noting that NT-Xent has nothing to do with the discrete cosine transform; what it does use is the cosine similarity between projected embeddings, which is scaled by the temperature and then passed through a softmax cross-entropy, as shown in the sketch below.
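This sketch assumes the two projected views $z_{i}$ and $z_{j}$ have already been computed for a batch; the function name and index convention are illustrative rather than taken from any particular library:

```python
import tensorflow as tf

def nt_xent_loss(z_i, z_j, temperature=0.5):
    # z_i, z_j: (N, d) projections of the two augmented views of the same N images.
    batch_size = tf.shape(z_i)[0]
    z = tf.math.l2_normalize(tf.concat([z_i, z_j], axis=0), axis=1)  # (2N, d)
    # Pairwise cosine similarities, scaled by the temperature.
    sim = tf.matmul(z, z, transpose_b=True) / temperature            # (2N, 2N)
    # Mask out self-similarities so a view is never its own negative.
    sim = sim - 1e9 * tf.eye(2 * batch_size)
    # For row i the positive sits at index i + N (and vice versa).
    positives = tf.concat([tf.range(batch_size) + batch_size,
                           tf.range(batch_size)], axis=0)
    loss = tf.keras.losses.sparse_categorical_crossentropy(
        positives, sim, from_logits=True)
    return tf.reduce_mean(loss)
```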
Image from The Illustrated SimCLR Framework (amitness.com).
Image from The Illustrated SimCLR Framework (amitness.com).
Let's implement a simple form of contrastive learning on the CIFAR-10 image dataset using TensorFlow, training a siamese network with a margin-based contrastive loss:
import numpy as np
import tensorflow as tf
# Load the dataset of images
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Preprocess the data
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
# Define the encoder that maps an image to an embedding vector
encoder = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(32, 32, 3)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu')
])
# Siamese model: embed both images of a pair and output the Euclidean
# distance between the two embeddings
input_a = tf.keras.Input(shape=(32, 32, 3))
input_b = tf.keras.Input(shape=(32, 32, 3))
distance = tf.keras.layers.Lambda(
    lambda t: tf.sqrt(tf.reduce_sum(tf.square(t[0] - t[1]), axis=1, keepdims=True) + 1e-9)
)([encoder(input_a), encoder(input_b)])
model = tf.keras.Model(inputs=[input_a, input_b], outputs=distance)
# Define the contrastive loss function: pull similar pairs (y_true = 1) together,
# push dissimilar pairs (y_true = 0) at least `margin` apart
def contrastive_loss(y_true, y_pred):
    margin = 1.0
    y_true = tf.cast(y_true, y_pred.dtype)
    return tf.reduce_mean(y_true * tf.square(y_pred) +
                          (1 - y_true) * tf.square(tf.maximum(margin - y_pred, 0)))
# Compile the model with the contrastive loss function
model.compile(optimizer='adam', loss=contrastive_loss)
# Define a function to generate pairs of images for training; the label is
# 1 when the two images share a class and 0 otherwise
def generate_pairs(x, y, batch_size=128):
    while True:
        idx_a = np.random.randint(0, len(x), batch_size)
        idx_b = np.random.randint(0, len(x), batch_size)
        labels = (y[idx_a] == y[idx_b]).astype('float32')
        yield (x[idx_a], x[idx_b]), labels
# Use the `fit` method to train the model on the generated pairs of images
model.fit(generate_pairs(x_train, y_train), steps_per_epoch=200, epochs=5,
          validation_data=generate_pairs(x_test, y_test), validation_steps=50)
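After training, only the shared encoder needs to be kept, mirroring how SimCLR discards the projection head and keeps $f(\cdot)$ for downstream tasks. A minimal sketch, assuming the `encoder` defined above:

```python
# Use the trained encoder to produce embeddings for downstream tasks,
# e.g. training a simple classifier on top of the frozen representations.
embeddings = encoder.predict(x_test)
print(embeddings.shape)  # (10000, 64)
```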