facebookresearch / moco

PyTorch implementation of MoCo: https://arxiv.org/abs/1911.05722
MIT License

Buffer of BN in EMA update #70

Closed LeeDoYup closed 4 years ago

LeeDoYup commented 4 years ago

Hello. Thanks for the awesome project! I have a question.

I wonder why the EMA update doesn't track the running mean and variance of BN.

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

I think the code below is correct, because the EMA model performs poorly when the running variables in BN are not tracked.

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

        # buffer update
        for buffer_q, buffer_k in zip(self.encoder_q.buffers(), self.encoder_k.buffers()):
            buffer_k.copy_(buffer_q)
ppwwyyxx commented 4 years ago

Because the running variables in the momentum encoder are never used in training or fine-tuning in this codebase.

If you need to use them for your project, you can include them in the update.
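For anyone who does need the buffers, a minimal self-contained sketch of the suggestion above, using hypothetical toy encoders rather than the real ResNet-50: the parameters get the usual EMA, while the BN buffers (`running_mean`, `running_var`, `num_batches_tracked`) are copied directly, since `num_batches_tracked` is an integer tensor and cannot be interpolated.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def momentum_update(encoder_q, encoder_k, m=0.999):
    # EMA on the parameters, as in _momentum_update_key_encoder above
    for p_q, p_k in zip(encoder_q.parameters(), encoder_k.parameters()):
        p_k.data.mul_(m).add_(p_q.data, alpha=1.0 - m)
    # Direct copy of the buffers (BN running stats); copying avoids
    # interpolating the integer num_batches_tracked buffer
    for b_q, b_k in zip(encoder_q.buffers(), encoder_k.buffers()):
        b_k.copy_(b_q)


# Hypothetical toy encoders, for illustration only
encoder_q = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
encoder_k = copy.deepcopy(encoder_q)

encoder_q.train()
encoder_q(torch.randn(8, 4))           # updates encoder_q's BN running stats
momentum_update(encoder_q, encoder_k)  # encoder_k's buffers now match
```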

LeeDoYup commented 4 years ago

Thanks for the quick reply! I want to check the following things.

ResNet-50 uses batch normalization, and the default setting uses the running variables, so the momentum encoder does use them. However, I think they are also updated when the momentum encoder runs forward(im_k).
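The point above can be checked in isolation: a BatchNorm layer in train mode updates its running statistics as a side effect of every forward pass, independently of any EMA on the parameters. A minimal sketch (standalone BN layer, not the actual MoCo encoder):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(3)
before = bn.running_mean.clone()  # starts at zeros

bn.train()                      # key encoder runs in train mode
bn(torch.randn(16, 3) + 5.0)    # forward pass with nonzero-mean input

# running_mean moved toward the batch mean: buffers are updated by
# forward() itself, so forward(im_k) does update the key encoder's stats
print(before, bn.running_mean)
```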