facebookresearch / swav

PyTorch implementation of SwAV https://arxiv.org/abs/2006.09882

How to download a model with projection head #44

Closed bransGl closed 3 years ago

bransGl commented 3 years ago

Hello. Thanks for the work. As far as I can see, the pretrained models end in a fully connected layer with dim = 1000. Shouldn't there be a projection head with dim = 128 instead? I want to get embeddings; how do I do this? Do you have a pretrained model with the 128-d projection head?

mathildecaron31 commented 3 years ago

Hi @bransGl, the models here https://github.com/facebookresearch/swav#model-zoo are provided with the projection head. You can list the relevant keys by doing:

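import torch

# Example checkpoint URL; any full checkpoint from the model zoo can be used.
MODEL_URL = "https://dl.fbaipublicfiles.com/deepcluster/swav_400ep_pretrain.pth.tar"

# Download (or read from the local cache) the checkpoint onto the CPU.
ckp = torch.hub.load_state_dict_from_url(MODEL_URL, map_location="cpu")

# Print the name and shape of every projection-head and prototype parameter.
for k in ckp.keys():
    if "projection_head" in k or "prototypes" in k:
        print(k, ckp[k].shape)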

where MODEL_URL can be "https://dl.fbaipublicfiles.com/deepcluster/swav_400ep_pretrain.pth.tar" for example.
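Once the keys are listed, it can help to separate the projection-head and prototype entries from the backbone entries. The following is only a minimal sketch: `split_swav_state_dict` is a hypothetical helper (not part of the SwAV repo), and it assumes the checkpoint keys may carry a `module.` prefix left over from distributed training, which should be checked against the printed keys. A toy dictionary stands in for the real checkpoint so the snippet runs without a download.

```python
def split_swav_state_dict(state_dict, prefix="module."):
    """Split checkpoint entries into backbone ("trunk") and head entries.

    Hypothetical helper for illustration; not part of the SwAV repo.
    The `prefix` default is an assumption about how the keys are named.
    """
    trunk, head = {}, {}
    for key, value in state_dict.items():
        # Drop the wrapper prefix if present, otherwise keep the key as-is.
        name = key[len(prefix):] if key.startswith(prefix) else key
        # Same filter as in the key-listing snippet above.
        if "projection_head" in name or "prototypes" in name:
            head[name] = value
        else:
            trunk[name] = value
    return trunk, head


# Toy stand-in for a real checkpoint; real values would be tensors.
toy_ckp = {
    "module.conv1.weight": "w0",
    "module.projection_head.0.weight": "w1",
    "module.prototypes.weight": "w2",
}
trunk, head = split_swav_state_dict(toy_ckp)
print(sorted(trunk))
print(sorted(head))
```

With a real checkpoint, `ckp` from the snippet above would be passed in place of `toy_ckp`. The `head` entries could then be loaded into a model that defines layers with matching names, while `trunk` holds the backbone weights.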

RylanSchaeffer commented 2 years ago

@mathildecaron31 is there a way to get the projection heads with the recommended torch.hub.load?

Currently, when I do

model = torch.hub.load('facebookresearch/swav:main', model='resnet50')

the model appears to have nothing related to the projection head.

RylanSchaeffer commented 2 years ago

@bransGl how did you actually load the checkpoint into the model?