DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

Dict Observation Space with MultiDiscrete Action Space Issue #1713

Open lilmrmagoo opened 11 months ago

lilmrmagoo commented 11 months ago

🐛 Bug

Using a Dict or Tuple observation space together with a MultiDiscrete action space causes PPO or A2C models to fail. If either space is swapped for a Box, the issue is resolved. The environment passes the env checker without any issue but fails when the model is set up.

Code example

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env

class CustomEnv(gym.Env):

    def __init__(self):
        super().__init__()
        # Dict observation space combined with a MultiDiscrete action space
        self.observation_space = spaces.Dict({"hand": spaces.Box(low=-1, high=1, shape=(6,))})
        # Note the nested (2D) nvec passed to MultiDiscrete
        self.action_space = spaces.MultiDiscrete([[1, 2]])

    def reset(self, seed=None, options=None):
        return self.observation_space.sample(), {}

    def step(self, action):
        obs = self.observation_space.sample()
        reward = 1.0
        terminated = False
        truncated = False
        info = {}
        return obs, reward, terminated, truncated, info

env = CustomEnv()
check_env(env)

model = PPO("MultiInputPolicy", env, verbose=1).learn(1000)

Relevant log output / Error message

TypeError                                 Traceback (most recent call last)
Cell In[16], line 30
     27 env = CustomEnv()
     28 check_env(env)
---> 30 model = PPO("MultiInputPolicy", env, verbose=1).learn(1000)

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\ppo\ppo.py:164, in PPO.__init__(self, policy, env, learning_rate, n_steps, batch_size, n_epochs, gamma, gae_lambda, clip_range, clip_range_vf, normalize_advantage, ent_coef, vf_coef, max_grad_norm, use_sde, sde_sample_freq, target_kl, stats_window_size, tensorboard_log, policy_kwargs, verbose, seed, device, _init_setup_model)
    161 self.target_kl = target_kl
    163 if _init_setup_model:
--> 164     self._setup_model()

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\ppo\ppo.py:167, in PPO._setup_model(self)
    166 def _setup_model(self) -> None:
--> 167     super()._setup_model()
    169     # Initialize schedules for policy/value clipping
    170     self.clip_range = get_schedule_fn(self.clip_range)

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\on_policy_algorithm.py:123, in OnPolicyAlgorithm._setup_model(self)
    113 self.rollout_buffer = buffer_cls(
    114     self.n_steps,
    115     self.observation_space,
   (...)
    120     n_envs=self.n_envs,
    121 )
    122 # pytype:disable=not-instantiable
--> 123 self.policy = self.policy_class(  # type: ignore[assignment]
    124     self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
    125 )
    126 # pytype:enable=not-instantiable
    127 self.policy = self.policy.to(self.device)

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\policies.py:853, in MultiInputActorCriticPolicy.__init__(self, observation_space, action_space, lr_schedule, net_arch, activation_fn, ortho_init, use_sde, log_std_init, full_std, use_expln, squash_output, features_extractor_class, features_extractor_kwargs, share_features_extractor, normalize_images, optimizer_class, optimizer_kwargs)
    833 def __init__(
    834     self,
    835     observation_space: spaces.Dict,
   (...)
    851     optimizer_kwargs: Optional[Dict[str, Any]] = None,
    852 ):
--> 853     super().__init__(
    854         observation_space,
    855         action_space,
    856         lr_schedule,
    857         net_arch,
    858         activation_fn,
    859         ortho_init,
    860         use_sde,
    861         log_std_init,
    862         full_std,
    863         use_expln,
    864         squash_output,
    865         features_extractor_class,
    866         features_extractor_kwargs,
    867         share_features_extractor,
    868         normalize_images,
    869         optimizer_class,
    870         optimizer_kwargs,
    871     )

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\policies.py:507, in ActorCriticPolicy.__init__(self, observation_space, action_space, lr_schedule, net_arch, activation_fn, ortho_init, use_sde, log_std_init, full_std, use_expln, squash_output, features_extractor_class, features_extractor_kwargs, share_features_extractor, normalize_images, optimizer_class, optimizer_kwargs)
    504 # Action distribution
    505 self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
--> 507 self._build(lr_schedule)

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\policies.py:577, in ActorCriticPolicy._build(self, lr_schedule)
    573     self.action_net, self.log_std = self.action_dist.proba_distribution_net(
    574         latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
    575     )
    576 elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
--> 577     self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
    578 else:
    579     raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\distributions.py:336, in MultiCategoricalDistribution.proba_distribution_net(self, latent_dim)
    325 def proba_distribution_net(self, latent_dim: int) -> nn.Module:
    326     """
    327     Create the layer that represents the distribution:
    328     it will be the logits (flattened) of the MultiCategorical distribution.
   (...)
    333     :return:
    334     """
--> 336     action_logits = nn.Linear(latent_dim, sum(self.action_dims))
    337     return action_logits

File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\linear.py:96, in Linear.__init__(self, in_features, out_features, bias, device, dtype)
     94 self.in_features = in_features
     95 self.out_features = out_features
---> 96 self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
     97 if bias:
     98     self.bias = Parameter(torch.empty(out_features, **factory_kwargs))

TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

System Info

Checklist

araffin commented 11 months ago

Hello, I think you need to specify spaces.MultiDiscrete([1, 2]) instead of spaces.MultiDiscrete([[1,2]])
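Applied to the reproduction above, the corrected definitions would look like this (a minimal sketch):

from gymnasium import spaces

# Same Dict observation space as before, but a flat (1D) nvec for the actions
observation_space = spaces.Dict({"hand": spaces.Box(low=-1, high=1, shape=(6,))})
action_space = spaces.MultiDiscrete([1, 2])  # instead of the nested [[1, 2]]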

lilmrmagoo commented 11 months ago

That does seem to be related to the issue, although an example given in the Gymnasium docs for MultiDiscrete is:

observation_space = MultiDiscrete(np.array([[1, 2], [3, 4]]), seed=42)
observation_space.sample()
array([[0, 0],
       [2, 2]])

which causes the issue as well. Specifically, I want to use MultiDiscrete(np.array([[13, 13], [9, 8]])) for my environment; I had just simplified it for the example.

araffin commented 11 months ago

> an example given in the gymnasium docs for MultiDiscrete is

we do not support everything that Gymnasium allows (a warning is indeed missing in the env checker). What you have should be equivalent to MultiDiscrete([13, 13, 9, 8]).
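A minimal sketch of how the nested spec could be flattened for SB3 and reshaped back inside the environment (the unflatten_action helper is hypothetical, just for illustration):

import numpy as np
from gymnasium import spaces

# Nested spec in the style of the Gymnasium docs example
nested_nvec = np.array([[13, 13], [9, 8]])

# SB3 expects a 1D nvec, so flatten it when declaring the space
action_space = spaces.MultiDiscrete(nested_nvec.flatten())  # equivalent to MultiDiscrete([13, 13, 9, 8])

def unflatten_action(action):
    # Hypothetical helper: restore the original (2, 2) layout, e.g. inside step()
    return np.asarray(action).reshape(nested_nvec.shape)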

lilmrmagoo commented 11 months ago

I see, it would be great if that were mentioned somewhere or were part of the env check. I'm not sure whether that warrants leaving this issue open or not.

araffin commented 11 months ago

> was part of the env check.

I would be happy to receive a PR that adds a warning in the env checker ;).
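For reference, such a check might look roughly like this (a sketch only, not the actual env checker code; the function name and message are hypothetical):

import warnings

import numpy as np
from gymnasium import spaces

def _warn_on_nested_multidiscrete(action_space):
    # Hypothetical check: SB3 expects a flat (1D) nvec for MultiDiscrete action spaces
    if isinstance(action_space, spaces.MultiDiscrete) and np.asarray(action_space.nvec).ndim > 1:
        warnings.warn(
            "Multi-dimensional MultiDiscrete action spaces are not supported by Stable-Baselines3. "
            "Consider flattening the nvec, e.g. MultiDiscrete(nvec.flatten())."
        )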

IsYang23 commented 3 weeks ago

Hey, when trying to run my highway script in a multi-agent setting, I run into this error:

File ~\.conda\envs\spyder\Lib\site-packages\stable_baselines3\common\base_class.py:180 in __init__
    assert isinstance(self.action_space, supported_action_spaces), (

AssertionError: The algorithm only supports (<class 'gymnasium.spaces.discrete.Discrete'>,) as action spaces but Tuple(Discrete(5), Discrete(5)) was provided

How can I solve this? Here is my env config:

config = {
    "action": {
        "type": "MultiAgentAction",
        "action_config": {
            "type": "DiscreteMetaAction",
            "longitudinal": True,
            "lateral": True,
            "target_speeds": [50, 60, 70, 80],
        },
    },
    "observation": {
        "type": "MultiAgentObservation",
        "observation_config": {
            "type": "Kinematics",
            "vehicles_count": 8,
            "features": ["presence", "x", "y", "vx", "vy", "cos_h", "sin_h"],
            "absolute": False,
        },
    },
    "lanes_count": 3,
    "vehicles_count": 10,
    "controlled_vehicles": 2,
    "collision_reward": -1,
    "right_lane_reward": 0,
    "high_speed_reward": 1,
    "lane_change_reward": 0.1,
    "reward_speed_range": [20, 30],
}
# the env is created with render_mode="rgb_array"