HobbitLong / SupContrast

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)
BSD 2-Clause "Simplified" License

About the loss. Sincerely, I would like to ask: #116

Open Struggle-Forever opened 1 year ago

Struggle-Forever commented 1 year ago

The purpose of a contrastive loss is to minimize the distance between positive samples while maximizing the distance between negative samples. However, I can only find the minimization of positive-sample distance in this loss; I don't see where the negative-sample distance is maximized. Can you tell me which lines of code achieve the maximization of the negative-sample distance?

HobbitLong commented 1 year ago

I think this line does it.
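For readers following along: the referenced line sits inside `SupConLoss.forward` in the repo's `losses.py`. Below is a minimal runnable sketch of the surrounding computation, with toy numbers standing in for real batch values (the variable names follow the repo; the similarities are made up for illustration):

```python
import torch

# Toy setup: one anchor (row) scored against four samples (columns).
# Column 0 is the anchor itself, columns 1-2 are positives, column 3 is
# a negative. Entries play the role of the repo's `logits`
# (similarity / temperature).
logits = torch.tensor([[0.0, 0.9, 0.8, 0.2]])
logits_mask = torch.tensor([[0.0, 1.0, 1.0, 1.0]])  # drop only self-contrast
mask = torch.tensor([[0.0, 1.0, 1.0, 0.0]])         # select positives

# The denominator sums over every non-self sample, positives AND negatives,
# so a negative moving closer to the anchor makes the denominator larger:
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

# Only positive log-probabilities are averaged into the loss, but the
# negative has already entered through the log-sum term above:
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
loss = -mean_log_prob_pos
print(loss)
```

Increasing the negative entry (0.2) inflates `exp_logits.sum(...)`, which lowers `log_prob` at the positive columns and raises the loss; in other words, the loss does penalize negatives that move closer to the anchor.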

Struggle-Forever commented 1 year ago

> I think this line does it.

```python
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
```

The `logits` denote all samples' distances, and `torch.log(exp_logits.sum(1, keepdim=True))` denotes the negative samples' distances.

The `log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))` then denotes the positive samples' distances, and `mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)` denotes the average positive loss.

What confuses me is that this seems to only minimize the positive-sample distance; a term that maximizes the negative-sample distance doesn't appear in the final loss.

I feel like I'm not understanding something. Can you help me?
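One way to see where the negatives act: for a single anchor, each per-positive term can be written as `-sim_pos / tau + log(sum_a exp(sim_a / tau))`, where the sum runs over all non-anchor samples. The first part pulls the positive closer; the second part grows whenever any similarity grows, including similarity to negatives, so minimizing the loss pushes negatives away. A tiny numeric check with toy values (one positive, one negative, self-contrast already excluded):

```python
import torch

def supcon_row_loss(sim_pos, sim_neg, tau=0.1):
    # Toy per-anchor SupCon term with one positive and one negative:
    # -log( e^{pos/tau} / (e^{pos/tau} + e^{neg/tau}) )
    logits = torch.tensor([sim_pos, sim_neg]) / tau
    return -(logits[0] - torch.logsumexp(logits, dim=0))

# Identical positive similarity; only the negative moves closer to the anchor:
print(supcon_row_loss(0.8, -0.5))  # far negative   -> loss ~ 2e-06
print(supcon_row_loss(0.8,  0.5))  # close negative -> loss ~ 0.049
```

So the "maximize negative distance" part is implicit: gradient descent lowers `sim_neg` because doing so shrinks the log-sum-exp denominator.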

Struggle-Forever commented 1 year ago

> I think this line does it.

I still don't understand it. Please help me, thanks.

Struggle-Forever commented 1 year ago

> I think this line does it.

The `log_prob` denotes all samples' distances, and `mask * log_prob` extracts the positive samples. That is, all sample distances enter the numerator/denominator, and the mask then picks out the positive-sample terms of the loss. This time my understanding should be correct.
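That reading matches the code. For completeness, here is a sketch of how the positive mask is built from labels; it mirrors the `torch.eq(labels, labels.T)` construction in `losses.py`, with a hypothetical mini-batch of four labeled samples:

```python
import torch

# Hypothetical labels for a mini-batch of 4 samples; mask[i, j] = 1 when
# samples i and j share a label, as in losses.py.
labels = torch.tensor([0, 0, 1, 1]).view(-1, 1)
mask = torch.eq(labels, labels.T).float()
mask.fill_diagonal_(0)  # self-contrast removed, as the repo's logits_mask does
print(mask)
# tensor([[0., 1., 0., 0.],
#         [1., 0., 0., 0.],
#         [0., 0., 0., 1.],
#         [0., 0., 1., 0.]])
```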

RizhaoCai commented 1 year ago

I also wonder about this. At that line, isn't the distance between positive pairs also part of the denominator, because of `exp_logits.sum(1, keepdim=True)`?
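Judging from the quoted code, yes: `exp_logits` is zeroed only at the self-contrast positions (via `logits_mask`), so positive pairs do appear in the denominator alongside negatives, which matches the "outside" form of the loss in the SupCon paper, where the denominator runs over all non-anchor samples. A quick check with made-up similarities:

```python
import torch

# With repo-style masking, only the diagonal (self-contrast) is removed
# before the row sum; positives and negatives both remain.
logits = torch.tensor([[0.0, 0.9, 0.2],
                       [0.9, 0.0, 0.1],
                       [0.2, 0.1, 0.0]])
logits_mask = 1.0 - torch.eye(3)          # zero out self-contrast only
exp_logits = torch.exp(logits) * logits_mask
print(exp_logits.sum(1, keepdim=True))    # each row sums over positives AND negatives
```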