yaoyao-liu / meta-transfer-learning

TensorFlow and PyTorch implementation of "Meta-Transfer Learning for Few-Shot Learning" (CVPR2019)
https://lyy.mpi-inf.mpg.de/mtl/
MIT License
746 stars, 149 forks

Question in pre-train phase: Why use meta-val set to select pre-train model? #28

Open Errorfinder opened 4 years ago

Errorfinder commented 4 years ago

Hello, thanks for your very enlightening work. However, I found a spot in your code that is quite confusing to me. I describe it as follows:

The paper clearly says that in the pre-train phase you train a feature extractor with a ResNet backbone on the 64-class training set (in the case of Mini-ImageNet). I think that is basically a standard training procedure: we train on part of the training set and validate on the rest, where both the training and validation splits are drawn from the same 64 classes. As you say in the paper, each class uses 550 images for training and the remaining 50 for validation.
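The standard split described above can be sketched as follows. This is only an illustration: the function name `split_per_class`, the random shuffling, and the toy file names are assumptions, not code from the repository.

```python
import random

def split_per_class(images_by_class, n_val=50, seed=0):
    """Hold out n_val images from every class; the rest are used for training."""
    rng = random.Random(seed)
    train, val = {}, {}
    for label, images in images_by_class.items():
        shuffled = list(images)
        rng.shuffle(shuffled)
        val[label] = shuffled[:n_val]
        train[label] = shuffled[n_val:]
    return train, val

# Toy stand-in for the 64 pre-train classes with 600 images each (hypothetical names).
data = {c: [f"class{c}_img{i}" for i in range(600)] for c in range(64)}
train, val = split_per_class(data)
print(len(train[0]), len(val[0]))  # 550 50
```

Here the validation images share their classes with the training images, so the resulting accuracy measures ordinary 64-way classification, not few-shot performance on unseen classes.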

However, in your code's PreTrainer class, the validation is done with 'N-way K-shot' tasks. This confuses me, since I can't find any description of this part in the paper. In my opinion, the pre-train phase is actually independent of the later meta-training phase. So what is the reason for doing so? Or have I missed some important information? Please let me know, and thanks in advance!

yaoyao-liu commented 4 years ago

Hi @Errorfinder,

Thanks for your interest in our work. In the pre-train phase, we aim to find the best model to initialize the backbone for the meta-train phase. Intuitively, the best model is the one with the highest meta-validation accuracy (we cannot access the meta-test set during the pre-train phase, so we use the meta-validation set). This way, we don't need to split the 64-class images into training and validation sets; we can use all of them for pre-training.
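A minimal sketch of this selection rule, under stated assumptions: each saved checkpoint is represented by an `embed` function, and few-shot accuracy is measured with a nearest-centroid classifier on sampled N-way K-shot episodes. The nearest-centroid classifier, the function names, and the toy 2-D "images" are stand-ins chosen for brevity; they are not necessarily what the repository's PreTrainer does.

```python
import random

def sample_episode(features_by_class, n_way, k_shot, n_query, rng):
    """Draw one N-way K-shot task as (support, query) lists of (feature, label)."""
    classes = rng.sample(sorted(features_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        picked = rng.sample(features_by_class[cls], k_shot + n_query)
        support += [(f, label) for f in picked[:k_shot]]
        query += [(f, label) for f in picked[k_shot:]]
    return support, query

def nearest_centroid_accuracy(support, query):
    """Classify each query feature by its closest class mean (stand-in base learner)."""
    sums, counts = {}, {}
    for feat, label in support:
        acc = sums.setdefault(label, [0.0] * len(feat))
        for i, x in enumerate(feat):
            acc[i] += x
        counts[label] = counts.get(label, 0) + 1
    centroids = {l: [x / counts[l] for x in s] for l, s in sums.items()}

    def dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))

    correct = sum(
        min(centroids, key=lambda l: dist(feat, centroids[l])) == label
        for feat, label in query
    )
    return correct / len(query)

def meta_val_accuracy(embed, val_by_class, n_episodes=100,
                      n_way=5, k_shot=1, n_query=15, seed=0):
    """Average episode accuracy of one checkpoint on the meta-validation classes."""
    rng = random.Random(seed)
    feats = {c: [embed(x) for x in imgs] for c, imgs in val_by_class.items()}
    scores = [
        nearest_centroid_accuracy(*sample_episode(feats, n_way, k_shot, n_query, rng))
        for _ in range(n_episodes)
    ]
    return sum(scores) / len(scores)

def select_best_checkpoint(checkpoints, val_by_class, **kw):
    """Return the checkpoint name with the highest meta-validation accuracy."""
    accs = {name: meta_val_accuracy(embed, val_by_class, **kw)
            for name, embed in checkpoints.items()}
    return max(accs, key=accs.get), accs

# Toy meta-validation set: 16 classes of 2-D points, 20 samples per class.
rng = random.Random(1)
val_set = {
    c: [(10.0 * c + rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(20)]
    for c in range(16)
}
checkpoints = {
    "epoch_10": lambda x: (0.0, 0.0),  # collapsed features (hypothetical bad checkpoint)
    "epoch_50": lambda x: x,           # separable features (hypothetical good checkpoint)
}
best, accs = select_best_checkpoint(checkpoints, val_set, n_episodes=20)
print(best)  # epoch_50
```

Because the episodes are drawn from classes the backbone never saw during pre-training, no images from the 64 pre-train classes have to be held out.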

Besides, please note that the PyTorch version is not the code we used for the published paper. If you want to reproduce the results in the paper, please use the TensorFlow version.

Errorfinder commented 4 years ago

@yaoyao-liu Oh I see, that makes sense then. Thanks for your quick reply.

I am really interested in this work, but I am not familiar with TensorFlow. Since you mentioned that the paper's results were generated with the TensorFlow version, could you please point out the major differences between the two? I mean, are they the same content, just built on different frameworks?

yaoyao-liu commented 4 years ago

Hi @Errorfinder,

The PyTorch version is built on FEAT. For some details (e.g. backbone, learning rate, optimizer, dataloader), we directly follow FEAT.

If you have any further questions, feel free to add more comments.

Sword-keeper commented 4 years ago

Oh, I was also confused about this when I tried to pre-train my model, and this issue answered my question. Besides, I wonder which method gives better results: standard validation or meta-validation? I tried modifying your PyTorch code. With standard validation, the best accuracy I get is 61%, which I think is not very good. Maybe there are some problems in my code. Could you tell me the validation accuracy you get with standard validation (TensorFlow code)?

yaoyao-liu commented 4 years ago

Hi @Sword-keeper, the best meta-validation accuracy for our PyTorch implementation is around 63%. You may try my best pre-trained model here: Google Drive.

For the SOTA methods, I recommend you check the published papers of recent machine learning and computer vision conferences (e.g. ICLR 2020, CVPR 2020) and the Papers with Code website.

Sword-keeper commented 4 years ago

Oh! Thank you very much! So this is the '.pth' file for the model ( ResNetMtl(mtl=False) ), right?

Another question: I found on the Internet that the common dataset split ratio is 8:2 (train set : test set), and that the train set's share increases as the amount of data grows. Is there any theory about this for FSL? In the TF code, the ratio is 550:50.

yaoyao-liu commented 4 years ago

@Sword-keeper

  1. This is the final model for the pre-train phase.
  2. For the validation of the pre-train phase, I suggest you follow our PyTorch implementation. I think using meta-validation to select pre-train model is more reasonable.

Sword-keeper commented 4 years ago

I got it. Thank you very much!

Sword-keeper commented 4 years ago

Hi, I'm so sorry to bother you. I ran into some problems when I loaded your .pth file: [image] And the error from the code is: [image]

I think your model has changed. My code is different from yours.

Sword-keeper commented 4 years ago

Oh, I found that you updated 'ResNetMtl' 23 days ago. That caused this problem.
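A mismatch like this between a checkpoint and a changed model definition can usually be pinned down by comparing parameter names. The sketch below is a generic helper, not code from the repository; `diff_state_dict_keys` and the parameter names in the example are hypothetical.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected): names the model expects but the file lacks,
    and names the file contains but the model does not know."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected

# Hypothetical parameter names, for illustration only.
current_model = ["conv1.weight", "conv1.mtl_weight", "bn1.weight"]
old_checkpoint = ["conv1.weight", "bn1.weight", "fc.weight"]
missing, unexpected = diff_state_dict_keys(current_model, old_checkpoint)
print("missing:", missing)        # missing: ['conv1.mtl_weight']
print("unexpected:", unexpected)  # unexpected: ['fc.weight']
```

With PyTorch, the two key lists would come from `model.state_dict().keys()` and from the dictionary returned by `torch.load(path, map_location="cpu")`; if the file wraps the weights in an outer dictionary, unwrap it first. Passing `strict=False` to `load_state_dict` skips non-matching names, but the skipped parameters keep their initial values, and tensors whose shapes differ still raise an error.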

yaoyao-liu commented 4 years ago

@Sword-keeper,

Sorry. I'll check the pre-trained checkpoint file again to make sure that it matches the current repo.