fabro66 / GAST-Net-3DPoseEstimation

A Graph Attention Spatio-temporal Convolutional Network for 3D Human Pose Estimation in Video (GAST-Net)
MIT License

Causal model? #40

Closed: emredog closed this issue 3 years ago

emredog commented 3 years ago

Hello,

Thanks for sharing this inspiring work :pray:

I'm interested in using only the past N frames (the receptive field) to predict the current frame. I believe this corresponds to the causal model: [Screenshot from 2021-04-20 17-12-40]

I also see that the code contains models with a *_causal suffix.

However, when I create a SpatioTemporalModelOptimized1f object with causal=True, I can still load the "regular" checkpoint files, e.g. 27_frame_model_toe.bin.
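
To make the receptive field concrete, here is a minimal sketch. It assumes GAST-Net stacks dilated temporal convolutions in the VideoPose3D style, where each layer's dilation is the product of the earlier filter widths. This is an assumption: the thread only shows filter_widths=[3, 3, 3, 3] being used for the 81-frame model. Under that assumption the receptive field is the product of the filter widths. The hypothetical helpers receptive_field and frame_window list which input frames feed the prediction for frame t in causal and non-causal mode.

```python
from math import prod


def receptive_field(filter_widths):
    """Number of input frames seen by stacked dilated temporal convs.

    Assumes dilation d_i = w_0 * ... * w_{i-1}; the field is then
    1 + sum((w_i - 1) * d_i), which telescopes to prod(w_i).
    """
    return prod(filter_widths)


def frame_window(t, filter_widths, causal):
    """Input frame indices used to predict frame t (sketch)."""
    rf = receptive_field(filter_widths)
    if causal:
        # Current frame plus the rf - 1 frames before it; nothing from the future.
        return list(range(t - rf + 1, t + 1))
    # Symmetric window centred on t (rf is odd when every width is odd).
    half = (rf - 1) // 2
    return list(range(t - half, t + half + 1))


for widths in ([3, 3, 3], [3, 3, 3, 3]):
    causal = frame_window(100, widths, causal=True)
    centred = frame_window(100, widths, causal=False)
    print(receptive_field(widths), (causal[0], causal[-1]), (centred[0], centred[-1]))
# 27 (74, 100) (87, 113)
# 81 (20, 100) (60, 140)
```

In causal mode the prediction for frame t waits for no future frames, at the cost of seeing only one side of the motion.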

Questions

  1. Are the weights 27_frame_model.bin and 27_frame_model_causal.bin identical, or at least interchangeable?
  2. If not, could you share the causal weights?

Looking forward to your answer, many thanks in advance!

fabro66 commented 3 years ago

Hi~ Thank you for your interest in our work.

  1. The weights 27_frame_model.bin and 27_frame_model_causal.bin are not identical.

  2. I just uploaded 81_frame_model_causal.bin to Google Drive. Please check it.
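
This answer is consistent with the earlier observation that a regular checkpoint still loads into a model built with causal=True. The sketch below assumes that the causal flag only changes how frames are padded or aligned and not any parameter shapes. That holds for VideoPose3D-style temporal convolutions but is not confirmed for GAST-Net in this thread. If it holds, strict loading succeeds, yet the same weights produce predictions shifted in time. A checkpoint trained for one alignment is therefore not a drop-in replacement for the other. TemporalBlock is a hypothetical toy layer, not a GAST-Net class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TemporalBlock(nn.Module):
    """Hypothetical toy temporal layer; `causal` changes only the padding."""

    def __init__(self, channels, width, causal):
        super().__init__()
        self.width = width
        self.causal = causal
        self.conv = nn.Conv1d(channels, channels, width, bias=False)

    def forward(self, x):  # x: (batch, channels, frames)
        pad = self.width - 1
        # Causal: pad only on the left, so output t sees frames t-pad..t.
        # Centred: split the padding, so output t sees frames around t.
        left, right = (pad, 0) if self.causal else (pad // 2, pad - pad // 2)
        return self.conv(F.pad(x, (left, right)))


torch.manual_seed(0)
centred = TemporalBlock(2, 3, causal=False)
causal = TemporalBlock(2, 3, causal=True)

# Same parameter names and shapes, so strict loading raises no error.
causal.load_state_dict(centred.state_dict())

x = torch.randn(1, 2, 10)
with torch.no_grad():
    y_centred = centred(x)
    y_causal = causal(x)
    # Changing frames 6..9 must not affect causal outputs for frames 0..5.
    x_future = x.clone()
    x_future[..., 6:] += 1.0
    y_causal_future = causal(x_future)

# The causal output is the centred output delayed by one frame.
print(torch.allclose(y_causal[..., 1:], y_centred[..., :-1]))
print(torch.allclose(y_causal_future[..., :6], y_causal[..., :6]))
```

In a deep stack the delay grows to half the receptive field, which is why weights trained for centred prediction degrade when reused causally.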

emredog commented 3 years ago

Thank you for your quick response @fabro66 !

I downloaded the causal model, but can't seem to load it.

I'm using the reconstruction.py script with the baseball example.

The following configurations work flawlessly:

python reconstruction.py -w 27_frame_model_toe.bin -n 19 -k ./data/keypoints/baseball_wholebody.json -kf wholebody
python reconstruction.py -w 27_frame_model.bin -n 17 -k ./data/keypoints/baseball.json -kf coco
python reconstruction.py -f 81 -w 81_frame_model.bin -n 17 -k ./data/keypoints/baseball.json -kf coco

However, I wasn't able to initialize the model properly for the checkpoint file that you kindly provided. I tried:

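import torch

# SpatioTemporalModel and adj come from the GAST-Net code (not shown here).
model_pos = SpatioTemporalModel(adj=adj, num_joints_in=17, in_features=2, num_joints_out=17,
                                filter_widths=[3, 3, 3, 3], channels=64, dropout=0.05, causal=True)
# map_location keeps every tensor on the CPU regardless of where it was saved.
checkpoint = torch.load('./checkpoint/gastnet/81_frame_model_causal.bin', map_location=lambda storage, loc: storage)
model_pos.load_state_dict(checkpoint['model_pos'])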

but got the following error:

Error(s) in loading state_dict for SpatioTemporalModel:
    Unexpected key(s) in state_dict: "layers_graph_conv.0.local_graph_layer.gcn_sym.bias", "layers_graph_conv.0.local_graph_layer.gcn_con.bias", "layers_graph_conv.1.local_graph_layer.gcn_sym.bias", "layers_graph_conv.1.local_graph_layer.gcn_con.bias", "layers_graph_conv.2.local_graph_layer.gcn_sym.bias", "layers_graph_conv.2.local_graph_layer.gcn_con.bias", "layers_graph_conv.3.local_graph_layer.gcn_sym.bias", "layers_graph_conv.3.local_graph_layer.gcn_con.bias". 
    size mismatch for layers_graph_conv.0.local_graph_layer.gcn_sym.e: copying a param with shape torch.Size([1, 29]) from checkpoint, the shape in current model is torch.Size([64, 29]).
    size mismatch for layers_graph_conv.0.local_graph_layer.gcn_con.e: copying a param with shape torch.Size([1, 54]) from checkpoint, the shape in current model is torch.Size([64, 54]).
    size mismatch for layers_graph_conv.1.local_graph_layer.gcn_sym.e: copying a param with shape torch.Size([1, 29]) from checkpoint, the shape in current model is torch.Size([128, 29]).
    size mismatch for layers_graph_conv.1.local_graph_layer.gcn_con.e: copying a param with shape torch.Size([1, 54]) from checkpoint, the shape in current model is torch.Size([128, 54]).
    size mismatch for layers_graph_conv.2.local_graph_layer.gcn_sym.e: copying a param with shape torch.Size([1, 29]) from checkpoint, the shape in current model is torch.Size([256, 29]).
    size mismatch for layers_graph_conv.2.local_graph_layer.gcn_con.e: copying a param with shape torch.Size([1, 54]) from checkpoint, the shape in current model is torch.Size([256, 54]).
    size mismatch for layers_graph_conv.3.local_graph_layer.gcn_sym.e: copying a param with shape torch.Size([1, 29]) from checkpoint, the shape in current model is torch.Size([512, 29]).
    size mismatch for layers_graph_conv.3.local_graph_layer.gcn_con.e: copying a param with shape torch.Size([1, 54]) from checkpoint, the shape in current model is torch.Size([512, 54]).
  File "/home/emredog/git/gastnet_lindera/reconstruction.py", line 263, in reconstruction
    model_pos.load_state_dict(checkpoint['model_pos'])
  File "/home/emredog/git/gastnet_lindera/reconstruction.py", line 300, in <module>
    reconstruction(args)
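
When a checkpoint does not match the model definition, it can help to list the differences before calling load_state_dict. In the error above, the checkpoint stores extra gcn_*.bias entries and an edge parameter e of shape [1, N] in each layer. The locally built model expects [channels, N] and no bias. One plausible reading, not confirmed in this thread, is that the checkpoint was saved from a different version of the graph-convolution layer than the local code. Passing strict=False would only silence the missing and unexpected keys; PyTorch still raises on size mismatches. diff_state_dicts is a hypothetical torch-only helper, and the toy tensors only mimic the shapes in the error (the bias size is arbitrary).

```python
import torch


def diff_state_dicts(model_sd, ckpt_sd):
    """Compare a model's state_dict with a checkpoint before loading.

    Returns (missing, unexpected, mismatched): keys only in the model,
    keys only in the checkpoint, and shared keys whose shapes differ,
    mapped to (checkpoint_shape, model_shape).
    """
    missing = sorted(set(model_sd) - set(ckpt_sd))
    unexpected = sorted(set(ckpt_sd) - set(model_sd))
    mismatched = {
        k: (tuple(ckpt_sd[k].shape), tuple(model_sd[k].shape))
        for k in sorted(set(model_sd) & set(ckpt_sd))
        if ckpt_sd[k].shape != model_sd[k].shape
    }
    return missing, unexpected, mismatched


# Toy dictionaries; real use would be
# diff_state_dicts(model_pos.state_dict(), checkpoint['model_pos']).
model_sd = {"gcn_sym.e": torch.zeros(64, 29)}
ckpt_sd = {"gcn_sym.e": torch.zeros(1, 29), "gcn_sym.bias": torch.zeros(64)}
missing, unexpected, mismatched = diff_state_dicts(model_sd, ckpt_sd)
print(missing, unexpected, mismatched)
# [] ['gcn_sym.bias'] {'gcn_sym.e': ((1, 29), (64, 29))}
```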

So it looked like a single-frame vs. multi-frame issue, and I tried the same parameters with SpatioTemporalModelOptimized1f, but got the same error.

I then tried both models again with causal=False, but that didn't work either.

I also tried it with the 19-joint wholebody skeleton, just in case, but of course the errors got worse.

Looking at the error, I'm almost sure that I'm not setting the model right. But I can't figure it out.

Could you help me initialize the model correctly so that it can load the causal weights you provided?

Thanks in advance!

fabro66 commented 3 years ago

Hi~
I have uploaded 27_frame_model_causal.bin to Google Drive. Please update the reconstruction.py file and reproduce the baseball sample by running the following command:

python reconstruction.py -w 27_frame_model_causal.bin --causal 
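
The command above enables causal mode with a single switch. This sketch shows one way such a switch might be wired into model construction with argparse. Only -w, -f, -n and --causal appear in this thread. The dest names, the defaults, the filter_widths_for helper, and the rule of one width-3 layer per factor of 3 in -f are all hypothetical and may not match the updated reconstruction.py.

```python
import argparse


def build_parser():
    # Hypothetical parser reusing flags seen in this thread; the real
    # reconstruction.py may define them, their defaults and dests differently.
    p = argparse.ArgumentParser(description="Causal reconstruction sketch")
    p.add_argument("-w", dest="weights", default="27_frame_model.bin")
    p.add_argument("-f", dest="frames", type=int, default=27)
    p.add_argument("-n", dest="num_joints", type=int, default=17)
    p.add_argument("--causal", action="store_true")  # False unless given
    return p


def filter_widths_for(frames):
    # Assumption: every temporal layer has width 3, so frames is a power of 3.
    widths = []
    while frames > 1:
        if frames % 3:
            raise ValueError("frames must be a power of 3")
        widths.append(3)
        frames //= 3
    return widths


args = build_parser().parse_args(["-w", "27_frame_model_causal.bin", "--causal"])
model_kwargs = dict(
    num_joints_in=args.num_joints,
    num_joints_out=args.num_joints,
    filter_widths=filter_widths_for(args.frames),
    causal=args.causal,
)
print(args.weights, model_kwargs)
```

The keyword names mirror the SpatioTemporalModel call earlier in the thread. Building the real model would also need adj and the remaining arguments.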

emredog commented 3 years ago

Works perfectly. Thank you very much for your help!