facebookresearch / vicreg

VICReg official code base
MIT License

Can't load the full checkpoint #6

Closed: lkshrsch closed this issue 2 years ago

lkshrsch commented 2 years ago

I downloaded the available checkpoint for ResNet-50 through the provided link: https://dl.fbaipublicfiles.com/vicreg/resnet50_fullckpt.pth

But upon loading the checkpoint, the following error appears:

>>> torch.load(checkpoint_path)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/user/anaconda3/envs/pytorch_env/lib/python3.8/site-packages/torch/serialization.py", line 594, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/home/user/anaconda3/envs/pytorch_env/lib/python3.8/site-packages/torch/serialization.py", line 853, in _load
    result = unpickler.load()
AttributeError: Can't get attribute 'exclude_bias_and_norm' on <module '__main__' (built-in)>

The same error happens with the other checkpoints. Is there something I am doing wrong? I'd appreciate the help!

I'm using PyTorch 1.7.1 and torchvision 0.8.2.

Adrien987k commented 2 years ago

Hi,

Are you trying to load the checkpoint from the file "main_vicreg.py" of this repository, or from somewhere else? If it is from somewhere else, you need to define the function "exclude_bias_and_norm" in your own script as follows:

def exclude_bias_and_norm(p):
    # True for 1-D parameters, i.e. biases and normalization-layer weights
    return p.ndim == 1

Please tell me if that works.
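Why this works: torch.load unpickles the file with Python's pickle module by default, and pickle does not store a function's code, only the module and name it was defined under. It looks that name up again at load time. The checkpoint was written by a script run directly (hence `__main__` in the error), so the loading script's `__main__` must provide a function called `exclude_bias_and_norm`. The stdlib-only sketch below reproduces both the failure and the fix. `train_script` is a hypothetical stand-in for the training run's `__main__` module, and the `weight_decay_filter` key is purely illustrative; it is not necessarily how the real checkpoint stores the function.

```python
import pickle
import sys
import types


def exclude_bias_and_norm(p):
    # True for 1-D parameters (biases, normalization weights)
    return p.ndim == 1


# Hypothetical stand-in for the __main__ module of the training run.
train_script = types.ModuleType("train_script")
sys.modules["train_script"] = train_script
exclude_bias_and_norm.__module__ = "train_script"
train_script.exclude_bias_and_norm = exclude_bias_and_norm

# "Save": the object references the function (illustrative key name).
# pickle records only the location ("train_script", "exclude_bias_and_norm").
blob = pickle.dumps({"weight_decay_filter": exclude_bias_and_norm})


def try_load(data):
    """Return the unpickled object, or the error text if a name lookup fails."""
    try:
        return pickle.loads(data)
    except AttributeError as err:
        return str(err)


# "Load" where the function was never defined: the lookup fails,
# just like the AttributeError in the traceback above.
del train_script.exclude_bias_and_norm
failure = try_load(blob)
print(failure)


# The fix: provide a function under that name in the module pickle searches.
def restored_filter(p):
    return p.ndim == 1


train_script.exclude_bias_and_norm = restored_filter
ckpt = try_load(blob)
print(ckpt["weight_decay_filter"] is restored_filter)
```

Only the name has to resolve; pickle never checks that the code matches. The redefined function should still behave like the original, though, if the optimizer state is reused to resume training.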

lkshrsch commented 2 years ago

Thank you for the help; that worked! You can close this issue now ;)

I had only downloaded the checkpoint and was loading the model with my own code, not with the repository's scripts.

Adrien987k commented 2 years ago

Thanks!