Khrylx / Transform2Act

[ICLR 2022 Oral] Official PyTorch Implementation of "Transform2Act: Learning a Transform-and-Control Policy for Efficient Agent Design".
https://sites.google.com/view/transform2act
MIT License
54 stars 14 forks source link

I can't use the pretrained model because of a state_dict error #4

Open winggggZ opened 1 year ago

winggggZ commented 1 year ago
python design_opt/eval.py --cfg swimmer --save_video

```
loading model from checkpoint: results/swimmer/models/best.p
Traceback (most recent call last):
  File "/home/dzyao/shiyuexin/shiyan/transform2act/design_opt/eval.py", line 28, in <module>
    agent = Transform2ActAgent(cfg=cfg, dtype=dtype, device=device, seed=cfg.seed, num_threads=1, training=False, checkpoint=epoch)
  File "/home/dzyao/shiyuexin/shiyan/transform2act/design_opt/agents/transform2act_agent.py", line 38, in __init__
    self.load_checkpoint(checkpoint)
  File "/home/dzyao/shiyuexin/shiyan/transform2act/design_opt/agents/transform2act_agent.py", line 162, in load_checkpoint
    self.policy_net.load_state_dict(model_cp['policy_dict'])
  File "/home/dzyao/anaconda3/envs/pytorch2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Transform2ActPolicy:
	Missing key(s) in state_dict: "skel_gnn.gconv_layers.0.lin_rel.weight", "skel_gnn.gconv_layers.0.lin_rel.bias", "skel_gnn.gconv_layers.0.lin_root.weight", "skel_gnn.gconv_layers.1.lin_rel.weight", "skel_gnn.gconv_layers.1.lin_rel.bias", "skel_gnn.gconv_layers.1.lin_root.weight", "skel_gnn.gconv_layers.2.lin_rel.weight", "skel_gnn.gconv_layers.2.lin_rel.bias", "skel_gnn.gconv_layers.2.lin_root.weight", "attr_gnn.gconv_layers.0.lin_rel.weight", "attr_gnn.gconv_layers.0.lin_rel.bias", "attr_gnn.gconv_layers.0.lin_root.weight", "attr_gnn.gconv_layers.1.lin_rel.weight", "attr_gnn.gconv_layers.1.lin_rel.bias", "attr_gnn.gconv_layers.1.lin_root.weight", "attr_gnn.gconv_layers.2.lin_rel.weight", "attr_gnn.gconv_layers.2.lin_rel.bias", "attr_gnn.gconv_layers.2.lin_root.weight", "control_gnn.gconv_layers.0.lin_rel.weight", "control_gnn.gconv_layers.0.lin_rel.bias", "control_gnn.gconv_layers.0.lin_root.weight", "control_gnn.gconv_layers.1.lin_rel.weight", "control_gnn.gconv_layers.1.lin_rel.bias", "control_gnn.gconv_layers.1.lin_root.weight", "control_gnn.gconv_layers.2.lin_rel.weight", "control_gnn.gconv_layers.2.lin_rel.bias", "control_gnn.gconv_layers.2.lin_root.weight".
	Unexpected key(s) in state_dict: "skel_gnn.gconv_layers.0.lin_l.weight", "skel_gnn.gconv_layers.0.lin_l.bias", "skel_gnn.gconv_layers.0.lin_r.weight", "skel_gnn.gconv_layers.1.lin_l.weight", "skel_gnn.gconv_layers.1.lin_l.bias", "skel_gnn.gconv_layers.1.lin_r.weight", "skel_gnn.gconv_layers.2.lin_l.weight", "skel_gnn.gconv_layers.2.lin_l.bias", "skel_gnn.gconv_layers.2.lin_r.weight", "attr_gnn.gconv_layers.0.lin_l.weight", "attr_gnn.gconv_layers.0.lin_l.bias", "attr_gnn.gconv_layers.0.lin_r.weight", "attr_gnn.gconv_layers.1.lin_l.weight", "attr_gnn.gconv_layers.1.lin_l.bias", "attr_gnn.gconv_layers.1.lin_r.weight", "attr_gnn.gconv_layers.2.lin_l.weight", "attr_gnn.gconv_layers.2.lin_l.bias", "attr_gnn.gconv_layers.2.lin_r.weight", "control_gnn.gconv_layers.0.lin_l.weight", "control_gnn.gconv_layers.0.lin_l.bias", "control_gnn.gconv_layers.0.lin_r.weight", "control_gnn.gconv_layers.1.lin_l.weight", "control_gnn.gconv_layers.1.lin_l.bias", "control_gnn.gconv_layers.1.lin_r.weight", "control_gnn.gconv_layers.2.lin_l.weight", "control_gnn.gconv_layers.2.lin_l.bias", "control_gnn.gconv_layers.2.lin_r.weight".
```

Caiyishuai commented 11 months ago
    # The pretrained checkpoint stores the graph-conv sub-layers as `lin_l` /
    # `lin_r`, while the current model expects `lin_rel` / `lin_root` (see the
    # missing/unexpected keys in the traceback above).
    # NOTE(review): presumably caused by a torch_geometric version change in the
    # layer's parameter names — confirm against the installed version.
    renamed_layers = {'lin_l': 'lin_rel', 'lin_r': 'lin_root'}

    def remap_keys(state_dict):
        """Return a copy of *state_dict* with old layer names renamed.

        Matching is done on whole dotted path components rather than on
        substrings. "lin_r" is a substring of both "lin_rel" and "lin_root",
        so a substring replace would corrupt checkpoints that already use the
        new names (e.g. "lin_rel" -> "lin_rootel"). Keys without an old name
        are passed through unchanged, and insertion order is preserved.
        """
        remapped = {}
        for key, value in state_dict.items():
            parts = [renamed_layers.get(part, part) for part in key.split('.')]
            remapped['.'.join(parts)] = value
        return remapped

    self.policy_net.load_state_dict(remap_keys(model_cp['policy_dict']))
    self.value_net.load_state_dict(remap_keys(model_cp['value_dict']))