facebookresearch / ijepa

Official codebase for I-JEPA, the Image-based Joint-Embedding Predictive Architecture. First outlined in the CVPR paper, "Self-supervised learning from images with a joint-embedding predictive architecture."

Loading pre-trained model: state_dict key mismatch #34

Closed: HuFY-dev closed this issue 1 year ago

HuFY-dev commented 1 year ago

Firstly, thanks for the amazing work! I wrote my own code to load your pre-trained model, IN1K-vit.h.16-448px-300e.pth, and ran into this issue:

RuntimeError: Error(s) in loading state_dict for VisionTransformer:
    Missing key(s) in state_dict: "pos_embed", "patch_embed.proj.weight", "patch_embed.proj.bias", "blocks.0.norm1.weight", ......
    Unexpected key(s) in state_dict: "module.pos_embed", "module.patch_embed.proj.weight", "module.patch_embed.proj.bias", "module.blocks.0.norm1.weight", ......

I used the exact same model architecture as in your vision_transformer.py file, and the problem was solved after I added this line before loading the state dict:

pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}

I wonder if there is an issue with your released model weights. Did you forget to update them to the newest version?

MidoAssran commented 1 year ago

Hi @HuFY-dev, since the model was trained with PyTorch DistributedDataParallel (DDP), the saved parameter names are implicitly wrapped with an extra module. prefix. Adding that line before loading the state dict is the correct way to load the model without DDP.
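
For anyone hitting the same error, here is a minimal sketch of loading the released checkpoint into a plain (non-DDP) encoder. The vit_huge constructor comes from the repo's src/models/vision_transformer.py; the "encoder" checkpoint key and the constructor arguments are assumptions based on the checkpoint name and may need adjusting for your setup:

    import torch
    from src.models.vision_transformer import vit_huge

    # Build a plain (non-DDP) encoder with the same architecture as the checkpoint.
    # patch_size / img_size below are assumptions matching "vit.h.16-448px".
    encoder = vit_huge(patch_size=16, img_size=[448])

    checkpoint = torch.load("IN1K-vit.h.16-448px-300e.pth", map_location="cpu")
    pretrained_dict = checkpoint["encoder"]  # assumption: encoder weights stored under this key

    # Strip the "module." prefix that DistributedDataParallel adds to parameter names.
    pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}

    msg = encoder.load_state_dict(pretrained_dict)
    print(msg)  # should report that all keys matched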

HuFY-dev commented 1 year ago

Thank you!