Chiaraplizz / ST-TR

Spatial Temporal Transformer Network for Skeleton-Based Activity Recognition
MIT License

Cannot test checkpoint with AGCN #24

Closed zhengthomastang closed 2 years ago

zhengthomastang commented 3 years ago

There are 5 pre-trained models for each split (xsub vs. xview) on NTU60: ntu60_xsub_spatial.pt, ntu60_xsub_temporal.pt, ntu60_xsub_bones_spatial.pt, ntu60_xsub_bones_temporal.pt and ntu60_xsub_bones_temporal-agcn.pt. I was able to test the first 4 successfully by editing the configs in train.yaml, but I cannot get the last one (with AGCN) to work. My configs are listed below (compared to ntu60_xsub_bones_temporal.pt, the only thing I changed was agcn: True):

# feeder
feeder: st_gcn.feeder.Feeder
feeder_augmented: st_gcn.feeder.FeederAugmented
train_feeder_args:
  data_path: ../data/NTU-RGB-D_Processed/xsub/train_data_joint_bones.npy
  label_path: ../data/NTU-RGB-D_Processed/xsub/train_label.pkl
  random_choose: False
  random_shift: False
  random_move: False
  window_size: -1
  normalization: False
  mirroring: False
test_feeder_args:
  data_path: ../data/NTU-RGB-D_Processed/xsub/val_data_joint_bones.npy
  label_path: ../data/NTU-RGB-D_Processed/xsub/val_label.pkl
# model
model: st_gcn.net.ST_GCN
training: True
model_args:
  num_class: 60
  channel: 6
  window_size: 300
  num_point: 25
  num_person: 2
  mask_learning: True
  use_data_bn: True
  attention: False
  only_attention: True
  tcn_attention: True
  data_normalization: True
  skip_conn: True
  weight_matrix: 2
  only_temporal_attention: True
  bn_flag: True
  attention_3: False
  kernel_temporal: 9
  more_channels: False
  double_channel: True
  drop_connect: True
  concat_original: True
  all_layers: False
  adjacency: False
  agcn: True
  dv: 0.25
  dk: 0.25
  Nh: 8
  n: 4
  dim_block1: 10
  dim_block2: 30
  dim_block3: 75
  relative: False
  graph: st_gcn.graph.NTU_RGB_D
  visualization: False
  graph_args:
    labeling_mode: 'spatial'
  #optical_flow: True
# optim
#0: old one, 1: new one
scheduler: 1
weight_decay: 0.0001
base_lr: 0.1
step: [60,90]
# training
device: [0,1,2,3]
batch_size: 2
test_batch_size: 1
num_epoch: 120
nesterov: True

Here is the command I ran: python main.py --config config/st_gcn/nturgbd/train.yaml --phase test --weights ../checkpoint_ST-TR/ntu60_xsub_bones_temporal-agcn.pt

And here is the traceback of error messages:

Traceback (most recent call last):
  File "main.py", line 966, in <module>
    processor = Processor(arg)
  File "main.py", line 234, in __init__
    self.load_model()
  File "main.py", line 354, in load_model
    self.model.load_state_dict(state)
  File "/localhome/local-thtang/anaconda3/envs/st-tr/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1406, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Model:
    size mismatch for fcn.weight: copying a param with shape torch.Size([120, 512, 1]) from checkpoint, the shape in current model is torch.Size([60, 512, 1]).
    size mismatch for fcn.bias: copying a param with shape torch.Size([120]) from checkpoint, the shape in current model is torch.Size([60]).
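The shapes in the error suggest the uploaded file was trained with a 120-class head. One way to confirm what a checkpoint actually contains before building the model is to read the classifier shape straight from the state dict. This is only a sketch (the helper name `infer_num_class` is mine), assuming PyTorch and the `fcn.weight` layout of (num_class, channels, 1) shown in the traceback:

```python
import torch

def infer_num_class(state_dict):
    # In this model the final classifier is `fcn`, a 1x1 conv whose weight
    # has shape (num_class, channels, 1), so the first dimension is the
    # number of classes the checkpoint was trained for.
    return state_dict["fcn.weight"].shape[0]

# Simulated state dict with the shapes reported in the traceback above;
# in real use you would pass torch.load(path, map_location="cpu") instead.
ckpt = {
    "fcn.weight": torch.zeros(120, 512, 1),
    "fcn.bias": torch.zeros(120),
}
print(infer_num_class(ckpt))  # prints 120, not the 60 the config expects
```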
Chiaraplizz commented 3 years ago

Hi,

It seems the checkpoint expects 120 classes instead of 60. Maybe I mixed the files up when uploading them. I'll check and correct it!
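In the meantime, a possible stopgap (only to make load_state_dict succeed, assuming the uploaded file really is a 120-class model) is to match the classifier size in the config; note that the logits would then follow the 120-class label set, so NTU-60 test accuracy under this setting would not be meaningful:

```yaml
# Stopgap until the checkpoint is re-uploaded: match the classifier size
# to the file (the error shows it has 120 outputs).
model_args:
  num_class: 120   # was 60; lets load_state_dict succeed, labels won't align
```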

Chiara