AntreasAntoniou / HowToTrainYourMAMLPytorch

The original code for the paper "How to train your MAML" along with a replication of the original "Model Agnostic Meta Learning" (MAML) paper in Pytorch.

Bias in final layer #23

Open bkj opened 5 years ago

bkj commented 5 years ago

Hi --

This is a really great implementation and improvement of MAML!

I'm curious about whether it's actually a good idea to let the network meta-learn a nonzero bias initialization for the linear classifier head. Since a support/target set could be fed to the model under any permutation, I would think that it doesn't make sense to favor one of the output classes over another. Any thoughts on this?
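The permutation argument can be sketched in a few lines. This is only an illustration, not code from the repository: the bias values and the slot indices below are hypothetical. It shows that when the features carry no signal, a nonzero bias makes the prediction for a class depend on which output slot it was assigned to, while a zero bias gives a uniform prior.

```python
import torch

num_classes = 5

# Hypothetical meta-learned bias that favors output slot 0.
bias = torch.tensor([2.0, 0.0, 0.0, 0.0, 0.0])

# What the head predicts when the features contribute nothing to the logits.
prior = bias.softmax(dim=0)

# The same real class can land in different slots under two task permutations.
slot_first, slot_second = 0, 3
p_first = prior[slot_first]    # probability when the class sits in slot 0
p_second = prior[slot_second]  # probability when the class sits in slot 3

# With a zero bias, every slot gets the same prior probability.
uniform_prior = torch.zeros(num_classes).softmax(dim=0)
```

Here `p_first` is larger than `p_second` even though both refer to the same underlying class, which is the asymmetry the question is about.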

In a very similar vein, I'm wondering whether it would make sense to tie the weights of the linear classifier head. That is, for a 20-way classification problem, instead of having:

classifier.layer_dict.linear.weights = random_init(64, 20)

you'd have something like

classifier.layer_dict.linear.weights = random_init(64, 1).repeat((1, 20))
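A minimal PyTorch sketch of this tied initialization follows. It makes several assumptions: `tied_head_init` is a hypothetical helper, the `std` value is arbitrary, and the weights use the `(out_features, in_features)` layout expected by `torch.nn.functional.linear`, which is transposed relative to the pseudocode above.

```python
import torch
import torch.nn.functional as F


def tied_head_init(in_features: int, num_classes: int, std: float = 0.01):
    """Hypothetical helper: every class starts from the same weight vector."""
    shared = torch.randn(1, in_features) * std   # one shared row
    weights = shared.repeat(num_classes, 1)      # (num_classes, in_features), real copies
    bias = torch.zeros(num_classes)              # no output slot favored
    return weights, bias


weights, bias = tied_head_init(in_features=64, num_classes=20)

features = torch.randn(5, 64)                    # stand-in for embedded images
logits = F.linear(features, weights, bias)       # identical logit for every class
probs = logits.softmax(dim=-1)                   # uniform over the 20 classes
```

Because `repeat` makes independent copies, wrapping `weights` in `nn.Parameter` lets each row diverge during the inner-loop updates; only the starting point is tied.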

I did a simple experiment where I took a trained Omniglot-20way-1shot MAML model and added the following code to MAMLFewShotClassifier.load_model

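with torch.no_grad():
    # zero bias (key name assumed by analogy with the weights key below)
    b_key = 'classifier.layer_dict.linear.bias'
    state_dict_loaded[b_key] = torch.zeros_like(state_dict_loaded[b_key])

    # replace linear weights w/ average linear weight, then write it back
    w_key = 'classifier.layer_dict.linear.weights'
    w = state_dict_loaded[w_key].clone()
    state_dict_loaded[w_key] = w.mean(dim=0, keepdim=True).expand_as(w).clone()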

This bumped accuracy from 0.943 to 0.953, which is neat. This is just anecdotal evidence for now though, since it's tested with only a single model.

I'm wondering whether you're able to send me all of your pretrained models somehow, so I can run this experiment across all of the various configurations that you report in the paper? It would be cool if this trick led to improved performance across the board.

(Then, of course, the next step would be to train models end-to-end w/ these modifications and see if that gives a slight bump.)

Thanks! ~ Ben

AntreasAntoniou commented 5 years ago

You bring up a very interesting point. Initializing all bias params to zero accelerates training of MAML++, because it's closer to a good solution. Your idea lies in a similar vein. Let me know if this improves any of the Mini-ImageNet results.

Unfortunately, I no longer have the full pretrained models, but I could rerun them if you want, in case you are lacking the compute. I could also provide some of the 'High-End MAML++' models as described in my Learning to Learn via Self-Critique paper.