KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

Question/Suggestions on Integrating Metric Learning to Similarity Problem #328

Closed: Madhvi19 closed this 2 years ago

Madhvi19 commented 3 years ago

Hi @KevinMusgrave, I have gone through the documentation as well as some notebooks, but I am still at a loss.

Problem: Given two data points (p1, p2), I need to identify if they are similar or not. I have a list of N samples which I know are positive. One data point can be related to multiple data points:

p1, p2 - 1
p1, p3 - 1
p2, p3 - 1
...

If I build negative pairs offline, I can have <anchor, positive> pairs with label 1 and <anchor, negative> pairs with label 0. Then I can use TripletMarginLoss, but how would I use online or mixed mining approaches? From the documentation, the following seems to be the flow, where triplets/pairs are created within the batch based on the labels we pass.

for i, (data, labels) in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = model(data)
    hard_pairs = miner(embeddings, labels)
    loss = loss_func(embeddings, labels, hard_pairs)
    loss.backward()
    optimizer.step()

However, in my case, even though the labels are 0/1, they don't represent classes as in MNIST or CIFAR. So how would triplets or pairs be formed by the miner here? More concretely, what do I pass to loss_func, miner, and sampler?

I have used Siamese networks to solve this problem, but I don't know how to integrate the components of this library into that setup. Are there other approaches too? I think I might have to use a custom miner and/or NPairsLoss, but again I am not able to figure it out :confused:

Any suggestions and pointers on how to implement this would be really helpful. Thank you.

KevinMusgrave commented 3 years ago

This issue might be related: https://github.com/KevinMusgrave/pytorch-metric-learning/issues/263

Let me know if that helps or not.

KevinMusgrave commented 3 years ago

That issue (#263) has code for converting similarity labels (1s and 0s) to "class" labels that can be used by the loss functions and miners.
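
For readers who want a rough idea of what such a conversion can look like, here is a minimal sketch. It is not the code from #263. It assumes that similarity is transitive (if p1~p2 and p2~p3, all three share one "class") and that data points are referred to by integer index. The helper name `pairs_to_class_labels` is hypothetical and not part of the library.

```python
def pairs_to_class_labels(num_points, positive_pairs):
    """Assign one integer label per group of points linked by positive pairs.

    Uses union-find: points connected through any chain of positive pairs
    end up in the same group. Points that never appear in a positive pair
    get a label of their own.
    """
    parent = list(range(num_points))

    def find(i):
        # Follow parent links to the group root, shortening the path as we go.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for a, b in positive_pairs:
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a  # merge the two groups

    # Map each group root to a consecutive integer label.
    root_to_label = {}
    labels = []
    for i in range(num_points):
        root = find(i)
        labels.append(root_to_label.setdefault(root, len(root_to_label)))
    return labels


# Indices 0, 1, 2 play the role of p1, p2, p3 from the question above.
positive_pairs = [(0, 1), (0, 2), (1, 2), (3, 4)]
labels = pairs_to_class_labels(6, positive_pairs)
print(labels)
```

The resulting list could then be turned into a tensor and passed as `labels` to the miner and loss in the training loop quoted in the question. One caveat with this approach: any two points in different groups will be treated as a negative pair, which is an assumption when only the positive pairs are actually known.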