KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

Performing one-shot process after training #343

GCidd closed this issue 3 years ago

GCidd commented 3 years ago

Greetings,

I have read the documentation and trained my model using the provided objects, and so far so good. What I'm having trouble understanding from the Inference notebook is the following:

After training the model, how do you actually perform zero-shot or few-shot? I understand that InferenceModel provides functions for finding matches etc., but I'm confused by the step where you train the model's indexer (train_indexer). After training my model with the MetricLossOnly class on dataset1, how do I perform a zero-shot procedure on dataset2 and evaluate the model's performance? Do I have to train the indexer of InferenceModel with dataset1 and then use the is_match method? Or should I use get_nearest_neighbors to make a prediction? In any case, how do I evaluate the model's performance on dataset2 with metrics such as precision, recall, or F1-score, which need the predicted class and not just the is_match output?

On the other hand, after training my model with the TrainWithClassifier class on dataset1, the model performs very poorly in the zero/few-shot process (F1-score below 40% with 3 classes).

Regarding the models, the trunk is a pre-trained resnet18 and the embedder is the two-layer MLP provided in the example notebooks.

KevinMusgrave commented 3 years ago

Here are some pointers. I'll write a more complete response later.

After training my model with the MetricLossOnly class on dataset1, how do I perform a zero-shot procedure on dataset2 and evaluate the model's performance? Do I have to train the indexer of InferenceModel with dataset1 and then use the is_match method? Or should I use get_nearest_neighbors to make a prediction?

This issue might be related: https://github.com/KevinMusgrave/pytorch-metric-learning/issues/342

In any case, how do I evaluate the model's performance on dataset2 with metrics such as precision, recall, or F1-score, which need the predicted class and not just the is_match output?

Have you looked at the testers and AccuracyCalculator?

And a related issue: https://github.com/KevinMusgrave/pytorch-metric-learning/issues/320

On the other hand, after training my model with the TrainWithClassifier class on dataset1, the model performs very poorly in the zero/few-shot process (F1-score below 40% with 3 classes).

It's hard to know what could be causing the low accuracy (maybe it's just a difficult dataset?). I would suggest starting with the simplest possible baseline, MetricLossOnly + ContrastiveLoss, and tuning the hyperparameters a bit.
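Here is a NumPy sketch of a pairwise contrastive loss, to give a rough picture of what that baseline optimizes. It is not the library's implementation. It assumes what I believe are ContrastiveLoss's defaults: pos_margin=0, neg_margin=1, Euclidean distance on L2-normalized embeddings, and averaging only the nonzero positive and negative pair losses. For actual training you would pass losses.ContrastiveLoss() to MetricLossOnly instead.

```python
import numpy as np


def contrastive_loss(embeddings, labels, pos_margin=0.0, neg_margin=1.0):
    """Pairwise contrastive loss on L2-normalized embeddings (illustrative sketch)."""
    labels = np.asarray(labels)
    x = np.asarray(embeddings, dtype=float)
    x = x / np.linalg.norm(x, axis=1, keepdims=True)
    # Euclidean distance between every pair of embeddings.
    dist = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    same = labels[:, None] == labels[None, :]
    not_self = ~np.eye(len(labels), dtype=bool)
    # Positive pairs are pulled below pos_margin; negative pairs pushed beyond neg_margin.
    pos_losses = np.maximum(dist[same & not_self] - pos_margin, 0.0)
    neg_losses = np.maximum(neg_margin - dist[~same], 0.0)

    def avg_nonzero(losses):
        # Average only over pairs that still contribute to the loss.
        nonzero = losses[losses > 0]
        return float(nonzero.mean()) if nonzero.size else 0.0

    return avg_nonzero(pos_losses) + avg_nonzero(neg_losses)
```

A loss of exactly 0 means every pair loss is zero. In other words, each class is tightly clustered and every pair of samples from different classes is at least neg_margin apart.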

GCidd commented 3 years ago

Thank you for replying!

This issue might be related: #342

Have you looked at the testers and AccuracyCalculator?

And a related issue: #320

I should have mentioned that I have already used the AccuracyCalculator. Both cases come back to my original question about using get_nearest_neighbors. Is the choice of k in that method up to me, along with how I decide which class the model predicted? For example, taking as the prediction the class that the majority of the nearest neighbors belong to (and, as a result, having to choose the best k)?

Also, I should mention that if I train and validate the model only on dataset2, it performs better (80% accuracy). So does this mean that dataset2 is not actually hard, or that dataset1 is not "similar" to dataset2? (By "similar" I mean, for example, both being texture classification in my case.)

KevinMusgrave commented 3 years ago

After training my model with the MetricLossOnly class on dataset1, how do I perform a zero-shot procedure on dataset2 and evaluate the model's performance? Do I have to train the indexer of InferenceModel with dataset1 and then use the is_match method?

You should train the indexer on dataset1. Then for each image in dataset2, you can call get_nearest_neighbors. This will return the indices of the images in dataset1 that are most similar to the dataset2 image.
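Here is a minimal NumPy sketch of what a nearest-neighbor index does conceptually. It stores the dataset1 (reference) embeddings and, for each dataset2 (query) embedding, returns the indices of the most similar references, nearest first. The class and its fit/search methods are hypothetical. As far as I know, InferenceModel delegates this work to faiss, and this brute-force cosine version only illustrates the idea.

```python
import numpy as np


class BruteForceCosineIndex:
    """Hypothetical exact cosine-similarity index (illustration only)."""

    def fit(self, reference_embeddings):
        # Store L2-normalized reference (dataset1) embeddings.
        ref = np.asarray(reference_embeddings, dtype=float)
        self.reference = ref / np.linalg.norm(ref, axis=1, keepdims=True)

    def search(self, query_embeddings, k):
        # Normalize queries (dataset2) so the dot product is cosine similarity.
        q = np.asarray(query_embeddings, dtype=float)
        q = q / np.linalg.norm(q, axis=1, keepdims=True)
        sims = q @ self.reference.T  # shape: (num_queries, num_references)
        # Sort references by descending similarity and keep the top k per query.
        indices = np.argsort(-sims, axis=1)[:, :k]
        return indices, np.take_along_axis(sims, indices, axis=1)
```

In practice both sets of embeddings would come from running the trained trunk + embedder on each dataset. The returned indices point into dataset1, so dataset1's labels tell you which classes the neighbors belong to.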

I should have mentioned that I have already used the AccuracyCalculator. Both cases come back to my original question about using get_nearest_neighbors. Is the choice of k in that method up to me, along with how I decide which class the model predicted? For example, taking as the prediction the class that the majority of the nearest neighbors belong to (and, as a result, having to choose the best k)?

Currently, the only thing get_nearest_neighbors does is return the indices of the most similar images. It's up to you to use that information to predict the class of the query image. The value of k is also something you'll have to adjust to get the best performance.
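To turn those indices into a class prediction, take a majority vote over the neighbors' labels and pick k on labeled held-out queries. Here is a minimal sketch. It assumes the neighbor indices for each query are sorted nearest first. The function names are hypothetical. Ties are broken in favor of the nearest neighbor, which is just one reasonable choice.

```python
import numpy as np
from collections import Counter


def knn_vote(neighbor_indices, reference_labels, k):
    """Predict one label per query by majority vote over its k nearest neighbors."""
    reference_labels = np.asarray(reference_labels)
    predictions = []
    for row in np.asarray(neighbor_indices):
        nearest_labels = reference_labels[row[:k]].tolist()
        counts = Counter(nearest_labels)
        top_count = max(counts.values())
        # Ties: choose the tied label that appears first, i.e. the nearest one.
        predictions.append(next(label for label in nearest_labels if counts[label] == top_count))
    return np.array(predictions)


def choose_k(neighbor_indices, reference_labels, true_labels, candidate_ks):
    """Return the k with the best accuracy on labeled validation queries, plus all scores."""
    true_labels = np.asarray(true_labels)
    scores = {
        k: float(np.mean(knn_vote(neighbor_indices, reference_labels, k) == true_labels))
        for k in candidate_ks
    }
    best_k = max(scores, key=scores.get)
    return best_k, scores
```

The validation queries pick k. The chosen k is then applied unchanged to the test queries.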

It would be nice to have the class prediction built into InferenceModel, so I have created an issue for that (#344).
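Once every dataset2 query has a predicted class, you can compute the precision, recall, and F1-score asked about earlier. Computing them per class and macro-averaging is usually more informative than accuracy when classes are unbalanced. Here is a NumPy sketch; the function name is hypothetical. If scikit-learn is available, sklearn.metrics.precision_recall_fscore_support computes the same per-class values.

```python
import numpy as np


def per_class_prf(y_true, y_pred, num_classes):
    """Per-class precision, recall, and F1 from true and predicted class labels."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    precision = np.zeros(num_classes)
    recall = np.zeros(num_classes)
    f1 = np.zeros(num_classes)
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))  # correctly predicted as c
        fp = np.sum((y_pred == c) & (y_true != c))  # wrongly predicted as c
        fn = np.sum((y_pred != c) & (y_true == c))  # c predicted as something else
        precision[c] = tp / (tp + fp) if tp + fp else 0.0
        recall[c] = tp / (tp + fn) if tp + fn else 0.0
        denom = precision[c] + recall[c]
        f1[c] = 2 * precision[c] * recall[c] / denom if denom else 0.0
    return precision, recall, f1
```

The macro F1-score is then f1.mean(), which weights each of the 3 classes equally regardless of how many samples it has.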

Also, I should mention that if I train and validate the model only on dataset2, it performs better (80% accuracy). So does this mean that dataset2 is not actually hard, or that dataset1 is not "similar" to dataset2? (By "similar" I mean, for example, both being texture classification in my case.)

If you train on dataset1, evaluate on dataset2, and get 40%, then yes, I would guess that dataset1 and dataset2 are not very similar. You could try training and evaluating only on dataset1 to see whether dataset1 itself is difficult.

GCidd commented 3 years ago

Sorry for the late reply; I was running various tests with the MetricLossOnly trainer you suggested. From the results I have gathered, it performs worse than TrainWithClassifier. Even during training, the loss does not go below the margin value of 0.5 that I set.

Dataset2 is quite small (3 unbalanced classes, 900 samples in total). That is why I am trying to do one/few-shot on it after first training on various similar datasets (dataset1). If I train the model with a simple trainer and loss like you suggested, it overfits. For this reason, in both trainer cases, I have been using ZeroMeanRegularizer, a small learning rate with weight decay, and an MPerClassSampler, along with augmentation of the training dataset. The augmentation is only horizontal/vertical flips, since I am doing texture analysis and can't alter the pixels. The training is done with TripletMarginLoss with CosineSimilarity and a TripletMarginMiner mining hard triplets. For the classifier (in the TrainWithClassifier case) I use a simple CrossEntropyLoss. All of this is done with 10-fold cross-validation.

When I train and validate with TrainWithClassifier on dataset2 only, the best result I have gotten is about 80% accuracy. But when I train on dataset1 and do one/few-shot on dataset2, it performs poorly (the 40% mentioned before). This basically defeats my original purpose of training on a similar, larger dataset and then performing one/few-shot.

Having said that, I have the following questions:

I understand that some questions are basic or hard to answer. I am trying to make sure that I understand the whole process of metric learning and one/few-shot learning correctly, that I am not doing anything wrong in training, and that my conclusions from the results are correct.

KevinMusgrave commented 3 years ago
  • In the case of TrainWithClassifier, can you perform a one/few-shot process? If so, given that my classifier has 3 outputs (as many as the classes I originally trained it on), do I have to train it the same way as in transfer learning to do few-shot on dataset2 (i.e., re-train it on a small part of the dataset)?

What is your use-case?

  • Do dataset1 and dataset2 have to be similar to perform one/few-shot? If so, how do I determine whether two datasets are similar? Do they both have to be texture datasets? Or texture datasets of objects similar to the ones I am trying to predict?

Well, the more similar they are, the easier it will be. I believe measuring dataset similarity is not a solved problem. You could try something simple like measuring the distance between the mean embeddings of dataset1 and dataset2.
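Here is a minimal sketch of that heuristic. It assumes you already have embeddings for both datasets from the same trained model, and the function name is hypothetical. L2-normalizing first keeps the distance on a comparable scale. Comparing means ignores everything else about the two distributions, so treat the result as a rough signal only.

```python
import numpy as np


def mean_embedding_distance(emb1, emb2, normalize=True):
    """Euclidean distance between the mean embeddings of two datasets."""
    a = np.asarray(emb1, dtype=float)
    b = np.asarray(emb2, dtype=float)
    if normalize:
        # Put every embedding on the unit sphere before averaging.
        a = a / np.linalg.norm(a, axis=1, keepdims=True)
        b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return float(np.linalg.norm(a.mean(axis=0) - b.mean(axis=0)))
```

On its own the number is hard to interpret. One simple reference point is the same distance computed between two random halves of dataset1.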

  • Does training and validating on dataset2 defeat the purpose of using Siamese networks to solve a classification problem? By purpose I mean training on dataset1 so that the model then performs well on dataset2.

If you're validating on dataset2, then you're using the images and labels from dataset2. From what I understand, you want to see how a model trained on dataset1 performs on a totally unseen dataset2, so validating on dataset2 is basically "cheating".

  • I have previously done transfer learning for classification with larger models (DenseNet169, ResNet152, and others) on the same dataset, and the best model reached an 82% F1-score. Does this mean that the dataset is hard? Even if it is, is a 40% F1-score on one/few-shot expected?

I don't have enough experience with this type of problem to know if that drop in accuracy is expected.

GCidd commented 3 years ago

Basically, what I'm trying to do is overcome the problem of my small, imbalanced dataset (without augmentation techniques). That's why I am trying to take advantage of Siamese networks. Essentially, I want to train on a similar dataset1 and predict on dataset2. If that works, I would train on both datasets (or just one) and then predict every new sample that comes in.

Eventually, my use case will be to have dataset2 and predict the class of every new, unlabeled sample that comes in. Does PyRetri, which you suggested, fit my purpose?

If you're validating on dataset2, then you're using the images and labels from dataset2. From what I understand, you want to see how a model trained on dataset1 performs on a totally unseen dataset2, so validating on dataset2 is basically "cheating".

What I meant is: train on dataset1, then predict with and evaluate the model on dataset2.

KevinMusgrave commented 3 years ago

Eventually, my use case will be to have dataset2 and predict the class of every new, unlabeled sample that comes in. Does PyRetri, which you suggested, fit my purpose?

What do you mean by "having dataset2"? Do you mean that the "new samples" will come from dataset2?

I haven't used PyRetri but it looks like a general purpose image retrieval toolbox.

What I meant is: train on dataset1, then predict with and evaluate the model on dataset2.

If dataset2 is supposed to be totally unknown, then you should set it aside as a test set. You need a train/val/test split. You should evaluate on the val split to pick your best model, and then see how the best model performs on the test set.
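Here is a minimal sketch of that protocol. It assumes the validation split is carved out of dataset1 and dataset2 is kept entirely as the test set. The helpers are hypothetical, and val_score and test_score stand for whatever evaluation you use (e.g. kNN accuracy).

```python
import numpy as np


def train_val_split(num_samples, val_fraction=0.2, seed=0):
    """Randomly split dataset1 indices into train and validation indices."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_samples)
    n_val = int(round(num_samples * val_fraction))
    return perm[n_val:], perm[:n_val]


def select_then_test(candidates, val_score, test_score):
    """Pick the best candidate by validation score, then evaluate it once on the test set."""
    # Model selection only ever looks at validation scores.
    best = max(candidates, key=val_score)
    # dataset2 (the test set) is used exactly once, for the chosen model.
    return best, test_score(best)
```

For example, with validation scores {"a": 0.7, "b": 0.9} and test scores {"a": 0.8, "b": 0.6}, the procedure reports "b" with 0.6. Picking "a" because it did better on dataset2 would be the "cheating" mentioned earlier in the thread.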

GCidd commented 3 years ago

What do you mean by "having dataset2"? Do you mean that the "new samples" will come from dataset2?

Yes, the new samples will come from dataset2.

If dataset2 is supposed to be totally unknown, then you should set it aside as a test set. You need a train/val/test split. You should evaluate on the val split to pick your best model, and then see how the best model performs on the test set.

That's essentially what I am doing. If I'm not mistaken, the provided trainer classes do not expect a validation set (or do they internally split the dataset into training and validation?). So I am training on dataset1 and testing on dataset2, with the results mentioned in my previous replies.

KevinMusgrave commented 3 years ago

That's essentially what I am doing. If I'm not mistaken, the provided trainer classes do not expect a validation set (or do they internally split the dataset into training and validation?). So I am training on dataset1 and testing on dataset2, with the results mentioned in my previous replies.

Re: validation set with this library's trainers, you have to specify that separately. (Take a look at this section in the MetricLossOnly example notebook.)
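As far as I know, the trainers accept an end_of_epoch_hook callable. It receives the trainer at the end of each epoch, and training stops if it returns False. The example notebooks build one with logging_presets and a tester. Below is a minimal, hypothetical alternative that tracks a validation score and stops early. evaluate_fn is a placeholder you would supply, e.g. kNN accuracy on the validation split.

```python
class EarlyStoppingHook:
    """Hypothetical end_of_epoch_hook: track the best validation score, stop after `patience` bad epochs."""

    def __init__(self, evaluate_fn, patience=3):
        self.evaluate_fn = evaluate_fn  # placeholder: trainer -> validation score
        self.patience = patience
        self.best_score = float("-inf")
        self.best_epoch = None
        self.epochs_seen = 0
        self.bad_epochs = 0

    def __call__(self, trainer):
        self.epochs_seen += 1
        score = self.evaluate_fn(trainer)
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = self.epochs_seen
            self.bad_epochs = 0
            # A real hook would also save the model weights here.
        else:
            self.bad_epochs += 1
        # Returning False asks the trainer to stop training.
        return self.bad_epochs < self.patience
```

You would pass an instance as end_of_epoch_hook=... when constructing MetricLossOnly. Check the trainer docs for your version to confirm how the return value is handled.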

Anyway, if you weren't using dataset2 for validation, then your setup seems fine. So training on dataset1 doesn't transfer well to dataset2, and I think you need to either: