vturrisi / solo-learn

solo-learn: a library of self-supervised methods for visual representation learning powered by PyTorch Lightning
MIT License

Error in loading state_dict for VisionTransformer #218

Closed ChintanTrivedi closed 2 years ago

ChintanTrivedi commented 2 years ago

I ran the pretraining code on a custom dataset with the vit_small backbone, keeping the other input arguments the same as in the CIFAR example.

Now I'm trying to load the saved checkpoint with the following:

import torch

from solo.utils.backbones import vit_small

ckpt_path = './byol-vit-carla-ep=49.ckpt'

backbone = vit_small()
state = torch.load(ckpt_path)["state_dict"]
# keep only the backbone weights, stripping the "backbone." prefix
for k in list(state.keys()):
    if "backbone" in k:
        state[k.replace("backbone.", "")] = state[k]
    del state[k]
backbone.load_state_dict(state, strict=False)

However, I'm facing this issue when loading back the weights:

    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VisionTransformer:
    size mismatch for pos_embed: copying a param with shape torch.Size([1, 5, 384]) from checkpoint, the shape in current model is torch.Size([1, 197, 384]).

Do you know where the pos_embed dimension mismatch is arising from? Is it the input size of the image (crop_size=32)?

vturrisi commented 2 years ago

Hey @ChintanTrivedi, I think we never tried ViT with CIFAR; it's probably some argument that's missing or some extra parsing that's needed. I'll check between today and tomorrow and get back to you.

vturrisi commented 2 years ago

@ChintanTrivedi just tried it: you are missing --crop_size 32 in your linear script. Our different backbones are thin wrappers around timm, and for ViT we need to pass it the img_size (we expose this through --crop_size because it's more consistent with the rest of solo). Two other notes: if you run the default linear evaluation (100 epochs, etc.) you will likely get worse performance than the online linear eval on CIFAR datasets, and you should probably play around with --patch_size, as the default of 16 is probably too big for CIFAR.
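To see where the two shapes in the error come from, the length of a (non-hybrid) ViT's pos_embed is one embedding per patch plus one for the CLS token, so it is fixed by img_size and patch_size at construction time. A minimal arithmetic sketch (pos_embed_len is a hypothetical helper, not a solo-learn or timm function):

```python
def pos_embed_len(img_size: int, patch_size: int) -> int:
    """Number of positional embeddings: one per patch, plus the CLS token."""
    num_patches = (img_size // patch_size) ** 2
    return num_patches + 1

# Checkpoint pretrained at crop_size=32 with patch_size=16:
print(pos_embed_len(32, 16))   # 5  -> matches torch.Size([1, 5, 384])

# vit_small() built with timm's default img_size=224, patch_size=16:
print(pos_embed_len(224, 16))  # 197 -> matches torch.Size([1, 197, 384])
```

So the 5-vs-197 mismatch is exactly the checkpoint's img_size=32 model being loaded into a default img_size=224 model.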

If you encounter any more issues, please re-open the issue.

ChintanTrivedi commented 2 years ago

Yes, that was the issue. I was passing crop_size=32 as the input parameter to vit_small() instead of img_size=32. Thanks!
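For reference, the prefix-stripping loop in the snippet above can also be written so it keeps only the online backbone's keys. The `if "backbone" in k` test additionally matches `momentum_backbone.*` keys (which `strict=False` then silently drops). A minimal, torch-free sketch on plain dicts, with placeholder strings standing in for the real tensor values:

```python
# Hypothetical state_dict with the key layout a BYOL checkpoint uses;
# values are placeholders where real checkpoints hold torch tensors.
state = {
    "backbone.pos_embed": "tensor",
    "backbone.patch_embed.proj.weight": "tensor",
    "momentum_backbone.pos_embed": "tensor",  # momentum encoder: not wanted here
}

# Keep only "backbone.*" keys and strip the prefix so they match
# a bare VisionTransformer's parameter names.
backbone_state = {
    k.removeprefix("backbone."): v
    for k, v in state.items()
    if k.startswith("backbone.")
}

print(sorted(backbone_state))  # ['patch_embed.proj.weight', 'pos_embed']
```

With the key names cleaned up this way, the remaining load error is purely the pos_embed shape, which the img_size=32 fix resolves.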