Chiaraplizz / ST-TR

Spatial Temporal Transformer Network for Skeleton-Based Activity Recognition
MIT License

Problems loading checkpoints #17

Closed: bszczapa closed this issue 3 years ago

bszczapa commented 3 years ago

Hi Chiara,

I am trying to load the file "checkpoint_ST-TR/ntu120_xsub_spatial.pt" using the following command:

model.load_state_dict(weights)

and I get an error saying that there are missing keys and unexpected keys in state_dict. The trace shows that every expected key starts with "module." (i.e., "module.key_name"), but when I load the file and print its keys, none of them has "module." in front. Otherwise the keys look correct, and I don't know where this "module." comes from. Have you encountered this problem before, and do you know how to load this model?

Thanks in advance,
Benjamin

Here is the full trace:

RuntimeError: Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict: "module.data_bn.weight", "module.data_bn.bias", "module.data_bn.running_mean", "module.data_bn.running_var", "module.backbone.0.gcn1.conv_list.0.weight", "module.backbone.0.gcn1.conv_list.0.bias", "module.backbone.0.gcn1.conv_list.1.weight", "module.backbone.0.gcn1.conv_list.1.bias", "module.backbone.0.gcn1.conv_list.2.weight", "module.backbone.0.gcn1.conv_list.2.bias", "module.backbone.0.gcn1.bn.weight", "module.backbone.0.gcn1.bn.bias", "module.backbone.0.gcn1.bn.running_mean", "module.backbone.0.gcn1.bn.running_var", "module.backbone.0.tcn1.conv.weight", "module.backbone.0.tcn1.conv.bias", "module.backbone.0.tcn1.bn.weight", "module.backbone.0.tcn1.bn.bias", "module.backbone.0.tcn1.bn.running_mean", "module.backbone.0.tcn1.bn.running_var", "module.backbone.1.gcn1.conv_list.0.weight", "module.backbone.1.gcn1.conv_list.0.bias", "module.backbone.1.gcn1.conv_list.1.weight", "module.backbone.1.gcn1.conv_list.1.bias", "module.backbone.1.gcn1.conv_list.2.weight", "module.backbone.1.gcn1.conv_list.2.bias", "module.backbone.1.gcn1.bn.weight", "module.backbone.1.gcn1.bn.bias", "module.backbone.1.gcn1.bn.running_mean", "module.backbone.1.gcn1.bn.running_var", "module.backbone.1.tcn1.conv.weight", "module.backbone.1.tcn1.conv.bias", "module.backbone.1.tcn1.bn.weight", "module.backbone.1.tcn1.bn.bias", "module.backbone.1.tcn1.bn.running_mean", "module.backbone.1.tcn1.bn.running_var", "module.backbone.2.gcn1.conv_list.0.weight", "module.backbone.2.gcn1.conv_list.0.bias", "module.backbone.2.gcn1.conv_list.1.weight", "module.backbone.2.gcn1.conv_list.1.bias", "module.backbone.2.gcn1.conv_list.2.weight", "module.backbone.2.gcn1.conv_list.2.bias", "module.backbone.2.gcn1.bn.weight", "module.backbone.2.gcn1.bn.bias", "module.backbone.2.gcn1.bn.running_mean", "module.backbone.2.gcn1.bn.running_var", "module.backbone.2.tcn1.conv.weight", "module.backbone.2.tcn1.conv.bias", 
"module.backbone.2.tcn1.bn.weight", "module.backbone.2.tcn1.bn.bias", "module.backbone.2.tcn1.bn.running_mean", "module.backbone.2.tcn1.bn.running_var", "module.backbone.3.gcn1.data_bn.weight", "module.backbone.3.gcn1.data_bn.bias", "module.backbone.3.gcn1.data_bn.running_mean", "module.backbone.3.gcn1.data_bn.running_var", "module.backbone.3.gcn1.bn.weight", "module.backbone.3.gcn1.bn.bias", "module.backbone.3.gcn1.bn.running_mean", "module.backbone.3.gcn1.bn.running_var", "module.backbone.3.gcn1.attention_conv.qkv_conv.weight", "module.backbone.3.gcn1.attention_conv.qkv_conv.bias", "module.backbone.3.gcn1.attention_conv.attn_out.weight", "module.backbone.3.gcn1.attention_conv.attn_out.bias", "module.backbone.3.tcn1.conv.weight", "module.backbone.3.tcn1.conv.bias", "module.backbone.3.tcn1.bn.weight", "module.backbone.3.tcn1.bn.bias", "module.backbone.3.tcn1.bn.running_mean", "module.backbone.3.tcn1.bn.running_var", "module.backbone.3.down1.conv.weight", "module.backbone.3.down1.conv.bias", "module.backbone.3.down1.bn.weight", "module.backbone.3.down1.bn.bias", "module.backbone.3.down1.bn.running_mean", "module.backbone.3.down1.bn.running_var", "module.backbone.4.gcn1.data_bn.weight", "module.backbone.4.gcn1.data_bn.bias", "module.backbone.4.gcn1.data_bn.running_mean", "module.backbone.4.gcn1.data_bn.running_var", "module.backbone.4.gcn1.bn.weight", "module.backbone.4.gcn1.bn.bias", "module.backbone.4.gcn1.bn.running_mean", "module.backbone.4.gcn1.bn.running_var", "module.backbone.4.gcn1.attention_conv.qkv_conv.weight", "module.backbone.4.gcn1.attention_conv.qkv_conv.bias", "module.backbone.4.gcn1.attention_conv.attn_out.weight", "module.backbone.4.gcn1.attention_conv.attn_out.bias", "module.backbone.4.tcn1.conv.weight", "module.backbone.4.tcn1.conv.bias", "module.backbone.4.tcn1.bn.weight", "module.backbone.4.tcn1.bn.bias", "module.backbone.4.tcn1.bn.running_mean", "module.backbone.4.tcn1.bn.running_var", "module.backbone.5.gcn1.data_bn.weight", 
"module.backbone.5.gcn1.data_bn.bias", "module.backbone.5.gcn1.data_bn.running_mean", "module.backbone.5.gcn1.data_bn.running_var", "module.backbone.5.gcn1.bn.weight", "module.backbone.5.gcn1.bn.bias", "module.backbone.5.gcn1.bn.running_mean", "module.backbone.5.gcn1.bn.running_var", "module.backbone.5.gcn1.attention_conv.qkv_conv.weight", "module.backbone.5.gcn1.attention_conv.qkv_conv.bias", "module.backbone.5.gcn1.attention_conv.attn_out.weight", "module.backbone.5.gcn1.attention_conv.attn_out.bias", "module.backbone.5.tcn1.conv.weight", "module.backbone.5.tcn1.conv.bias", "module.backbone.5.tcn1.bn.weight", "module.backbone.5.tcn1.bn.bias", "module.backbone.5.tcn1.bn.running_mean", "module.backbone.5.tcn1.bn.running_var", "module.backbone.6.gcn1.data_bn.weight", "module.backbone.6.gcn1.data_bn.bias", "module.backbone.6.gcn1.data_bn.running_mean", "module.backbone.6.gcn1.data_bn.running_var", "module.backbone.6.gcn1.bn.weight", "module.backbone.6.gcn1.bn.bias", "module.backbone.6.gcn1.bn.running_mean", "module.backbone.6.gcn1.bn.running_var", "module.backbone.6.gcn1.attention_conv.qkv_conv.weight", "module.backbone.6.gcn1.attention_conv.qkv_conv.bias", "module.backbone.6.gcn1.attention_conv.attn_out.weight", "module.backbone.6.gcn1.attention_conv.attn_out.bias", "module.backbone.6.tcn1.conv.weight", "module.backbone.6.tcn1.conv.bias", "module.backbone.6.tcn1.bn.weight", "module.backbone.6.tcn1.bn.bias", "module.backbone.6.tcn1.bn.running_mean", "module.backbone.6.tcn1.bn.running_var", "module.backbone.6.down1.conv.weight", "module.backbone.6.down1.conv.bias", "module.backbone.6.down1.bn.weight", "module.backbone.6.down1.bn.bias", "module.backbone.6.down1.bn.running_mean", "module.backbone.6.down1.bn.running_var", "module.backbone.7.gcn1.data_bn.weight", "module.backbone.7.gcn1.data_bn.bias", "module.backbone.7.gcn1.data_bn.running_mean", "module.backbone.7.gcn1.data_bn.running_var", "module.backbone.7.gcn1.bn.weight", "module.backbone.7.gcn1.bn.bias", 
"module.backbone.7.gcn1.bn.running_mean", "module.backbone.7.gcn1.bn.running_var", "module.backbone.7.gcn1.attention_conv.qkv_conv.weight", "module.backbone.7.gcn1.attention_conv.qkv_conv.bias", "module.backbone.7.gcn1.attention_conv.attn_out.weight", "module.backbone.7.gcn1.attention_conv.attn_out.bias", "module.backbone.7.tcn1.conv.weight", "module.backbone.7.tcn1.conv.bias", "module.backbone.7.tcn1.bn.weight", "module.backbone.7.tcn1.bn.bias", "module.backbone.7.tcn1.bn.running_mean", "module.backbone.7.tcn1.bn.running_var", "module.backbone.8.gcn1.data_bn.weight", "module.backbone.8.gcn1.data_bn.bias", "module.backbone.8.gcn1.data_bn.running_mean", "module.backbone.8.gcn1.data_bn.running_var", "module.backbone.8.gcn1.bn.weight", "module.backbone.8.gcn1.bn.bias", "module.backbone.8.gcn1.bn.running_mean", "module.backbone.8.gcn1.bn.running_var", "module.backbone.8.gcn1.attention_conv.qkv_conv.weight", "module.backbone.8.gcn1.attention_conv.qkv_conv.bias", "module.backbone.8.gcn1.attention_conv.attn_out.weight", "module.backbone.8.gcn1.attention_conv.attn_out.bias", "module.backbone.8.tcn1.conv.weight", "module.backbone.8.tcn1.conv.bias", "module.backbone.8.tcn1.bn.weight", "module.backbone.8.tcn1.bn.bias", "module.backbone.8.tcn1.bn.running_mean", "module.backbone.8.tcn1.bn.running_var", "module.gcn0.conv_list.0.weight", "module.gcn0.conv_list.0.bias", "module.gcn0.conv_list.1.weight", "module.gcn0.conv_list.1.bias", "module.gcn0.conv_list.2.weight", "module.gcn0.conv_list.2.bias", "module.gcn0.bn.weight", "module.gcn0.bn.bias", "module.gcn0.bn.running_mean", "module.gcn0.bn.running_var", "module.tcn0.conv.weight", "module.tcn0.conv.bias", "module.tcn0.bn.weight", "module.tcn0.bn.bias", "module.tcn0.bn.running_mean", "module.tcn0.bn.running_var", "module.person_bn.weight", "module.person_bn.bias", "module.person_bn.running_mean", "module.person_bn.running_var", "module.fcn.weight", "module.fcn.bias". 

Unexpected key(s) in state_dict: "data_bn.weight", "data_bn.bias", "data_bn.running_mean", "data_bn.running_var", "data_bn.num_batches_tracked", "backbone.0.gcn1.mask", "backbone.0.gcn1.conv_list.0.weight", "backbone.0.gcn1.conv_list.0.bias", "backbone.0.gcn1.conv_list.1.weight", "backbone.0.gcn1.conv_list.1.bias", "backbone.0.gcn1.conv_list.2.weight", "backbone.0.gcn1.conv_list.2.bias", "backbone.0.gcn1.bn.weight", "backbone.0.gcn1.bn.bias", "backbone.0.gcn1.bn.running_mean", "backbone.0.gcn1.bn.running_var", "backbone.0.gcn1.bn.num_batches_tracked", "backbone.0.tcn1.conv.weight", "backbone.0.tcn1.conv.bias", "backbone.0.tcn1.bn.weight", "backbone.0.tcn1.bn.bias", "backbone.0.tcn1.bn.running_mean", "backbone.0.tcn1.bn.running_var", "backbone.0.tcn1.bn.num_batches_tracked", "backbone.1.gcn1.mask", "backbone.1.gcn1.conv_list.0.weight", "backbone.1.gcn1.conv_list.0.bias", "backbone.1.gcn1.conv_list.1.weight", "backbone.1.gcn1.conv_list.1.bias", "backbone.1.gcn1.conv_list.2.weight", "backbone.1.gcn1.conv_list.2.bias", "backbone.1.gcn1.bn.weight", "backbone.1.gcn1.bn.bias", "backbone.1.gcn1.bn.running_mean", "backbone.1.gcn1.bn.running_var", "backbone.1.gcn1.bn.num_batches_tracked", "backbone.1.tcn1.conv.weight", "backbone.1.tcn1.conv.bias", "backbone.1.tcn1.bn.weight", "backbone.1.tcn1.bn.bias", "backbone.1.tcn1.bn.running_mean", "backbone.1.tcn1.bn.running_var", "backbone.1.tcn1.bn.num_batches_tracked", "backbone.2.gcn1.mask", "backbone.2.gcn1.conv_list.0.weight", "backbone.2.gcn1.conv_list.0.bias", "backbone.2.gcn1.conv_list.1.weight", "backbone.2.gcn1.conv_list.1.bias", "backbone.2.gcn1.conv_list.2.weight", "backbone.2.gcn1.conv_list.2.bias", "backbone.2.gcn1.bn.weight", "backbone.2.gcn1.bn.bias", "backbone.2.gcn1.bn.running_mean", "backbone.2.gcn1.bn.running_var", "backbone.2.gcn1.bn.num_batches_tracked", "backbone.2.tcn1.conv.weight", "backbone.2.tcn1.conv.bias", "backbone.2.tcn1.bn.weight", "backbone.2.tcn1.bn.bias", "backbone.2.tcn1.bn.running_mean", 
"backbone.2.tcn1.bn.running_var", "backbone.2.tcn1.bn.num_batches_tracked", "backbone.3.gcn1.data_bn.weight", "backbone.3.gcn1.data_bn.bias", "backbone.3.gcn1.data_bn.running_mean", "backbone.3.gcn1.data_bn.running_var", "backbone.3.gcn1.data_bn.num_batches_tracked", "backbone.3.gcn1.bn.weight", "backbone.3.gcn1.bn.bias", "backbone.3.gcn1.bn.running_mean", "backbone.3.gcn1.bn.running_var", "backbone.3.gcn1.bn.num_batches_tracked", "backbone.3.gcn1.attention_conv.qkv_conv.weight", "backbone.3.gcn1.attention_conv.qkv_conv.bias", "backbone.3.gcn1.attention_conv.attn_out.weight", "backbone.3.gcn1.attention_conv.attn_out.bias", "backbone.3.tcn1.conv.weight", "backbone.3.tcn1.conv.bias", "backbone.3.tcn1.bn.weight", "backbone.3.tcn1.bn.bias", "backbone.3.tcn1.bn.running_mean", "backbone.3.tcn1.bn.running_var", "backbone.3.tcn1.bn.num_batches_tracked", "backbone.3.down1.conv.weight", "backbone.3.down1.conv.bias", "backbone.3.down1.bn.weight", "backbone.3.down1.bn.bias", "backbone.3.down1.bn.running_mean", "backbone.3.down1.bn.running_var", "backbone.3.down1.bn.num_batches_tracked", "backbone.4.gcn1.data_bn.weight", "backbone.4.gcn1.data_bn.bias", "backbone.4.gcn1.data_bn.running_mean", "backbone.4.gcn1.data_bn.running_var", "backbone.4.gcn1.data_bn.num_batches_tracked", "backbone.4.gcn1.bn.weight", "backbone.4.gcn1.bn.bias", "backbone.4.gcn1.bn.running_mean", "backbone.4.gcn1.bn.running_var", "backbone.4.gcn1.bn.num_batches_tracked", "backbone.4.gcn1.attention_conv.qkv_conv.weight", "backbone.4.gcn1.attention_conv.qkv_conv.bias", "backbone.4.gcn1.attention_conv.attn_out.weight", "backbone.4.gcn1.attention_conv.attn_out.bias", "backbone.4.tcn1.conv.weight", "backbone.4.tcn1.conv.bias", "backbone.4.tcn1.bn.weight", "backbone.4.tcn1.bn.bias", "backbone.4.tcn1.bn.running_mean", "backbone.4.tcn1.bn.running_var", "backbone.4.tcn1.bn.num_batches_tracked", "backbone.5.gcn1.data_bn.weight", "backbone.5.gcn1.data_bn.bias", "backbone.5.gcn1.data_bn.running_mean", 
"backbone.5.gcn1.data_bn.running_var", "backbone.5.gcn1.data_bn.num_batches_tracked", "backbone.5.gcn1.bn.weight", "backbone.5.gcn1.bn.bias", "backbone.5.gcn1.bn.running_mean", "backbone.5.gcn1.bn.running_var", "backbone.5.gcn1.bn.num_batches_tracked", "backbone.5.gcn1.attention_conv.qkv_conv.weight", "backbone.5.gcn1.attention_conv.qkv_conv.bias", "backbone.5.gcn1.attention_conv.attn_out.weight", "backbone.5.gcn1.attention_conv.attn_out.bias", "backbone.5.tcn1.conv.weight", "backbone.5.tcn1.conv.bias", "backbone.5.tcn1.bn.weight", "backbone.5.tcn1.bn.bias", "backbone.5.tcn1.bn.running_mean", "backbone.5.tcn1.bn.running_var", "backbone.5.tcn1.bn.num_batches_tracked", "backbone.6.gcn1.data_bn.weight", "backbone.6.gcn1.data_bn.bias", "backbone.6.gcn1.data_bn.running_mean", "backbone.6.gcn1.data_bn.running_var", "backbone.6.gcn1.data_bn.num_batches_tracked", "backbone.6.gcn1.bn.weight", "backbone.6.gcn1.bn.bias", "backbone.6.gcn1.bn.running_mean", "backbone.6.gcn1.bn.running_var", "backbone.6.gcn1.bn.num_batches_tracked", "backbone.6.gcn1.attention_conv.qkv_conv.weight", "backbone.6.gcn1.attention_conv.qkv_conv.bias", "backbone.6.gcn1.attention_conv.attn_out.weight", "backbone.6.gcn1.attention_conv.attn_out.bias", "backbone.6.tcn1.conv.weight", "backbone.6.tcn1.conv.bias", "backbone.6.tcn1.bn.weight", "backbone.6.tcn1.bn.bias", "backbone.6.tcn1.bn.running_mean", "backbone.6.tcn1.bn.running_var", "backbone.6.tcn1.bn.num_batches_tracked", "backbone.6.down1.conv.weight", "backbone.6.down1.conv.bias", "backbone.6.down1.bn.weight", "backbone.6.down1.bn.bias", "backbone.6.down1.bn.running_mean", "backbone.6.down1.bn.running_var", "backbone.6.down1.bn.num_batches_tracked", "backbone.7.gcn1.data_bn.weight", "backbone.7.gcn1.data_bn.bias", "backbone.7.gcn1.data_bn.running_mean", "backbone.7.gcn1.data_bn.running_var", "backbone.7.gcn1.data_bn.num_batches_tracked", "backbone.7.gcn1.bn.weight", "backbone.7.gcn1.bn.bias", "backbone.7.gcn1.bn.running_mean", 
"backbone.7.gcn1.bn.running_var", "backbone.7.gcn1.bn.num_batches_tracked", "backbone.7.gcn1.attention_conv.qkv_conv.weight", "backbone.7.gcn1.attention_conv.qkv_conv.bias", "backbone.7.gcn1.attention_conv.attn_out.weight", "backbone.7.gcn1.attention_conv.attn_out.bias", "backbone.7.tcn1.conv.weight", "backbone.7.tcn1.conv.bias", "backbone.7.tcn1.bn.weight", "backbone.7.tcn1.bn.bias", "backbone.7.tcn1.bn.running_mean", "backbone.7.tcn1.bn.running_var", "backbone.7.tcn1.bn.num_batches_tracked", "backbone.8.gcn1.data_bn.weight", "backbone.8.gcn1.data_bn.bias", "backbone.8.gcn1.data_bn.running_mean", "backbone.8.gcn1.data_bn.running_var", "backbone.8.gcn1.data_bn.num_batches_tracked", "backbone.8.gcn1.bn.weight", "backbone.8.gcn1.bn.bias", "backbone.8.gcn1.bn.running_mean", "backbone.8.gcn1.bn.running_var", "backbone.8.gcn1.bn.num_batches_tracked", "backbone.8.gcn1.attention_conv.qkv_conv.weight", "backbone.8.gcn1.attention_conv.qkv_conv.bias", "backbone.8.gcn1.attention_conv.attn_out.weight", "backbone.8.gcn1.attention_conv.attn_out.bias", "backbone.8.tcn1.conv.weight", "backbone.8.tcn1.conv.bias", "backbone.8.tcn1.bn.weight", "backbone.8.tcn1.bn.bias", "backbone.8.tcn1.bn.running_mean", "backbone.8.tcn1.bn.running_var", "backbone.8.tcn1.bn.num_batches_tracked", "gcn0.mask", "gcn0.conv_list.0.weight", "gcn0.conv_list.0.bias", "gcn0.conv_list.1.weight", "gcn0.conv_list.1.bias", "gcn0.conv_list.2.weight", "gcn0.conv_list.2.bias", "gcn0.bn.weight", "gcn0.bn.bias", "gcn0.bn.running_mean", "gcn0.bn.running_var", "gcn0.bn.num_batches_tracked", "tcn0.conv.weight", "tcn0.conv.bias", "tcn0.bn.weight", "tcn0.bn.bias", "tcn0.bn.running_mean", "tcn0.bn.running_var", "tcn0.bn.num_batches_tracked", "person_bn.weight", "person_bn.bias", "person_bn.running_mean", "person_bn.running_var", "person_bn.num_batches_tracked", "fcn.weight", "fcn.bias". 
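
Comparing the two key lists above by hand is tedious. Below is a minimal sketch in plain Python that ignores a leading "module." on both sides and reports whatever still differs. The helper `diagnose_prefix` is hypothetical and is not part of ST-TR or PyTorch.

```python
def _strip(keys, prefix):
    # Drop one leading `prefix` from each key that has it
    return {k[len(prefix):] if k.startswith(prefix) else k for k in keys}


def diagnose_prefix(model_keys, ckpt_keys, prefix="module."):
    """Compare state_dict keys while ignoring a leading `prefix` on either side.

    Returns (still_missing, still_unexpected) as sorted lists: keys the model
    expects but the checkpoint lacks, and vice versa, once the prefix is removed.
    """
    model_set = _strip(model_keys, prefix)
    ckpt_set = _strip(ckpt_keys, prefix)
    still_missing = sorted(model_set - ckpt_set)
    still_unexpected = sorted(ckpt_set - model_set)
    return still_missing, still_unexpected


# Example using a few key names taken from the trace above
missing, unexpected = diagnose_prefix(
    ["module.data_bn.weight", "module.fcn.bias"],
    ["data_bn.weight", "data_bn.num_batches_tracked", "fcn.bias"],
)
print(missing, unexpected)  # [] ['data_bn.num_batches_tracked']
```

With the real objects you would call it as `diagnose_prefix(model.state_dict().keys(), weights.keys())`. If both lists come back empty, the mismatch is only the "module." prefix. Anything left over points to a difference beyond the prefix. For example, the trace above lists "num_batches_tracked" and "mask" entries among the unexpected keys with no counterpart among the missing ones, which may be worth checking separately.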

Chiaraplizz commented 3 years ago

Hi :)

The "module." prefix is added by DataParallel, which stores the wrapped model in its "module" attribute. Make sure you wrap the model in DataParallel (l. 357) only AFTER loading the weights.

Chiara
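
Here is a minimal sketch of the order described above. It assumes the checkpoint was saved from the unwrapped model, as the key names in the trace suggest. `TinyNet` is a hypothetical stand-in for the ST-TR model, and the helper functions are illustrative. Only `load_state_dict`, `nn.DataParallel` and its `.module` attribute are standard PyTorch.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the ST-TR model (two layers named as in the trace)."""

    def __init__(self):
        super().__init__()
        self.data_bn = nn.BatchNorm1d(4)
        self.fcn = nn.Linear(4, 2)

    def forward(self, x):
        return self.fcn(self.data_bn(x))


def load_then_wrap(model, state_dict):
    # Keys in state_dict match the bare model, so load first...
    model.load_state_dict(state_dict)
    # ...then wrap; from here on, state_dict() keys carry the "module." prefix
    return nn.DataParallel(model)


def load_into_wrapped(wrapped, state_dict):
    # If the model is already wrapped, load into the inner module instead
    wrapped.module.load_state_dict(state_dict)
    return wrapped


def strip_module_prefix(state_dict, prefix="module."):
    # Opposite case: checkpoint saved from a wrapped model, loaded into a bare one
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}


weights = TinyNet().state_dict()  # stands in for torch.load(...)
model = load_then_wrap(TinyNet(), weights)
print(next(iter(model.state_dict())))  # module.data_bn.weight
```

With the real checkpoint, the first step would be something like `weights = torch.load("checkpoint_ST-TR/ntu120_xsub_spatial.pt", map_location="cpu")`, followed by `load_then_wrap(model, weights)`.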

bszczapa commented 3 years ago

Hi,

OK, I never thought of that, thank you. I have another question, unrelated to this one: how do you generate the heatmaps that visualize the attention? I searched the project and couldn't find any code for it. I would like to generate the same heatmaps for a dataset I use.

Thanks. Benjamin
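
The authors' visualization code is not in this thread (Chiara offered to share it by e-mail), so what follows is not the ST-TR code. It is a generic PyTorch sketch that captures an intermediate tensor with a forward hook and normalizes it for display. The helper names are hypothetical. Which submodule actually exposes the attention weights in ST-TR is an assumption you need to verify in the source. The trace shows modules such as "backbone.3.gcn1.attention_conv", but that module's output may be the attended features rather than the weight matrix. The tensor may also carry extra dimensions (e.g., heads) that `to_heatmap` does not handle.

```python
import torch
import torch.nn as nn


def capture_output(model, module_name, *inputs):
    """Run `model` once and return the detached output of the named submodule."""
    target = dict(model.named_modules())[module_name]
    captured = {}

    def hook(module, args, output):
        captured["out"] = output.detach()

    handle = target.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        handle.remove()  # always detach the hook, even if the forward pass fails
    return captured["out"]


def to_heatmap(t):
    """Average a (batch, rows, cols) tensor over the batch and min-max scale it to [0, 1]."""
    m = t.float().mean(dim=0)
    lo, hi = m.min(), m.max()
    return (m - lo) / (hi - lo + 1e-12)


# Toy example: the softmax output of a tiny model plays the role of attention weights
toy = nn.Sequential(nn.Linear(3, 5), nn.Softmax(dim=-1))
weights_map = capture_output(toy, "1", torch.randn(2, 4, 3))
hm = to_heatmap(weights_map)  # shape (4, 5), values in [0, 1]
```

You can then draw the resulting 2-D tensor with `matplotlib.pyplot.imshow(hm.numpy(), cmap="viridis")` and add a scale with `plt.colorbar()`.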

Chiaraplizz commented 3 years ago

Dear Benjamin, can you send me an e-mail so that I can share the code with you?

Chiara