RElbers / info-nce-pytorch

PyTorch implementation of the InfoNCE loss for self-supervised learning.
MIT License

Why do you use cross entropy? #9

Closed: shamanez closed this issue 1 year ago

shamanez commented 1 year ago

Is it the same as the following InfoNCE loss?

[Image: the InfoNCE loss formula, not preserved]
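The formula image from the question is not preserved. For reference, a common way to write InfoNCE for a query $q$, one positive key $k^{+}$, $K-1$ negative keys $k_i^{-}$, and a temperature $\tau$ is shown below (the notation is an assumption, not taken from the original image). The second line restates it as cross entropy over $K$ logits $z$ with the positive as the target class $y = 0$:

$$
\mathcal{L}_{\text{InfoNCE}} = -\log \frac{\exp(q \cdot k^{+} / \tau)}{\exp(q \cdot k^{+} / \tau) + \sum_{i=1}^{K-1} \exp(q \cdot k_i^{-} / \tau)}
$$

$$
z = \left[\, q \cdot k^{+},\; q \cdot k_1^{-},\; \dots,\; q \cdot k_{K-1}^{-} \,\right] / \tau, \qquad
\mathcal{L}_{\text{CE}}(z, y=0) = -\log \frac{\exp(z_0)}{\sum_{j=0}^{K-1} \exp(z_j)} = \mathcal{L}_{\text{InfoNCE}}
$$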

RElbers commented 1 year ago

Yes, it's the same (see https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html). The cross entropy function is usually used for classification problems, where you want your model to predict the correct class out of K classes for each sample. The InfoNCE loss can also be interpreted as a kind of classification, where you have K-1 negative classes and 1 positive class.
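To make the equivalence concrete, here is a minimal sketch (not this repository's code) that computes InfoNCE twice: once written out as $-\log$ of a softmax ratio, and once with `torch.nn.functional.cross_entropy`, treating the positive as class 0 among the 1 + M candidates. The helper names, the dot-product similarity, and `temperature=0.1` are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F

def _logits(query, positive, negatives, temperature):
    # query: (N, D), positive: (N, D), negatives: (N, M, D)
    pos_logit = (query * positive).sum(dim=-1, keepdim=True)   # (N, 1)
    neg_logits = torch.einsum("nd,nmd->nm", query, negatives)   # (N, M)
    # Positive goes in column 0, the M negatives follow it.
    return torch.cat([pos_logit, neg_logits], dim=1) / temperature

def info_nce_explicit(query, positive, negatives, temperature=0.1):
    """InfoNCE as -log(exp(pos) / sum(exp(pos and all negatives)))."""
    logits = _logits(query, positive, negatives, temperature)
    log_prob_pos = logits[:, 0] - torch.logsumexp(logits, dim=1)
    return -log_prob_pos.mean()

def info_nce_cross_entropy(query, positive, negatives, temperature=0.1):
    """The same loss as (1 + M)-way classification whose correct class is 0."""
    logits = _logits(query, positive, negatives, temperature)
    labels = torch.zeros(logits.size(0), dtype=torch.long)
    return F.cross_entropy(logits, labels)

torch.manual_seed(0)
q = F.normalize(torch.randn(4, 8), dim=-1)
p = F.normalize(torch.randn(4, 8), dim=-1)
n = F.normalize(torch.randn(4, 5, 8), dim=-1)
print(torch.allclose(info_nce_explicit(q, p, n), info_nce_cross_entropy(q, p, n), atol=1e-6))  # True
```

Using `cross_entropy` is mostly a convenience: it fuses `log_softmax` and the negative log-likelihood in a numerically stable way, so the ratio never has to be formed explicitly.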

shamanez commented 1 year ago

OK, got it.

One more thing: I guess this loss is the same as the NT-Xent loss, or the MultipleNegativesRankingLoss in Sentence Transformers, right?

RElbers commented 1 year ago

I'm not familiar with MultipleNegativesRankingLoss, but it looks the same. I often see the NT-Xent loss and Info-NCE being used interchangeably, but I believe technically the difference is that NT-Xent excludes the positive sample from the denominator whereas Info-NCE includes it. In the papers there is some variation in how they get the embeddings and which (combination of) samples they use for the loss, but functionally it comes down to a softmax (cross entropy) over the similarity scores between embeddings.
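A minimal sketch of the in-batch setup, assuming (as its documentation describes, as far as I know) that MultipleNegativesRankingLoss pairs each anchor with its own positive and uses the other positives in the batch as negatives. For comparison it also includes the variant mentioned in the comment above, where the positive term is dropped from the denominator. All function names and `temperature=0.05` are hypothetical, chosen only for illustration; this is not the Sentence Transformers or info-nce-pytorch implementation.

```python
import torch
import torch.nn.functional as F

def in_batch_logits(anchors, positives, temperature=0.05):
    # Cosine similarity of every anchor with every positive in the batch: (N, N).
    a = F.normalize(anchors, dim=-1)
    p = F.normalize(positives, dim=-1)
    return a @ p.T / temperature

def info_nce_in_batch(anchors, positives, temperature=0.05):
    # Row i: column i is the positive, the other N - 1 columns act as negatives.
    logits = in_batch_logits(anchors, positives, temperature)
    labels = torch.arange(logits.size(0))
    # Standard InfoNCE: the positive stays in the softmax denominator.
    return F.cross_entropy(logits, labels)

def info_nce_in_batch_no_pos_in_denominator(anchors, positives, temperature=0.05):
    # Hypothetical variant: the denominator sums over negatives only (needs N >= 2).
    logits = in_batch_logits(anchors, positives, temperature)
    pos = logits.diagonal()
    eye = torch.eye(logits.size(0), dtype=torch.bool)
    neg_only = logits.masked_fill(eye, float("-inf"))
    return (-pos + torch.logsumexp(neg_only, dim=1)).mean()

torch.manual_seed(0)
anchors, positives = torch.randn(8, 32), torch.randn(8, 32)
print(info_nce_in_batch(anchors, positives))
print(info_nce_in_batch_no_pos_in_denominator(anchors, positives))
```

Since removing a positive term can only shrink the denominator, the second variant is always strictly smaller than the first for the same batch, and it can even go negative; the gradients with respect to the negatives differ only in how each row's softmax is normalized.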

shamanez commented 1 year ago

Thanks a lot. In one tutorial I've seen that the difference with NT-Xent is that it uses softmax scores and a temperature.
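Here is a minimal sketch of a SimCLR-style NT-Xent ("normalized temperature-scaled cross entropy"), assuming the formulation from the SimCLR paper: 2N embeddings from two augmented views of N samples, cosine similarity divided by a temperature τ, and a softmax over all other 2N − 1 embeddings. In that formulation, as far as I can tell, the indicator in the denominator only excludes the self-similarity (k = i), so the positive stays in the denominator. The function name and `temperature=0.5` are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def nt_xent(z1, z2, temperature=0.5):
    """NT-Xent over 2N views; z1[i] and z2[i] are two augmentations of sample i."""
    n = z1.size(0)
    # "Normalized": L2-normalize so the dot product is the cosine similarity.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=-1)       # (2N, D)
    # "Temperature-scaled": divide the similarities by tau.
    sim = z @ z.T / temperature                                # (2N, 2N)
    # Exclude only self-similarity, so a view is never its own candidate.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # Row i's positive is its other view: i + N for the first half, i - N for the second.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    # "Cross entropy": softmax over the remaining 2N - 1 similarities in each row.
    return F.cross_entropy(sim, targets)

torch.manual_seed(0)
view1, view2 = torch.randn(8, 128), torch.randn(8, 128)
print(nt_xent(view1, view2))
```

Written this way, NT-Xent is the same softmax-over-similarities pattern as InfoNCE, with cosine similarity, a temperature, and both views of every other sample serving as negatives.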

https://youtu.be/iqzJybIk4Go?t=1408