bmartacho / UniPose

We propose UniPose, a unified framework for human pose estimation based on our “Waterfall” Atrous Spatial Pooling architecture, that achieves state-of-the-art results on several pose estimation metrics. Current pose estimation methods utilizing standard CNN architectures rely heavily on statistical postprocessing or predefined anchor poses for joint localization. UniPose incorporates contextual segmentation and joint localization to estimate the human pose in a single stage, with high accuracy, without relying on statistical postprocessing methods. The Waterfall module in UniPose leverages the efficiency of progressive filtering in the cascade architecture, while maintaining multi-scale fields-of-view comparable to spatial pyramid configurations. Additionally, our method is extended to UniPose-LSTM for multi-frame processing and achieves state-of-the-art results for temporal pose estimation in video. Our results on multiple datasets demonstrate that UniPose, with a ResNet backbone and Waterfall module, is a robust and efficient architecture for pose estimation, obtaining state-of-the-art results in single-person pose detection for both single images and videos.

Test pretrained model #6

Closed NguyenVanThanhHust closed 3 years ago

NguyenVanThanhHust commented 3 years ago

Thank you for your work.

I'm using your pretrained models to test on my dataset, and I encounter an error. When I use the model for the LSP dataset, I run the command `CUDA_VISIBLE_DEVICES=1 python3.7 test.py --dataset LSP --pretrained UniPose_LSP.tar --img_folder my_dataset` and get this error:

Traceback (most recent call last):
  File "test_hand.py", line 166, in <module>
    trainer = Trainer(args)
  File "test_hand.py", line 89, in __init__
    self.model.load_state_dict(state_dict)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for unipose:
    size mismatch for decoder.last_conv.8.weight: copying a param with shape torch.Size([15, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([20, 256, 1, 1]).
    size mismatch for decoder.last_conv.8.bias: copying a param with shape torch.Size([15]) from checkpoint, the shape in current model is torch.Size([20]).

I changed to the MPII dataset with the command `CUDA_VISIBLE_DEVICES=1 python3.7 test.py --dataset MPII --pretrained UniPose_MPII.tar --img_folder my_dataset`, and then I get:

Traceback (most recent call last):
  File "test_hand.py", line 167, in <module>
    trainer = Trainer(args)
  File "test_hand.py", line 89, in __init__
    self.model.load_state_dict(state_dict)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for unipose:
    size mismatch for decoder.last_conv.8.weight: copying a param with shape torch.Size([17, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([22, 256, 1, 1]).
    size mismatch for decoder.last_conv.8.bias: copying a param with shape torch.Size([17]) from checkpoint, the shape in current model is torch.Size([22]).

My test.py is just unipose.py with new args and the training/validation code removed.

Could you check your pretrained models, or let me know what I did wrong?
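
For reference, the mismatch can be seen by inspecting the checkpoint's final layer directly. A quick sketch (it assumes the .tar file stores its weights under a 'state_dict' key, matching how test.py loads it, and that model is the unipose instance built there):

import torch

# Compare the final layer's shape in the checkpoint against the current model.
checkpoint = torch.load('UniPose_LSP.tar', map_location='cpu')
print(checkpoint['state_dict']['decoder.last_conv.8.weight'].shape)  # torch.Size([15, 256, 1, 1])
print(model.state_dict()['decoder.last_conv.8.weight'].shape)        # torch.Size([20, 256, 1, 1])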

bmartacho commented 3 years ago

Thank you for the feedback and interest in our work.

When loading weights from a pretrained model, the dimensions must match. When applying our model to another dataset with a different number of classes, the last layer should be ignored when copying weights: https://pytorch.org/tutorials/beginner/saving_loading_models.html.

E.g.

import torch

# Copy only the parameters whose names and shapes match the current model;
# the final classification layer (different number of joints) is skipped.
checkpoint = torch.load(checkpoint_file)
model_state_dict = model.state_dict()
new_model_state_dict = {}
for k in model_state_dict:
    if k in checkpoint['state_dict'] and model_state_dict[k].size() == checkpoint['state_dict'][k].size():
        new_model_state_dict[k] = checkpoint['state_dict'][k]
    else:
        print('Skipped loading parameter {}'.format(k))
model.load_state_dict(new_model_state_dict, strict=False)
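
With strict=False, load_state_dict simply leaves the filtered-out parameters (here decoder.last_conv.8.weight and decoder.last_conv.8.bias) at their fresh initialization, so the final layer should be fine-tuned on the target dataset before running inference.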