wei-tim / YOWO

You Only Watch Once: A Unified CNN Architecture for Real-Time Spatiotemporal Action Localization

Is it possible to finetune with new classes of a custom UCF101 dataset? #59

Open. chuong opened this issue 3 years ago

chuong commented 3 years ago

Currently, an error occurs when I try to fine-tune YOWO on a custom UCF101-style dataset with a new set of classes (3 in this case):

$ python train.py --dataset ucf101-24 --data_cfg cfg/ucf24.data --cfg_file cfg/ucf24.cfg --n_classes 3 --backbone_3d resnext101 --backbone_3d_weights weights/resnext-101-kinetics.pth --backbone_2d darknet --backbone_2d_weights weights/yolo.weights --resume_path weights/yowo_ucf101-24_16f_best.pth 
Namespace(backbone_2d='darknet', backbone_2d_weights='weights/yolo.weights', backbone_3d='resnext101', backbone_3d_weights='weights/resnext-101-kinetics.pth', begin_epoch=1, cfg_file='cfg/ucf24.cfg', data_cfg='cfg/ucf24.data', dataset='ucf101-24', end_epoch=25, evaluate=False, freeze_backbone_2d=False, freeze_backbone_3d=False, n_classes=3, resume_path='weights/yowo_ucf101-24_16f_best.pth')
DataParallel(
  (module): YOWO(
    (backbone_2d): Darknet(
...................
    (conv_final): Conv2d(1024, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
)
===================================================================
loading checkpoint weights/yowo_ucf101-24_16f_best.pth
Traceback (most recent call last):
  File "train.py", line 109, in <module>
    model.load_state_dict(checkpoint['state_dict'])
  File "/apps/pytorch/1.4.0-py36-cuda10/lib/python3.6/site-packages/torch/nn/modules/module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
    size mismatch for module.conv_final.weight: copying a param with shape torch.Size([145, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([40, 1024, 1, 1]).
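The two shapes in the error fit a YOLO-style head whose `conv_final` layer emits `anchors * (classes + 5)` channels per grid cell, with 5 anchors: 5 × (24 + 5) = 145 for the UCF101-24 checkpoint and 5 × (3 + 5) = 40 for the 3-class model. The anchor count and the "+ 5" (four box coordinates plus an objectness score) are inferred from these numbers, not read from the YOWO source. A quick check of that arithmetic:

```python
# Assumption: the head predicts, for each of NUM_ANCHORS anchors, 4 box values,
# 1 objectness score and one score per class. This is inferred from the shapes
# in the traceback, not taken from the YOWO code.
NUM_ANCHORS = 5


def conv_final_out_channels(n_classes: int, n_anchors: int = NUM_ANCHORS) -> int:
    """Output channels a YOLO-style detection head needs for n_classes."""
    return n_anchors * (n_classes + 5)


print(conv_final_out_channels(24))  # 145, the shape stored in the UCF101-24 checkpoint
print(conv_final_out_channels(3))   # 40, the shape of the new 3-class model
```

Since `load_state_dict` collects every error before raising and the traceback reports only this one mismatch, the rest of the checkpoint appears to be shape-compatible with the 3-class model.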

If I train without the --resume_path weights/yowo_ucf101-24_16f_best.pth option, the model fails to converge. Is there a better way to handle this? Thanks
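A common PyTorch workaround, not something the YOWO repository is confirmed to provide, is to load every pretrained tensor whose shape still matches and leave the mismatched head freshly initialized. Passing `strict=False` to `load_state_dict` alone does not help here: it tolerates missing and unexpected keys but still raises on size mismatches, so the incompatible entries have to be filtered out first. A minimal sketch of that filter, where `split_compatible` is a hypothetical helper and the demo key `module.backbone.w` is made up for illustration:

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Mapping, Tuple


def split_compatible(
    checkpoint_state: Mapping[str, Any],
    model_state: Mapping[str, Any],
) -> Tuple[Dict[str, Any], List[str]]:
    """Keep checkpoint entries whose name and shape match the current model.

    Only relies on each value having a ``.shape`` attribute, so it works on
    PyTorch state dicts but can be exercised without torch.
    Returns the filtered entries and the names that were skipped.
    """
    compatible: Dict[str, Any] = {}
    skipped: List[str] = []
    for name, value in checkpoint_state.items():
        target = model_state.get(name)
        # Skip keys the model lacks, and keys whose tensor shape differs
        # (e.g. a detection head sized for a different number of classes).
        if target is not None and tuple(target.shape) == tuple(value.shape):
            compatible[name] = value
        else:
            skipped.append(name)
    return compatible, skipped


# Demo with stand-in objects; "module.backbone.w" is a hypothetical key.
ckpt = {
    "module.conv_final.weight": SimpleNamespace(shape=(145, 1024, 1, 1)),
    "module.backbone.w": SimpleNamespace(shape=(8,)),
}
model = {
    "module.conv_final.weight": SimpleNamespace(shape=(40, 1024, 1, 1)),
    "module.backbone.w": SimpleNamespace(shape=(8,)),
}
state, skipped = split_compatible(ckpt, model)
print(sorted(state))  # ['module.backbone.w']
print(skipped)        # ['module.conv_final.weight']
```

In `train.py`, the failing `model.load_state_dict(checkpoint['state_dict'])` call could then become `state, skipped = split_compatible(checkpoint['state_dict'], model.state_dict())` followed by `model.load_state_dict(state, strict=False)`. Both the checkpoint and the current model are wrapped in `DataParallel`, so their keys already share the `module.` prefix and need no renaming.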