DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Bug]: Issue with SubprocVecEnv when observation space is a Dict #1428

Closed: BDonnot closed this issue 1 year ago

BDonnot commented 1 year ago

πŸ› Bug

When I try to use an environment with a gym Dict observation space, everything works pretty well in a single process.

I wanted to speed up the computation using SubprocVecEnv, but it appears it no longer works.

I was using a complicated gym environment when I noticed the bug (hence the need for SubprocVecEnv), but I managed to reproduce it with a very simple one (see below). What the environment does is pretty irrelevant to this bug (in reality it is for managing a power grid; in the example it just increments one of two counters and returns +1 if the first one reaches 100 before the other...).

To Reproduce

import numpy as np
import gym
from gym.spaces import Discrete, Box, Dict
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

class CustomGym(gym.Env):
    """A simple gym environment to illustrate the bug, will work with
    any gym environment with Dict observation space
    """
    def __init__(self) -> None:
        self.action_space = Discrete(2)
        self.observation_space = Dict(spaces={
            "one": Box(low=0.0, high=100.0, shape=(1,), dtype=np.float32),
            "two": Box(low=0.0, high=100.0, shape=(1,), dtype=np.float32),
        })
        self.reward_range = (-1, 1)
        self._internal_state = [0, 0]

    def reset(self):
        self._internal_state = [0, 0]
        return {"one": self._internal_state[0], "two": self._internal_state[1]}

    def seed(self, seed):
        # nothing to do here
        super().seed(seed)

    def step(self, action):
        self._internal_state[action] += 1

        obs = {"one": self._internal_state[0], "two": self._internal_state[1]}
        reward = 0
        done = False
        info = {}
        if self._internal_state[0] == 100:
            reward = 1
            done = True
        elif self._internal_state[1] == 100:
            reward = -1
            done = True
        return obs, reward, done, info

def make_env():
    return CustomGym()

if __name__ == '__main__':
    # env = make_env()  # single process: it works fine
    env = make_vec_env(make_env, n_envs=1, vec_env_cls=SubprocVecEnv)  # "multi" process it does not
    model = PPO("MultiInputPolicy", env, verbose=1)
    model.learn(total_timesteps=250)

Relevant log output / Error message

Using cuda device
Traceback (most recent call last):
  File "bug_sb3.py", line 56, in <module>
    model.learn(total_timesteps=250)
  File "/usr/local/lib/python3.8/dist-packages/stable_baselines3/ppo/ppo.py", line 299, in learn
    return super(PPO, self).learn(
  File "/usr/local/lib/python3.8/dist-packages/stable_baselines3/common/on_policy_algorithm.py", line 250, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/usr/local/lib/python3.8/dist-packages/stable_baselines3/common/on_policy_algorithm.py", line 169, in collect_rollouts
    actions, values, log_probs = self.policy.forward(obs_tensor)
  File "/usr/local/lib/python3.8/dist-packages/stable_baselines3/common/policies.py", line 588, in forward
    features = self.extract_features(obs)
  File "/usr/local/lib/python3.8/dist-packages/stable_baselines3/common/policies.py", line 129, in extract_features
    return self.features_extractor(preprocessed_obs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/stable_baselines3/common/torch_layers.py", line 274, in forward
    encoded_tensor_list.append(extractor(observations[key]))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/flatten.py", line 42, in forward
    return input.flatten(self.start_dim, self.end_dim)
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
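
For context, the IndexError comes from the MultiInput features extractor applying nn.Flatten, which flattens starting at dim 1, to a batch tensor that only has one dimension. A minimal sketch that reproduces the same error:

import torch

# a batch of scalar observations has shape (n_envs,) = (1,): a single dimension
x = torch.zeros(1)
# nn.Flatten defaults to start_dim=1, which does not exist on a 1-D tensor
torch.nn.Flatten()(x)
# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)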

System Info

OS: Linux-5.15.0-69-generic-x86_64-with-glibc2.29 #76~20.04.1-Ubuntu SMP Mon Mar 20 15:54:19 UTC 2023
Python: 3.8.10
Stable-Baselines3: 1.4.0
PyTorch: 1.10.1+cu102
GPU Enabled: True
Numpy: 1.20.3
Gym: 0.18.0

araffin commented 1 year ago

Hello, what are the output/warnings from the env checker? (see the custom env issue template)
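
For reference, a minimal way to run the env checker on the custom env above (check_env is part of SB3's common utilities; warn=True is optional):

from stable_baselines3.common.env_checker import check_env

# raises an assertion error (and prints warnings) if reset()/step() return
# observations that do not match the declared observation_space
check_env(CustomGym(), warn=True)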

BDonnot commented 1 year ago

Hello,

Thanks for the quick reply.

Sorry for missing the specific issue template for custom environments; I did not see it before posting.

For the env I used as an example, I did indeed make a mistake and the env checker failed. I modified it so that it returns:

        obs = {"one": np.array([self._internal_state[0]]),
               "two": np.array([self._internal_state[1]])}

(instead of simply returning Python integers), and it no longer fails.
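
Spelled out, the corrected reset would look like this (a sketch applying the same fix; step's obs needs the same change, and the dtype is pinned here to match the float32 Box spaces):

    def reset(self):
        self._internal_state = [0, 0]
        # return 1-D float32 arrays matching the declared Box(shape=(1,)) spaces,
        # not bare Python ints
        return {"one": np.array([self._internal_state[0]], dtype=np.float32),
                "two": np.array([self._internal_state[1]], dtype=np.float32)}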

However, I am pretty sure the env checker passed on the env from my initial problem (which is not this one). I will try to modify the "CustomGym" so that it exhibits the same problem, if I can. Thanks for the help.

BDonnot commented 1 year ago

Good news: the bug (like most bugs) was between the chair and the keyboard...

With the error spotted in this simple case, I managed to run the initial code, which previously did not work.

Thanks for your help, and sorry for this error. Stable-Baselines3 works well in this case :-)