EGO4D / social-interactions

MIT License

LAM: failed to load pretrained model - lam_gaze360_best.pth #27

Open foolhard opened 5 months ago

foolhard commented 5 months ago

Hello,

I tried to test with your pretrained model "lam_gaze360_best.pth", but it failed to load the parameters.

I ran the command python run.py --eval --checkpoint ckpts/lam-gaze360-best.pth --exp_path output/ --model GazeLSTM

But I got an error about missing keys in state_dict.

I found that in model.py, the GazeLSTM model has two MLP layers for the classification head, the same as BaselineLSTM, but this does not match the pretrained weights.
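One way to confirm this kind of mismatch is to compare the parameter names the model expects with the names stored in the checkpoint. The following is a minimal sketch, not code from the repository: the helper `diff_state_dict_keys` and the example key lists (including `lstm.weight_ih_l0`) are hypothetical and only illustrate the comparison.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) parameter names, each sorted.

    missing:    expected by the model but absent from the checkpoint.
    unexpected: present in the checkpoint but not used by the model.
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical key names for illustration only; the real lists would come
# from the model and from the checkpoint file.
model_keys = [
    "lstm.weight_ih_l0",
    "last_layer1.weight", "last_layer1.bias",
    "last_layer2.weight", "last_layer2.bias",
]
checkpoint_keys = ["lstm.weight_ih_l0"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

With PyTorch, the first argument would typically be `model.state_dict().keys()` and the second the keys of the dictionary returned by `torch.load(path, map_location="cpu")`. Whether this particular checkpoint nests its weights under another key (for example `"state_dict"`) is an assumption that would need checking.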

How can I test your pretrained model?

Below is the detailed error.

(lam) root@workspace:~# python run.py --eval --checkpoint ckpts/lam-gaze360-best.pth --exp_path output/ --model GazeLSTM
Namespace(backend='nccl', batch_size=64, checkpoint='ckpts/lam-gaze360-best.pth', device_id=0, dist=False, epochs=40, eval=True, exp_path='output/', gt_path='data/result_LAM', init_method=None, json_path='data/json_original', lr=0.0005, model='GazeLSTM', num_workers=16, rank=0, source_path='data/video_imgs', start_rank=0, test_path='data/videos_challenge', test_stride=1, train_file='data/split/train.list', train_stride=13, val_file='data/split/val.list', val_stride=13, weights=[0.136, 0.864], world_size=None)
Model: GazeLSTM
loading checkpoint ckpts/lam-gaze360-best.pth
Traceback (most recent call last):
  File "run.py", line 165, in <module>
    run()
  File "run.py", line 161, in run
    main(args)
  File "run.py", line 33, in main
    model = eval(args.model)(args)
  File "/dfs/data/workspace/social-interactions/model/model.py", line 71, in __init__
    self.load_checkpoint()
  File "/dfs/data/workspace/social-interactions/model/model.py", line 95, in load_checkpoint
    self.load_state_dict(state_dict, strict=self.args.eval)
  File "/dfs/data/miniconda/envs/lam/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GazeLSTM:
        Missing key(s) in state_dict: "last_layer1.weight", "last_layer1.bias", "last_layer2.weight", "last_layer2.bias".
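The traceback shows `load_state_dict` being called with `strict=self.args.eval`, so with `--eval` the load is strict and any missing key raises. The sketch below reproduces that behavior on two toy modules and shows what `strict=False` reports instead. The classes `BackboneOnly` and `TwoLayerHead` are hypothetical stand-ins, not the repository's models; only the `last_layer1` / `last_layer2` names are taken from the error message.

```python
import torch.nn as nn


class BackboneOnly(nn.Module):
    """Hypothetical stand-in for a checkpoint that lacks the two head layers."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)


class TwoLayerHead(nn.Module):
    """Hypothetical stand-in for a model with last_layer1 and last_layer2."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.last_layer1 = nn.Linear(4, 4)
        self.last_layer2 = nn.Linear(4, 2)


state_dict = BackboneOnly().state_dict()
model = TwoLayerHead()

# strict=True raises RuntimeError listing the keys that could not be matched.
strict_error = None
try:
    model.load_state_dict(state_dict, strict=True)
except RuntimeError as err:
    strict_error = str(err)
    print(strict_error)

# strict=False loads the matching keys and returns the mismatches instead.
result = model.load_state_dict(state_dict, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

Note that `strict=False` only suppresses the error: the parameters listed under `missing_keys` keep their random initialization, so evaluation results obtained this way would not reflect the pretrained model.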