Tobias-Fischer / rt_gene

RT-GENE: Real-Time Eye Gaze and Blink Estimation in Natural Environments
http://www.imperial.ac.uk/personal-robotics

error when running estimate_gaze_standalone.py with the PyTorch model #83

Closed TranThanh96 closed 3 years ago

TranThanh96 commented 3 years ago

I am trying to run estimate_gaze_standalone.py with the PyTorch model "Model_allsubjects1_pytorch.model", but it fails with errors that look related to the model definition:

Traceback (most recent call last):
  File "estimate_gaze_standalone.py", line 241, in <module>
    gaze_estimator = GazeEstimator("cuda:0", args.models)
  File "../rt_gene/src/rt_gene/estimate_gaze_pytorch.py", line 32, in __init__
    _model.load_state_dict(torch.load(ckpt))
  File "/home/thanhtm/anaconda3/envs/eye_gaze/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1045, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for GazeEstimationAllCNNModel:
    Missing key(s) in state_dict: "_left_module.0.weight", "_left_module.0.bias", "_left_module.2.weight", "_left_module.2.bias", "_left_module.4.weight", "_left_module.4.bias", "_left_module.6.weight", "_left_module.6.bias", "_left_module.8.weight", "_left_module.8.bias", "_left_module.10.weight", "_left_module.10.bias", "_left_module.12.weight", "_left_module.12.bias", "_left_module.14.weight", "_left_module.14.bias", "_left_module.16.weight", "_left_module.16.bias", "_left_module.18.weight", "_left_module.18.bias", "_left_module.20.weight", "_left_module.20.bias", "_left_module.22.weight", "_left_module.22.bias", "_right_module.0.weight", "_right_module.0.bias", "_right_module.2.weight", "_right_module.2.bias", "_right_module.4.weight", "_right_module.4.bias", "_right_module.6.weight", "_right_module.6.bias", "_right_module.8.weight", "_right_module.8.bias", "_right_module.10.weight", "_right_module.10.bias", "_right_module.12.weight", "_right_module.12.bias", "_right_module.14.weight", "_right_module.14.bias", "_right_module.16.weight", "_right_module.16.bias", "_right_module.18.weight", "_right_module.18.bias", "_right_module.20.weight", "_right_module.20.bias", "_right_module.22.weight", "_right_module.22.bias", "concat.2.weight", "concat.2.bias", "concat.4.weight", "concat.4.bias", "concat.5.weight", "concat.5.bias", "concat.7.weight", "concat.7.bias".
    Unexpected key(s) in state_dict: "left_features.0.weight", "left_features.0.bias", "left_features.2.weight", "left_features.2.bias", "left_features.5.weight", "left_features.5.bias", "left_features.7.weight", "left_features.7.bias", "left_features.10.weight", "left_features.10.bias", "left_features.12.weight", "left_features.12.bias", "left_features.14.weight", "left_features.14.bias", "left_features.17.weight", "left_features.17.bias", "left_features.19.weight", "left_features.19.bias", "left_features.21.weight", "left_features.21.bias", "left_features.24.weight", "left_features.24.bias", "left_features.26.weight", "left_features.26.bias", "left_features.28.weight", "left_features.28.bias", "right_features.0.weight", "right_features.0.bias", "right_features.2.weight", "right_features.2.bias", "right_features.5.weight", "right_features.5.bias", "right_features.7.weight", "right_features.7.bias", "right_features.10.weight", "right_features.10.bias", "right_features.12.weight", "right_features.12.bias", "right_features.14.weight", "right_features.14.bias", "right_features.17.weight", "right_features.17.bias", "right_features.19.weight", "right_features.19.bias", "right_features.21.weight", "right_features.21.bias", "right_features.24.weight", "right_features.24.bias", "right_features.26.weight", "right_features.26.bias", "right_features.28.weight", "right_features.28.bias", "xl.0.weight", "xl.0.bias", "xl.1.weight", "xl.1.bias", "xl.1.running_mean", "xl.1.running_var", "xl.1.num_batches_tracked", "xr.0.weight", "xr.0.bias", "xr.1.weight", "xr.1.bias", "xr.1.running_mean", "xr.1.running_var", "xr.1.num_batches_tracked", "fc.0.weight", "fc.0.bias", "fc.2.weight", "fc.2.bias", "concat.0.weight", "concat.0.bias", "concat.1.running_mean", "concat.1.running_var", "concat.1.num_batches_tracked".
    size mismatch for concat.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048, 4098, 1, 1]).
    size mismatch for concat.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
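
For anyone who hits a similar error: the "Missing key(s)" and "Unexpected key(s)" lists are effectively a set difference between the parameter names the instantiated model expects and the names stored in the checkpoint. Completely disjoint prefixes, as above, usually suggest the checkpoint was saved from a different model definition than the one being loaded. The following is only a minimal sketch of that comparison in plain Python; the helper name diff_state_dict_keys is hypothetical (it is not part of rt_gene or PyTorch), and the key lists are a short excerpt copied from the error message for illustration.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Compare parameter names the way load_state_dict reports them.

    Returns (missing, unexpected):
      missing    = names the model expects but the checkpoint lacks
      unexpected = names the checkpoint has but the model does not define
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Illustrative excerpt of the names from the error message (not the full lists).
model_keys = ["_left_module.0.weight", "_right_module.0.weight", "concat.1.weight"]
checkpoint_keys = ["left_features.0.weight", "right_features.0.weight", "concat.1.weight"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

With PyTorch installed, the same helper could be fed model.state_dict().keys() and torch.load(path, map_location="cpu").keys(), assuming the checkpoint file holds a raw state dict as it appears to here. Matching names alone are not sufficient: the size mismatch lines show that a shared name such as concat.1.weight can still differ in shape.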

Tobias-Fischer commented 3 years ago

Looks like something related to your recent changes @ahmed-alhindawi?

@TranThanh96 could you please try the latest release instead of the master branch? https://github.com/Tobias-Fischer/rt_gene/releases/tag/v4.0.1 (simply run git checkout v4.0.1 and then try again).

TranThanh96 commented 3 years ago

Thank you. Now I can run it with the PyTorch model.

TranThanh96 commented 3 years ago

Another question: have you tried benchmarking the PyTorch version against the TensorFlow version? Which one gives better results?

ahmed-alhindawi commented 3 years ago

> Looks like something related to your recent changes @ahmed-alhindawi?
>
> @TranThanh96 could you please try the latest release instead of the master branch? https://github.com/Tobias-Fischer/rt_gene/releases/tag/v4.0.1 (simply run git checkout v4.0.1 and then try again).

Quite right, sorry. That shouldn't have been committed. Edit: I've removed the latest commit, so the repository should "just work" now.

Tobias-Fischer commented 3 years ago

Thanks Ahmed!