Open dvtailor opened 10 months ago

Hi, is there a Torch implementation of Meta-BN available, as introduced in https://arxiv.org/pdf/2003.03284.pdf?

Here is a class that implements Meta-BN, derived from the base class NormalizationLayer in https://github.com/cambridge-mlg/cnaps/blob/master/src/normalization_layers.py#L5.
Apologies, somehow the code formatting below is wacky.
```python
import torch

from normalization_layers import NormalizationLayer  # base class from cnaps/src


class MetaBN(NormalizationLayer):
    """MetaBN Normalization Layer"""

    def __init__(self, num_features):
        """
        Initialize.
        :param num_features: number of channels in the 2D convolutional layer
        """
        super(MetaBN, self).__init__(num_features)
        # Cached context-set moments, overwritten on every training-mode forward pass.
        self.context_batch_mean = torch.zeros((1, num_features, 1, 1), requires_grad=True)
        self.context_batch_var = torch.ones((1, num_features, 1, 1), requires_grad=True)

    def forward(self, x):
        """
        Normalize activations.
        :param x: input activations
        :return: normalized activations
        """
        if self.training:  # normalize the context and save off the moments
            batch_mean, batch_var = self._compute_batch_moments(x)
            x = self._normalize(x, batch_mean, batch_var)
            self.context_batch_mean = batch_mean
            self.context_batch_var = batch_var
        else:  # normalize the target with the saved moments
            x = self._normalize(x, self.context_batch_mean, self.context_batch_var)
        return x
```
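To make the intended usage concrete, here is a minimal sketch of the train/eval semantics (the feature count, tensor names, and shapes are illustrative assumptions, not from the repo): the context set is normalized in training mode, which caches its batch moments, and the target set is then normalized in eval mode with those cached context statistics.

```python
# Minimal usage sketch; the shapes and names here are illustrative assumptions.
meta_bn = MetaBN(num_features=64)

context_activations = torch.randn(25, 64, 14, 14)  # context-set activations
target_activations = torch.randn(75, 64, 14, 14)   # target-set activations

meta_bn.train()
context_out = meta_bn(context_activations)  # normalizes context, caches its moments

meta_bn.eval()
target_out = meta_bn(target_activations)    # normalizes target with cached context moments
```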
Thanks! I ran into an issue when using this in combination with `copy.deepcopy(model)`, which is common in MAML implementations when performing fine-tuning. It leads to the following error:

```
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
```
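For what it's worth, I believe the root cause is that after a training-mode forward pass, `context_batch_mean` and `context_batch_var` are reassigned to the outputs of `_compute_batch_moments`, which are non-leaf tensors attached to the autograd graph, and `copy.deepcopy` currently only supports leaf tensors. A standalone sketch of that failure mode (not using the repo's code):

```python
import copy

import torch

leaf = torch.zeros(3, requires_grad=True)
copy.deepcopy(leaf)      # fine: a leaf tensor created explicitly by the user

non_leaf = leaf * 2.0    # produced by an autograd op, so it is not a leaf
copy.deepcopy(non_leaf)  # raises the RuntimeError quoted above
```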
I can get around this by deleting the batch-moment attributes prior to the deepcopy, e.g.:

```python
for m in model.modules():
    if isinstance(m, MetaBN):
        delattr(m, 'context_batch_mean')
        delattr(m, 'context_batch_var')
```
But this does make execution time longer. Also, do you see any issues with this approach?
If you delete them before the copy, you would need to add them back afterwards, as these attributes hold crucial state for normalizing the target set with the statistics of the context set.
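Putting that together, one possible delete/copy/restore sequence might look like this (a sketch based on the advice above, not something from the repo):

```python
import copy

# Stash the cached moments, delete them so deepcopy succeeds, then restore
# them on both the original model and the copy. Note that the two models end
# up sharing the same moment tensors; clone() them if that is a problem.
saved = {}
for name, m in model.named_modules():
    if isinstance(m, MetaBN):
        saved[name] = (m.context_batch_mean, m.context_batch_var)
        delattr(m, 'context_batch_mean')
        delattr(m, 'context_batch_var')

model_copy = copy.deepcopy(model)

for target in (model, model_copy):
    for name, m in target.named_modules():
        if isinstance(m, MetaBN):
            m.context_batch_mean, m.context_batch_var = saved[name]
```

An alternative that avoids the attribute juggling is to detach the cached moments before copying (e.g. `m.context_batch_mean = m.context_batch_mean.detach()`), at the cost of dropping their autograd history, which matters if gradients need to flow through the saved moments after the copy.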