mx-mark / VideoTransformer-pytorch

PyTorch implementation of a collection of scalable Video Transformer benchmarks.
272 stars · 34 forks

Errors when loading pretrained weights -pretrain_pth 'vivit_model.pth' -weights_from 'kinetics' #27

Open desti-nation opened 1 year ago

desti-nation commented 1 year ago

When I try to finetune on my dataset starting from the pretrained Kinetics ViViT model, the errors below occur. I am new to PyTorch; how can I solve these errors? Thanks.

command

python model_pretrain.py \
    -lr 0.001 -epoch 100 -batch_size 32 -num_workers 4  -frame_interval 16  \
    -arch 'vivit' -attention_type 'fact_encoder' -optim_type 'sgd' -lr_schedule 'cosine' \
    -objective 'supervised' -root_dir ./ \
    -gpus 0 -num_class 2 -img_size 50 -num_frames 13 \
    -warmup_epochs 5 \
    -pretrain_pth 'vivit_model.pth' -weights_from 'kinetics'

Errors:

  File "/home/VideoTransformer-pytorch/weight_init.py", line 319, in init_from_kinetics_pretrain_
    msg = module.load_state_dict(state_dict, strict=False)
  File "/home/anaconda3/envs/pytorchvideo/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1407, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ViViT:
        size mismatch for pos_embed: copying a param with shape torch.Size([1, 197, 768]) from checkpoint, the shape in current model is torch.Size([1, 10, 768]).
        size mismatch for time_embed: copying a param with shape torch.Size([1, 9, 768]) from checkpoint, the shape in current model is torch.Size([1, 7, 768]).
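For what it's worth, the mismatched shapes follow directly from the token-count arithmetic. A rough sketch of where each number comes from, assuming the usual ViT/ViViT setup of a 16×16 spatial patch, a temporal tubelet of 2 frames, and one extra CLS token (these values are my assumptions, not read from the repo's config):

```python
# Hypothetical token-count arithmetic; patch_size=16, tubelet=2, and the
# CLS token are assumptions about the model configuration.

def spatial_tokens(img_size, patch_size=16, cls_token=1):
    """Number of entries expected in pos_embed for a square input."""
    per_side = img_size // patch_size
    return per_side * per_side + cls_token

def temporal_tokens(num_frames, tubelet=2, cls_token=1):
    """Number of entries expected in time_embed."""
    return num_frames // tubelet + cls_token

print(spatial_tokens(224))   # 197 -> the checkpoint's pos_embed length
print(spatial_tokens(50))    # 10  -> the current model's pos_embed length
print(temporal_tokens(16))   # 9   -> the checkpoint's time_embed length
print(temporal_tokens(13))   # 7   -> the current model's time_embed length
```

So `-img_size 50` and `-num_frames 13` produce embedding tables of length 10 and 7, while the K400 checkpoint was built for 224×224 inputs and 16 frames (lengths 197 and 9), hence the two size-mismatch lines.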
kennyckk commented 1 year ago

From the error, I can see that the shape of the pos embedding does not match the one in the pretrained parameters. From your input, the img size is 50; I guess you need to resize your video width and height to 224, which is what the K400 ViViT weights were trained on.
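If resizing the videos to 224 is not an option, another common workaround is to interpolate the pretrained positional embedding to the new token grid before loading. A minimal sketch with bicubic interpolation follows; the function name and the assumption of a single leading CLS token are illustrative, not the repo's actual API:

```python
# Sketch: resize a pretrained (1, 1 + old_grid**2, dim) positional embedding
# to a new spatial grid. Assumes one leading CLS token; this is an
# illustration, not a function from VideoTransformer-pytorch.
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, new_grid):
    cls_tok, grid_tok = pos_embed[:, :1], pos_embed[:, 1:]
    old_grid = int(grid_tok.shape[1] ** 0.5)
    dim = grid_tok.shape[2]
    # (1, N, dim) -> (1, dim, H, W) so interpolate can treat it as an image
    grid_tok = grid_tok.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    grid_tok = F.interpolate(grid_tok, size=(new_grid, new_grid),
                             mode="bicubic", align_corners=False)
    grid_tok = grid_tok.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([cls_tok, grid_tok], dim=1)

# Checkpoint pos_embed: 224x224 / patch 16 -> 14x14 grid + CLS = 197 tokens
ckpt_pos = torch.randn(1, 197, 768)
# Target model: 3x3 grid + CLS = 10 tokens (img_size 50, patch 16)
print(resize_pos_embed(ckpt_pos, 3).shape)  # torch.Size([1, 10, 768])
```

The same idea applies to `time_embed` with a 1-D interpolation over the frame axis, but simply training with `-img_size 224 -num_frames 16` to match the checkpoint is the easier fix.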