HobbitLong / SupContrast

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)
BSD 2-Clause "Simplified" License

Purpose of logits_mask? #137

Closed · adv010 closed this issue 10 months ago

adv010 commented 11 months ago

Hi @HobbitLong, thank you for the code! It has been 3 years since you released it, and I still find it really tidy and readable (something that newer papers often fail to maintain ;) ).

My question: is the purpose of `logits_mask` simply to mask away the diagonal elements and keep the off-diagonal elements as 1?

I ran a few scenarios, and this is what I understood (I am working on the `anchor_count = all` scenario). Perhaps, instead of

```python
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
    0
)
```

I would have preferred something like this:

```python
logits_mask = 1 - torch.eye(labels.shape[0])
logits_mask = logits_mask.repeat(batch_size, anchor_count).to(device)
```
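
For reference, here is a minimal standalone sketch (toy sizes; my own illustration, not taken from `losses.py`) of what the scatter-based construction produces:

```python
import torch

# Toy reproduction of the scatter call above: batch of 3 samples, 2 views each.
batch_size, anchor_count, contrast_count = 3, 2, 2
device = torch.device("cpu")

mask = torch.ones(batch_size * anchor_count, batch_size * contrast_count)
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
    0,
)
print(logits_mask)
# 1 everywhere except the main diagonal: each anchor row keeps every
# contrast column except its own copy of itself.
```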

adv010 commented 11 months ago

Hi @HobbitLong, could you respond to this please? Thanks!

adv010 commented 10 months ago

This is indeed the purpose. The tiling prevents the alternative I asked about from working: once the mask is tiled across views, repeating `(1 - eye)` would also zero out the cross-view pairs of the same image, not just the true diagonal.

Can confirm: `logits_mask` is used to zero out the self-contrastive cases.
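
To make the tiling point concrete, a small check (toy sizes; my own sketch, with the repeat dimensions of the alternative adjusted so the shapes match) comparing the two constructions:

```python
import torch

batch_size, anchor_count, contrast_count = 3, 2, 2

# scatter-based mask from losses.py: zeros only the exact diagonal
logits_mask = torch.scatter(
    torch.ones(batch_size * anchor_count, batch_size * contrast_count),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1),
    0,
)

# proposed alternative: tile (1 - eye) over the view blocks
tiled_mask = (1 - torch.eye(batch_size)).repeat(anchor_count, contrast_count)

# the two differ exactly at (i, i +/- batch_size): the other view of the
# same image, which is a positive pair and must stay unmasked
print((logits_mask - tiled_mask).nonzero())
```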

Aikoin commented 1 month ago

> This is indeed the purpose. The tiling prevents the alternative I asked about from working: once the mask is tiled across views, repeating `(1 - eye)` would also zero out the cross-view pairs of the same image, not just the true diagonal.
>
> Can confirm: `logits_mask` is used to zero out the self-contrastive cases.

Hi there! I noticed the same thing. It seems that the code only excludes the self-contrast cases when building the contrast set. However, I'm curious whether samples from the same class as the anchor are also treated as negative samples. That seems a bit counterintuitive to me. What do you think? I'd love to hear your thoughts on this!
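
For context, here is a condensed sketch of the masking steps in `SupConLoss.forward` (toy sizes; variable names follow `losses.py`, but this is an illustration rather than the exact code). `logits_mask` removes only the anchor itself, while the label-based `mask` restricts the numerator average to positives:

```python
import torch
import torch.nn.functional as F

# Toy setup: 3 samples, 2 views each; samples 0 and 2 share a class.
temperature = 0.07
batch_size, anchor_count, contrast_count = 3, 2, 2
labels = torch.tensor([[0], [1], [0]])
features = F.normalize(torch.randn(batch_size * contrast_count, 8), dim=1)

logits = features @ features.T / temperature
logits = logits - logits.max(dim=1, keepdim=True).values.detach()

mask = torch.eq(labels, labels.T).float()                 # positives by label
mask = mask.repeat(anchor_count, contrast_count)
logits_mask = 1 - torch.eye(batch_size * anchor_count)    # drop self-contrast only
mask = mask * logits_mask

# denominator: every non-self sample, same-class ones included
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
# numerator: averaged over label positives only
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
```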