nicklashansen / tdmpc2

Code for "TD-MPC2: Scalable, Robust World Models for Continuous Control"
https://www.tdmpc2.com
MIT License

RuntimeError When Loading State_Dict for WorldModel #11

Closed: jc-bao closed this issue 6 months ago

jc-bao commented 6 months ago

I am trying to run the model that was downloaded from the website using the following command:

python evaluate.py task=mt30 model_size=48 checkpoint="/home/pcy/Research/code/tdmpc2/models/mt30-48M.pt"

Unfortunately, execution fails with the following error:

Creating multi-task environment with tasks: [...] (list of tasks omitted for clarity)
Error executing job with overrides: ['task=mt30', 'model_size=1', 'checkpoint=/home/pcy/Research/code/tdmpc2/models/mt30-1M.pt']
Traceback (most recent call last):
  File "/home/pcy/Research/code/tdmpc2/tdmpc2/evaluate.py", line 59, in evaluate
    agent.load(cfg.checkpoint)
  File "/home/pcy/Research/code/tdmpc2/tdmpc2/tdmpc2.py", line 68, in load
    self.model.load_state_dict(state_dict["model"])
  File "/home/pcy/mambaforge/envs/tdmpc/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for WorldModel:
        Missing key(s) in state_dict: "_target_Qs.params.0", "_target_Qs.params.1", "_target_Qs.params.2", "_target_Qs.params.3", "_target_Qs.params.4", "_target_Qs.params.5", "_target_Qs.params.6", "_target_Qs.params.7", "_target_Qs.params.8", "_target_Qs.params.9". 

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

Has anyone else come across the same issue? I suspect a compatibility issue between the released model and the latest version of the code. Any information would be highly appreciated!
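One way to test the compatibility suspicion before calling `load_state_dict` is to diff the parameter names the model expects against the names stored in the checkpoint. The sketch below is a hypothetical, framework-free helper (the functions `diff_state_dict_keys` and `summarize_prefixes` are illustrative, not part of the tdmpc2 codebase). It returns the same two lists, missing and unexpected keys, that PyTorch reports in strict mode, and collapses names such as `_target_Qs.params.0` to their top-level prefix so a renamed or restructured module stands out.

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(
    model_keys: Iterable[str], checkpoint_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) parameter names.

    Hypothetical helper: mirrors the two lists that strict-mode
    load_state_dict complains about, without needing PyTorch.
    """
    model_set = set(model_keys)
    ckpt_set = set(checkpoint_keys)
    # Keys the model expects but the checkpoint does not provide.
    missing = sorted(model_set - ckpt_set)
    # Keys the checkpoint provides but the model has no slot for.
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


def summarize_prefixes(keys: Iterable[str]) -> List[str]:
    """Collapse names like '_target_Qs.params.3' to '_target_Qs'."""
    return sorted({key.split(".", 1)[0] for key in keys})


if __name__ == "__main__":
    # Hypothetical key names for illustration only; they are NOT copied
    # from a real TD-MPC2 checkpoint.
    model_keys = ["encoder.weight", "_target_Qs.params.0", "_target_Qs.params.1"]
    ckpt_keys = ["encoder.weight", "old_target_q.0.weight"]
    missing, unexpected = diff_state_dict_keys(model_keys, ckpt_keys)
    print("missing:", summarize_prefixes(missing))
    print("unexpected:", summarize_prefixes(unexpected))
```

With PyTorch installed, `model_keys` would come from `agent.model.state_dict().keys()` and `checkpoint_keys` from `torch.load(path, map_location="cpu")["model"].keys()`. This assumes the checkpoint is a dict with a `"model"` entry, as the `state_dict["model"]` lookup in the traceback suggests. A non-empty "unexpected" list next to the missing `_target_Qs` keys would point to renamed modules rather than a corrupted file.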

nicklashansen commented 6 months ago

Hi @jc-bao, the models hosted on Google Drive are for an anonymized version of our codebase. Our public release hosts models on Hugging Face (listed at https://nicklashansen.github.io/td-mpc2/models). The weights are identical, but the internal modules have been restructured, which is what causes your error. Sorry for the confusion!

jc-bao commented 6 months ago

Got it. I believe that will resolve the issue. Thanks for the prompt reply!

nicklashansen commented 6 months ago

Great -- let me know if there's anything else!