KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

Contrastive loss using only indices_tuple without labels. #681

Closed fdila closed 9 months ago

fdila commented 9 months ago

Hi! I'm trying to use the contrastive loss from this library. To be specific, I have two embedding tensors (emb0 and emb1), and I know the indices of the positive pairs between emb0 and emb1.

From my understanding I should have something like this, to use the loss without labels:

import torch
from pytorch_metric_learning import losses

emb0 = torch.randn(batch, 64)
emb1 = torch.randn(batch, 64)

triplets = ???

loss_fn = losses.ContrastiveLoss(pos_margin=0, neg_margin=1)
loss = loss_fn(emb0, indices_tuple=triplets, ref_emb=emb1)

However, I have not been able to find in the documentation how the triplets should be formatted. The only references I've found are here: https://kevinmusgrave.github.io/pytorch-metric-learning/losses/

and here: https://kevinmusgrave.github.io/pytorch-metric-learning/extend/losses/#using-indices_tuple

I am constructing a triplets tuple like this:

a = [0, 1, 2, 3, 4]
p = [[0, 1], [1, 3], [1, 4], [0, 3], [1, 3]]
n = [[3], [0, 2, 3], [0, 2], [1], [2, 4]]

triplets = (a, p, n)

but the loss function is throwing errors about wrong shapes in the triplets.
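(For illustration: the ragged per-anchor lists above can be flattened into equal-length columns, one row per anchor/positive/negative combination. This is a sketch in plain Python, assuming each anchor is meant to be paired with every entry of its positive and negative lists:)

```python
# Per-anchor positive/negative indices into ref_emb (from the lists above)
a_idx = [0, 1, 2, 3, 4]
pos = [[0, 1], [1, 3], [1, 4], [0, 3], [1, 3]]
neg = [[3], [0, 2, 3], [0, 2], [1], [2, 4]]

# Flatten into triplets: one (anchor, positive, negative) per row,
# taking the Cartesian product of each anchor's positives and negatives.
triplets = [(anc, pi, ni)
            for anc, ps, ns in zip(a_idx, pos, neg)
            for pi in ps
            for ni in ns]

# Split into three equal-length columns, to be converted to 1-D tensors
# (e.g. torch.tensor(a)) before being passed as indices_tuple=(a, p, n).
a = [t[0] for t in triplets]
p = [t[1] for t in triplets]
n = [t[2] for t in triplets]
```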

KevinMusgrave commented 9 months ago

Sorry about the poor documentation.

Here's maybe the easiest way to understand the format of (a,p,n):

x = torch.tensor([[0,0,3], [0,1,3]])
a = x[:,0]
p = x[:,1]
n = x[:,2]

In the above code, I've made a torch tensor where each row is one (anchor, positive, negative) triplet.

(a,p,n) corresponds to the columns of that tensor, so each of a, p, and n is a 1-D tensor of the same length.

fdila commented 9 months ago

Thanks!