HobbitLong / SupContrast

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)
BSD 2-Clause "Simplified" License

About args n_views #95

Open xianxuan-z opened 2 years ago

xianxuan-z commented 2 years ago

I am really confused about this argument description: "features: hidden vector of shape [bsz, n_views, ...]". What does n_views mean? Could I get an example to help me understand it? Many thanks to everybody!

towzeur commented 2 years ago

Hi, n_views is the number of "views" of each image in the batch. Contrastive methods based on joint embedding use each sample image multiple times (exactly n_views times). So if you have a batch X of bsz images, you generate n_views (here 2) augmented batches by applying the same augmentation pipeline with independent random draws: X_a1 and X_a2.

Then you feed X_a1 and X_a2 forward through your model to produce the two views Z_a1 and Z_a2, each of shape (bsz, f_dim).

If you stack them along axis=1, you get a tensor of shape (bsz, n_views, ...), here (bsz, 2, f_dim).
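A minimal PyTorch sketch of the steps above. The encoder (a single nn.Linear), the noise-based augment function and all sizes are hypothetical stand-ins, not SupContrast code; only the final stacking step is the point.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bsz, in_dim, f_dim, n_views = 4, 8, 16, 2

# Hypothetical encoder standing in for the real backbone + projection head.
model = nn.Linear(in_dim, f_dim)

def augment(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stochastic augmentation: add small Gaussian noise."""
    return x + 0.1 * torch.randn_like(x)

X = torch.randn(bsz, in_dim)         # one batch of bsz "images"
X_a1, X_a2 = augment(X), augment(X)  # same pipeline, different random draws

Z_a1 = model(X_a1)                   # (bsz, f_dim)
Z_a2 = model(X_a2)                   # (bsz, f_dim)

# Stack along axis 1 to get the [bsz, n_views, f_dim] layout the loss expects.
features = torch.stack([Z_a1, Z_a2], dim=1)
print(features.shape)  # torch.Size([4, 2, 16])
```

Each slice features[:, v] is one view of the whole batch, and features[i] holds all n_views embeddings of image i.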

Contrastive learning requires positive pairs and negative pairs. Negatives: "and negative pairs are formed by the anchor and randomly chosen samples from the minibatch".

Positives: in the supervised setting (SupCon), you have access to ground-truth labels, so you can form a positive pair from any two views of samples in the same semantic class, not necessarily of the same image.
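To see how the n_views axis combines with labels, here is a sketch (my own illustration, not copied from the repo's loss) that flattens a (bsz, n_views, f_dim) tensor view by view and builds the supervised positive mask. The features, labels and sizes are made up.

```python
import torch

torch.manual_seed(0)

bsz, n_views, f_dim = 4, 2, 16
features = torch.randn(bsz, n_views, f_dim)  # stand-in for stacked, projected views
labels = torch.tensor([0, 1, 0, 2])          # hypothetical class label of each image

# Flatten view by view: rows 0..bsz-1 hold view 1, rows bsz..2*bsz-1 hold view 2.
contrast = torch.cat(torch.unbind(features, dim=1), dim=0)  # (n_views * bsz, f_dim)
all_labels = labels.repeat(n_views)                         # [0, 1, 0, 2, 0, 1, 0, 2]

# Two rows form a positive pair when they share a class label...
mask = (all_labels.unsqueeze(0) == all_labels.unsqueeze(1)).float()
# ...except that an anchor is never its own positive.
mask.fill_diagonal_(0)

# Positives of row 0 (view 1 of image 0): view 1 of image 2, view 2 of images 0 and 2.
print(mask[0].nonzero().flatten().tolist())  # [2, 4, 6]
```

Off-diagonal zeros in a row mark samples with a different label, which act as that anchor's negatives.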

In the self-supervised (unsupervised) setting, positive pairs are two views of the same image, i.e.:

          X[i]
       /        \
X_a1[i]      X_a2[i]
   |              |
Z_a1[i]      Z_a2[i]

(Z_a1[i], Z_a2[i]) is a positive pair
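The self-supervised case can be read as the supervised mask with every image given its own label, which is one way to see why the same loss also covers SimCLR. A small sketch under that assumption (labels = torch.arange(bsz) is illustrative, not taken from the repo):

```python
import torch

bsz, n_views = 4, 2

# Self-supervised case: every image is treated as its own "class".
labels = torch.arange(bsz)
all_labels = labels.repeat(n_views)  # view-major order: [0, 1, 2, 3, 0, 1, 2, 3]

mask = (all_labels.unsqueeze(0) == all_labels.unsqueeze(1)).float()
mask.fill_diagonal_(0)               # an anchor is not its own positive

# With two views, the only positive of row i is the other view of the same image.
partner = mask.argmax(dim=1)
print(partner.tolist())  # [4, 5, 6, 7, 0, 1, 2, 3]
```

Row i (Z_a1[i]) pairs with row i + bsz (Z_a2[i]), matching the diagram.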