facebookresearch / mae

PyTorch implementation of MAE https://arxiv.org/abs/2111.06377

Error in loading pretrained weight for 'mae_vit_base_patch16' #180

Open nightrain-vampire opened 8 months ago

nightrain-vampire commented 8 months ago

I tried to use mae_vit_base in the demo, but it reports the following error:

RuntimeError                              Traceback (most recent call last)
/data/user3/zspace/Mcm/demo/mae_visualize.ipynb Cell 9 line 8
      [5] get_ipython().system('wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth')
      [7] chkpt_dir = 'mae_visualize_vit_base.pth'
----> [8] model_mae = prepare_model(chkpt_dir, 'mae_vit_base_patch16')
      [9] print('Model loaded.')

/data/user3/zspace/Mcm/demo/mae_visualize.ipynb Cell 9 line 1
     [17] # load model
     [18] checkpoint = torch.load(chkpt_dir, map_location='cpu')
---> [19] msg = model.load_state_dict(checkpoint['model'], strict=False)
     [20] print(msg)
     [21] return model

File ~/miniconda3/envs/mae/lib/python3.8/site-packages/torch/nn/modules/module.py:1671, in Module.load_state_dict(self, state_dict, strict)
   1666         error_msgs.insert(
   1667             0, 'Missing key(s) in state_dict: {}. '.format(
   1668                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   1670 if len(error_msgs) > 0:
-> 1671     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1672                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1673 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for MaskedAutoencoderViT:
    size mismatch for pos_embed: copying a param with shape torch.Size([1, 197, 768]) from checkpoint, the shape in current model is torch.Size([1, 4097, 768]).
    size mismatch for decoder_pos_embed: copying a param with shape torch.Size([1, 197, 512]) from checkpoint, the shape in current model is torch.Size([1, 4097, 512]).

My code is below:

# This is an MAE model trained with pixels as targets for visualization (ViT-Base, training mask ratio=0.75)

# download checkpoint if not exist
# !wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth
!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth

chkpt_dir = 'mae_visualize_vit_base.pth'
model_mae = prepare_model(chkpt_dir, 'mae_vit_base_patch16')
print('Model loaded.')

What's the matter with the pretrained model? I also tried 'mae_pretrain_vit_base_full.pth', but it reports the same error. Can anyone help?
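
A quick way to confirm which side is off is to inspect the checkpoint directly. A minimal diagnostic sketch (the file name and the state-dict keys are taken from the snippet and traceback above):

import torch

# Print the positional-embedding shapes stored in the downloaded checkpoint.
ckpt = torch.load('mae_visualize_vit_base.pth', map_location='cpu')
print(ckpt['model']['pos_embed'].shape)          # torch.Size([1, 197, 768])
print(ckpt['model']['decoder_pos_embed'].shape)  # torch.Size([1, 197, 512])

If these print [1, 197, ...] while the model you build reports [1, 4097, ...], the checkpoint is fine and the mismatch comes from how the model is instantiated.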

MakoOfficial commented 8 months ago

According to the error report, something seems to be wrong with your model parameters, not the checkpoint. The checkpoint was trained with image size 224 and patch size 16, so its pos_embed has shape [1, 197, 768] (14 * 14 patches plus the cls token). Your initialized model's pos_embed is [1, 4097, 768], which corresponds to image size 1024 with patch size still 16 (64 * 64 patches plus the cls token). So I guess you changed the parameter "img_size" from 224 to 1024 while keeping "patch_size" at 16. Restoring that hyperparameter should solve it.
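
If that is what happened, a minimal sketch of the fix, assuming the stock models_mae.py from this repo (its model builders forward keyword arguments such as img_size to MaskedAutoencoderViT):

import torch
import models_mae  # models_mae.py from the facebookresearch/mae repo

# Build the model for 224x224 inputs with 16x16 patches,
# giving 14*14 + 1 = 197 positional embeddings, matching the checkpoint.
model = models_mae.mae_vit_base_patch16(img_size=224)

checkpoint = torch.load('mae_visualize_vit_base.pth', map_location='cpu')
msg = model.load_state_dict(checkpoint['model'], strict=False)
print(msg)  # expect no size-mismatch errors once the shapes agree

If the default img_size in your local models_mae.py was edited, passing it explicitly like this (or reverting the edit) should be enough.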

sALTaccount commented 6 months ago

I'm also unable to get the weights to load :/