mx-mark / VideoTransformer-pytorch

PyTorch implementation of a collection of scalable Video Transformer benchmarks.
272 stars · 34 forks

Missing keys in demo notebook #24

Closed Simcs closed 2 years ago

Simcs commented 2 years ago

Hi, thank you for sharing your work.

When I follow the instructions in the notebook (VideoTransformer_demo.ipynb), I have trouble loading the pre-trained weights of the ViViT model.

After downloading and placing the "./vivit_model.pth" file, I was able to instantiate the ViViT model. However, the log reports many missing keys for the given .pth file.

Is this the desired behavior, or should I do some preprocessing to match the parameter names?

This is the output after parameter loading.

load model finished, the missing key of transformer is:['transformer_layers.0.layers.0.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.0.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.0.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.0.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.1.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.1.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.1.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.1.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.2.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.2.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.2.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.2.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.3.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.3.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.3.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.3.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.4.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.4.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.4.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.4.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.5.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.5.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.5.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.5.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.6.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.6.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.6.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.6.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.7.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.7.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.7.attentions.0.attn.proj.weight', 
'transformer_layers.0.layers.7.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.8.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.8.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.8.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.8.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.9.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.9.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.9.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.9.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.10.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.10.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.10.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.10.attentions.0.attn.proj.bias', 'transformer_layers.0.layers.11.attentions.0.attn.qkv.weight', 'transformer_layers.0.layers.11.attentions.0.attn.qkv.bias', 'transformer_layers.0.layers.11.attentions.0.attn.proj.weight', 'transformer_layers.0.layers.11.attentions.0.attn.proj.bias', 'transformer_layers.1.layers.0.attentions.0.attn.qkv.weight', 'transformer_layers.1.layers.0.attentions.0.attn.qkv.bias', 'transformer_layers.1.layers.0.attentions.0.attn.proj.weight', 'transformer_layers.1.layers.0.attentions.0.attn.proj.bias', 'transformer_layers.1.layers.1.attentions.0.attn.qkv.weight', 'transformer_layers.1.layers.1.attentions.0.attn.qkv.bias', 'transformer_layers.1.layers.1.attentions.0.attn.proj.weight', 'transformer_layers.1.layers.1.attentions.0.attn.proj.bias', 'transformer_layers.1.layers.2.attentions.0.attn.qkv.weight', 'transformer_layers.1.layers.2.attentions.0.attn.qkv.bias', 'transformer_layers.1.layers.2.attentions.0.attn.proj.weight', 'transformer_layers.1.layers.2.attentions.0.attn.proj.bias', 'transformer_layers.1.layers.3.attentions.0.attn.qkv.weight', 'transformer_layers.1.layers.3.attentions.0.attn.qkv.bias', 'transformer_layers.1.layers.3.attentions.0.attn.proj.weight', 
'transformer_layers.1.layers.3.attentions.0.attn.proj.bias'], cls is:[]

Thank you in advance!

+edit) FYI, these are the unexpected keys from the load_state_dict(). transformer unexpected: ['cls_head.weight', 'cls_head.bias', 'transformer_layers.0.layers.0.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.0.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.0.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.0.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.1.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.1.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.1.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.1.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.2.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.2.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.2.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.2.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.3.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.3.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.3.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.3.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.4.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.4.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.4.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.4.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.5.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.5.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.5.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.5.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.6.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.6.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.6.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.6.attentions.0.attn.out_proj.bias', 
'transformer_layers.0.layers.7.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.7.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.7.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.7.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.8.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.8.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.8.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.8.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.9.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.9.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.9.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.9.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.10.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.10.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.10.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.10.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.11.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.11.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.11.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.11.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.0.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.0.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.0.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.0.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.1.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.1.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.1.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.1.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.2.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.2.attentions.0.attn.in_proj_bias', 
'transformer_layers.1.layers.2.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.2.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.3.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.3.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.3.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.3.attentions.0.attn.out_proj.bias']

classification head unexpected: ['cls_token', 'pos_embed', 'time_embed', 'patch_embed.projection.weight', 'patch_embed.projection.bias', 'transformer_layers.0.layers.0.attentions.0.norm.weight', 'transformer_layers.0.layers.0.attentions.0.norm.bias', 'transformer_layers.0.layers.0.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.0.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.0.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.0.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.0.ffns.0.norm.weight', 'transformer_layers.0.layers.0.ffns.0.norm.bias', 'transformer_layers.0.layers.0.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.0.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.0.ffns.0.layers.1.weight', 'transformer_layers.0.layers.0.ffns.0.layers.1.bias', 'transformer_layers.0.layers.1.attentions.0.norm.weight', 'transformer_layers.0.layers.1.attentions.0.norm.bias', 'transformer_layers.0.layers.1.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.1.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.1.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.1.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.1.ffns.0.norm.weight', 'transformer_layers.0.layers.1.ffns.0.norm.bias', 'transformer_layers.0.layers.1.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.1.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.1.ffns.0.layers.1.weight', 'transformer_layers.0.layers.1.ffns.0.layers.1.bias', 'transformer_layers.0.layers.2.attentions.0.norm.weight', 'transformer_layers.0.layers.2.attentions.0.norm.bias', 'transformer_layers.0.layers.2.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.2.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.2.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.2.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.2.ffns.0.norm.weight', 
'transformer_layers.0.layers.2.ffns.0.norm.bias', 'transformer_layers.0.layers.2.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.2.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.2.ffns.0.layers.1.weight', 'transformer_layers.0.layers.2.ffns.0.layers.1.bias', 'transformer_layers.0.layers.3.attentions.0.norm.weight', 'transformer_layers.0.layers.3.attentions.0.norm.bias', 'transformer_layers.0.layers.3.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.3.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.3.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.3.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.3.ffns.0.norm.weight', 'transformer_layers.0.layers.3.ffns.0.norm.bias', 'transformer_layers.0.layers.3.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.3.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.3.ffns.0.layers.1.weight', 'transformer_layers.0.layers.3.ffns.0.layers.1.bias', 'transformer_layers.0.layers.4.attentions.0.norm.weight', 'transformer_layers.0.layers.4.attentions.0.norm.bias', 'transformer_layers.0.layers.4.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.4.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.4.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.4.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.4.ffns.0.norm.weight', 'transformer_layers.0.layers.4.ffns.0.norm.bias', 'transformer_layers.0.layers.4.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.4.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.4.ffns.0.layers.1.weight', 'transformer_layers.0.layers.4.ffns.0.layers.1.bias', 'transformer_layers.0.layers.5.attentions.0.norm.weight', 'transformer_layers.0.layers.5.attentions.0.norm.bias', 'transformer_layers.0.layers.5.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.5.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.5.attentions.0.attn.out_proj.weight', 
'transformer_layers.0.layers.5.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.5.ffns.0.norm.weight', 'transformer_layers.0.layers.5.ffns.0.norm.bias', 'transformer_layers.0.layers.5.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.5.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.5.ffns.0.layers.1.weight', 'transformer_layers.0.layers.5.ffns.0.layers.1.bias', 'transformer_layers.0.layers.6.attentions.0.norm.weight', 'transformer_layers.0.layers.6.attentions.0.norm.bias', 'transformer_layers.0.layers.6.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.6.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.6.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.6.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.6.ffns.0.norm.weight', 'transformer_layers.0.layers.6.ffns.0.norm.bias', 'transformer_layers.0.layers.6.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.6.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.6.ffns.0.layers.1.weight', 'transformer_layers.0.layers.6.ffns.0.layers.1.bias', 'transformer_layers.0.layers.7.attentions.0.norm.weight', 'transformer_layers.0.layers.7.attentions.0.norm.bias', 'transformer_layers.0.layers.7.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.7.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.7.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.7.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.7.ffns.0.norm.weight', 'transformer_layers.0.layers.7.ffns.0.norm.bias', 'transformer_layers.0.layers.7.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.7.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.7.ffns.0.layers.1.weight', 'transformer_layers.0.layers.7.ffns.0.layers.1.bias', 'transformer_layers.0.layers.8.attentions.0.norm.weight', 'transformer_layers.0.layers.8.attentions.0.norm.bias', 'transformer_layers.0.layers.8.attentions.0.attn.in_proj_weight', 
'transformer_layers.0.layers.8.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.8.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.8.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.8.ffns.0.norm.weight', 'transformer_layers.0.layers.8.ffns.0.norm.bias', 'transformer_layers.0.layers.8.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.8.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.8.ffns.0.layers.1.weight', 'transformer_layers.0.layers.8.ffns.0.layers.1.bias', 'transformer_layers.0.layers.9.attentions.0.norm.weight', 'transformer_layers.0.layers.9.attentions.0.norm.bias', 'transformer_layers.0.layers.9.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.9.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.9.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.9.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.9.ffns.0.norm.weight', 'transformer_layers.0.layers.9.ffns.0.norm.bias', 'transformer_layers.0.layers.9.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.9.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.9.ffns.0.layers.1.weight', 'transformer_layers.0.layers.9.ffns.0.layers.1.bias', 'transformer_layers.0.layers.10.attentions.0.norm.weight', 'transformer_layers.0.layers.10.attentions.0.norm.bias', 'transformer_layers.0.layers.10.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.10.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.10.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.10.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.10.ffns.0.norm.weight', 'transformer_layers.0.layers.10.ffns.0.norm.bias', 'transformer_layers.0.layers.10.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.10.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.10.ffns.0.layers.1.weight', 'transformer_layers.0.layers.10.ffns.0.layers.1.bias', 'transformer_layers.0.layers.11.attentions.0.norm.weight', 
'transformer_layers.0.layers.11.attentions.0.norm.bias', 'transformer_layers.0.layers.11.attentions.0.attn.in_proj_weight', 'transformer_layers.0.layers.11.attentions.0.attn.in_proj_bias', 'transformer_layers.0.layers.11.attentions.0.attn.out_proj.weight', 'transformer_layers.0.layers.11.attentions.0.attn.out_proj.bias', 'transformer_layers.0.layers.11.ffns.0.norm.weight', 'transformer_layers.0.layers.11.ffns.0.norm.bias', 'transformer_layers.0.layers.11.ffns.0.layers.0.0.weight', 'transformer_layers.0.layers.11.ffns.0.layers.0.0.bias', 'transformer_layers.0.layers.11.ffns.0.layers.1.weight', 'transformer_layers.0.layers.11.ffns.0.layers.1.bias', 'transformer_layers.1.layers.0.attentions.0.norm.weight', 'transformer_layers.1.layers.0.attentions.0.norm.bias', 'transformer_layers.1.layers.0.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.0.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.0.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.0.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.0.ffns.0.norm.weight', 'transformer_layers.1.layers.0.ffns.0.norm.bias', 'transformer_layers.1.layers.0.ffns.0.layers.0.0.weight', 'transformer_layers.1.layers.0.ffns.0.layers.0.0.bias', 'transformer_layers.1.layers.0.ffns.0.layers.1.weight', 'transformer_layers.1.layers.0.ffns.0.layers.1.bias', 'transformer_layers.1.layers.1.attentions.0.norm.weight', 'transformer_layers.1.layers.1.attentions.0.norm.bias', 'transformer_layers.1.layers.1.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.1.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.1.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.1.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.1.ffns.0.norm.weight', 'transformer_layers.1.layers.1.ffns.0.norm.bias', 'transformer_layers.1.layers.1.ffns.0.layers.0.0.weight', 'transformer_layers.1.layers.1.ffns.0.layers.0.0.bias', 
'transformer_layers.1.layers.1.ffns.0.layers.1.weight', 'transformer_layers.1.layers.1.ffns.0.layers.1.bias', 'transformer_layers.1.layers.2.attentions.0.norm.weight', 'transformer_layers.1.layers.2.attentions.0.norm.bias', 'transformer_layers.1.layers.2.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.2.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.2.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.2.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.2.ffns.0.norm.weight', 'transformer_layers.1.layers.2.ffns.0.norm.bias', 'transformer_layers.1.layers.2.ffns.0.layers.0.0.weight', 'transformer_layers.1.layers.2.ffns.0.layers.0.0.bias', 'transformer_layers.1.layers.2.ffns.0.layers.1.weight', 'transformer_layers.1.layers.2.ffns.0.layers.1.bias', 'transformer_layers.1.layers.3.attentions.0.norm.weight', 'transformer_layers.1.layers.3.attentions.0.norm.bias', 'transformer_layers.1.layers.3.attentions.0.attn.in_proj_weight', 'transformer_layers.1.layers.3.attentions.0.attn.in_proj_bias', 'transformer_layers.1.layers.3.attentions.0.attn.out_proj.weight', 'transformer_layers.1.layers.3.attentions.0.attn.out_proj.bias', 'transformer_layers.1.layers.3.ffns.0.norm.weight', 'transformer_layers.1.layers.3.ffns.0.norm.bias', 'transformer_layers.1.layers.3.ffns.0.layers.0.0.weight', 'transformer_layers.1.layers.3.ffns.0.layers.0.0.bias', 'transformer_layers.1.layers.3.ffns.0.layers.1.weight', 'transformer_layers.1.layers.3.ffns.0.layers.1.bias', 'norm.weight', 'norm.bias']
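For context on how these two lists arise: `load_state_dict(strict=False)` reports mismatches instead of raising — `missing_keys` are parameters the model defines but the checkpoint lacks, and `unexpected_keys` are checkpoint entries the model has no slot for. A minimal sketch (the toy module and key names are hypothetical, standing in for the qkv/proj vs. in_proj/out_proj names above):

```python
import torch
import torch.nn as nn

# Toy module: 'weight'/'bias' stand in for the model's expected keys,
# 'w' for a differently named checkpoint entry.
model = nn.Linear(4, 2)
state = {"w": torch.zeros(2, 4)}

# strict=False logs the mismatch instead of raising, which is why the
# demo prints key lists rather than crashing outright.
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # parameters the model has, checkpoint lacks
print(result.unexpected_keys)  # checkpoint entries the model has no slot for
```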

realgump commented 2 years ago

Have you set `weights_from`? I met the same error and fixed it after setting `weights_from`.

mx-mark commented 2 years ago

@Simcs Sorry about that. It should be for the reason @realgump mentioned: we updated the weight-initialization function, but we have not kept the notebook in sync with it.
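Looking at the two logs, the mismatch is purely a key-naming difference: the checkpoint stores `nn.MultiheadAttention`-style `in_proj_*`/`out_proj.*` names, while the updated model expects fused `qkv.*`/`proj.*` names. If setting `weights_from` were not an option, a rename pass over the state dict would be another way to reconcile them. The helper below is a hypothetical sketch, not part of the repo, and assumes the fused qkv Linear uses the same `(3*dim, dim)` layout as `in_proj_weight`:

```python
# Hypothetical helper: rename MultiheadAttention-style checkpoint keys
# to the fused qkv/proj names the updated model expects. Assumes the
# fused qkv layout matches the concatenated in_proj_weight layout.
def remap_attention_keys(state_dict):
    rules = [
        ('.attn.in_proj_weight', '.attn.qkv.weight'),
        ('.attn.in_proj_bias', '.attn.qkv.bias'),
        ('.attn.out_proj.weight', '.attn.proj.weight'),
        ('.attn.out_proj.bias', '.attn.proj.bias'),
    ]
    remapped = {}
    for key, value in state_dict.items():
        for old, new in rules:
            if key.endswith(old):
                key = key[: -len(old)] + new
                break  # each key matches at most one rule
        remapped[key] = value
    return remapped

# Example with one key from the log above
demo = {'transformer_layers.0.layers.0.attentions.0.attn.in_proj_weight': None}
print(list(remap_attention_keys(demo)))
# → ['transformer_layers.0.layers.0.attentions.0.attn.qkv.weight']
```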

Simcs commented 2 years ago

Thank you for the answer! After following @realgump's suggestion, I was able to load the ViViT parameters normally.

For those who want the detailed instructions, here are the steps I took.

  1. import the ViViT model from the original repo
    from video_transformer import ViViT
  2. set the weights_from param to 'kinetics' when initializing the ViViT model
    pretrain_pth = './vivit_model.pth'
    num_frames = num_frames * 2
    frame_interval = frame_interval // 2
    model = ViViT(num_frames=num_frames,
                  img_size=224,
                  patch_size=16,
                  embed_dims=768,
                  in_channels=3,
                  attention_type='fact_encoder',
                  return_cls_token=True,
                  weights_from='kinetics',
                  pretrain_pth=pretrain_pth)
  3. remove the remaining initialization code
    # comment out init_from_pretrain_() method
    cls_head = ClassificationHead(num_classes=num_class, in_channels=768)
    # msg_trans = init_from_pretrain_(model, pretrain_pth, init_module='transformer')
    msg_cls = init_from_pretrain_(cls_head, pretrain_pth, init_module='cls_head')

Thanks again for the quick response!

zz0126 commented 5 months ago

I'd like to ask: did you successfully load the vivit-b pre-trained weights?