KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

Suggestion on loss function and other combinations #663

Closed jainrahulsethi closed 10 months ago

jainrahulsethi commented 11 months ago

I want to learn an embedding model that can be used to tell whether two images show the same object, possibly taken from different angles. I have a dataset of several images of several objects to train this model on. Which combination of loss function, distance, reducer, etc. should I use here? If someone could shed some light on this, I would really appreciate it.

KevinMusgrave commented 10 months ago

Apologies for the late response. I don't think any loss function will have a particular advantage in this case. I would just start with ContrastiveLoss or NTXentLoss.

However, your dataset sounds very small, so it might be difficult to obtain a useful model.

Maybe feature matching could help. Here's an example using the kornia library: https://github.com/kornia/kornia-examples/blob/master/image-matching-example.ipynb

You'd have to figure out how to compute the confidence of the matching, to determine if the images are a match.

jainrahulsethi commented 10 months ago

First of all, thanks for the awesome work! I have a dataset of nearly 70k images, and each object has about 5-6 images. Can't I train a ResNet-based feature extractor with some combination of triplet loss, contrastive loss, ArcFace loss, etc. to eventually get a good embedding model for my use case? Also, if I intend to use the ArcFace loss, how do I decide on the number of classes?

KevinMusgrave commented 10 months ago

Oh that's a big enough dataset. I would start with NTXentLoss with the default settings, because it should give reasonable results.

For ArcFaceLoss, num_classes should be the number of distinct objects in your dataset. For example, if every object has exactly 5 images, then you have 70k / 5 = 14000 objects, so you would set num_classes=14000.

jainrahulsethi commented 10 months ago

Thanks. However, if I set the number of classes to 14000, will it generalize to new objects that it has not seen during training? If yes, then why was the exact number 14k important in the first place? Also, is NTXentLoss supposed to give me the best accuracy? It is important for me to be able to compare two images and determine, with very high accuracy, whether they show the exact same object.

KevinMusgrave commented 10 months ago

> Thanks.. However, if I set the number of classes to 14000, will it generalize to new objects that it has not seen during training? If yes, then why exactly the number 14k was important in the first place?

I can't say for sure whether it will generalize, but ArcFaceLoss has been used very successfully to generalize to new faces in face recognition.

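The exact number of classes is important because ArcFaceLoss is basically a modified classification loss: it learns one weight vector per training class, so it has to know how many classes there are. But it's a classification loss that's designed to be good at separating classes in the embedding space.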
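To make the generalization point concrete: the class weights are only needed during training. At inference time, only the embedding model is used, so two images of an object never seen in training can still be compared. Below is a rough sketch in plain PyTorch. The function name `same_object`, the stand-in model, and the 0.5 threshold are all hypothetical; a real threshold would have to be tuned on held-out pairs.

```python
import torch
import torch.nn.functional as F

def same_object(model, image_a, image_b, threshold=0.5):
    """Embed two images and compare them by cosine similarity.

    Returns (decision, similarity). The threshold is a placeholder that
    should be tuned on held-out pairs.
    """
    model.eval()
    with torch.no_grad():
        emb = model(torch.stack([image_a, image_b]))  # shape: (2, embedding_size)
    similarity = F.cosine_similarity(emb[0], emb[1], dim=0).item()
    return similarity >= threshold, similarity

# Demo with an untrained stand-in model (assumption, for illustration only).
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 128))
image = torch.randn(3, 32, 32)
match, sim = same_object(model, image, image)  # an image compared with itself
```

Nothing in this function refers to the 14000 training classes; whether the embeddings actually separate unseen objects well is something only evaluation on held-out objects can show.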

> Also, NTXentLoss - is it supposed to give me the best accuracy? For me to be able to compare 2 images and find out if they represent the exact same object with a very high accuracy is important

It's a very widely-used loss function. I don't know whether it will give you the best accuracy. That will require experimentation. In my experience MultiSimilarityLoss also gives good performance.