OpenGVLab / unmasked_teacher

[ICCV2023 Oral] Unmasked Teacher: Towards Training-Efficient Video Foundation Models
https://arxiv.org/abs/2303.16058
MIT License

'pos_embed' key not found in downloadable checkpoints. #6

Closed zyoNoob closed 1 year ago

zyoNoob commented 1 year ago

I have been trying to use the finetuned model weights available in the repo to do inference on some videos. While trying to load the state_dict into the model, I encountered this log output: Weights of VisionTransformer not initialized from pretrained model: ['pos_embed']

So I checked the downloaded checkpoints' keys to find out whether pos_embed exists, and apparently it doesn't. Is this the desired behavior? If so, does the model still work fine without pos_embed, or do I have to follow a separate procedure to get the pos_embed weights?
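For reference, a quick way to inspect which top-level keys a checkpoint actually contains (the toy file path below is just an example; substitute your downloaded checkpoint):

```python
import torch

def checkpoint_keys(path):
    """Return the sorted parameter names stored in a checkpoint file."""
    state = torch.load(path, map_location='cpu')
    # Some checkpoints nest the weights under a wrapper key such as
    # 'module' or 'model'; unwrap if present.
    for wrapper in ('module', 'model', 'state_dict'):
        if isinstance(state, dict) and wrapper in state:
            state = state[wrapper]
    return sorted(state.keys())

# Toy example: write a tiny checkpoint to disk, then list its keys.
torch.save({'patch_embed.proj.weight': torch.zeros(1)}, '/tmp/toy.pth')
keys = checkpoint_keys('/tmp/toy.pth')
print('pos_embed' in keys)  # False for this toy checkpoint
```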

Any help is appreciated, thanks.

Andy1621 commented 1 year ago

Thanks for your question and sorry for the late response.

To answer your question: you should set use_learnable_pos_emb=False to use fixed sine-cosine positional embeddings. With that setting, the warning Weights of VisionTransformer not initialized from pretrained model: ['pos_embed'] should no longer appear when you load the model weights.
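For context: with use_learnable_pos_emb=False, the positional embedding is computed on the fly rather than stored as a parameter, which is why pos_embed is absent from the checkpoint. A minimal sketch of the standard sine-cosine table (the repo's own helper may differ in details):

```python
import numpy as np
import torch

def get_sinusoid_encoding_table(n_position, d_hid):
    """Fixed sine-cosine positional embedding table, shape (1, n_position, d_hid)."""
    def angle(pos, i):
        return pos / np.power(10000, 2 * (i // 2) / d_hid)
    table = np.array([[angle(pos, i) for i in range(d_hid)]
                      for pos in range(n_position)])
    table[:, 0::2] = np.sin(table[:, 0::2])  # even dimensions: sine
    table[:, 1::2] = np.cos(table[:, 1::2])  # odd dimensions: cosine
    return torch.FloatTensor(table).unsqueeze(0)

pos_embed = get_sinusoid_encoding_table(196, 768)
print(pos_embed.shape)  # torch.Size([1, 196, 768])
```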

mustafahalimeh commented 1 week ago

I am also facing the same error: RuntimeError: Error(s) in loading state_dict for VisionTransformer: Missing key(s) in state_dict: "pos_embed", although I passed use_learnable_pos_emb=False to create_model():


model = create_model(
    'vit_large_patch16_224',
    pretrained=False, tubelet_size=1, use_learnable_pos_emb=False,
    num_classes=700, use_mean_pooling=True, all_frames=16)
state = torch.load('l16_ptk710_ftk710_ftk700_f16_res224.pth')
model.load_state_dict(state, strict=True)

It also fails when setting use_learnable_pos_emb=True.

mustafahalimeh commented 1 week ago

The error is gone when I use the load_state_dict function from the repo's utils instead of model.load_state_dict(state): load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"), with use_learnable_pos_emb=True when creating the model passed to this function.
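For readers hitting the same issue, the essential idea of such a tolerant loader can be sketched as follows. This is a simplified stand-in, not the repo's actual utils.load_state_dict (which additionally handles prefixes and other details):

```python
import torch

def load_state_dict_tolerant(model, state_dict, ignore_missing=('pos_embed',)):
    """Load weights non-strictly, but only tolerate the listed missing keys.

    Simplified sketch of a tolerant loader; the repo's real helper
    differs in details (prefix handling, logging, etc.).
    """
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    # Complain only about missing keys we did NOT ask to ignore.
    problematic = [k for k in missing
                   if not any(pat in k for pat in ignore_missing)]
    if problematic or unexpected:
        raise RuntimeError(f'missing={problematic}, unexpected={unexpected}')

# Toy example: a linear layer whose bias is absent from the checkpoint.
m = torch.nn.Linear(4, 2)
partial = {'weight': torch.zeros(2, 4)}
load_state_dict_tolerant(m, partial, ignore_missing=('bias',))
```

The weight tensor is loaded while the missing bias is silently skipped; any other missing or unexpected key would still raise.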

Sorry for missing that.