DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

PPO doesn't work with MultiDiscrete observation space #1836

Closed: elisavio closed this issue 1 month ago

elisavio commented 9 months ago

🐛 Bug

I am implementing a simple custom environment to use PPO with a MultiDiscrete observation space. It works if I use MultiDiscrete([5, 2, 2]), but it fails when the space becomes a multidimensional array. In the attached code I use the MultiDiscrete observation space given as an example in https://gymnasium.farama.org/api/spaces/fundamental/#gymnasium.spaces.MultiDiscrete .
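
For clarity (my own illustration, not part of the original report), the difference between the two cases is the dimensionality of nvec: the working space is 1-D, while the failing one is 2-D.

import numpy as np
from gymnasium.spaces import MultiDiscrete

works = MultiDiscrete([5, 2, 2])                    # 1-D nvec, supported
fails = MultiDiscrete(np.array([[1, 2], [3, 4]]))   # 2-D nvec, triggers the error below

print(works.shape, works.nvec.ndim)  # (3,) 1
print(fails.shape, fails.nvec.ndim)  # (2, 2) 2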

Code example

import numpy as np
import gymnasium as gym
from gymnasium.spaces import MultiDiscrete
from stable_baselines3 import PPO

class CustomEnv(gym.Env):
    def __init__(self):
        self.observation_space = MultiDiscrete(np.array([[1,2], [3,4]]), seed=42)  # Example multi-discrete observation space

        self.action_space = MultiDiscrete(np.array([3, 4, 3, 4]), seed=42)
        self.reset()

    def reset(self, seed=None, options=None):
        self.state = self.observation_space.sample()
        return self.state, {}

    def step(self, action):
        self.state = self.observation_space.sample()
        reward = 1    # Example reward function
        done = False  # Example termination condition
        info = {}     # Additional information (optional)
        return self.state, reward, done, False, info

env = CustomEnv()
model = PPO('MlpPolicy', env, verbose=1)

Relevant log output / Error message

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[5], line 25
     22         return self.state, reward, done, False, info
     24 env = CustomEnv()
---> 25 model = PPO('MlpPolicy', env, verbose=1)

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\ppo\ppo.py:171, in PPO.__init__(self, policy, env, learning_rate, n_steps, batch_size, n_epochs, gamma, gae_lambda, clip_range, clip_range_vf, normalize_advantage, ent_coef, vf_coef, max_grad_norm, use_sde, sde_sample_freq, rollout_buffer_class, rollout_buffer_kwargs, target_kl, stats_window_size, tensorboard_log, policy_kwargs, verbose, seed, device, _init_setup_model)
    168 self.target_kl = target_kl
    170 if _init_setup_model:
--> 171     self._setup_model()

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\ppo\ppo.py:174, in PPO._setup_model(self)
    173 def _setup_model(self) -> None:
--> 174     super()._setup_model()
    176     # Initialize schedules for policy/value clipping
    177     self.clip_range = get_schedule_fn(self.clip_range)

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\common\on_policy_algorithm.py:133, in OnPolicyAlgorithm._setup_model(self)
    121         self.rollout_buffer_class = RolloutBuffer
    123 self.rollout_buffer = self.rollout_buffer_class(
    124     self.n_steps,
    125     self.observation_space,  # type: ignore[arg-type]
   (...)
    131     **self.rollout_buffer_kwargs,
    132 )
--> 133 self.policy = self.policy_class(  # type: ignore[assignment]
    134     self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
    135 )
    136 self.policy = self.policy.to(self.device)

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\common\policies.py:505, in ActorCriticPolicy.__init__(self, observation_space, action_space, lr_schedule, net_arch, activation_fn, ortho_init, use_sde, log_std_init, full_std, use_expln, squash_output, features_extractor_class, features_extractor_kwargs, share_features_extractor, normalize_images, optimizer_class, optimizer_kwargs)
    502 self.ortho_init = ortho_init
    504 self.share_features_extractor = share_features_extractor
--> 505 self.features_extractor = self.make_features_extractor()
    506 self.features_dim = self.features_extractor.features_dim
    507 if self.share_features_extractor:

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\common\policies.py:120, in BaseModel.make_features_extractor(self)
    118 def make_features_extractor(self) -> BaseFeaturesExtractor:
    119     """Helper method to create a features extractor."""
--> 120     return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\common\torch_layers.py:41, in FlattenExtractor.__init__(self, observation_space)
     40 def __init__(self, observation_space: gym.Space) -> None:
---> 41     super().__init__(observation_space, get_flattened_obs_dim(observation_space))
     42     self.flatten = nn.Flatten()

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\common\torch_layers.py:23, in BaseFeaturesExtractor.__init__(self, observation_space, features_dim)
     21 def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:
     22     super().__init__()
---> 23     assert features_dim > 0
     24     self._observation_space = observation_space
     25     self._features_dim = features_dim

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
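
A plausible explanation of the assertion failure (my reading of the traceback, not stated in the thread): for MultiDiscrete spaces, SB3 appears to compute the flattened feature dimension by summing nvec, and Python's builtin sum over a 2-D nvec returns an array instead of an int, so assert features_dim > 0 ends up evaluating an array in a boolean context. A minimal sketch reproducing that behaviour outside SB3:

import numpy as np

nvec = np.array([[1, 2], [3, 4]])

# Python's builtin sum iterates over the first axis of a 2-D array,
# so the "dimension" ends up being an array, not a scalar:
features_dim = sum(nvec)  # array([4, 6])

# Evaluating that array in a boolean context reproduces the error above:
try:
    assert features_dim > 0
except ValueError as err:
    print(err)  # The truth value of an array with more than one element is ambiguous...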


qgallouedec commented 9 months ago

The simplest workaround is to flatten the observation space:

from gymnasium.wrappers import FlattenObservation

env = FlattenObservation(CustomEnv())
elisavio commented 9 months ago

Thank you very much for your answer. If I try your suggestion with the example above and then sample a random observation, I get something completely different from what I expect:

env1 = CustomEnv()
env1.observation_space.shape
env1.observation_space.sample()
env2 = FlattenObservation(CustomEnv())
env2.observation_space.shape
env2.observation_space.sample()

The two shapes and the sampled observations are different: env1 has a shape of (2, 2), while env2 has (10,).

A question naturally arises: does the performance of an algorithm differ depending on how the observation is represented (in this case, a flattened vs. non-flattened observation)?

qgallouedec commented 9 months ago

Indeed, it's different from what I expected too. It seems that flattening a MultiDiscrete space works in a very counterintuitive way (at least to me).
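
For context (my understanding of Gymnasium, not something stated in this thread): FlattenObservation relies on gymnasium.spaces.utils.flatten_space, which one-hot encodes a MultiDiscrete, so the flattened space becomes a Box of size sum(nvec) = 1 + 2 + 3 + 4 = 10 rather than a 4-element MultiDiscrete.

import numpy as np
from gymnasium.spaces import MultiDiscrete
from gymnasium.spaces.utils import flatten, flatten_space

space = MultiDiscrete(np.array([[1, 2], [3, 4]]), seed=42)

print(flatten_space(space))            # Box of shape (10,): one slot per possible value
print(flatten(space, space.sample()))  # one-hot encoded vector of length 10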

As far as I can see, there's no wrapper that allows this, so you'll have to create your own wrapper:

from gymnasium import ObservationWrapper
from gymnasium.spaces import MultiDiscrete

class FlattenMultiDiscrete(ObservationWrapper):
    """Flatten a multi-dimensional MultiDiscrete space into a 1-D MultiDiscrete."""

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = MultiDiscrete(env.observation_space.nvec.flatten())

    def observation(self, observation):
        return observation.flatten()

env = FlattenMultiDiscrete(CustomEnv())
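
A quick usage sketch (my addition, building on the CustomEnv and wrapper defined above, not from the thread): the wrapped environment exposes a 1-D MultiDiscrete([1 2 3 4]) observation space, which PPO accepts.

print(env.observation_space)         # MultiDiscrete([1 2 3 4])

model = PPO("MlpPolicy", env, verbose=1)   # builds without the ValueError
model.learn(total_timesteps=1_000)         # short smoke test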
araffin commented 9 months ago

Note: the env checker should be updated to warn users that we don't support multi-dimensional MultiDiscrete spaces and to propose a fix (the one from @qgallouedec).
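
Purely as an illustration of what such a check could look like (this is not the actual env_checker code), the warning might be keyed on the dimensionality of the observation space's nvec:

import warnings
from gymnasium import spaces

def warn_on_multidim_multidiscrete(observation_space: spaces.Space) -> None:
    # Hypothetical helper, not part of stable_baselines3.common.env_checker
    if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.shape) > 1:
        warnings.warn(
            "Multi-dimensional MultiDiscrete observation spaces are not supported by SB3; "
            "flatten the space, e.g. with a wrapper like FlattenMultiDiscrete above."
        )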

elisavio commented 9 months ago

Thank you very much for the answer. Tell me if I should close the issue, or whether I should leave it open until the bug is fixed.

qgallouedec commented 9 months ago

Please leave it open until the env checker is updated :)