Status: Open — issue opened by winggggZ 1 year ago.
# The checkpoint stores graph-conv sub-layers under the legacy names
# ``lin_l`` / ``lin_r``, while the current model expects ``lin_rel`` /
# ``lin_root`` (``lin_l`` carries a bias like ``lin_rel``; ``lin_r`` has none,
# like ``lin_root``). Presumably this is the PyTorch Geometric ``GraphConv``
# rename between versions — TODO confirm against the installed version.
def _rename_legacy_gconv_keys(state_dict):
    """Return a copy of *state_dict* with legacy graph-conv layer names renamed.

    Only whole dot-separated path components are renamed
    (``lin_l`` -> ``lin_rel``, ``lin_r`` -> ``lin_root``). Keys that already
    use the new names pass through unchanged. A substring ``replace`` would
    turn ``lin_rel`` into ``lin_rootel`` and break loading of new-format
    checkpoints.
    """
    renames = {"lin_l": "lin_rel", "lin_r": "lin_root"}
    return {
        ".".join(renames.get(part, part) for part in key.split(".")): value
        for key, value in state_dict.items()
    }

new_policy_dict = _rename_legacy_gconv_keys(model_cp['policy_dict'])
new_value_dict = _rename_legacy_gconv_keys(model_cp['value_dict'])
self.policy_net.load_state_dict(new_policy_dict)
self.value_net.load_state_dict(new_value_dict)
loading model from checkpoint: results/swimmer/models/best.p
Traceback (most recent call last):
File "/home/dzyao/shiyuexin/shiyan/transform2act/design_opt/eval.py", line 28, in <module>
agent = Transform2ActAgent(cfg=cfg, dtype=dtype, device=device, seed=cfg.seed, num_threads=1, training=False, checkpoint=epoch)
File "/home/dzyao/shiyuexin/shiyan/transform2act/design_opt/agents/transform2act_agent.py", line 38, in __init__
self.load_checkpoint(checkpoint)
File "/home/dzyao/shiyuexin/shiyan/transform2act/design_opt/agents/transform2act_agent.py", line 162, in load_checkpoint
self.policy_net.load_state_dict(model_cp['policy_dict'])
File "/home/dzyao/anaconda3/envs/pytorch2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Transform2ActPolicy:
Missing key(s) in state_dict: "skel_gnn.gconv_layers.0.lin_rel.weight", "skel_gnn.gconv_layers.0.lin_rel.bias", "skel_gnn.gconv_layers.0.lin_root.weight", "skel_gnn.gconv_layers.1.lin_rel.weight", "skel_gnn.gconv_layers.1.lin_rel.bias", "skel_gnn.gconv_layers.1.lin_root.weight", "skel_gnn.gconv_layers.2.lin_rel.weight", "skel_gnn.gconv_layers.2.lin_rel.bias", "skel_gnn.gconv_layers.2.lin_root.weight", "attr_gnn.gconv_layers.0.lin_rel.weight", "attr_gnn.gconv_layers.0.lin_rel.bias", "attr_gnn.gconv_layers.0.lin_root.weight", "attr_gnn.gconv_layers.1.lin_rel.weight", "attr_gnn.gconv_layers.1.lin_rel.bias", "attr_gnn.gconv_layers.1.lin_root.weight", "attr_gnn.gconv_layers.2.lin_rel.weight", "attr_gnn.gconv_layers.2.lin_rel.bias", "attr_gnn.gconv_layers.2.lin_root.weight", "control_gnn.gconv_layers.0.lin_rel.weight", "control_gnn.gconv_layers.0.lin_rel.bias", "control_gnn.gconv_layers.0.lin_root.weight", "control_gnn.gconv_layers.1.lin_rel.weight", "control_gnn.gconv_layers.1.lin_rel.bias", "control_gnn.gconv_layers.1.lin_root.weight", "control_gnn.gconv_layers.2.lin_rel.weight", "control_gnn.gconv_layers.2.lin_rel.bias", "control_gnn.gconv_layers.2.lin_root.weight".
Unexpected key(s) in state_dict: "skel_gnn.gconv_layers.0.lin_l.weight", "skel_gnn.gconv_layers.0.lin_l.bias", "skel_gnn.gconv_layers.0.lin_r.weight", "skel_gnn.gconv_layers.1.lin_l.weight", "skel_gnn.gconv_layers.1.lin_l.bias", "skel_gnn.gconv_layers.1.lin_r.weight", "skel_gnn.gconv_layers.2.lin_l.weight", "skel_gnn.gconv_layers.2.lin_l.bias", "skel_gnn.gconv_layers.2.lin_r.weight", "attr_gnn.gconv_layers.0.lin_l.weight", "attr_gnn.gconv_layers.0.lin_l.bias", "attr_gnn.gconv_layers.0.lin_r.weight", "attr_gnn.gconv_layers.1.lin_l.weight", "attr_gnn.gconv_layers.1.lin_l.bias", "attr_gnn.gconv_layers.1.lin_r.weight", "attr_gnn.gconv_layers.2.lin_l.weight", "attr_gnn.gconv_layers.2.lin_l.bias", "attr_gnn.gconv_layers.2.lin_r.weight", "control_gnn.gconv_layers.0.lin_l.weight", "control_gnn.gconv_layers.0.lin_l.bias", "control_gnn.gconv_layers.0.lin_r.weight", "control_gnn.gconv_layers.1.lin_l.weight", "control_gnn.gconv_layers.1.lin_l.bias", "control_gnn.gconv_layers.1.lin_r.weight", "control_gnn.gconv_layers.2.lin_l.weight", "control_gnn.gconv_layers.2.lin_l.bias", "control_gnn.gconv_layers.2.lin_r.weight".