HobbitLong / SupContrast

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)
BSD 2-Clause "Simplified" License

Question about potentially incomplete usage of 'logits_mask' #131

Open ruohuali opened 1 year ago

ruohuali commented 1 year ago

Hi @HobbitLong, thanks for this great work! My question concerns these two lines of the loss: https://github.com/HobbitLong/SupContrast/blob/331aab5921c1def3395918c05b214320bef54815/losses.py#L88 https://github.com/HobbitLong/SupContrast/blob/331aab5921c1def3395918c05b214320bef54815/losses.py#L89 Here `logits_mask` is applied only to `exp_logits`, not to `logits` itself. I cannot work out the mathematical reason for this. Could you please shed some light?

forgotton-wind commented 1 year ago

I have also been studying the code for this loss today, and my understanding is as follows: `logits_mask` selects the terms of the denominator (positives and negatives), while `mask` selects the terms of the numerator (positives only). You can see this in line 92, where `mask` is applied to `log_prob`: https://github.com/HobbitLong/SupContrast/blob/331aab5921c1def3395918c05b214320bef54815/losses.py#L92
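
To make this concrete, here is a minimal pure-Python sketch of the same computation for a tiny batch. It is not the repository's code: the function name `mean_log_prob_pos`, the toy logits, and the labels are made up for illustration, and it assumes (as I read the linked file) that `mask` has already been zeroed at the anchor's own position. Under that assumption the unmasked self entry of `logits` is multiplied by zero in the numerator, so only the denominator needs `logits_mask`.

```python
import math


def mean_log_prob_pos(logits, labels):
    """Per-anchor mean log-probability over positives (illustrative sketch)."""
    n = len(labels)
    result = []
    for i in range(n):
        # logits_mask: 1 for every sample except the anchor itself
        logits_mask = [0.0 if a == i else 1.0 for a in range(n)]
        # mask: same-label samples, with the self-contrast entry zeroed (assumption)
        mask = [logits_mask[a] if labels[a] == labels[i] else 0.0 for a in range(n)]
        # denominator: masked sum over positives and negatives
        denom = sum(math.exp(logits[i][a]) * logits_mask[a] for a in range(n))
        # logits are NOT masked here, so log_prob[i] still holds a self entry
        log_prob = [logits[i][a] - math.log(denom) for a in range(n)]
        # numerator: mask discards the self entry and all negatives
        picked = sum(m * lp for m, lp in zip(mask, log_prob))
        result.append(picked / sum(mask))
    return result


labels = [0, 0, 1, 1]
logits = [
    [5.0, 0.3, -0.2, 0.1],
    [0.3, 5.0, 0.4, -0.5],
    [-0.2, 0.4, 5.0, 0.6],
    [0.1, -0.5, 0.6, 5.0],
]

# Same logits, but with a very different value on the diagonal (self-contrast)
altered = [row[:] for row in logits]
for i in range(len(labels)):
    altered[i][i] = -3.0

base = mean_log_prob_pos(logits, labels)
other = mean_log_prob_pos(altered, labels)
same = all(math.isclose(x, y) for x, y in zip(base, other))
print(same)
```

Changing the diagonal of `logits` leaves the result unchanged, which is why masking `logits` itself would be redundant in this sketch.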