Zhehui-Huang / quad-swarm-rl

Additional environments compatible with OpenAI gym

Model size mismatch #51

Closed. CowFromSpace closed this issue 2 months ago.

CowFromSpace commented 6 months ago

"I downloaded a model (Multi drone without obstacles) from the following URL for testing: https://huggingface.co/andrewzhang505/quad-swarm-rl-multi-drone-no-obstacles/tree/main.

When I executed the following command: python -m swarm_rl.enjoy --algo=APPO --env=quadrotor_multi --replay_buffer_sample_prob=0 --quads_use_numba=False --train_dir="./train_dir/quad_multi_baseline/baseline_multidrone" --experiment="01_baseline_multi_drone_see_1111" --quads_view_mode "global" --quads_render=True --quads_num_agents=8

I encountered the following error:

Traceback (most recent call last):
  File "/home/handsome/anaconda3/envs/swarm-rl/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/handsome/anaconda3/envs/swarm-rl/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/handsome/Desktop/drome/quad-swarm-rl/swarm_rl/enjoy.py", line 17, in <module>
    sys.exit(main())
  File "/home/handsome/Desktop/drome/quad-swarm-rl/swarm_rl/enjoy.py", line 12, in main
    status = enjoy(cfg)
  File "/home/handsome/anaconda3/envs/swarm-rl/lib/python3.8/site-packages/sample_factory/enjoy.py", line 126, in enjoy
    actor_critic.load_state_dict(checkpoint_dict["model"])
  File "/home/handsome/anaconda3/envs/swarm-rl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ActorCriticSeparateWeights:
    Unexpected key(s) in state_dict: "actor_encoder.neighbor_encoder.embedding_mlp.0.weight", "actor_encoder.neighbor_encoder.embedding_mlp.0.bias", "actor_encoder.neighbor_encoder.embedding_mlp.2.weight", "actor_encoder.neighbor_encoder.embedding_mlp.2.bias", "actor_encoder.neighbor_encoder.neighbor_value_mlp.0.weight", "actor_encoder.neighbor_encoder.neighbor_value_mlp.0.bias", "actor_encoder.neighbor_encoder.neighbor_value_mlp.2.weight", "actor_encoder.neighbor_encoder.neighbor_value_mlp.2.bias", "actor_encoder.neighbor_encoder.attention_mlp.0.weight", "actor_encoder.neighbor_encoder.attention_mlp.0.bias", "actor_encoder.neighbor_encoder.attention_mlp.2.weight", "actor_encoder.neighbor_encoder.attention_mlp.2.bias", "actor_encoder.neighbor_encoder.attention_mlp.4.weight", "actor_encoder.neighbor_encoder.attention_mlp.4.bias", "critic_encoder.neighbor_encoder.embedding_mlp.0.weight", "critic_encoder.neighbor_encoder.embedding_mlp.0.bias", "critic_encoder.neighbor_encoder.embedding_mlp.2.weight", "critic_encoder.neighbor_encoder.embedding_mlp.2.bias", "critic_encoder.neighbor_encoder.neighbor_value_mlp.0.weight", "critic_encoder.neighbor_encoder.neighbor_value_mlp.0.bias", "critic_encoder.neighbor_encoder.neighbor_value_mlp.2.weight", "critic_encoder.neighbor_encoder.neighbor_value_mlp.2.bias", "critic_encoder.neighbor_encoder.attention_mlp.0.weight", "critic_encoder.neighbor_encoder.attention_mlp.0.bias", "critic_encoder.neighbor_encoder.attention_mlp.2.weight", "critic_encoder.neighbor_encoder.attention_mlp.2.bias", "critic_encoder.neighbor_encoder.attention_mlp.4.weight", "critic_encoder.neighbor_encoder.attention_mlp.4.bias".
    size mismatch for obs_normalizer.running_mean_std.running_mean_std.obs.running_mean: copying a param with shape torch.Size([54]) from checkpoint, the shape in current model is torch.Size([18]).
    size mismatch for obs_normalizer.running_mean_std.running_mean_std.obs.running_var: copying a param with shape torch.Size([54]) from checkpoint, the shape in current model is torch.Size([18]).
    size mismatch for actor_encoder.feed_forward.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([512, 256]).
    size mismatch for critic_encoder.feed_forward.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([512, 256]).
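For reference, here is a minimal sketch of how the downloaded checkpoint can be inspected to see exactly which keys and shapes it contains (the checkpoint path is a placeholder; the "model" key is the one the traceback shows being loaded):

import torch

# Placeholder path for illustration; point it at the actual checkpoint file under train_dir.
checkpoint_path = "path/to/checkpoint.pth"
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# sample_factory loads the weights from the "model" entry of the checkpoint dict,
# so these are the keys and shapes that must match the constructed actor-critic model.
for name, tensor in checkpoint["model"].items():
    print(name, tuple(tensor.shape))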

How can I resolve this issue?

Zhehui-Huang commented 6 months ago

In the config file, https://huggingface.co/andrewzhang505/quad-swarm-rl-multi-drone-no-obstacles/blob/main/cfg.json, they record the repository and commit the model was trained with: "git_hash": "819afd374748be2bf5f9336ad8651ee215470c84", "git_repo_name": "https://github.com/andrewzhang505/sample-factory.git".

The easiest way to resolve this is to use that repository at the specified commit.
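For example (this exact install command is an assumption on my side, not something verified against that fork's README), you could install the fork at that commit in your environment and then rerun the enjoy command above:

pip uninstall -y sample-factory
pip install git+https://github.com/andrewzhang505/sample-factory.git@819afd374748be2bf5f9336ad8651ee215470c84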