AntreasAntoniou / HowToTrainYourMAMLPytorch

The original code for the paper "How to train your MAML" along with a replication of the original "Model Agnostic Meta Learning" (MAML) paper in Pytorch.
https://arxiv.org/abs/1810.09502

line 244 in meta_neural_network_architectures.py #24

Open RuohanW opened 4 years ago

RuohanW commented 4 years ago

Thank you for releasing the code.

I notice that the function def forward(self, input, num_step, params=None, training=False, backup_running_statistics=False) has a training flag. However, within the function (line 244):

output = F.batch_norm(input, running_mean, running_var, weight, bias,
                              training=True, momentum=momentum, eps=self.eps)

Should training always be set to True here? Does this affect the results reported in the paper, since per-step batch norm appears, from the paper, to be an important trick for improving MAML?
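For illustration, the change I am asking about would simply forward the method's training flag (a hypothetical edit, not what the current code does):

output = F.batch_norm(input, running_mean, running_var, weight, bias,
                              training=training, momentum=momentum, eps=self.eps)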

Many thanks.

jfb54 commented 4 years ago

This is the same as issue: https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/issues/3

My understanding is that this will affect the results reported in the paper. The code as written will always use the batch statistics, not a running average accumulated per step.
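To make this concrete, a scheme that actually used per-step running statistics at evaluation time might look roughly like the following sketch (self.running_means and self.running_vars are hypothetical per-step buffers, not names from the repository):

# sketch only: one running-statistics buffer per inner-loop step
running_mean = self.running_means[num_step]
running_var = self.running_vars[num_step]
output = F.batch_norm(input, running_mean, running_var, weight, bias,
                      training=training,  # batch stats while training, accumulated running stats otherwise
                      momentum=momentum, eps=self.eps)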

AntreasAntoniou commented 4 years ago

What you stated was my initial understanding as well. However, after doing a few tests I found that what I stated previously was the right way to go. Check for yourself.

jfb54 commented 4 years ago

Thanks for the quick response! The following is a short script that demonstrates my assertion. If you have tests that show otherwise, it would be great to see them.

import torch
import torch.nn.functional as F

N = 64  # batch size
C = 16  # number of channels
H = 32  # image height
W = 32  # image width
eps = 1e-05

input = 10 * torch.randn(N, C, H, W)  # create a random input

running_mean = torch.zeros(C)  # set the running mean for all channels to be 0
running_var = torch.ones(C)  # set the running var for all channels to be 1

# Call batch norm with training=False. Expect that the input is normalized with the running mean and running variance
output = F.batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05)

# Assert that the output is equal to the input
assert torch.allclose(input, output)

# Call batch norm with training=True. Expect that the input is normalized with batch statistics of the input.
output_bn = F.batch_norm(input, running_mean, running_var, weight=None, bias=None, training=True, momentum=0.1, eps=eps)

# Normalize the input manually
batch_mean = torch.mean(input, dim=(0, 2, 3), keepdim=True)
batch_var = torch.var(input, dim=(0, 2, 3), unbiased=False, keepdim=True)  # biased variance, matching what batch_norm uses
output_manual = (input - batch_mean) / torch.sqrt(batch_var + eps)

# Assert that output_bn equals output_manual
assert torch.allclose(output_bn, output_manual)
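One further observation: with training=True, F.batch_norm also updates running_mean and running_var in place via the momentum, even though the normalization itself uses the batch statistics, so after the second call the buffers no longer hold their initial values:

# the training=True call above modified the running-statistics buffers in place
assert not torch.allclose(running_mean, torch.zeros(C))
assert not torch.allclose(running_var, torch.ones(C))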

AntreasAntoniou commented 4 years ago

I can definitely confirm that it was the case back in 2018. I'll need to reconfirm with the latest versions of pytorch. Will come back to you soon.
