htdt / hyp_metric

Hyperbolic Vision Transformers: Combining Improvements in Metric Learning | Official repository
https://arxiv.org/abs/2203.10833
MIT License

What is the reason for concatenating two logit matrices in contrastive loss? #4

Closed. ghost closed this issue 2 years ago.

ghost commented 2 years ago

Hi,

Thanks for such an interesting technique. I have a question about the contrastive loss. Could you please explain the rationale behind the following code snippet:

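# Large value on the diagonal so a sample is never scored against itself
eye_mask = torch.eye(bsize).cuda() * 1e9
logits00 = dist_f(x0, x0) / tau - eye_mask  # x0 vs. x0 scores (diagonal masked)
logits01 = dist_f(x0, x1) / tau  # x0 vs. x1 scores
logits = torch.cat([logits01, logits00], dim=1)  # shape (bsize, 2 * bsize)
logits -= logits.max(1, keepdim=True)[0].detach()  # subtract row-wise max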

I understand that this resembles the InfoNCE loss, but I don't understand why you concatenate the two logit matrices and then subtract the row-wise max of the resulting matrix.

Thanks

htdt commented 2 years ago

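Hi, the batch consists of two subsets, x0 and x1. Appending logits00 after logits01 (along dim=1) means that pairs within x0 are also used as negatives, in addition to the non-matching pairs from x1; the 1e9 mask keeps each sample from being compared with itself. You can see the same approach here: https://github.com/google-research/simclr/blob/master/objective.py#L83. Subtracting the row-wise max on the last line does not change the result of cross_entropy, because softmax is invariant to adding the same constant to every logit in a row; it is only there for numerical stability.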
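A minimal, self-contained sketch of how the snippet fits into a full loss, under stated assumptions: `dot_similarity` is a hypothetical stand-in for the repository's `dist_f` (a plain dot product, i.e. a similarity where larger means closer), the inputs are L2-normalised embeddings, and the cross-entropy target `torch.arange(bsize)` is assumed from the standard SimCLR-style formulation rather than copied from the repo. It illustrates that row i of the concatenated matrix holds one positive (column i of the `logits01` block) and `2 * bsize - 2` negatives, while the masked diagonal of `logits00` contributes essentially nothing to the softmax.

```python
import torch
import torch.nn.functional as F


def dot_similarity(a, b):
    # Hypothetical stand-in for dist_f: larger value = more similar.
    return a @ b.t()


def contrastive_loss(x0, x1, tau, dist_f=dot_similarity):
    """InfoNCE / NT-Xent-style loss using x0 anchors (sketch, not the repo's code).

    Positive for x0[i] is x1[i]; every x1[j] (j != i) and every
    x0[j] (j != i) serves as a negative.
    """
    bsize = x0.shape[0]
    # Large penalty on the diagonal: a sample must never be its own negative.
    eye_mask = torch.eye(bsize, device=x0.device) * 1e9
    logits00 = dist_f(x0, x0) / tau - eye_mask  # x0[i] vs x0[j]: negatives only
    logits01 = dist_f(x0, x1) / tau             # x0[i] vs x1[j]: positive at j == i
    # Row i scores x0[i] against 2 * bsize - 1 candidates (self-pair masked out).
    logits = torch.cat([logits01, logits00], dim=1)  # (bsize, 2 * bsize)
    # Row-wise shift for numerical stability; softmax is shift-invariant.
    logits = logits - logits.max(1, keepdim=True)[0].detach()
    # The positive for row i sits in column i (inside the logits01 block).
    target = torch.arange(bsize, device=x0.device)
    return F.cross_entropy(logits, target)


if __name__ == "__main__":
    torch.manual_seed(0)
    z0 = F.normalize(torch.randn(4, 8), dim=1)
    z1 = F.normalize(torch.randn(4, 8), dim=1)
    print(contrastive_loss(z0, z1, tau=0.5).item())
```

Using `device=x0.device` instead of `.cuda()` keeps the sketch runnable on CPU. SimCLR-style objectives usually also compute the loss with x1 as the anchors and average the two directions; whether the repository does exactly this is not shown in the snippet above.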