DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License
9.05k stars, 1.7k forks

[Bug]: PyTorch 2.0 compile results in bad state_dict keys and policy fails to load #1438

Closed: Pythoniasm closed this issue 10 months ago

Pythoniasm commented 1 year ago

🐛 Bug

This can be fixed easily; see #1439.

To Reproduce

import gymnasium as gym
import torch as th
from stable_baselines3 import SAC

env = gym.make("Pendulum-v1")

model = SAC("MlpPolicy", env, verbose=1)
model.policy = th.compile(model.policy)  # Compile the model
model.save("sac_pendulum")

del model  # remove to demonstrate saving and loading

SAC.load("sac_pendulum")

Relevant log output / Error message

RuntimeError: Error(s) in loading state_dict for SACPolicy:
        Missing key(s) in state_dict: "actor.latent_pi.0.weight", "actor.latent_pi.0.bias", "actor.latent_pi.2.weight", "actor.latent_pi.2.bias", "actor.mu.weight", "actor.mu.bias", "actor.log_std.weight", "actor.log_std.bias", "critic.qf0.0.weight", "critic.qf0.0.bias", "critic.qf0.2.weight", "critic.qf0.2.bias", "critic.qf0.4.weight", "critic.qf0.4.bias", "critic.qf1.0.weight", "critic.qf1.0.bias", "critic.qf1.2.weight", "critic.qf1.2.bias", "critic.qf1.4.weight", "critic.qf1.4.bias", "critic_target.qf0.0.weight", "critic_target.qf0.0.bias", "critic_target.qf0.2.weight", "critic_target.qf0.2.bias", "critic_target.qf0.4.weight", "critic_target.qf0.4.bias", "critic_target.qf1.0.weight", "critic_target.qf1.0.bias", "critic_target.qf1.2.weight", "critic_target.qf1.2.bias", "critic_target.qf1.4.weight", "critic_target.qf1.4.bias". 
        Unexpected key(s) in state_dict: "_orig_mod.actor.latent_pi.0.weight", "_orig_mod.actor.latent_pi.0.bias", "_orig_mod.actor.latent_pi.2.weight", "_orig_mod.actor.latent_pi.2.bias", "_orig_mod.actor.mu.weight", "_orig_mod.actor.mu.bias", "_orig_mod.actor.log_std.weight", "_orig_mod.actor.log_std.bias", "_orig_mod.critic.qf0.0.weight", "_orig_mod.critic.qf0.0.bias", "_orig_mod.critic.qf0.2.weight", "_orig_mod.critic.qf0.2.bias", "_orig_mod.critic.qf0.4.weight", "_orig_mod.critic.qf0.4.bias", "_orig_mod.critic.qf1.0.weight", "_orig_mod.critic.qf1.0.bias", "_orig_mod.critic.qf1.2.weight", "_orig_mod.critic.qf1.2.bias", "_orig_mod.critic.qf1.4.weight", "_orig_mod.critic.qf1.4.bias", "_orig_mod.critic_target.qf0.0.weight", "_orig_mod.critic_target.qf0.0.bias", "_orig_mod.critic_target.qf0.2.weight", "_orig_mod.critic_target.qf0.2.bias", "_orig_mod.critic_target.qf0.4.weight", "_orig_mod.critic_target.qf0.4.bias", "_orig_mod.critic_target.qf1.0.weight", "_orig_mod.critic_target.qf1.0.bias", "_orig_mod.critic_target.qf1.2.weight", "_orig_mod.critic_target.qf1.2.bias", "_orig_mod.critic_target.qf1.4.weight", "_orig_mod.critic_target.qf1.4.bias".
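Every unexpected key in the log is one of the missing keys with a leading `_orig_mod.` added: `torch.compile` wraps the policy in a module that keeps the original under the attribute `_orig_mod`, so the saved keys carry that prefix. A minimal workaround sketch is to strip the prefix before loading. `strip_compile_prefix` is a hypothetical helper written for illustration; it is not part of Stable-Baselines3 or PyTorch, and it is not necessarily what #1439 does.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")

# Prefix seen on every unexpected key in the error log above.
COMPILE_PREFIX = "_orig_mod."


def strip_compile_prefix(state_dict: Mapping[str, V], prefix: str = COMPILE_PREFIX) -> "OrderedDict[str, V]":
    """Return a copy of state_dict with `prefix` removed from every key that starts with it.

    Keys without the prefix are copied unchanged, so calling this on a
    state_dict from an uncompiled module is harmless. Values (tensors in
    practice) are not copied, only re-keyed.
    """
    cleaned: "OrderedDict[str, V]" = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Toy stand-in for a compiled policy's state_dict (integers instead of tensors).
compiled_keys = {"_orig_mod.actor.mu.weight": 1, "_orig_mod.critic.qf0.0.bias": 2}
print(list(strip_compile_prefix(compiled_keys)))
# ['actor.mu.weight', 'critic.qf0.0.bias']
```

In a custom loading path, one could pass the cleaned mapping to `torch.nn.Module.load_state_dict` on an uncompiled policy.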

System Info

Tested on torch.compile-compatible platforms (Linux, WSL), on both CUDA and CPU.


araffin commented 1 year ago

Hello, thanks for reporting the issue. Doesn't PyTorch provide an API or helper to deal with this?

vmoens commented 10 months ago

Hello! I can run this code without any issues with the latest PyTorch (nightly release):

import gymnasium as gym
from stable_baselines3 import SAC
import torch as th

env = gym.make("Pendulum-v1")

model = SAC("MlpPolicy", env, verbose=1)
model.policy = th.compile(model.policy)  # Compile the model
model.save("sac_pendulum")

del model # remove to demonstrate saving and loading

SAC.load("sac_pendulum")

Therefore, I think this issue can be closed (?)

araffin commented 10 months ago

Thanks for trying it out =) I can reproduce the issue with PyTorch 2.1.1 (CPU), but it is fixed in torch-2.2.0.dev20231212+cpu.
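Based only on the two data points above (2.1.1 fails, 2.2.0.dev20231212 works), a rough version gate could decide whether to save a compiled policy directly or fall back to a workaround. This is a hedged sketch: the function names are hypothetical, and it assumes the fix is present in every 2.2+ build, which is not true for 2.2 nightlies built before the fix landed; check the PyTorch release notes before relying on it.

```python
import re
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings like '2.1.1' or '2.2.0.dev20231212+cpu'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def compiled_policy_save_expected_to_work(torch_version: str) -> bool:
    """Heuristic (assumption): treat PyTorch 2.2 and later as containing the fix.

    Early 2.2 nightlies that predate the fix also report '2.2.0.dev...'
    and would be misclassified by this check.
    """
    # Tuple comparison handles multi-digit minors, e.g. (2, 10) >= (2, 2).
    return parse_major_minor(torch_version) >= (2, 2)


for v in ["2.1.1", "2.2.0.dev20231212+cpu"]:
    print(v, compiled_policy_save_expected_to_work(v))
# 2.1.1 False
# 2.2.0.dev20231212+cpu True
```

In practice the argument would be `torch.__version__`, which is a string subclass and works with `re.match`.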

araffin commented 10 months ago

It was fixed two days ago; see https://github.com/pytorch/pytorch/issues/94575 and commit https://github.com/pytorch/pytorch/commit/38f890341df7a83decf6b6a7eed74786ce1ab866