sicara / easy-few-shot-learning

Ready-to-use code and tutorial notebooks to boost your way into few-shot learning for image classification.
MIT License
1.06k stars 144 forks

Question on meta-training in the tutorial notebook #16

Closed karndeepsingh closed 2 years ago

karndeepsingh commented 3 years ago

Hi, thanks for making such a simple and beautiful library for few-shot learning. I have a query: when we run the meta-training cell in your notebook, does it also train the ResNet18 on the given dataset to learn better image representations, the way we fine-tune an ImageNet-pretrained classifier on a custom dataset in transfer learning? Or does it only train the Prototypical Network?

Please clarify this doubt. Thanks again.

ebennequin commented 3 years ago

Hi!

The only trainable parameters in a Prototypical Network are those of its backbone (a ResNet18 in the notebook), so training the Prototypical Network is equivalent to training the ResNet.

In the notebook, when we execute the meta-training, we start from a ResNet18 pretrained on ImageNet:

We will keep the same model, so our weights will be pre-trained on ImageNet. If you want to start training from scratch, feel free to set pretrained=False in the definition of the ResNet.

The difference with transfer learning is that the classes of the Omniglot dataset that are seen during meta-training are not the same as the classes on which we will test our meta-trained model.
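To illustrate that point, here is a minimal sketch of a Prototypical Networks forward pass (the toy backbone and shapes are made up for the example, not the notebook's actual code). The wrapper adds no parameters of its own, so training it trains only the backbone:

```python
import torch
from torch import nn

class ProtoNet(nn.Module):
    """Toy Prototypical Networks: all trainable parameters live in the backbone."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, support, support_labels, query):
        z_support = self.backbone(support)
        z_query = self.backbone(query)
        n_way = len(support_labels.unique())
        # One prototype per class: the mean of that class's support features
        prototypes = torch.stack(
            [z_support[support_labels == c].mean(0) for c in range(n_way)]
        )
        # Score each query by (negative) distance to each prototype
        return -torch.cdist(z_query, prototypes)

# A tiny stand-in backbone; in the notebook this would be the ResNet18
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))
model = ProtoNet(backbone)
```

Counting `model.parameters()` against `backbone.parameters()` confirms the two are identical, which is why meta-training the Prototypical Network is exactly training the ResNet.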

Did I answer your question?

karndeepsingh commented 3 years ago

@ebennequin Thanks for answering. I changed the architecture from ResNet18 to ResNet152 and started getting a CUDA memory error. Below is the image.

Please clarify one more thing: if I have already trained a classifier model for my use case, can I use that in place of ResNet18? [screenshot: MicrosoftTeams-image (2)]

ebennequin commented 3 years ago

You can use your own nn.Module object as the backbone argument when initializing a PrototypicalNetworks object (source code). The only additional requirement (if I'm not mistaken) is that the output of the backbone for a single image must be 1-dimensional. I guess it must be the case with your custom classifier.
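For illustration, a hedged sketch of how you might adapt an existing classifier to meet that requirement (the classifier below is a made-up stand-in for your own model): replace the final classification head with an identity so the network outputs a flat feature vector per image instead of logits.

```python
import torch
from torch import nn

# Stand-in for a classifier you already trained on your own data
my_classifier = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),  # original classification head (10 classes)
)

# Drop the head: the network now outputs features, not class logits
my_classifier[-1] = nn.Identity()

features = my_classifier(torch.randn(2, 3, 32, 32))
# Each image now maps to a 1-dimensional feature vector, as required
```

For a torchvision ResNet you would typically do the equivalent by replacing its `fc` attribute.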

As for your error: keep in mind that when you're meta-training, your batch size is a function of the number of classes, the number of support images per class, and the number of query images per class (batch_size = n_way * (n_support + n_query)). In the notebook, this gives a batch size of 75. Your GPU is too small to train a 60M-parameter model like ResNet152 with a batch size of 75. You can choose a smaller model, and/or reduce the batch size (via n_support or n_query), and/or reduce your image size.
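That arithmetic as a small sketch (the helper function name is mine, not the library's):

```python
def episode_batch_size(n_way: int, n_support: int, n_query: int) -> int:
    """Number of images processed per episode: n_way * (n_support + n_query)."""
    return n_way * (n_support + n_query)

# The notebook's setting: 5-way, 5 support and 10 query images per class
print(episode_batch_size(5, 5, 10))  # 75
```

So, for example, dropping to 5 query images per class halves the query load and brings the effective batch size down to 50.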