trying to mask actions for an environment with dict observation and multidiscrete action space #149

Open zbenmo opened 1 year ago

zbenmo commented 1 year ago

Describe the bug Trying to mask actions for an environment with dict observation and multidiscrete action space.

Code example I have opened a pull request with a failing test.

from typing import Tuple

import gym
import numpy as np
from sb3_contrib import MaskablePPO
from stable_baselines3.common.env_checker import check_env

class MyBreakingCodeEnv(gym.Env):
    def __init__(self):
        obs_dict = dict(
            board=gym.spaces.Box(low=0, high=1, shape=(8 * 8,), dtype=bool),
            player=gym.spaces.MultiDiscrete([8, 8]),
        )
        self.observation_space = gym.spaces.Dict(obs_dict)
        self.action_space = gym.spaces.MultiDiscrete([8, 8])

    def reset(self):
        return self._obs()

    def _obs(self):
        return {"board": np.zeros(shape=(8, 8), dtype=bool).flatten(), "player": (2, 3)}

    def step(self, action: Tuple[int, int]):
        reward = 0.2
        done = False
        info = {}
        return self._obs(), reward, done, info

    def render(self):
        pass

    def action_masks(self) -> np.ndarray:
        masks = np.zeros(shape=(8, 8), dtype=bool)
        for i in range(8):
            masks[i, i] = True
        return masks

env = MyBreakingCodeEnv()

check_env(env, warn=True)

model = MaskablePPO("MultiInputPolicy", env, n_steps=32, seed=8)
RuntimeError: shape '[1, 8]' is invalid for input of size 32

sb3 - 1.8.0a2 (also with 1.7.0)

zbenmo commented 1 year ago

The same issue as #148

araffin commented 1 year ago

If I use the env checker I get:

Traceback (most recent call last):
  File "/home/raff_an/USERDIR/projects/torchy-baselines/stable_baselines3/common/", line 219, in _check_returned_values
    _check_obs(obs[key], observation_space.spaces[key], "reset")
  File "/home/raff_an/USERDIR/projects/torchy-baselines/stable_baselines3/common/", line 164, in _check_obs
    ), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"
AssertionError: The observation returned by the `reset()` method should be a single value, not a tuple

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "", line 42, in <module>
    check_env(env, warn=True)
  File "/home/raff_an/USERDIR/projects/torchy-baselines/stable_baselines3/common/", line 377, in check_env
    _check_returned_values(env, observation_space, action_space)
  File "/home/raff_an/USERDIR/projects/torchy-baselines/stable_baselines3/common/", line 221, in _check_returned_values
    raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
AssertionError: Error while checking key=player: The observation returned by the `reset()` method should be a single value, not a tuple

with the correct observation:

Traceback (most recent call last):
  File "", line 46, in <module>
  File "/sb3_contrib/sb3_contrib/ppo_mask/", line 521, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
  File "/sb3_contrib/sb3_contrib/ppo_mask/", line 298, in collect_rollouts
    actions, values, log_probs = self.policy(obs_tensor, action_masks=action_masks)
  File "/volume/USERSTORE/raff_an/mambaforge/envs/th/lib/python3.7/site-packages/torch/nn/modules/", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/sb3_contrib/sb3_contrib/common/maskable/", line 140, in forward
  File "/sb3_contrib/sb3_contrib/common/maskable/", line 246, in apply_masking
  File "/sb3_contrib/sb3_contrib/common/maskable/", line 58, in apply_masking
    self.masks = th.as_tensor(masks, dtype=th.bool, device=device).reshape(self.logits.shape)
RuntimeError: shape '[1, 8]' is invalid for input of size 32

~My guess is that we don't support multi-dimensional multi discrete space for masks.~

Please take a look at the built-in multi discrete env.
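The numbers in the error can be reproduced with plain NumPy. This is only a sketch, assuming the mask is flattened to one row of `sum(nvec)` flags per step and then split per action dimension, as the `view(-1, sum(self.action_dims))` line quoted later in this thread suggests. The variable names are illustrative and not taken from sb3_contrib.

```python
import numpy as np

nvec = [8, 8]                        # MultiDiscrete([8, 8]) from the issue
joint_mask = np.eye(8, dtype=bool)   # the 8 x 8 mask returned by action_masks()

# Assumed layout: one row of sum(nvec) flags per environment step.
per_step = joint_mask.reshape(-1, sum(nvec))   # 64 values -> shape (4, 16)
first_sub_mask = per_step[:, : nvec[0]]        # shape (4, 8), i.e. 32 values
print(per_step.shape, first_sub_mask.size)     # (4, 16) 32

# A mask in the assumed layout: 8 flags for the first dimension
# followed by 8 flags for the second one.
flat_mask = np.concatenate([np.ones(8, dtype=bool), np.ones(8, dtype=bool)])
print(flat_mask.shape)                         # (16,)
```

Under that assumption, the 64-value mask is read as four steps, so each sub-mask carries 32 values where the logits expect 8, which is consistent with the reported `shape '[1, 8]' is invalid for input of size 32`.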

araffin commented 1 year ago

Probably a duplicate of and

Maxxxel commented 5 months ago

I get the same error for my env with MultiDiscrete([3, 10, 10]). My action mask is an array of 300 bool values: shape '[-1, 23]' is invalid for input of size 300. But I think the problem comes from the Dict observation space. I'm using the FlattenObservation wrapper for it, and I guess this doesn't work with MaskablePPO.

When I run it without the FlattenObservation wrapper, the .learn method instead fails with: 'dict' object has no attribute 'flatten'

When I use the MaskableMultiInputActorCriticPolicy, I also get the same shape error: MaskablePPO(MaskableMultiInputActorCriticPolicy, env, verbose=2)

The problem is that masks_tensor = masks_tensor.view(-1, sum(self.action_dims)) takes the sum of the dims. My dims are (3, 10, 10), so it expects 23 values but received 300!
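To make the 23 versus 300 mismatch concrete, here is a rough NumPy sketch. It assumes that a MultiDiscrete mask is the concatenation of one mask per dimension (3 + 10 + 10 flags), and the two valid combinations are made up for illustration.

```python
import numpy as np

nvec = (3, 10, 10)  # button, x, y

# Joint mask: one flag per (button, x, y) combination.
joint = np.zeros(nvec, dtype=bool)
joint[0, 2, 3] = True   # hypothetical valid action
joint[1, 5, 5] = True   # hypothetical valid action
print(joint.size)       # 300

# Per-dimension masks: a value is allowed if any valid combination uses it.
per_dim = [
    joint.any(axis=tuple(a for a in range(len(nvec)) if a != d))
    for d in range(len(nvec))
]
flat = np.concatenate(per_dim)
print(flat.size)        # 23
```

Note that the per-dimension form is lossy: in this example it would also allow (0, 5, 5), which the joint mask forbids, because each dimension is masked independently.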

For me, it's pressing 1 of 3 buttons on a 10 x 10 grid.

I 'fixed' it by switching to a Discrete representation of size 300, where the first digit is the button and the 2nd and 3rd digits are the x/y coords.
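A minimal sketch of that workaround, assuming row-major ordering so that the flat index reads as button, x, y digits. The `encode` and `decode` helpers are hypothetical names, not part of any library.

```python
import numpy as np

DIMS = (3, 10, 10)  # button, x, y


def encode(button: int, x: int, y: int) -> int:
    """Map (button, x, y) to a flat Discrete(300) action index."""
    return int(np.ravel_multi_index((button, x, y), DIMS))


def decode(action: int) -> tuple:
    """Map a flat Discrete(300) action index back to (button, x, y)."""
    return tuple(int(v) for v in np.unravel_index(action, DIMS))


# A joint mask flattens directly into the 300 flags a Discrete(300) space needs.
joint = np.zeros(DIMS, dtype=bool)
joint[1, 5, 5] = True        # hypothetical valid action
mask = joint.reshape(-1)     # shape (300,)

print(encode(1, 5, 5))       # 155
print(decode(155))           # (1, 5, 5)
print(mask[encode(1, 5, 5)]) # True
```

With this encoding the mask can express constraints on whole (button, x, y) combinations, at the cost of a larger single action dimension.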