learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

F1 Score Curve During Meta Train/Validation #377

Closed furkanpala closed 1 year ago

furkanpala commented 1 year ago

Hi, I am training VGG11 on a custom image dataset for 3-way 5-shot image classification using MAML. I am encapsulating the whole VGG11 model with MAML, i.e., not just the classification head. My hyperparameters are as follows:

During training, I noticed that after taking the first outer-loop optimization step, i.e., AdamW.step(), the loss skyrockets into the tens of thousands. Is this normal? I am also measuring the micro F1 score as the accuracy metric; its curve for meta training/validation is shown below. [figure: F1 score curve during meta training/validation]

In my opinion it fluctuates too much. Is this normal?
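For context, the outer-loop step in question can be sketched in plain PyTorch. This is a minimal illustration only: the toy linear regression model, `inner_lr`, and random task sampling below are placeholders for the issue's actual VGG11 / 3-way 5-shot setup, and it bypasses the learn2learn MAML wrapper entirely.

```python
import torch
import torch.nn as nn

# Toy model standing in for VGG11; AdamW is the outer-loop optimizer
# mentioned in the issue.
model = nn.Linear(1, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
inner_lr = 0.01  # illustrative inner-loop learning rate

def adapt(params, x, y):
    """One inner-loop SGD step. create_graph=True lets the outer
    step backpropagate through the adaptation (second-order MAML)."""
    pred = x @ params["weight"].t() + params["bias"]
    loss = ((pred - y) ** 2).mean()
    grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
    return {k: p - inner_lr * g for (k, p), g in zip(params.items(), grads)}

for step in range(5):
    opt.zero_grad()
    meta_loss = 0.0
    for _ in range(4):  # meta-batch of tasks
        x_s, y_s = torch.randn(5, 1), torch.randn(5, 1)  # support set
        x_q, y_q = torch.randn(5, 1), torch.randn(5, 1)  # query set
        fast = adapt(dict(model.named_parameters()), x_s, y_s)
        pred_q = x_q @ fast["weight"].t() + fast["bias"]
        meta_loss = meta_loss + ((pred_q - y_q) ** 2).mean()
    (meta_loss / 4).backward()
    opt.step()  # the outer-loop AdamW.step() the issue refers to
```

A large jump in loss right after this first `opt.step()` usually points at something stateful in the model interacting badly with the cloned/adapted parameters, which is what the resolution below confirms.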

Thanks.

furkanpala commented 1 year ago

I figured it out. I was using VGG11 with vanilla BatchNorm layers from PyTorch, which do not behave properly in the meta-training setup. I removed the BatchNorm layers and now it works as expected. Thanks...
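For later readers: a common alternative to removing the layers outright is to swap each `BatchNorm2d` for `GroupNorm`, which normalizes per sample and keeps no running statistics that could leak across MAML adaptation steps. A minimal sketch; the `replace_batchnorm` helper is illustrative and not part of learn2learn, and `num_groups` must divide each layer's channel count:

```python
import torch
import torch.nn as nn

def replace_batchnorm(module: nn.Module, num_groups: int = 8) -> nn.Module:
    """Recursively replace every BatchNorm2d with a GroupNorm of the
    same channel count. Mutates the module in place and returns it."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.GroupNorm(num_groups, child.num_features))
        else:
            replace_batchnorm(child, num_groups)
    return module

# Usage on a small stand-in network (VGG11 would work the same way):
net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)
replace_batchnorm(net)
```

Another option, if you want to keep BatchNorm itself, is constructing the layers with `track_running_stats=False` so they always use batch statistics; which choice works better will depend on the episode batch sizes.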