Eclectic-Sheep / sheeprl

Distributed Reinforcement Learning accelerated by Lightning Fabric
https://eclecticsheep.ai
Apache License 2.0
303 stars, 31 forks

`sheeprl_eval` loading model with different keys #305

Open belerico opened 3 months ago

belerico commented 3 months ago

I cannot evaluate my trained model with `sheeprl-eval`, because the keys in the world model's state_dict have different names:

Stacktrace

Error executing job with overrides: ['checkpoint_path=/home/drt/Desktop/sheeprl/sheeprl/logs/runs/dreamer_v3/PyFlyt/2024-06-23_19-34-31_dreamer_v3_PyFlyt_42/version_0/checkpoint/ckpt_730000_0.ckpt', 'fabric.accelerator=gpu', 'env.capture_video=True', 'seed=52']
Traceback (most recent call last):
  File "/home/drt/miniconda3/envs/sheeprl/lib/python3.10/site-packages/sheeprl/cli.py", line 404, in evaluation
    eval_algorithm(ckpt_cfg)
  File "/home/drt/miniconda3/envs/sheeprl/lib/python3.10/site-packages/sheeprl/cli.py", line 267, in eval_algorithm
    fabric.launch(command, cfg, state)
  File "/home/drt/miniconda3/envs/sheeprl/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 839, in launch
    return self._wrap_and_launch(function, self, *args, **kwargs)
  File "/home/drt/miniconda3/envs/sheeprl/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 925, in _wrap_and_launch
    return to_run(*args, **kwargs)
  File "/home/drt/miniconda3/envs/sheeprl/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 930, in _wrap_with_setup
    return to_run(*args, **kwargs)
  File "/home/drt/miniconda3/envs/sheeprl/lib/python3.10/site-packages/sheeprl/cli.py", line 262, in wrapper
    return func(*args, **kwargs)
  File "/home/drt/miniconda3/envs/sheeprl/lib/python3.10/site-packages/sheeprl/algos/dreamer_v3/evaluate.py", line 47, in evaluate
    _, _, _, _, player = build_agent(
  File "/home/drt/miniconda3/envs/sheeprl/lib/python3.10/site-packages/sheeprl/algos/dreamer_v3/agent.py", line 1186, in build_agent
    world_model.load_state_dict(world_model_state)
  File "/home/drt/miniconda3/envs/sheeprl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for WorldModel:
    Missing key(s) in state_dict: "encoder.mlp_encoder.model._model.0.weight", "encoder.mlp_encoder.model._model.1.weight", "encoder.mlp_encoder.model._model.1.bias", "encoder.mlp_encoder.model._model.3.weight", "encoder.mlp_encoder.model._model.4.weight", "encoder.mlp_encoder.model._model.4.bias", "encoder.mlp_encoder.model._model.6.weight", "encoder.mlp_encoder.model._model.7.weight", "encoder.mlp_encoder.model._model.7.bias", "rssm.recurrent_model.mlp._model.0.weight", "rssm.recurrent_model.mlp._model.1.weight", "rssm.recurrent_model.mlp._model.1.bias", "rssm.recurrent_model.rnn.linear.weight", "rssm.recurrent_model.rnn.layer_norm.weight", "rssm.recurrent_model.rnn.layer_norm.bias", "rssm.representation_model._model.0.weight", "rssm.representation_model._model.1.weight", "rssm.representation_model._model.1.bias", "rssm.representation_model._model.3.weight", "rssm.representation_model._model.3.bias", "rssm.transition_model._model.0.weight", "rssm.transition_model._model.1.weight", "rssm.transition_model._model.1.bias", "rssm.transition_model._model.3.weight", "rssm.transition_model._model.3.bias", "observation_model.mlp_decoder.model._model.0.weight", "observation_model.mlp_decoder.model._model.1.weight", "observation_model.mlp_decoder.model._model.1.bias", "observation_model.mlp_decoder.model._model.3.weight", "observation_model.mlp_decoder.model._model.4.weight", "observation_model.mlp_decoder.model._model.4.bias", "observation_model.mlp_decoder.model._model.6.weight", "observation_model.mlp_decoder.model._model.7.weight", "observation_model.mlp_decoder.model._model.7.bias", "observation_model.mlp_decoder.heads.0.weight", "observation_model.mlp_decoder.heads.0.bias", "reward_model._model.0.weight", "reward_model._model.1.weight", "reward_model._model.1.bias", "reward_model._model.3.weight", "reward_model._model.4.weight", "reward_model._model.4.bias", "reward_model._model.6.weight", "reward_model._model.7.weight", "reward_model._model.7.bias", "reward_model._model.9.weight", "reward_model._model.9.bias".
    Unexpected key(s) in state_dict: "encoder._orig_mod.mlp_encoder.model._model.0.weight", "encoder._orig_mod.mlp_encoder.model._model.1.weight", "encoder._orig_mod.mlp_encoder.model._model.1.bias", "encoder._orig_mod.mlp_encoder.model._model.3.weight", "encoder._orig_mod.mlp_encoder.model._model.4.weight", "encoder._orig_mod.mlp_encoder.model._model.4.bias", "encoder._orig_mod.mlp_encoder.model._model.6.weight", "encoder._orig_mod.mlp_encoder.model._model.7.weight", "encoder._orig_mod.mlp_encoder.model._model.7.bias", "rssm.recurrent_model._orig_mod.mlp._model.0.weight", "rssm.recurrent_model._orig_mod.mlp._model.1.weight", "rssm.recurrent_model._orig_mod.mlp._model.1.bias", "rssm.recurrent_model._orig_mod.rnn.linear.weight", "rssm.recurrent_model._orig_mod.rnn.layer_norm.weight", "rssm.recurrent_model._orig_mod.rnn.layer_norm.bias", "rssm.representation_model._orig_mod._model.0.weight", "rssm.representation_model._orig_mod._model.1.weight", "rssm.representation_model._orig_mod._model.1.bias", "rssm.representation_model._orig_mod._model.3.weight", "rssm.representation_model._orig_mod._model.3.bias", "rssm.transition_model._orig_mod._model.0.weight", "rssm.transition_model._orig_mod._model.1.weight", "rssm.transition_model._orig_mod._model.1.bias", "rssm.transition_model._orig_mod._model.3.weight", "rssm.transition_model._orig_mod._model.3.bias", "observation_model._orig_mod.mlp_decoder.model._model.0.weight", "observation_model._orig_mod.mlp_decoder.model._model.1.weight", "observation_model._orig_mod.mlp_decoder.model._model.1.bias", "observation_model._orig_mod.mlp_decoder.model._model.3.weight", "observation_model._orig_mod.mlp_decoder.model._model.4.weight", "observation_model._orig_mod.mlp_decoder.model._model.4.bias", "observation_model._orig_mod.mlp_decoder.model._model.6.weight", "observation_model._orig_mod.mlp_decoder.model._model.7.weight", "observation_model._orig_mod.mlp_decoder.model._model.7.bias", "observation_model._orig_mod.mlp_decoder.heads.0.weight", "observation_model._orig_mod.mlp_decoder.heads.0.bias", "reward_model._orig_mod._model.0.weight", "reward_model._orig_mod._model.1.weight", "reward_model._orig_mod._model.1.bias", "reward_model._orig_mod._model.3.weight", "reward_model._orig_mod._model.4.weight", "reward_model._orig_mod._model.4.bias", "reward_model._orig_mod._model.6.weight", "reward_model._orig_mod._model.7.weight", "reward_model._orig_mod._model.7.bias", "reward_model._orig_mod._model.9.weight", "reward_model._orig_mod._model.9.bias".

Originally posted by @defrag-bambino in https://github.com/Eclectic-Sheep/sheeprl/issues/261#issuecomment-2188777312
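For context on the key names: `torch.compile` wraps a module in an `OptimizedModule` that keeps the original module as its `_orig_mod` attribute, so the `state_dict()` of a compiled submodule prefixes every key with `_orig_mod.`. That is exactly the difference between the "Unexpected" and "Missing" keys in the traceback. A minimal sketch of a local workaround, assuming the saved tensors are otherwise identical to the uncompiled ones, is to drop that path segment before loading. `strip_compile_prefix` is a hypothetical helper written for this discussion, not part of SheepRL, Lightning or PyTorch.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")

# Attribute name under which torch.compile's OptimizedModule stores the wrapped module;
# it shows up as a path segment in every state_dict key of a compiled submodule.
COMPILE_SEGMENT = "_orig_mod"


def strip_compile_prefix(state_dict: Mapping[str, V]) -> "OrderedDict[str, V]":
    """Return a copy of ``state_dict`` with every ``_orig_mod`` path segment removed.

    Works on any mapping from dotted key names to values (e.g. tensors). Only whole
    segments equal to ``_orig_mod`` are dropped, so names that merely contain the
    substring are left untouched.
    """
    cleaned: "OrderedDict[str, V]" = OrderedDict()
    for key, value in state_dict.items():
        new_key = ".".join(part for part in key.split(".") if part != COMPILE_SEGMENT)
        if new_key in cleaned:
            # Two entries would map to the same name: refuse instead of silently overwriting.
            raise KeyError(f"stripping {COMPILE_SEGMENT!r} from {key!r} collides with {new_key!r}")
        cleaned[new_key] = value
    return cleaned


if __name__ == "__main__":
    # Key names taken from the traceback; the values stand in for tensors.
    saved = {
        "encoder._orig_mod.mlp_encoder.model._model.0.weight": "w0",
        "rssm.recurrent_model._orig_mod.rnn.linear.weight": "w1",
        "reward_model._orig_mod._model.9.bias": "b9",
    }
    for name in strip_compile_prefix(saved):
        print(name)
```

With the names from the traceback, a locally patched call in `build_agent` could read `world_model.load_state_dict(strip_compile_prefix(world_model_state))`. This only illustrates the renaming; it is not the fix adopted by the project, and it does not explain why the checkpoint was saved with the compiled names in the first place.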

belerico commented 3 months ago

Hi @defrag-bambino, could you please elaborate on the issue? Which version of SheepRL are you using, and which steps did you run before encountering the error? Thank you!

defrag-bambino commented 3 months ago

It is related to the `feature/compile` branch: I trained a model with it, and afterwards I cannot load its state_dict (still using the same branch).

belerico commented 3 months ago

I'm trying, but I'm not able to reproduce it. Which torch version are you using?

defrag-bambino commented 3 months ago

pytorch-lightning        2.2.1
torch                    2.3.1
torchmetrics             1.3.2
torchvision              0.18.1

belerico commented 2 months ago

Hi @defrag-bambino, here is a screenshot showing that `world_model.encoder`, which is a `_FabricModule`, returns the keys of the correct module when its `state_dict` method is called:

[screenshot]

I'm not able to reproduce the error. Did you perhaps train the model with an older version of SheepRL and/or Lightning, and are you now trying to resume it with a newer one?
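One way to narrow this down is to inspect the key names stored in the checkpoint itself, independently of the code that loads it. The sketch below uses a hypothetical helper, `compiled_submodules` (not part of SheepRL or PyTorch), that lists which submodules were saved through a `torch.compile` wrapper. An empty result means the checkpoint keys are already in the uncompiled form.

```python
from typing import Iterable, Set

# Path segment that torch.compile's wrapper (OptimizedModule) adds to state_dict keys.
COMPILE_SEGMENT = "_orig_mod"


def compiled_submodules(keys: Iterable[str]) -> Set[str]:
    """Return the module paths whose parameters were saved through a torch.compile wrapper.

    For each key, everything before the first ``_orig_mod`` segment is the dotted path of
    the compiled submodule; an empty string means the top-level module itself was compiled.
    """
    paths: Set[str] = set()
    for key in keys:
        parts = key.split(".")
        if COMPILE_SEGMENT in parts:
            paths.add(".".join(parts[: parts.index(COMPILE_SEGMENT)]))
    return paths


if __name__ == "__main__":
    # Key names taken from the "Unexpected key(s)" list in the traceback.
    saved_keys = [
        "encoder._orig_mod.mlp_encoder.model._model.0.weight",
        "rssm.recurrent_model._orig_mod.rnn.linear.weight",
        "reward_model._orig_mod._model.9.bias",
    ]
    print(sorted(compiled_submodules(saved_keys)))
```

To apply it to a real checkpoint, the keys could come from the world-model entry of the dictionary returned by `torch.load(checkpoint_path, map_location="cpu")`. The name of that entry in SheepRL checkpoints (for example `"world_model"`) is an assumption to verify against the actual file.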