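I got the following error when loading a 6_6_4 model that I trained myself with PyTorch. Could someone more experienced tell me how to fix it? I suspect I changed some things incorrectly.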
root@autodl-container-464d11b752-0064c9d8:~/autodl-fs/gomoku/AlphaZero_Gomoku# python human_play.py
Traceback (most recent call last):
  File "human_play.py", line 88, in <module>
    run()
  File "human_play.py", line 60, in run
    best_policy = PolicyValueNet(width, height, model_file)
  File "/root/autodl-fs/gomoku/AlphaZero_Gomoku/policy_value_net_pytorch.py", line 78, in __init__
    self.policy_value_net.load_state_dict(net_params)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Net:
size mismatch for act_fc1.weight: copying a param with shape torch.Size([49, 196]) from checkpoint, the shape in current model is torch.Size([36, 144]).
size mismatch for act_fc1.bias: copying a param with shape torch.Size([49]) from checkpoint, the shape in current model is torch.Size([36]).
size mismatch for val_fc1.weight: copying a param with shape torch.Size([64, 98]) from checkpoint, the shape in current model is torch.Size([64, 72]).
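The shapes in these size mismatches can be checked mechanically. Below is a minimal sketch, assuming (inferred from the numbers in the traceback, not taken from the repository source) that `act_fc1` maps 4*w*h inputs to w*h outputs and `val_fc1` maps 2*w*h inputs to 64 outputs. Under that assumption the checkpoint shapes correspond to 49 board cells and the current model to 36. The helper names `board_cells_from_shapes` and `check_board` are hypothetical, and the script works on plain shape tuples so it runs without PyTorch installed.

```python
import math
from typing import Mapping, Sequence


def board_cells_from_shapes(shapes: Mapping[str, Sequence[int]]) -> int:
    """Infer w*h from parameter shapes.

    Assumed layout (inferred from the error message, not from the repo):
      act_fc1 = Linear(4*w*h -> w*h), weight shape (w*h, 4*w*h)
      val_fc1 = Linear(2*w*h -> 64),  weight shape (64, 2*w*h)
    """
    out_features, in_features = shapes["act_fc1.weight"]
    n_cells = out_features
    if in_features != 4 * n_cells:
        raise ValueError("act_fc1.weight does not look like Linear(4*w*h -> w*h)")
    if shapes["val_fc1.weight"][1] != 2 * n_cells:
        raise ValueError("val_fc1.weight input size is not 2*w*h")
    return n_cells


def check_board(shapes: Mapping[str, Sequence[int]], width: int, height: int) -> str:
    """Compare the board size implied by a checkpoint with the configured one."""
    n_cells = board_cells_from_shapes(shapes)
    if n_cells == width * height:
        return "ok"
    side = math.isqrt(n_cells)
    # Only a guess: a square board is assumed when the cell count is a perfect square.
    hint = f" (e.g. {side}x{side})" if side * side == n_cells else ""
    return (f"checkpoint expects {n_cells} cells{hint}, "
            f"but board is {width}x{height} = {width * height}")


# Shapes copied from the "from checkpoint" side of the error above.
# With a real file one could build this dict as, for example:
#   {k: tuple(v.shape) for k, v in torch.load(model_file, map_location="cpu").items()}
ckpt_shapes = {
    "act_fc1.weight": (49, 196),
    "act_fc1.bias": (49,),
    "val_fc1.weight": (64, 98),
}
print(check_board(ckpt_shapes, 6, 6))
# → checkpoint expects 49 cells (e.g. 7x7), but board is 6x6 = 36
```

If a check like this reports a mismatch, the `width` and `height` passed to `PolicyValueNet` in human_play.py probably differ from the values used when the checkpoint was trained. The 7x7 reading is only the square-board interpretation of 49 cells.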