AntreasAntoniou / HowToTrainYourMAMLPytorch

The original code for the paper "How to train your MAML" along with a replication of the original "Model Agnostic Meta Learning" (MAML) paper in Pytorch.
https://arxiv.org/abs/1810.09502

argument `training` hard set to `True` in call to F.batch_norm #3

Closed simonguiroy closed 5 years ago

simonguiroy commented 5 years ago

It should be `training=training` instead of `training=True`, since as written, even at inference time the BN function will use batch statistics instead of the moving (running) statistics.

https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/blob/09dbacc7415bb0659f84f40d5c264179d80a9bd6/meta_neural_network_architectures.py#L269
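A minimal sketch (with made-up values, not the repo's code) of the behavior being reported: with `training=True`, `F.batch_norm` normalizes with the current batch's statistics, whereas `training=False` normalizes with the stored running statistics.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# A batch whose statistics differ sharply from the stored running stats.
x = torch.randn(8, 4) * 5 + 10          # batch mean ~10, std ~5
running_mean = torch.zeros(4)
running_var = torch.ones(4)

# training=True: normalize with the *batch* stats (and update running stats).
out_train = F.batch_norm(x, running_mean.clone(), running_var.clone(), training=True)

# training=False: normalize with the stored *running* stats (0, 1).
out_eval = F.batch_norm(x, running_mean.clone(), running_var.clone(), training=False)

print(out_train.mean())   # ~0: the batch stats were used
print(out_eval.mean())    # ~10: the running stats were used, x is left roughly as-is
```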

AntreasAntoniou commented 5 years ago

In PyTorch, using training=True does not use the data stats directly; it actually updates the moving stats using the data stats and uses those updated moving stats to apply BN.

Now there is another question to answer: isn't it cheating for me to store information that came from validation data? Yes, it could be, but it is done in a way that prevents me from cheating. In meta-learning, you want to maximize the information you extract from a given task in order to do well on that task. If one were to use only the saved running stats, without the summary statistics of a particular task's data, one would prevent the model from potentially using that information to perform better. So I instead allow evaluation-data stats to update the running stats, which are then used when applying batch norm on the task's target set. However, once a task has been completed (i.e. inner-loop steps + inference on the target set), I reset the running stats to their state before any validation data was used to update them. Using this little trick, I allow my model to use data-distribution information from a new task at inference time without 'contaminating' the saved stats with validation stats.
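The backup-and-restore trick described above can be sketched roughly as follows. This is a hypothetical illustration, not the repo's actual implementation; `adapt_and_infer` stands in for the inner-loop adaptation plus target-set inference:

```python
import torch
import torch.nn as nn

def evaluate_task_with_bn_reset(model: nn.Module, adapt_and_infer):
    """Let a task's data update BN running stats during adaptation and
    inference, then restore the pre-task stats so the saved statistics
    are never contaminated by validation data. (Illustrative sketch.)"""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    # Snapshot every BN module's running statistics before the task.
    backup = {
        name: (m.running_mean.clone(), m.running_var.clone())
        for name, m in model.named_modules()
        if isinstance(m, bn_types)
    }
    try:
        return adapt_and_infer(model)   # inner-loop steps + target-set inference
    finally:
        # Reset the running stats to their state before the task's data touched them.
        for name, m in model.named_modules():
            if name in backup:
                mean, var = backup[name]
                m.running_mean.copy_(mean)
                m.running_var.copy_(var)
```

For simplicity the sketch restores only `running_mean`/`running_var`; a full version would also snapshot `num_batches_tracked`.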

namsan96 commented 5 years ago

Hi, I also thought the training argument should not always be set to True. If the referenced line is only used to save the data stats, then where is the part that actually uses them?

jfb54 commented 4 years ago

In PyTorch, using training=True does not use the data stats directly; it actually updates the moving stats using the data stats and uses those updated moving stats to apply BN.

This is not correct. When calling F.batch_norm with training=True, the batch statistics are used to normalize, not the running statistics. The running statistics are used for normalization only when training=False. You are correct in saying that setting training=True will update the running stats, but they are not applied unless training=False. Thus, the code as written will not utilize Per-Step Batch Normalization Running Statistics (BNRS) as described in the paper.
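Both halves of this claim can be checked directly (hypothetical values, assuming the default `momentum=0.1`): `training=True` does update the running stats in place, but the output is normalized with the batch statistics, not with the running statistics.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.randn(16, 3) * 2 + 5     # batch mean ~5
running_mean = torch.zeros(3)
running_var = torch.ones(3)

out = F.batch_norm(x, running_mean, running_var, training=True, momentum=0.1)

# The running stats were updated in place toward the batch stats...
print(running_mean)   # no longer all zeros (~0.1 * batch mean)
# ...but the output was normalized with the batch stats, not with (0, 1):
print(out.mean())     # ~0, not ~5
```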