KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

What loss is suitable for one anchor, multiple positives, and multiple negatives? #693

Closed ImmortalSdm closed 4 months ago

KevinMusgrave commented 5 months ago

Apologies for the late reply.

You can use the ref_emb argument to separate anchors from positives and negatives.

For example, using ContrastiveLoss:

from pytorch_metric_learning.losses import ContrastiveLoss

loss_fn = ContrastiveLoss()

# anchors has shape NxD
# anchor_labels has shape N
# ref_emb has shape MxD
# ref_labels has shape M
loss = loss_fn(anchors, anchor_labels, ref_emb=ref_emb, ref_labels=ref_labels)

Positive pairs are formed between embeddings in anchors and embeddings in ref_emb that share a label; negative pairs are formed between those with different labels.

You can have multiple positive pairs and negative pairs for any of the embeddings in anchors. In the extreme case, you could have a single embedding in anchors (shape 1xD), and many positive and negative embeddings in ref_emb.