We usually call model.save(path) to save the model as a zip file.
However, the policy also has its own save method, and calling it (model.policy.save(path)) with gSDE enabled (use_sde=True) raises a KeyError.
To Reproduce
import pybullet_envs
from stable_baselines3 import PPO
model = PPO('MlpPolicy', 'HopperBulletEnv-v0', use_sde=True)
model.learn(100)
model.policy.save("/tmp/sde_policy")
Traceback (most recent call last):
  File "try_sde_with_policy_save.py", line 6, in <module>
    model.policy.save("/tmp/sde_policy")
  File "/home/liusida/code/code_trysb3/stable-baselines3/stable_baselines3/common/policies.py", line 152, in save
    th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
  File "/home/liusida/code/code_trysb3/stable-baselines3/stable_baselines3/common/policies.py", line 438, in _get_constructor_parameters
    sde_net_arch=default_none_kwargs["sde_net_arch"],
KeyError: 'sde_net_arch'
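The traceback points at a lookup of the key "sde_net_arch" in default_none_kwargs inside _get_constructor_parameters. The following is a hypothetical, simplified stand-in (constructor_params is not the real Stable-Baselines3 function, and the dict contents are assumptions) that reproduces the suspected failure mode: assuming that without gSDE the lookup falls back to a collections.defaultdict returning None for any key, while with gSDE a plain dict of distribution kwargs is used instead, any key missing from that plain dict raises KeyError.

```python
import collections
from typing import Any, Dict, Optional


def constructor_params(dist_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical, simplified stand-in for the lookup in _get_constructor_parameters."""
    # If dist_kwargs is None (assumed: gSDE disabled), fall back to a
    # defaultdict so that every key reads as None instead of raising.
    default_none_kwargs = dist_kwargs or collections.defaultdict(lambda: None)
    return {
        "full_std": default_none_kwargs["full_std"],
        # A plain dict without this key raises KeyError here.
        "sde_net_arch": default_none_kwargs["sde_net_arch"],
    }


# Without gSDE: the defaultdict fallback is used, so missing keys become None.
print(constructor_params(None))  # {'full_std': None, 'sde_net_arch': None}

# With gSDE (assumed): a plain dict lacking "sde_net_arch" is passed in.
try:
    constructor_params({"full_std": True, "use_expln": False})
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'sde_net_arch'
```

If this is indeed the cause, a fix would need to store sde_net_arch somewhere the policy can read it back, rather than looking it up in the distribution kwargs.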
Expected behavior
The policy (the PyTorch module) should be saved as a file.
System Info
Describe the characteristics of your environment:
Tested with the pip-installed release and with commit 5d47296b8d85355765e2aa1293b68c9a01f3779f
GPU models and configuration: No GPU
Python version: 3.8
PyTorch version: 1.8.1
Gym version: 0.18.0
Checklist
[x] I have checked that there is no similar issue in the repo (required)