AntreasAntoniou / HowToTrainYourMAMLPytorch

The original code for the paper "How to train your MAML", along with a replication of the original "Model Agnostic Meta Learning" (MAML) paper, in PyTorch.
https://arxiv.org/abs/1810.09502

Backup running statistics is incorrect #42

Open denizetkar opened 2 years ago

denizetkar commented 2 years ago

When MetaBatchNormLayer's forward is called with backup_running_statistics=True, the running statistics are meant to be copied into the backup variables here:

https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/blob/86964701651c06e490271365fedb8789a72ef3f0/meta_neural_network_architectures.py#L241
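For context, a stripped-down stand-in for that backup step looks roughly like this (hypothetical class and buffer names that mirror the linked line; not the repository code verbatim):

import torch as th
import torch.nn as nn
from copy import copy

class TinyMetaBN(nn.Module):
    """Minimal, hypothetical stand-in for the batch-norm layer's backup step."""
    def __init__(self, num_features):
        super().__init__()
        self.register_buffer("running_mean", th.zeros(num_features))
        self.register_buffer("backup_running_mean", th.zeros(num_features))

    def backup_running_statistics(self):
        # Same pattern as the linked line: a shallow copy of .data
        self.backup_running_mean.data = copy(self.running_mean.data)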

However, that is not what happens: the backup ends up sharing the same underlying storage as the live running statistics, so whenever the running statistics are updated, the backup changes with them. Here is a short code snippet that minimally reproduces the behaviour:

import torch as th
from copy import copy

t = th.tensor([1, 2, 3])
t2 = th.empty_like(t)
t2.data = copy(t.data)  # shallow copy: t2 now shares t's underlying storage
t[0] = 5
assert t2[0] == 5  # passes: modifying t also changed the "backup" t2

Is my understanding of backing up the running statistics wrong or is this a bug that needs fixing?

AntreasAntoniou commented 2 years ago

I'll need some time to look into this, but I do believe that the behaviour I currently have coded is intentional.

I want the running statistics to be updated within a given episode, but then scrapped at the end if this isn't a training iteration.
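In other words, the intended behaviour within one episode is roughly the following (a minimal sketch with plain tensors, assuming the backup is an independent snapshot):

import torch as th

running_mean = th.zeros(3)
backup = running_mean.clone()  # independent snapshot taken before the episode

running_mean += 0.5            # statistics update during the episode

# Not a training iteration: scrap the episode's updates by restoring the snapshot
running_mean = backup.clone()
assert th.all(running_mean == 0)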

DubiousCactus commented 2 years ago

@denizetkar is right: that is not your intended behaviour, and the backup gets overwritten during the validation pass. You simply need to use deepcopy instead.
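For completeness, applying that fix to the reproduction snippet above (a minimal sketch, assuming the same shallow-copy pattern as the linked line):

import torch as th
from copy import deepcopy

t = th.tensor([1, 2, 3])
t2 = th.empty_like(t)
t2.data = deepcopy(t.data)  # deep copy: t2 gets its own storage
t[0] = 5
assert t2[0] == 1  # the backup keeps its original value

Equivalently, t2.data = t.data.clone() gives the same independence without the copy module.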