DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

Cannot load in Custom Environment #1342

Closed anilkurkcu closed 1 year ago

anilkurkcu commented 1 year ago

🐛 Bug

I can save my policy with model.save(), but loading it with SAC.load() fails with the error shown below.

== CURRENT SYSTEM INFO ==
OS: Linux-5.4.0-126-generic-x86_64-with-debian-bullseye-sid #142-Ubuntu SMP Fri Aug 26 12:12:57 UTC 2022
Python: 3.7.0
Stable-Baselines3: 1.6.2
PyTorch: 1.12.1
GPU Enabled: True
Numpy: 1.21.5
Gym: 0.21.0

== SAVED MODEL SYSTEM INFO ==
OS: Linux-5.4.0-126-generic-x86_64-with-debian-bullseye-sid #142-Ubuntu SMP Fri Aug 26 12:12:57 UTC 2022
Python: 3.7.0
Stable-Baselines3: 1.6.2
PyTorch: 1.12.1
GPU Enabled: True
Numpy: 1.21.5
Gym: 0.21.0

Code example

import gym
import numpy as np
from gym import spaces

from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env

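class CustomEnv(gym.Env):
    """Minimal environment that returns random observations."""

    def __init__(self):
        super().__init__()
        # 14-dimensional unbounded observations, 6-dimensional actions in [-1, 1]
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
        self.action_space = spaces.Box(low=-1, high=1, shape=(6,))

    def reset(self):
        # Gym 0.21 API: reset() returns only the observation
        return self.observation_space.sample()

    def step(self, action):
        # Gym 0.21 API: step() returns (obs, reward, done, info)
        obs = self.observation_space.sample()
        reward = 1.0
        done = False
        info = {}
        return obs, reward, done, info

env = CustomEnv()
# Check that the environment follows the Gym interface
check_env(env)

# Train A2C for 1000 timesteps
model = A2C("MlpPolicy", env, verbose=1).learn(1000)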

Relevant log output / Error message

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/kurkcua/miniconda3/envs/venv/lib/python3.7/site-packages/stable_baselines3/common/base_class.py", line 761, in load
    print_system_info=print_system_info,
  File "/home/kurkcua/miniconda3/envs/venv/lib/python3.7/site-packages/stable_baselines3/common/save_util.py", line 419, in load_from_zip_file
    data = json_to_data(json_data, custom_objects=custom_objects)
  File "/home/kurkcua/miniconda3/envs/venv/lib/python3.7/site-packages/stable_baselines3/common/save_util.py", line 164, in json_to_data
    deserialized_object = cloudpickle.loads(base64_object)
UnicodeDecodeError: 'ascii' codec can't decode byte 0xc3 in position 3: ordinal not in range(128)
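
The traceback shows the failure inside json_to_data, at the point where a base64-decoded payload is handed to cloudpickle.loads. As a rough illustration of that kind of round trip, here is a simplified sketch that uses only the standard library. It is not SB3's actual code: the helper names to_json and from_json and the "serialized" key are hypothetical, and plain pickle stands in for cloudpickle.

```python
import base64
import json
import pickle


def to_json(obj):
    """Pickle obj, base64-encode the bytes, and wrap them in a JSON document."""
    payload = base64.b64encode(pickle.dumps(obj)).decode("ascii")
    return json.dumps({"serialized": payload})


def from_json(text, **pickle_kwargs):
    """Reverse to_json: base64-decode the payload, then unpickle it."""
    raw = base64.b64decode(json.loads(text)["serialized"])
    # Any decoding problem in the pickled bytes surfaces in this call.
    return pickle.loads(raw, **pickle_kwargs)


saved = to_json({"learning_rate": 0.0007, "shape": (14,)})
restored = from_json(saved)
print(restored)
```

In a scheme like this, the JSON and base64 layers only transport bytes, so an error such as the one above would come from the pickled payload itself rather than from the wrapping.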

System Info

No response


anilkurkcu commented 1 year ago

OK, I think I found a solution for this. Posting it here in case it helps anyone :)

In file stable_baselines3/common/save_util.py, I had to change line 164 to:

deserialized_object = cloudpickle.loads(base64_object, encoding='latin1')
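
To show why the encoding argument matters, here is a minimal sketch with the standard pickle module. It assumes cloudpickle.loads accepts the same keyword arguments as pickle.loads. The payload is a hand-written protocol-2 pickle of a Python 2 byte string, chosen for illustration, and is not the data from the saved model; try_load is a hypothetical helper.

```python
import pickle

# Hand-written protocol-2 pickle of the Python 2 byte string '\xc3\xa9'
# (PROTO 2, SHORT_BINSTRING of length 2, BINPUT 0, STOP).
PY2_PICKLE = b"\x80\x02U\x02\xc3\xa9q\x00."


def try_load(data, **kwargs):
    """Return the unpickled object, or the UnicodeDecodeError that was raised."""
    try:
        return pickle.loads(data, **kwargs)
    except UnicodeDecodeError as err:
        return err


# Default encoding is "ASCII", which cannot decode byte 0xc3.
default_result = try_load(PY2_PICKLE)
# latin1 maps every byte to exactly one character, so decoding cannot fail.
latin1_result = try_load(PY2_PICKLE, encoding="latin1")
# "bytes" skips decoding and keeps the raw Python 2 string as bytes.
bytes_result = try_load(PY2_PICKLE, encoding="bytes")

print(type(default_result).__name__)
print(repr(latin1_result))
print(repr(bytes_result))
```

Note that latin1 only guarantees that decoding succeeds. Text that was originally UTF-8 may come back with the wrong characters, and editing the installed library is a local workaround rather than a fix in the saved file.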

araffin commented 1 year ago

Hello,

This type of error usually occurs when something was saved in Python 2 and then loaded in Python 3: https://stackoverflow.com/questions/11305790/pickle-incompatibility-of-numpy-arrays-between-python-2-and-3

I also cannot reproduce the error using the provided code (you can try it in a Google Colab).

> I found a solution for this

Good to hear =)