facebookresearch / mae

PyTorch implementation of MAE: https://arxiv.org/abs/2111.06377

Importing finetuned MAE into another project #27

Open TizianThieringer opened 2 years ago

TizianThieringer commented 2 years ago

I am trying to load an MAE model into another project. I use the VisionTransformer class from https://github.com/facebookresearch/mae/blob/main/models_vit.py#L56 (I also tried the MAE class from https://github.com/facebookresearch/mae/blob/main/models_mae.py#L223).

I tried creating the model from the class and then loading the weights from the fine-tuned checkpoint linked in https://github.com/facebookresearch/mae/blob/main/FINETUNE.md.

However, with my code the weights from the checkpoint do not seem to be loaded: when I print the model parameters before and after loading the checkpoint, the weights haven't changed. Does anyone have an idea why this happens?

Here is my implementation:

# Create the model (imports assume the repo's models_vit.py is on the path)
import torch
import torch.nn as nn
from functools import partial
from models_vit import VisionTransformer

model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                          qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

# Load checkpoint
checkpoint = torch.load('PATH/TO/mae_finetuned_vit_base.pth', map_location='cpu')
checkpoint_model = checkpoint['model']

# Load state dict (strict=False skips keys that do not match the model)
model.load_state_dict(checkpoint_model, strict=False)

Edit: simplified code and clarified my error
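For reference, one way to check what actually gets loaded (a minimal sketch continuing the code above): with strict=False, PyTorch silently skips checkpoint keys that do not match the model, and load_state_dict returns which keys were skipped.

msg = model.load_state_dict(checkpoint_model, strict=False)
# Non-empty lists here would explain why the printed weights look unchanged.
print('missing keys:', msg.missing_keys)
print('unexpected keys:', msg.unexpected_keys)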

Songloading commented 2 years ago

Try using MAE to load the checkpoint. I got the pretrained weights that way.
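For example, a minimal sketch: build the model with the repo's own factory in models_vit.py instead of instantiating VisionTransformer by hand, then load the fine-tuned checkpoint. The num_classes and global_pool settings below are assumptions, as is the checkpoint path.

import torch
import models_vit  # from the facebookresearch/mae repository

# Factory from models_vit.py; settings assumed to match the fine-tuned checkpoint
model = models_vit.vit_base_patch16(num_classes=1000, global_pool=True)
checkpoint = torch.load('PATH/TO/mae_finetuned_vit_base.pth', map_location='cpu')
msg = model.load_state_dict(checkpoint['model'], strict=False)
print(msg)  # check that no keys were skipped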

TizianThieringer commented 2 years ago

Thanks for your reply. I thought I was loading the pretrained weights using torch.load(); am I missing something here? I am using the pretrained weights from https://github.com/facebookresearch/mae/blob/main/FINETUNE.md, as stated above.

allezsyh commented 2 years ago

> Thanks for your reply. I thought I was loading the pretrained weights using torch.load(); am I missing something here? I am using the pretrained weights from https://github.com/facebookresearch/mae/blob/main/FINETUNE.md, as stated above.

I solved this issue with the following two methods:
(1) `model.load_state_dict(checkpoint_model, strict=False)`
(2) using the visualization checkpoints, e.g., mae_visualize_vit_base.pth for mae_vit_large_patch16
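For method (2), a minimal sketch, assuming the repo's models_mae.py is importable and that the ViT-Base visualization checkpoint pairs with the ViT-Base builder (the path is a placeholder):

import torch
import models_mae  # from the facebookresearch/mae repository

# mae_vit_base_patch16 builds the full MAE model (encoder + decoder),
# which is what the visualization checkpoints are meant for.
model = models_mae.mae_vit_base_patch16()
checkpoint = torch.load('PATH/TO/mae_visualize_vit_base.pth', map_location='cpu')
msg = model.load_state_dict(checkpoint['model'], strict=False)
print(msg)  # empty missing_keys / unexpected_keys means everything was loaded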