cambridge-mlg / cnaps

Code for: "Fast and Flexible Multi-Task Classification Using Conditional Neural Adaptive Processes" and "TaskNorm: Rethinking Batch Normalization for Meta-Learning"
MIT License

Meta-BN Implementation #16

Open dvtailor opened 10 months ago

dvtailor commented 10 months ago

Hi, is there a PyTorch implementation available of Meta-BN as introduced in https://arxiv.org/pdf/2003.03284.pdf?

jfb54 commented 10 months ago

Here is a class that implements Meta-BN; it is derived from the base class `NormalizationLayer` in https://github.com/cambridge-mlg/cnaps/blob/master/src/normalization_layers.py#L5.

```python
import torch

from normalization_layers import NormalizationLayer  # src/normalization_layers.py in this repo


class MetaBN(NormalizationLayer):
    """MetaBN Normalization Layer"""

    def __init__(self, num_features):
        """
        Initialize
        :param num_features: number of channels in the 2D convolutional layer
        """
        super(MetaBN, self).__init__(num_features)

        # Variables to store the context moments to use for normalizing the target.
        self.context_batch_mean = torch.zeros((1, num_features, 1, 1), requires_grad=True)
        self.context_batch_var = torch.ones((1, num_features, 1, 1), requires_grad=True)

    def forward(self, x):
        """
        Normalize activations.
        :param x: input activations
        :return: normalized activations
        """
        if self.training:  # normalize the context and save off the moments
            batch_mean, batch_var = self._compute_batch_moments(x)
            x = self._normalize(x, batch_mean, batch_var)
            self.context_batch_mean = batch_mean
            self.context_batch_var = batch_var
        else:  # normalize the target with the saved moments
            x = self._normalize(x, self.context_batch_mean, self.context_batch_var)

        return x
```
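As a rough usage sketch: train mode normalizes the context set and caches its moments, eval mode reuses them on the target set. The stand-in class below reimplements the base-class helpers (`_compute_batch_moments`, `_normalize`) with plain per-channel batch normalization so the snippet runs standalone; it is not the repo's implementation.

```python
import torch
import torch.nn as nn


class MetaBNSketch(nn.Module):
    """Standalone stand-in for MetaBN: normalizes the context set with its own
    batch moments, then reuses those saved moments for the target set."""

    def __init__(self, num_features):
        super().__init__()
        self.context_batch_mean = torch.zeros((1, num_features, 1, 1))
        self.context_batch_var = torch.ones((1, num_features, 1, 1))

    def forward(self, x):
        if self.training:  # context pass: compute and save the batch moments
            mean = x.mean(dim=(0, 2, 3), keepdim=True)
            var = x.var(dim=(0, 2, 3), keepdim=True, unbiased=False)
            self.context_batch_mean, self.context_batch_var = mean, var
        else:              # target pass: reuse the saved context moments
            mean, var = self.context_batch_mean, self.context_batch_var
        return (x - mean) / torch.sqrt(var + 1e-5)


layer = MetaBNSketch(num_features=3)

layer.train()
context = torch.randn(8, 3, 5, 5)
ctx_out = layer(context)     # saves the context moments

layer.eval()
target = torch.randn(4, 3, 5, 5)
out = layer(target)          # normalized with context (not target) moments
```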
dvtailor commented 10 months ago

Thanks! I had an issue when using this in combination with `copy.deepcopy(model)`, which is common in MAML implementations when performing fine-tuning. It leads to the following error:

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

I can get around this by deleting the batch moment attributes prior to a deepcopy, e.g.:

```python
for m in model.modules():
    if isinstance(m, MetaBN):
        delattr(m, 'context_batch_mean')
        delattr(m, 'context_batch_var')
```

But this does make execution time longer. Also, do you see any issues with this approach?

jfb54 commented 10 months ago

If you delete them before the copy, you would need to add them back afterwards, as these attributes hold crucial state for normalizing the target set with the statistics of the context set.
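A sketch of that delete-then-restore round trip, using a toy module in place of MetaBN (names here are illustrative, not from the repo). After a forward pass the moment attributes are non-leaf tensors with `requires_grad=True`, which is exactly what trips `deepcopy`; removing them lets the copy succeed, after which the originals go back on the source model and detached clones go on the copy:

```python
import copy

import torch
import torch.nn as nn


class ToyMetaBN(nn.Module):
    """Minimal stand-in: after a forward pass, the moment attributes are
    non-leaf tensors in the autograd graph, which is what breaks deepcopy."""

    def __init__(self, num_features):
        super().__init__()
        self.context_batch_mean = torch.zeros((1, num_features, 1, 1))
        self.context_batch_var = torch.ones((1, num_features, 1, 1))

    def forward(self, x):
        self.context_batch_mean = x.mean(dim=(0, 2, 3), keepdim=True)
        self.context_batch_var = x.var(dim=(0, 2, 3), keepdim=True, unbiased=False)
        return (x - self.context_batch_mean) / torch.sqrt(self.context_batch_var + 1e-5)


model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), ToyMetaBN(4))
_ = model(torch.randn(2, 3, 5, 5))  # moments are now non-leaf tensors

# 1) Save the moments and remove the attributes before the copy.
saved = {}
for name, m in model.named_modules():
    if isinstance(m, ToyMetaBN):
        saved[name] = (m.context_batch_mean, m.context_batch_var)
        delattr(m, 'context_batch_mean')
        delattr(m, 'context_batch_var')

model_copy = copy.deepcopy(model)  # no longer raises

# 2) Restore: originals (still graph-linked) on the source model,
#    detached clones on the copy.
for name, m in model.named_modules():
    if isinstance(m, ToyMetaBN):
        m.context_batch_mean, m.context_batch_var = saved[name]
for name, m in model_copy.named_modules():
    if isinstance(m, ToyMetaBN):
        mean, var = saved[name]
        m.context_batch_mean = mean.detach().clone()
        m.context_batch_var = var.detach().clone()
```

Note that the copy's restored moments are detached, so gradients will not flow through them; if the inner loop needs gradients with respect to the context moments, they have to be recomputed on the copy with a fresh context pass.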