AidenDurrant / MoCo-Pytorch

An unofficial Pytorch implementation of "Improved Baselines with Momentum Contrastive Learning" (MoCoV2) - X. Chen, et al.

Hello, I am wondering about the accuracy on CIFAR-10 #1

Open vmmm123 opened 3 years ago

vmmm123 commented 3 years ago

It's strange that when I rerun the official version of MoCo I can't obtain a good result. It usually reaches around 73% accuracy at the last epoch, and even the highest accuracy is only about 76%. This doesn't make sense, because SimCLR or BYOL can reach at least 90% accuracy. So I am wondering about your results. Could you please post them first?

a411919924 commented 3 years ago

I hypothesize the critical issue is the inconsistency of negative keys caused by the large queue size. Unlike the huge ImageNet dataset, the training set of CIFAR-10 has only 50,000 images. With the default K=65536, the queue maintained by MoCo holds 65536/50000 ≈ 1.31 epochs of training data, meaning a negative key is only dequeued after about 1.31 epochs, which is too late if the key is out of date. For ImageNet-1K (~1.28 million training images), the ratio is roughly 65536/1,280,000 ≈ 0.05 epochs, so an out-of-date negative key is replaced much more quickly. Hence, the inconsistency might be alleviated by decreasing K (e.g. to 4096) and lowering the learning rate.
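
To make the ratio concrete, here is a minimal back-of-the-envelope sketch (my own illustration, not part of this repo) of how many epochs' worth of keys the queue spans for the dataset and queue sizes discussed above; queue_epochs is a hypothetical helper.

    # Rough check of how many epochs' worth of keys the MoCo negative queue holds.
    # Illustrative only; queue_epochs() is a hypothetical helper, not repo code.
    def queue_epochs(queue_size: int, dataset_size: int) -> float:
        """Fraction of a training epoch that the negative-key queue spans."""
        return queue_size / dataset_size

    print(queue_epochs(65536, 50_000))     # CIFAR-10, default K: ~1.31 epochs of stale keys
    print(queue_epochs(65536, 1_280_000))  # ImageNet-1K, default K: ~0.05 epochs
    print(queue_epochs(4096, 50_000))      # CIFAR-10 with smaller K=4096: ~0.08 epochs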

Hope this helps.

k-stacke commented 3 years ago

There seems to be an error in your code when running on multiple GPUs (I have not tried it on a distributed system). Loading the pretrained weights in load_moco() assumes that the parameter names used when training base_model start with "encoder_q". This is not the case when using nn.DataParallel, which adds the prefix "module." to all parameter names. Therefore, none of the trained parameters are loaded when fine-tuning, only randomly initialized weights. Changing the code in load_moco() to

    # rename moco pre-trained keys
    state_dict = checkpoint['moco']
    for k in list(state_dict.keys()):
        # retain only encoder_q up to before the embedding layer
        if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
            # remove prefix
            state_dict[f'module.{k[len("module.encoder_q."):]}'] = state_dict[k]
        # delete renamed or unused k
        del state_dict[k]

will make it work in the multi-GPU case (it will, however, break the code in the single-GPU setting). Perhaps this solves your problem, @vmmm123?
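
For reference, a prefix-agnostic variant is sketched below (my own suggestion, not tested against this repo): it detects an optional "module." wrapper first, so the same load_moco() snippet should work whether or not nn.DataParallel was used, assuming the checkpoint and the fine-tuned model were both created the same way.

    # Prefix-agnostic renaming (sketch): detect an optional 'module.' prefix added
    # by nn.DataParallel and keep only the encoder_q weights (minus its fc head).
    state_dict = checkpoint['moco']
    prefix = 'module.' if any(k.startswith('module.') for k in state_dict) else ''
    for k in list(state_dict.keys()):
        if k.startswith(prefix + 'encoder_q') and not k.startswith(prefix + 'encoder_q.fc'):
            # e.g. 'module.encoder_q.conv1.weight' -> 'module.conv1.weight'
            state_dict[prefix + k[len(prefix + 'encoder_q.'):]] = state_dict[k]
        # delete the renamed or unused key
        del state_dict[k]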