HobbitLong / SupContrast

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)
BSD 2-Clause "Simplified" License

Regarding the "for numerical stability" block #65

Open AlexMoreo opened 3 years ago

AlexMoreo commented 3 years ago

I got confused with:

       # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

The max across each row of anchor_dot_contrast will be attained at what you later refer to as the "self-contrast cases", and these are always equal to 1/temperature since the features are expected to be l2-normalized. Would it not be better to perform this operation after the mask has been applied? Thanks for the great work.
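Concretely, a minimal sketch of what I mean (untested; it assumes the logits_mask that removes the self-contrast cases has already been built, whereas the current code builds it after this block):

    anchor_dot_contrast = torch.div(
        torch.matmul(anchor_feature, contrast_feature.T),
        self.temperature)
    # take the row-wise max over the non-self-contrast entries only
    masked_dot = anchor_dot_contrast.masked_fill(logits_mask == 0, float('-inf'))
    logits_max, _ = torch.max(masked_dot, dim=1, keepdim=True)
    logits = anchor_dot_contrast - logits_max.detach()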

yash0307 commented 2 years ago

@AlexMoreo, did you try this? Does it help?

AlexMoreo commented 2 years ago

Hi @yash0307! No, sorry, I don't remember ever trying this. Let us know if it happens to work better!

TonyWangX commented 2 years ago

(newcomer to this field; for the toy example below I do not assume the features are l2-normalized)

A toy example: consider logits [0, 0, 300, 500], logits_mask [1, 1, 1, 0].

implementation 1: no measure to ensure stability; this obviously does not work:

    >>> torch.exp(torch.tensor([0., 0., 300., 500.])) * torch.tensor([1., 1., 1., 0.])
    tensor([1., 1., inf, nan])

implementation 2: take logits_max before applying logits_mask, in which case logits_max = 500:

    >>> torch.exp(torch.tensor([0. - 500., 0. - 500., 300. - 500., 500. - 500.])) * torch.tensor([1., 1., 1., 0.])
    tensor([0., 0., 0., 0.])

This approach ensures that exp(logits) never explodes before masking, but the exponentials of logits that will not be masked may underflow to zero.
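For context, the repo's loss then feeds these masked exponentials into a log-sum-exp denominator, roughly like this (paraphrasing from memory, not verbatim):

    exp_logits = torch.exp(logits) * logits_mask
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

so if every unmasked exp(logits) underflows to zero, the sum inside the log is 0 and log_prob becomes -inf.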

implementation 3: I guess what @AlexMoreo suggested would give logits_max = 300, then

    >>> torch.exp(torch.tensor([0. - 300., 0. - 300., 300. - 300., 500. - 300.])) * torch.tensor([1., 1., 1., 0.])
    tensor([0., 0., 1., nan])

We get nan here: exp() is applied to all the logits, including the ones that will later be masked, so if logits_max is not taken over all the logits, a masked entry can still overflow to inf and turn into nan after multiplication by the mask.

implementation 4: For me, this seems to work:

  1. take logits_max after applying the mask (so logits_max = 300)
  2. apply the mask inside the exponential, i.e. exp(logits * logits_mask)

    >>> torch.exp(torch.tensor([0. - 300., 0. - 300., 300. - 300., (500. - 300.) * 0.])) * torch.tensor([1., 1., 1., 0.])
    tensor([0., 0., 1., 0.])

This toy example is an extreme case.

I think that when the features are l2-normalized, such an extreme case should never happen, since the logits are then bounded by 1/temperature.
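If someone wants to try implementation 4 inside the SupCon loss itself, a rough, untested sketch (reusing the anchor_dot_contrast and logits_mask names from the repo; treat it as an assumption, not the authors' method) could look like:

    # 1. row-wise max over the non-self-contrast entries only
    masked_dot = anchor_dot_contrast.masked_fill(logits_mask == 0, float('-inf'))
    logits_max, _ = torch.max(masked_dot, dim=1, keepdim=True)
    logits = anchor_dot_contrast - logits_max.detach()

    # 2. zero out the masked logits before exp() so the self-contrast
    #    entries (discarded anyway) can never overflow to inf
    exp_logits = torch.exp(logits * logits_mask) * logits_mask
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))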

adv010 commented 10 months ago

@TonyWangX great explanation! Have you considered the reason for normalizing the projection features and the encoder features? From the codebase I'm uncertain whether the projection features actually get normalized or not.
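For what it is worth, the usual setup I have seen is to l2-normalize the projection-head output right before the loss, something like this (illustrative names only, not the repo's code):

    import torch.nn.functional as F

    feat = encoder(images)           # backbone features (hypothetical names)
    proj = projection_head(feat)     # projection MLP output
    proj = F.normalize(proj, dim=1)  # unit-norm rows, so dot products are cosine similarities

It would still be good to confirm in the model definition whether the repo applies this normalization itself.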