Closed Stas384 closed 1 month ago
If there is no overlap between samples in your positive pairs and negative pairs, then I don't think NTXentLoss will work. It depends on there being samples that have at least 1 positive and 1 negative in the batch.
You could try ContrastiveLoss. Here's the logic, but you'll want something more efficient:
import torch
from pytorch_metric_learning.losses import ContrastiveLoss

loss_fn = ContrastiveLoss(neg_margin=0.05)

a1, p, a2, n = [], [], [], []
embeddings = []
for idx, (label, pair) in enumerate(batch):
    # Each pair contributes two embeddings, so pair idx maps to
    # embedding rows 2*idx and 2*idx + 1 after the cat below.
    if label == 1:
        a1.append(2 * idx)
        p.append(2 * idx + 1)
    else:
        a2.append(2 * idx)
        n.append(2 * idx + 1)
    embeddings.append(model(pair[0]))
    embeddings.append(model(pair[1]))
indices_tuple = torch.tensor(a1), torch.tensor(p), torch.tensor(a2), torch.tensor(n)
embeddings = torch.cat(embeddings, dim=0)
loss = loss_fn(embeddings, indices_tuple=indices_tuple)
In the above, a1 and p are the indices of the positive pairs, and a2 and n are the indices of the negative pairs.
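The per-pair index bookkeeping can also be done in one vectorized pass over the pair labels. A minimal sketch in plain PyTorch, assuming labels is a 1D 0/1 tensor where entry i labels pair i (the helper name build_indices_tuple is hypothetical, not part of the library):

import torch

def build_indices_tuple(labels):
    # labels: 1D tensor of pair labels (1 = positive pair, 0 = negative pair).
    # Pair i contributes two consecutive embeddings, at rows 2*i and 2*i + 1.
    idx = torch.arange(len(labels))
    pos = idx[labels == 1]  # pair indices of positive pairs
    neg = idx[labels == 0]  # pair indices of negative pairs
    return 2 * pos, 2 * pos + 1, 2 * neg, 2 * neg + 1

labels = torch.tensor([1, 0, 1])
a1, p, a2, n = build_indices_tuple(labels)
# a1 = [0, 4], p = [1, 5], a2 = [2], n = [3]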
Thank you very much for the answer! And if I move away from labeling pairs as 0 and 1, and instead assign each picture its own class, so that a positive pair is one whose two labels are equal, will this loss still not work? In fact, that is what is described in "how is ntxent loss calculated?", right? But then I have a problem: there will be a lot of classes, and if the dataset is shuffled, there is no guarantee at all that positive pairs will appear in a batch.
If you assign labels to samples, then all samples with different labels are assumed to be negative pairs. So the problem I mentioned above wouldn't occur.
You can use MPerClassSampler to ensure there are positive pairs in each batch.
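The idea behind MPerClassSampler is to build each batch by drawing m samples from each of several classes, so every sampled class is guaranteed at least m mutually positive samples per batch. A minimal sketch of that sampling logic in plain Python (this is an illustration of the concept, not the library's implementation; the function name m_per_class_batch is hypothetical):

import random
from collections import defaultdict

def m_per_class_batch(labels, m, classes_per_batch, seed=0):
    # Group sample indices by their class label.
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, lab in enumerate(labels):
        by_class[lab].append(i)
    # Pick classes_per_batch classes, then m sample indices from each,
    # guaranteeing positive pairs for every class in the batch.
    chosen = rng.sample(sorted(by_class), classes_per_batch)
    batch = []
    for c in chosen:
        batch.extend(rng.sample(by_class[c], m))
    return batch

labels = [0, 0, 1, 1, 2, 2, 3, 3]
batch = m_per_class_batch(labels, m=2, classes_per_batch=2)
# batch holds 4 indices: 2 samples from each of 2 classes

In the library, you pass the sampler to your DataLoader so batches are drawn this way automatically.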
Thanks!
I have labeled pairs of images, where 1 means same class and 0 otherwise. These pairs don't share any anchors with each other. So, I want to use NTXentLoss for this type of labeled data. How can I use this loss?