KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

Another ntxent loss question #705

Closed Stas384 closed 1 month ago

Stas384 commented 1 month ago

I have labeled pairs of images, where 1 means same class and 0 otherwise. The pairs don't share any anchors between each other. I want to use NTXentLoss for this type of labeled data. How can I use this loss?

KevinMusgrave commented 1 month ago

If there is no overlap between samples in your positive pairs and negative pairs, then I don't think NTXentLoss will work. It depends on there being samples that have at least 1 positive and 1 negative in the batch.

You could try ContrastiveLoss. Here's the logic, but you'll want something more efficient:

import torch
from pytorch_metric_learning.losses import ContrastiveLoss

loss_fn = ContrastiveLoss(neg_margin=0.05)

a1, p, a2, n = [], [], [], []
embeddings = []
for idx, (label, pair) in enumerate(batch):
    # Each pair contributes two embeddings, so pair idx's images
    # end up at positions 2*idx and 2*idx + 1 in the embeddings list.
    if label == 1:
        a1.append(2 * idx)
        p.append(2 * idx + 1)
    else:
        a2.append(2 * idx)
        n.append(2 * idx + 1)

    embeddings.append(model(pair[0]))
    embeddings.append(model(pair[1]))

indices_tuple = torch.tensor(a1), torch.tensor(p), torch.tensor(a2), torch.tensor(n)
embeddings = torch.cat(embeddings, dim=0)
loss = loss_fn(embeddings, indices_tuple=indices_tuple)

In the above, a1, p are the indices of the positive pairs, and a2, n are the indices of the negative pairs.

Stas384 commented 1 month ago

Thank you very much for the answer! If I move away from labeling pairs as 0 and 1, and instead assign each image its own class, so that a positive pair is one whose label 1 and label 2 are equal, will this loss still not work? In fact, that's what is described in "How is NTXentLoss calculated?", right? But then I have a problem: there will be a lot of classes, and if the dataset is shuffled, there is no guarantee that positive pairs will appear in a batch.

KevinMusgrave commented 1 month ago

If you assign labels to samples, then all samples with different labels are assumed to be negative pairs. So the problem I mentioned above wouldn't occur.

You can use MPerClassSampler to ensure there are positive pairs in each batch.

Stas384 commented 1 month ago

Thanks!