Walter0807 / MotionBERT

[ICCV 2023] PyTorch Implementation of "MotionBERT: A Unified Perspective on Learning Human Motion Representations"
Apache License 2.0

Infer_wild.py fails with the provided checkpoint. #75

Closed aarshp closed 1 year ago

aarshp commented 1 year ago

Hi, I was trying to run infer_wild.py as described in https://github.com/Walter0807/MotionBERT/blob/main/docs/inference.md and ran into a model-mismatch error with the provided 3D Pose checkpoint. I got the following error:

Loading checkpoint checkpoint/pose3d/FT_MB_lite_MB_ft_h36m_global_lite/best_epoch.bin
Traceback (most recent call last):
  File "infer_wild.py", line 38, in <module>
    model_backbone.load_state_dict(checkpoint['model_pos'], strict=True)
  File "C:\Users\aarsagar\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DSTformer:
Missing key(s) in state_dict: "temp_embed", "pos_embed", "joints_embed.weight", "joints_embed.bias", "blocks_st.0.norm1_s.weight", "blocks_st.0.norm1_s.bias", "blocks_st.0.norm1_t.weight", "blocks_st.0.norm1_t.bias", "blocks_st.0.attn_s.proj.weight", "blocks_st.0.attn_s.proj.bias", "blocks_st.0.attn_s.qkv.weight", "blocks_st.0.attn_s.qkv.bias", "blocks_st.0.attn_t.proj.weight", "blocks_st.0.attn_t.proj.bias", "blocks_st.0.attn_t.qkv.weight", "blocks_st.0.attn_t.qkv.bias", "blocks_st.0.norm2_s.weight", "blocks_st.0.norm2_s.bias", "blocks_st.0.norm2_t.weight", "blocks_st.0.norm2_t.bias", "blocks_st.0.mlp_s.fc1.weight", "blocks_st.0.mlp_s.fc1.bias", "blocks_st.0.mlp_s.fc2.weight", "blocks_st.0.mlp_s.fc2.bias", "blocks_st.0.mlp_t.fc1.weight", "blocks_st.0.mlp_t.fc1.bias", "blocks_st.0.mlp_t.fc2.weight", "blocks_st.0.mlp_t.fc2.bias", "blocks_st.1.norm1_s.weight", "blocks_st.1.norm1_s.bias", "blocks_st.1.norm1_t.weight", "blocks_st.1.norm1_t.bias", "blocks_st.1.attn_s.proj.weight", "blocks_st.1.attn_s.proj.bias", "blocks_st.1.attn_s.qkv.weight", "blocks_st.1.attn_s.qkv.bias", "blocks_st.1.attn_t.proj.weight", "blocks_st.1.attn_t.proj.bias", "blocks_st.1.attn_t.qkv.weight", "blocks_st.1.attn_t.qkv.bias", "blocks_st.1.norm2_s.weight", "blocks_st.1.norm2_s.bias", "blocks_st.1.norm2_t.weight", "blocks_st.1.norm2_t.bias", "blocks_st.1.mlp_s.fc1.weight", "blocks_st.1.mlp_s.fc1.bias", "blocks_st.1.mlp_s.fc2.weight", "blocks_st.1.mlp_s.fc2.bias", "blocks_st.1.mlp_t.fc1.weight",
"blocks_st.1.mlp_t.fc1.bias", "blocks_st.1.mlp_t.fc2.weight", "blocks_st.1.mlp_t.fc2.bias", "blocks_st.2.norm1_s.weight", "blocks_st.2.norm1_s.bias", "blocks_st.2.norm1_t.weight", "blocks_st.2.norm1_t.bias", "blocks_st.2.attn_s.proj.weight", "blocks_st.2.attn_s.proj.bias", "blocks_st.2.attn_s.qkv.weight", "blocks_st.2.attn_s.qkv.bias", "blocks_st.2.attn_t.proj.weight", "blocks_st.2.attn_t.proj.bias", "blocks_st.2.attn_t.qkv.weight", "blocks_st.2.attn_t.qkv.bias", "blocks_st.2.norm2_s.weight", "blocks_st.2.norm2_s.bias", "blocks_st.2.norm2_t.weight", "blocks_st.2.norm2_t.bias", "blocks_st.2.mlp_s.fc1.weight", "blocks_st.2.mlp_s.fc1.bias", "blocks_st.2.mlp_s.fc2.weight", "blocks_st.2.mlp_s.fc2.bias", "blocks_st.2.mlp_t.fc1.weight", "blocks_st.2.mlp_t.fc1.bias", "blocks_st.2.mlp_t.fc2.weight", "blocks_st.2.mlp_t.fc2.bias", "blocks_st.3.norm1_s.weight", "blocks_st.3.norm1_s.bias", "blocks_st.3.norm1_t.weight", "blocks_st.3.norm1_t.bias", "blocks_st.3.attn_s.proj.weight", "blocks_st.3.attn_s.proj.bias", "blocks_st.3.attn_s.qkv.weight", "blocks_st.3.attn_s.qkv.bias", "blocks_st.3.attn_t.proj.weight", "blocks_st.3.attn_t.proj.bias", "blocks_st.3.attn_t.qkv.weight", "blocks_st.3.attn_t.qkv.bias", "blocks_st.3.norm2_s.weight", "blocks_st.3.norm2_s.bias", "blocks_st.3.norm2_t.weight", "blocks_st.3.norm2_t.bias", "blocks_st.3.mlp_s.fc1.weight", "blocks_st.3.mlp_s.fc1.bias", "blocks_st.3.mlp_s.fc2.weight", "blocks_st.3.mlp_s.fc2.bias", "blocks_st.3.mlp_t.fc1.weight", "blocks_st.3.mlp_t.fc1.bias", "blocks_st.3.mlp_t.fc2.weight", "blocks_st.3.mlp_t.fc2.bias", "blocks_st.4.norm1_s.weight", "blocks_st.4.norm1_s.bias", "blocks_st.4.norm1_t.weight", "blocks_st.4.norm1_t.bias", "blocks_st.4.attn_s.proj.weight", "blocks_st.4.attn_s.proj.bias", "blocks_st.4.attn_s.qkv.weight", "blocks_st.4.attn_s.qkv.bias", "blocks_st.4.attn_t.proj.weight", "blocks_st.4.attn_t.proj.bias", "blocks_st.4.attn_t.qkv.weight", "blocks_st.4.attn_t.qkv.bias", "blocks_st.4.norm2_s.weight", 
"blocks_st.4.norm2_s.bias", "blocks_st.4.norm2_t.weight", "blocks_st.4.norm2_t.bias", "blocks_st.4.mlp_s.fc1.weight", "blocks_st.4.mlp_s.fc1.bias", "blocks_st.4.mlp_s.fc2.weight", "blocks_st.4.mlp_s.fc2.bias", "blocks_st.4.mlp_t.fc1.weight", "blocks_st.4.mlp_t.fc1.bias", "blocks_st.4.mlp_t.fc2.weight", "blocks_st.4.mlp_t.fc2.bias", "blocks_ts.0.norm1_s.weight", "blocks_ts.0.norm1_s.bias", "blocks_ts.0.norm1_t.weight", "blocks_ts.0.norm1_t.bias", "blocks_ts.0.attn_s.proj.weight", "blocks_ts.0.attn_s.proj.bias", "blocks_ts.0.attn_s.qkv.weight", "blocks_ts.0.attn_s.qkv.bias", "blocks_ts.0.attn_t.proj.weight", "blocks_ts.0.attn_t.proj.bias", "blocks_ts.0.attn_t.qkv.weight", "blocks_ts.0.attn_t.qkv.bias", "blocks_ts.0.norm2_s.weight", "blocks_ts.0.norm2_s.bias", "blocks_ts.0.norm2_t.weight", "blocks_ts.0.norm2_t.bias", "blocks_ts.0.mlp_s.fc1.weight", "blocks_ts.0.mlp_s.fc1.bias", "blocks_ts.0.mlp_s.fc2.weight", "blocks_ts.0.mlp_s.fc2.bias", "blocks_ts.0.mlp_t.fc1.weight", "blocks_ts.0.mlp_t.fc1.bias", "blocks_ts.0.mlp_t.fc2.weight", "blocks_ts.0.mlp_t.fc2.bias", "blocks_ts.1.norm1_s.weight", "blocks_ts.1.norm1_s.bias", "blocks_ts.1.norm1_t.weight", "blocks_ts.1.norm1_t.bias", "blocks_ts.1.attn_s.proj.weight", "blocks_ts.1.attn_s.proj.bias", "blocks_ts.1.attn_s.qkv.weight", "blocks_ts.1.attn_s.qkv.bias", "blocks_ts.1.attn_t.proj.weight", "blocks_ts.1.attn_t.proj.bias", "blocks_ts.1.attn_t.qkv.weight", "blocks_ts.1.attn_t.qkv.bias", "blocks_ts.1.norm2_s.weight", "blocks_ts.1.norm2_s.bias", "blocks_ts.1.norm2_t.weight", "blocks_ts.1.norm2_t.bias", "blocks_ts.1.mlp_s.fc1.weight", "blocks_ts.1.mlp_s.fc1.bias", "blocks_ts.1.mlp_s.fc2.weight", "blocks_ts.1.mlp_s.fc2.bias", "blocks_ts.1.mlp_t.fc1.weight", "blocks_ts.1.mlp_t.fc1.bias", "blocks_ts.1.mlp_t.fc2.weight", "blocks_ts.1.mlp_t.fc2.bias", "blocks_ts.2.norm1_s.weight", "blocks_ts.2.norm1_s.bias", "blocks_ts.2.norm1_t.weight", "blocks_ts.2.norm1_t.bias", "blocks_ts.2.attn_s.proj.weight", "blocks_ts.2.attn_s.proj.bias", 
"blocks_ts.2.attn_s.qkv.weight", "blocks_ts.2.attn_s.qkv.bias", "blocks_ts.2.attn_t.proj.weight", "blocks_ts.2.attn_t.proj.bias", "blocks_ts.2.attn_t.qkv.weight", "blocks_ts.2.attn_t.qkv.bias", "blocks_ts.2.norm2_s.weight", "blocks_ts.2.norm2_s.bias", "blocks_ts.2.norm2_t.weight", "blocks_ts.2.norm2_t.bias", "blocks_ts.2.mlp_s.fc1.weight", "blocks_ts.2.mlp_s.fc1.bias", "blocks_ts.2.mlp_s.fc2.weight", "blocks_ts.2.mlp_s.fc2.bias", "blocks_ts.2.mlp_t.fc1.weight", "blocks_ts.2.mlp_t.fc1.bias", "blocks_ts.2.mlp_t.fc2.weight", "blocks_ts.2.mlp_t.fc2.bias", "blocks_ts.3.norm1_s.weight", "blocks_ts.3.norm1_s.bias", "blocks_ts.3.norm1_t.weight", "blocks_ts.3.norm1_t.bias", "blocks_ts.3.attn_s.proj.weight", "blocks_ts.3.attn_s.proj.bias", "blocks_ts.3.attn_s.qkv.weight", "blocks_ts.3.attn_s.qkv.bias", "blocks_ts.3.attn_t.proj.weight", "blocks_ts.3.attn_t.proj.bias", "blocks_ts.3.attn_t.qkv.weight", "blocks_ts.3.attn_t.qkv.bias", "blocks_ts.3.norm2_s.weight", "blocks_ts.3.norm2_s.bias", "blocks_ts.3.norm2_t.weight", "blocks_ts.3.norm2_t.bias", "blocks_ts.3.mlp_s.fc1.weight", "blocks_ts.3.mlp_s.fc1.bias", "blocks_ts.3.mlp_s.fc2.weight", "blocks_ts.3.mlp_s.fc2.bias", "blocks_ts.3.mlp_t.fc1.weight", "blocks_ts.3.mlp_t.fc1.bias", "blocks_ts.3.mlp_t.fc2.weight", "blocks_ts.3.mlp_t.fc2.bias", "blocks_ts.4.norm1_s.weight", "blocks_ts.4.norm1_s.bias", "blocks_ts.4.norm1_t.weight", "blocks_ts.4.norm1_t.bias", "blocks_ts.4.attn_s.proj.weight", "blocks_ts.4.attn_s.proj.bias", "blocks_ts.4.attn_s.qkv.weight", "blocks_ts.4.attn_s.qkv.bias", "blocks_ts.4.attn_t.proj.weight", "blocks_ts.4.attn_t.proj.bias", "blocks_ts.4.attn_t.qkv.weight", "blocks_ts.4.attn_t.qkv.bias", "blocks_ts.4.norm2_s.weight", "blocks_ts.4.norm2_s.bias", "blocks_ts.4.norm2_t.weight", "blocks_ts.4.norm2_t.bias", "blocks_ts.4.mlp_s.fc1.weight", "blocks_ts.4.mlp_s.fc1.bias", "blocks_ts.4.mlp_s.fc2.weight", "blocks_ts.4.mlp_s.fc2.bias", "blocks_ts.4.mlp_t.fc1.weight", "blocks_ts.4.mlp_t.fc1.bias", 
"blocks_ts.4.mlp_t.fc2.weight", "blocks_ts.4.mlp_t.fc2.bias", "norm.weight", "norm.bias", "pre_logits.fc.weight", "pre_logits.fc.bias", "head.weight", "head.bias", "ts_attn.0.weight", "ts_attn.0.bias", "ts_attn.1.weight", "ts_attn.1.bias", "ts_attn.2.weight", "ts_attn.2.bias", "ts_attn.3.weight", "ts_attn.3.bias", "ts_attn.4.weight", "ts_attn.4.bias".

Could you please check whether the expected checkpoint was provided? Thanks!

Walter0807 commented 1 year ago

https://github.com/Walter0807/MotionBERT/issues/24

aarshp commented 1 year ago

Thanks for pointing it out!

abhaydoke09 commented 5 months ago

Here is a fix for https://github.com/Walter0807/MotionBERT/blob/1839f099ce9f128342c8f5499478ace328c0df4a/infer_wild.py#L37-L38

original_state_dict = checkpoint['model_pos']

# The checkpoint was saved from a model wrapped in nn.DataParallel,
# so every key carries a 'module.' prefix. Strip it (only where present).
new_state_dict = {(k[len("module."):] if k.startswith("module.") else k): v
                  for k, v in original_state_dict.items()}

# Load new_state_dict instead of checkpoint['model_pos']
model_backbone.load_state_dict(new_state_dict, strict=True)
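For reference, the prefix stripping above can be isolated into a small standalone helper. This is a sketch, and `strip_module_prefix` is a hypothetical name, not a function in this repo; it only assumes the checkpoint keys may carry the `module.` prefix that `nn.DataParallel` adds when a model is saved:

```python
from collections import OrderedDict


def strip_module_prefix(state_dict):
    """Return a copy of state_dict with any leading 'module.' removed
    from each key; keys without the prefix are kept unchanged."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len("module."):] if key.startswith("module.") else key
        cleaned[new_key] = value
    return cleaned


# Example: mixed keys, as you might see in checkpoints saved with and
# without DataParallel wrapping.
example = {"module.joints_embed.weight": 0, "norm.bias": 1}
print(strip_module_prefix(example))  # {'joints_embed.weight': 0, 'norm.bias': 1}
```

Making the strip conditional means the same loading code works whether or not the checkpoint was saved from a `DataParallel`-wrapped model.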