HobbitLong / SupContrast

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

Question about logits_mask #110

Open Alva-2020 opened 2 years ago

Alva-2020 commented 2 years ago

Thanks very much for your contributions, but I ran into some trouble when reading the code and paper:

```python
# mask-out self-contrast cases
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
    0
)
mask = mask * logits_mask

# compute log_prob
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
```

In the above code, is logits_mask used to filter out the negative pairs? Why is it a square matrix with 0 on the diagonal and 1 everywhere else? I would have expected logits_mask = ~mask, i.e., a flag marking which pairs are negative. Thank you for your reply.

QishengL commented 2 years ago

The diagonal is 0 because you don't want to use a sample's inner product with itself; the 1s mean every other position is usable. After mask * logits_mask you get exactly the positions that are inner products of positive pairs, with the sample itself excluded.
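
Concretely, here is a minimal sketch of what logits_mask looks like (with hypothetical values batch_size = 2 and anchor_count = 2, i.e., two views per sample):

```python
import torch

batch_size, anchor_count = 2, 2   # hypothetical small values for illustration
n = batch_size * anchor_count     # 4 anchors in total

# Same construction as in the snippet above: start from all ones and
# zero out entry (i, i) in every row i.
logits_mask = torch.scatter(
    torch.ones(n, n), 1,
    torch.arange(n).view(-1, 1), 0,
)
print(logits_mask)
# tensor([[0., 1., 1., 1.],
#         [1., 0., 1., 1.],
#         [1., 1., 0., 1.],
#         [1., 1., 1., 0.]])
```

So logits_mask only removes self-contrast; the positives/negatives split comes from mask, not from logits_mask.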

What I do not understand is log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)). Why do we subtract the log of the denominator of the equation in the paper from logits? Did you figure it out?

fanchi commented 2 years ago

log(a/b) = log(a) - log(b)
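
Spelled out in the paper's notation (normalized features $z$, temperature $\tau$): each entry of logits is a numerator exponent $z_i \cdot z_a / \tau$, and exp_logits.sum(1) is the denominator, so

$$\log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)} = \frac{z_i \cdot z_p}{\tau} - \log \sum_{a \in A(i)} \exp\!\left(\frac{z_i \cdot z_a}{\tau}\right),$$

which is exactly logits - torch.log(exp_logits.sum(1, keepdim=True)). (The code also subtracts the row max from logits beforehand for numerical stability; that constant cancels between numerator and denominator.)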

shuaiNJU commented 1 year ago

> The diagonal is 0 because you don't want to use a sample's inner product with itself; the 1s mean every other position is usable. After mask * logits_mask you get exactly the positions that are inner products of positive pairs, with the sample itself excluded.
>
> What I do not understand is log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)). Why do we subtract the log of the denominator of the equation in the paper from logits? Did you figure it out?

Hello, you said that "after mask * logits_mask you get all the positions that are inner products of positive pairs, excluding the sample itself". However, each row of the mask matrix is a one-hot vector, which means only the augmentation that comes from the same original image as the anchor is used as a positive, and no other augmented samples with the same label as the anchor are used. So can you tell me which part of the code shows that other augmented samples with the same label as the anchor are used as positives? Thanks a lot!

Arsiuuu commented 1 year ago

> The diagonal is 0 because you don't want to use a sample's inner product with itself; the 1s mean every other position is usable. After mask * logits_mask you get exactly the positions that are inner products of positive pairs, with the sample itself excluded. What I do not understand is log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)). Why do we subtract the log of the denominator of the equation in the paper from logits? Did you figure it out?

> Hello, you said that "after mask * logits_mask you get all the positions that are inner products of positive pairs, excluding the sample itself". However, each row of the mask matrix is a one-hot vector, which means only the augmentation that comes from the same original image as the anchor is used as a positive, and no other augmented samples with the same label as the anchor are used. So can you tell me which part of the code shows that other augmented samples with the same label as the anchor are used as positives? Thanks a lot!

Did you resolve it? Looking forward to your reply.

jlliRUC commented 1 year ago

> Hello, you said that "after mask * logits_mask you get all the positions that are inner products of positive pairs, excluding the sample itself". However, each row of the mask matrix is a one-hot vector, which means only the augmentation that comes from the same original image as the anchor is used as a positive, and no other augmented samples with the same label as the anchor are used. So can you tell me which part of the code shows that other augmented samples with the same label as the anchor are used as positives? Thanks a lot!

The mask matrix is one-hot only in the unsupervised (SimCLR) case. With multiple augmentations in the unsupervised setting, or multiple samples sharing a label in the supervised setting, each row of mask has several 1s: it encodes the label information of all positives for that anchor.
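
For reference, this is how the mask is built in losses.py when labels are passed in (a slightly simplified sketch; the example labels are hypothetical):

```python
import torch

labels = torch.tensor([0, 1, 0, 2]).view(-1, 1)  # hypothetical batch of 4 labels

# Supervised case: mask[i, j] = 1 iff samples i and j share a label,
# so an anchor's positives include every same-class sample, not just its own view.
mask = torch.eq(labels, labels.T).float()
print(mask)
# tensor([[1., 0., 1., 0.],
#         [0., 1., 0., 0.],
#         [1., 0., 1., 0.],
#         [0., 0., 0., 1.]])

# The mask is then tiled across the augmented views, e.g. with two views:
anchor_count = contrast_count = 2
mask = mask.repeat(anchor_count, contrast_count)  # shape (8, 8)
```

Afterwards mask * logits_mask keeps all same-label pairs while dropping only the self-contrast diagonal, which is where the extra same-class positives enter.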