facebookresearch / suncet

Code to reproduce the results in the FAIR research papers "Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples" https://arxiv.org/abs/2104.13963 and "Supervision Accelerates Pre-training in Contrastive Semi-Supervised Learning of Visual Representations" https://arxiv.org/abs/2006.10803
MIT License
486 stars 67 forks

difference between suncet loss and supervised contrastive loss #18

Closed mingkai-zheng closed 3 years ago

mingkai-zheng commented 3 years ago

What is the difference between the SuNCEt loss and the supervised contrastive loss (the L_{in} version)? They look like the same thing, right?

MidoAssran commented 3 years ago

Hi @KyleZheng1997

They are similar, but the L_{in} loss has an extra 1/|P(i)| term inside the logarithm, where P(i) is the set of positives for sample i. This term re-weights each class based on the number of instances belonging to that class.
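A minimal single-anchor sketch of the two losses, assuming the standard definitions: `s_pos` holds the anchor's similarities to same-class samples, `s_neg` its similarities to all other samples, the anchor itself is excluded from both, and `tau` is a temperature (0.1 is an arbitrary illustrative value). The function names are hypothetical and not taken from this repo. Because the 1/|P(i)| factor sits inside the logarithm, it comes out as an additive log|P(i)| term:

```python
import math


def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def suncet_anchor(s_pos, s_neg, tau=0.1):
    """-log( sum_p exp(s_p/tau) / sum_{a != i} exp(s_a/tau) ) for one anchor i."""
    pos = [s / tau for s in s_pos]
    others = pos + [s / tau for s in s_neg]
    return log_sum_exp(others) - log_sum_exp(pos)


def supcon_in_anchor(s_pos, s_neg, tau=0.1):
    """-log( (1/|P(i)|) * sum_p exp(s_p/tau) / sum_{a != i} exp(s_a/tau) )."""
    pos = [s / tau for s in s_pos]
    others = pos + [s / tau for s in s_neg]
    # log of the *mean* (rather than the sum) of exp over the positives
    log_mean_pos = log_sum_exp(pos) - math.log(len(pos))
    return log_sum_exp(others) - log_mean_pos


s_pos = [0.8, 0.6, 0.2]         # |P(i)| = 3, made-up similarities
s_neg = [0.3, 0.1, -0.2, 0.5]
gap = supcon_in_anchor(s_pos, s_neg) - suncet_anchor(s_pos, s_neg)
print(gap, math.log(len(s_pos)))  # the two numbers agree up to rounding
```

The offset depends only on the class count |P(i)|, not on the similarities themselves.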

mingkai-zheng commented 3 years ago

Hmm, interesting. I'm a bit confused about why this loss function can be optimized well.

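$$
\begin{aligned}
L_{SuNCEt} &= -\log \frac{\sum \exp(S_{pos})}{\sum \exp(S_{neg}) + \sum \exp(S_{pos})} \\
&= \log\left[1 + \frac{\sum \exp(S_{neg})}{\sum \exp(S_{pos})}\right] \\
&= \log\left[1 + \exp\left(\log \sum \exp(S_{neg}) - \log \sum \exp(S_{pos})\right)\right] \\
&\approx \left[\mathrm{LogSumExp}(S_{neg}) - \mathrm{LogSumExp}(S_{pos})\right]_+ \\
&\approx \left[\max(S_{neg}) - \max(S_{pos})\right]_+
\end{aligned}
$$

where [x]_+ = max(0, x), using log(1 + e^x) ≈ max(0, x).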

S_pos and S_neg denote the cosine similarities of the positive pairs and the negative pairs, respectively. As we can see, the loss is effectively optimizing [max(S_neg) - max(S_pos)]_+, which only requires the most similar positive pair to be more similar than the most similar negative pair. This is a typical case of "easy negative mining". As I understand it, L_{SuNCEt} should quickly become ineffective, since it does not require every S_pos to be greater than max(S_neg).
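To check the chain numerically, here is a sketch that adds a temperature `tau` (an assumption: the derivation leaves it implicit, which corresponds to `tau = 1`). The exact loss and the log(1 + exp(·)) form should agree to rounding error, since every step up to that point is an identity; the approximations only become tight as `tau` shrinks, where `tau * L_SuNCEt` approaches [max(S_neg) - max(S_pos)]_+. Helper names and similarity values are hypothetical:

```python
import math


def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def softplus(z):
    """log(1 + exp(z)), computed without overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def suncet_exact(s_pos, s_neg, tau):
    # -log( sum exp(pos) / (sum exp(neg) + sum exp(pos)) ), similarities scaled by 1/tau
    pos = [s / tau for s in s_pos]
    neg = [s / tau for s in s_neg]
    return log_sum_exp(pos + neg) - log_sum_exp(pos)


def suncet_softplus_form(s_pos, s_neg, tau):
    # log(1 + exp(LSE(neg) - LSE(pos))): the last exact step of the derivation
    lse_neg = log_sum_exp([s / tau for s in s_neg])
    lse_pos = log_sum_exp([s / tau for s in s_pos])
    return softplus(lse_neg - lse_pos)


s_pos = [0.7, 0.4, 0.1]         # made-up cosine similarities
s_neg = [0.8, 0.3, 0.2, -0.1]
hinge = max(0.0, max(s_neg) - max(s_pos))  # [max(S_neg) - max(S_pos)]_+, about 0.1

for tau in (1.0, 0.1, 0.01):
    exact = suncet_exact(s_pos, s_neg, tau)
    # exact and softplus forms match; tau * exact moves toward the hinge as tau shrinks
    print(tau, exact, suncet_softplus_form(s_pos, s_neg, tau), tau * exact, hinge)
```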

On the other hand, "hard negative mining" is usually considered much more effective in metric learning. For example, circle loss effectively optimizes [max(S_neg) - min(S_pos)]_+, which requires even the least similar positive pair to be more similar than the most similar negative pair. That sounds much more reasonable to me.
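To make the contrast concrete, here is a sketch comparing the two behaviours at low temperature (scale `gamma = 100`, an arbitrary choice; the similarities are made up). The second function is an assumption: a simplified pairwise loss log[1 + Σ_n Σ_p exp(γ(s_n − s_p))], the unweighted, margin-free shape that circle loss builds on. Real circle loss adds adaptive per-similarity weights and margins, so this is not its exact formula. Both functions are divided by `gamma` so they approach the hinge expressions directly:

```python
import math


def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def softplus(z):
    """log(1 + exp(z)), computed without overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def suncet_scaled(s_pos, s_neg, gamma):
    """(1/gamma) * SuNCEt at temperature 1/gamma; tends to [max(S_neg) - max(S_pos)]_+."""
    pos = [gamma * s for s in s_pos]
    neg = [gamma * s for s in s_neg]
    return (log_sum_exp(pos + neg) - log_sum_exp(pos)) / gamma


def pairwise_scaled(s_pos, s_neg, gamma):
    """(1/gamma) * log(1 + sum_{n,p} exp(gamma * (s_n - s_p))); tends to [max(S_neg) - min(S_pos)]_+."""
    # The double sum factorizes: sum_n exp(gamma*s_n) * sum_p exp(-gamma*s_p)
    z = log_sum_exp([gamma * s for s in s_neg]) + log_sum_exp([-gamma * s for s in s_pos])
    return softplus(z) / gamma


s_pos = [0.9, 0.2]   # one close positive, one far positive
s_neg = [0.5, 0.3]
print(suncet_scaled(s_pos, s_neg, 100.0))    # ≈ 0.0
print(pairwise_scaled(s_pos, s_neg, 100.0))  # ≈ 0.3
```

Here the close positive (0.9) already beats the hardest negative (0.5), so SuNCEt is essentially satisfied, while the pairwise form still charges about 0.5 − 0.2 = 0.3 for the far positive.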

I'm not sure whether I'm misunderstanding something. Could you explain a bit more about how this loss works? I would really appreciate it.

MidoAssran commented 3 years ago

Hi @KyleZheng1997,

I agree with much of your intuition, but let me clarify a few points:

For the SuNCEt loss, you can think of the argument of the log as the probability that the sample is assigned to its own class, under a probability distribution over the classes. A sample is classified correctly if its representation is close to at least one positive sample (where closeness is measured relative to the negative samples), which I think is somewhat similar to your intuition that the most similar positive pair is the important component of the loss. As you have surmised, this makes the task easier, which means the loss is easier to optimize. It is also worth noting that to fully minimize the loss, as you can see before your last approximation step, the model needs to make the query point similar to many positive samples, not just the most similar one, so you still get something similar to the hard-negative mining behaviour you were describing.
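The point about needing many positives can be checked through the gradient. Writing z = s/τ (τ is an assumed temperature, 0.1 here), the derivative of L_SuNCEt with respect to a positive similarity s_j is (1/τ)(p_all(j) − p_pos(j)), where p_all and p_pos are softmaxes over all non-anchor samples and over the positives only. Since the all-sample normalizer is larger, this is negative for every positive, so each positive is pulled in, weighted by how much it already contributes. A sketch with hypothetical helper names and made-up similarities:

```python
import math


def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def suncet(s_pos, s_neg, tau):
    """Single-anchor SuNCEt loss with temperature tau."""
    pos = [s / tau for s in s_pos]
    return log_sum_exp(pos + [s / tau for s in s_neg]) - log_sum_exp(pos)


def suncet_grad_pos(s_pos, s_neg, tau):
    """Analytic dL/ds_j for each positive similarity s_j."""
    pos = [s / tau for s in s_pos]
    lse_pos = log_sum_exp(pos)
    lse_all = log_sum_exp(pos + [s / tau for s in s_neg])
    # softmax over all non-anchor samples minus softmax over positives only, times 1/tau
    return [(math.exp(z - lse_all) - math.exp(z - lse_pos)) / tau for z in pos]


tau = 0.1
s_neg = [0.5, 0.3]
near_and_far = [0.9, 0.1]   # one close positive, one distant positive
both_near = [0.9, 0.9]      # the distant positive pulled in

print(suncet_grad_pos(near_and_far, s_neg, tau))  # both entries negative
print(suncet(near_and_far, s_neg, tau), suncet(both_near, s_neg, tau))  # loss drops
```

With these numbers the distant positive's gradient is orders of magnitude smaller than the close one's, consistent with the easy-mining intuition, yet moving it closer still roughly halves the loss.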

Of course, one can construct other contrastive losses using ideas similar to the one you mentioned that make the task harder, and there could be value in doing so; for example, the hard-negative vs. easy-negative mining argument is compelling. On the flip side, this would also make the loss harder to optimize and potentially less robust to noisy or misclassified examples/outliers in S_pos.
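A sketch of the robustness point, using a simplified pairwise loss log[1 + Σ_n Σ_p exp(γ(s_n − s_p))] as a hypothetical stand-in for a harder-to-satisfy objective (it is not circle loss itself, which adds adaptive weights and margins). One made-up "mislabeled" positive, far from the anchor, is added to an otherwise clean set; `gamma = 100` is an arbitrary low-temperature choice:

```python
import math


def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def softplus(z):
    """log(1 + exp(z)), computed without overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def suncet_scaled(s_pos, s_neg, gamma):
    # (1/gamma) * SuNCEt with temperature 1/gamma
    pos = [gamma * s for s in s_pos]
    return (log_sum_exp(pos + [gamma * s for s in s_neg]) - log_sum_exp(pos)) / gamma


def pairwise_scaled(s_pos, s_neg, gamma):
    # (1/gamma) * log(1 + sum_{n,p} exp(gamma * (s_n - s_p))), hypothetical hard-mining stand-in
    z = log_sum_exp([gamma * s for s in s_neg]) + log_sum_exp([-gamma * s for s in s_pos])
    return softplus(z) / gamma


gamma = 100.0
s_neg = [0.4, 0.2]
clean_pos = [0.9, 0.8, 0.85]
noisy_pos = clean_pos + [-0.6]  # one mislabeled sample that is really far away

print(suncet_scaled(clean_pos, s_neg, gamma), suncet_scaled(noisy_pos, s_neg, gamma))      # both ≈ 0
print(pairwise_scaled(clean_pos, s_neg, gamma), pairwise_scaled(noisy_pos, s_neg, gamma))  # ≈ 0, then ≈ 1.0
```

The single outlier drives the pairwise loss to about max(S_neg) − min(S_pos) = 0.4 − (−0.6) = 1.0, while SuNCEt, which only needs one close positive, is essentially unaffected.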