DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Bug] with gSDE, calling policy.save() causes error #399

Closed liusida closed 3 years ago

liusida commented 3 years ago

🐛 Bug

Usually we use model.save(path) to save a zip file, but the policy also has its own save method. If one calls that method with gSDE enabled (use_sde=True), it raises a KeyError.

To Reproduce

import pybullet_envs
from stable_baselines3 import PPO

model = PPO('MlpPolicy', 'HopperBulletEnv-v0', use_sde=True)
model.learn(100)
model.policy.save("/tmp/sde_policy")

Traceback (most recent call last):
  File "try_sde_with_policy_save.py", line 6, in <module>
    model.policy.save("/tmp/sde_policy")
  File "/home/liusida/code/code_trysb3/stable-baselines3/stable_baselines3/common/policies.py", line 152, in save
    th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
  File "/home/liusida/code/code_trysb3/stable-baselines3/stable_baselines3/common/policies.py", line 438, in _get_constructor_parameters
    sde_net_arch=default_none_kwargs["sde_net_arch"],
KeyError: 'sde_net_arch'
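
The last frame of the traceback is a plain dictionary lookup on default_none_kwargs that fails because the key 'sde_net_arch' is absent. Below is a minimal, library-free sketch of that failure mode. The dictionary contents and both helper functions are hypothetical stand-ins, not the actual stable_baselines3 code, and the tolerant variant is only one possible way to avoid the error, not necessarily the fix the maintainers chose.

```python
# Hypothetical stand-in for the kwargs dict named in the traceback.
# The real dict in stable_baselines3 may hold different keys and values.
default_none_kwargs = {"squash_output": False, "use_expln": False}


def get_params_strict(kwargs):
    # Same pattern as the failing line: direct indexing raises KeyError
    # when the key is absent.
    return {"sde_net_arch": kwargs["sde_net_arch"]}


def get_params_tolerant(kwargs):
    # dict.get returns None for a missing key instead of raising.
    return {"sde_net_arch": kwargs.get("sde_net_arch")}


try:
    get_params_strict(default_none_kwargs)
    strict_error = None
except KeyError as exc:
    strict_error = exc

tolerant = get_params_tolerant(default_none_kwargs)

print("strict lookup raised:", repr(strict_error))
print("tolerant lookup returned:", tolerant)
```

The strict helper reproduces the same exception type and message as the report, while the tolerant one falls back to None for the missing entry.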

Expected behavior

The policy (the PyTorch module) should be saved as a file.

araffin commented 3 years ago

Hello, thanks for reporting the bug. It will be fixed in https://github.com/DLR-RM/stable-baselines3/pull/401 ;)