nicklashansen / tdmpc2

Code for "TD-MPC2: Scalable, Robust World Models for Continuous Control"
https://www.tdmpc2.com
MIT License

RuntimeError When Loading State_Dict for Single-task Models #23

Open Zzl35 opened 3 months ago

Zzl35 commented 3 months ago

I am trying to run a model downloaded from Hugging Face using the following command:

python evaluate.py task=humanoid-run checkpoint=/home/ubuntu/zzl/tdmpc2/models/humanoid-run-3.pt save_video=true model_size=5

Unfortunately, the following error occurs during execution:

RuntimeError: Error(s) in loading state_dict for WorldModel:
        Missing key(s) in state_dict: "_encoder.state.0.weight", "_encoder.state.0.bias", "_encoder.state.0.ln.weight", "_encoder.state.0.ln.bias", "_encoder.state.1.ln.weight", "_encoder.state.1.ln.bias", "_dynamics.0.weight", "_dynamics.0.bias", "_dynamics.0.ln.weight", "_dynamics.0.ln.bias", "_dynamics.1.ln.weight", "_dynamics.1.ln.bias", "_dynamics.2.weight", "_dynamics.2.bias", "_dynamics.2.ln.weight", "_dynamics.2.ln.bias", "_reward.0.ln.weight", "_reward.0.ln.bias", "_reward.1.ln.weight", "_reward.1.ln.bias", "_reward.2.weight", "_reward.2.bias", "_pi.0.ln.weight", "_pi.0.ln.bias", "_pi.1.ln.weight", "_pi.1.ln.bias", "_pi.2.weight", "_pi.2.bias". 
        Unexpected key(s) in state_dict: "_encoder.state.2.weight", "_encoder.state.2.bias", "_encoder.state.4.weight", "_encoder.state.4.bias", "_encoder.state.5.weight", "_encoder.state.5.bias", "_dynamics.0.0.weight", "_dynamics.0.0.bias", "_dynamics.0.1.weight", "_dynamics.0.1.bias", "_dynamics.0.3.weight", "_dynamics.0.3.bias", "_dynamics.0.4.weight", "_dynamics.0.4.bias", "_dynamics.0.6.weight", "_dynamics.0.6.bias", "_reward.3.weight", "_reward.3.bias", "_reward.4.weight", "_reward.4.bias", "_reward.6.weight", "_reward.6.bias", "_pi.3.weight", "_pi.3.bias", "_pi.4.weight", "_pi.4.bias", "_pi.6.weight", "_pi.6.bias". 
        size mismatch for _encoder.state.1.weight: copying a param with shape torch.Size([256, 67]) from checkpoint, the shape in current model is torch.Size([512, 256]).
        size mismatch for _encoder.state.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for _dynamics.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
        size mismatch for _reward.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
        size mismatch for _pi.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([512, 512]).

It seems that the structure of the world model constructed by the code differs from that of the downloaded checkpoint. I've also tried previous versions of the code, but that didn't fix the issue. So I'd like to know how to use a pre-trained single-task model. Thank you!
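
For reference, a minimal sketch (using the checkpoint path from the command above) that prints the keys and parameter shapes stored in the checkpoint, which makes the layout mismatch visible:

import torch

# Inspect the checkpoint's keys and parameter shapes to compare against
# the WorldModel the code constructs; the 'model' sub-dict follows the
# convention used by the repo's saved checkpoints.
ckpt = torch.load('/home/ubuntu/zzl/tdmpc2/models/humanoid-run-3.pt', map_location='cpu')
state_dict = ckpt['model'] if 'model' in ckpt else ckpt
for k, v in sorted(state_dict.items()):
    print(k, tuple(v.shape))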

nicklashansen commented 3 months ago

Thanks for reporting this! I'll see if I can reproduce this error and get back to you soon.

nicklashansen commented 3 months ago

@Zzl35 Hi again, I was able to reproduce your error. This seems to be due to some code restructuring that happened around the time of release. I have not quite figured out a full solution yet, but it should be possible to simply rename the keys from the old format to the new one, e.g. by modifying the agent's load method as follows:

    def load(self, fp):
        """
        Load a saved state dict from filepath (or dictionary) into current agent.

        Args:
            fp (str or dict): Filepath or state dict to load.
        """
        state_dict = fp if isinstance(fp, dict) else torch.load(fp)
        if 'model' in state_dict:
            state_dict = state_dict['model']
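        # Keys below map the old (pre-release) checkpoint layout to the current
        # module layout; LayerNorm parameters now live under a .ln attribute.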
        key_mapping = {
            '_encoder.state.1.weight': '_encoder.state.0.weight',
            '_encoder.state.1.bias': '_encoder.state.0.bias',
            '_encoder.state.2.weight': '_encoder.state.0.ln.weight',
            '_encoder.state.2.bias': '_encoder.state.0.ln.bias',
            '_encoder.state.4.weight': '_encoder.state.1.weight',
            '_encoder.state.4.bias': '_encoder.state.1.bias',
            '_encoder.state.5.weight': '_encoder.state.1.ln.weight',
            '_encoder.state.5.bias': '_encoder.state.1.ln.bias',
            '_dynamics.0.0.weight': '_dynamics.0.weight',
            '_dynamics.0.0.bias': '_dynamics.0.bias',
            '_dynamics.0.1.weight': '_dynamics.0.ln.weight',
            '_dynamics.0.1.bias': '_dynamics.0.ln.bias',
            '_dynamics.0.3.weight': '_dynamics.1.weight',
            '_dynamics.0.3.bias': '_dynamics.1.bias',
            '_dynamics.0.4.weight': '_dynamics.1.ln.weight',
            '_dynamics.0.4.bias': '_dynamics.1.ln.bias',
            '_dynamics.0.6.weight': '_dynamics.2.weight',
            '_dynamics.0.6.bias': '_dynamics.2.bias',
            '_dynamics.1.weight': '_dynamics.2.ln.weight',
            '_dynamics.1.bias': '_dynamics.2.ln.bias',
            '_reward.0.weight': '_reward.0.weight',
            '_reward.0.bias': '_reward.0.bias',
            '_reward.1.weight': '_reward.0.ln.weight',
            '_reward.1.bias': '_reward.0.ln.bias',
            '_reward.3.weight': '_reward.1.weight',
            '_reward.3.bias': '_reward.1.bias',
            '_reward.4.weight': '_reward.1.ln.weight',
            '_reward.4.bias': '_reward.1.ln.bias',
            '_reward.6.weight': '_reward.2.weight',
            '_reward.6.bias': '_reward.2.bias',
            '_pi.0.weight': '_pi.0.weight',
            '_pi.0.bias': '_pi.0.bias',
            '_pi.1.weight': '_pi.0.ln.weight',
            '_pi.1.bias': '_pi.0.ln.bias',
            '_pi.3.weight': '_pi.1.weight',
            '_pi.3.bias': '_pi.1.bias',
            '_pi.4.weight': '_pi.1.ln.weight',
            '_pi.4.bias': '_pi.1.ln.bias',
            '_pi.6.weight': '_pi.2.weight',
            '_pi.6.bias': '_pi.2.bias',
        }
        # Rename keys found in the mapping; pass all other keys through unchanged.
        new_state_dict = {key_mapping.get(k, k): v for k, v in state_dict.items()}
        self.model.load_state_dict(new_state_dict)

I have verified that I am able to run the humanoid-run-3.pt checkpoint using the above code, but it does not reproduce the original policy performance at the moment, so I need to dig into it a bit more. The multi-task checkpoints are running fine, so this seems to be an issue with single-task checkpoints only. I'll let you know once I have found a solution!
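
As a sanity check, here is a minimal sketch (assuming agent is a constructed agent and new_state_dict is the remapped dict built inside load above) that flags any remaining key or shape mismatches before loading:

# Sketch: verify the remapped state dict against the model before loading.
# Assumes `agent` is a constructed agent and `new_state_dict` is the
# remapped dict from the load method above.
model_state = agent.model.state_dict()
for k, v in new_state_dict.items():
    if k not in model_state:
        print(f'unexpected key: {k}')
    elif v.shape != model_state[k].shape:
        print(f'shape mismatch for {k}: {tuple(v.shape)} vs {tuple(model_state[k].shape)}')
for k in model_state:
    if k not in new_state_dict:
        print(f'missing key: {k}')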

Zzl35 commented 3 months ago

Thank you for your reply, I successfully loaded the checkpoint!

nicklashansen commented 3 months ago

@Zzl35 Sounds great. Are you able to reproduce the original policy performance using the loaded checkpoints?

Zzl35 commented 3 months ago

I modified the code according to the instructions above and found that humanoid-run can indeed load the checkpoint, but it seems unable to reproduce the results from the paper. Additionally, I tested dog-run and found that the original code loads that checkpoint without any modifications and reproduces the performance reported in the paper. So I wonder whether some of the single-task models were trained with a previous version of the code and never updated. Thank you again for your help!
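
For reference, the analogous evaluation command I used for dog-run (the checkpoint filename is assumed to follow the same naming pattern as humanoid-run-3.pt, with the same model size; adjust the path to your setup):

python evaluate.py task=dog-run checkpoint=/home/ubuntu/zzl/tdmpc2/models/dog-run-3.pt save_video=true model_size=5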