Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License
461 stars, 169 forks

[BUG] action masking does not work with VecEnv and MultiDiscrete action space #74

Open clotodex opened 2 years ago

clotodex commented 2 years ago

Describe the bug: I am aware of https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/49#issuecomment-957629253, but masking still does not work for me. I have investigated the code, and this is what I found:

When there is more than one environment, each wrapped in its own ActionMasker, the masks are collected in batch form, so splitting the masks across the distributions does not work. This feels like a VecEnv bug to me; however, I followed the advice in the documentation and comments on how to set up the action masker on a per-environment basis.

https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/75b2de139927da26d5871aef9fd839632f73b296/sb3_contrib/common/maskable/distributions.py#L234

My action space is, for example, MultiDiscrete([5] * 72), and I am spinning up 128 environments (FYI: 5 * 72 = 360). When I inspect the MaskableMultiCategoricalDistribution, it creates 72 MaskableCategorical distributions, as it should. BUT: the shape of the mask is not (360,) or (1,360) but instead it is (128, 360). This way the masks get split weirdly, and as far as I know neither the above-mentioned line nor the distributions are built for that shape. When I track invalid actions taken in my environment, there are a ton instead of the expected 0.
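To make the shapes in this report concrete, here is a minimal sketch in plain Python (not the PyTorch code the library uses, and not taken from it) of a column-wise split of a batched flat mask. The helper name `split_mask_per_dimension` is hypothetical. It assumes each row is one environment's flat mask of length sum(action_dims), and it shows that a (128, 360) batch can be split into 72 blocks of shape (128, 5) without mixing environments.

```python
from itertools import accumulate


def split_mask_per_dimension(batch_mask, action_dims):
    """Split a batched flat mask column-wise (hypothetical helper).

    batch_mask: list of n_envs rows, each a list of sum(action_dims) booleans.
    Returns one block per action dimension; block k has n_envs rows of
    action_dims[k] booleans each.
    """
    total = sum(action_dims)
    for row in batch_mask:
        if len(row) != total:
            raise ValueError(f"expected rows of length {total}, got {len(row)}")
    # Column offsets where each action dimension starts and ends.
    offsets = [0, *accumulate(action_dims)]
    return [
        [row[start:stop] for row in batch_mask]
        for start, stop in zip(offsets, offsets[1:])
    ]


n_envs, action_dims = 128, [5] * 72
batch_mask = [[True] * sum(action_dims) for _ in range(n_envs)]
# Forbid choice 3 of action dimension 10, but only in environment 7.
batch_mask[7][10 * 5 + 3] = False

blocks = split_mask_per_dimension(batch_mask, action_dims)
```

In this sketch the forbidden entry ends up in block 10, row 7, and nowhere else, which is the layout a per-dimension categorical distribution would need.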

Am I doing something wrong or are there further ways I can debug this?
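One possible way to narrow this down, sketched here as an assumption rather than something from the library: check each action against the flat mask of the same environment and record which action dimensions violate it. The helper name `invalid_components` is hypothetical, and the sketch assumes the flat mask is the concatenation of the per-dimension masks, in order.

```python
def invalid_components(action, flat_mask, action_dims):
    """Return the indices of action dimensions whose chosen value is masked out.

    action: one integer per action dimension (a MultiDiscrete sample).
    flat_mask: booleans of length sum(action_dims), True meaning "allowed".
    """
    bad = []
    offset = 0
    for dim, (choice, n) in enumerate(zip(action, action_dims)):
        # The entry for this choice sits at the start of its dimension's
        # block plus the chosen value.
        if not flat_mask[offset + choice]:
            bad.append(dim)
        offset += n
    return bad


example_dims = [3, 2]
example_mask = [True, False, True, False, True]
violations = invalid_components([1, 1], example_mask, example_dims)
```

Logging the returned indices per environment would show whether the invalid actions cluster in particular dimensions or environments, which may help separate a mask-layout problem from a batching problem.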

araffin commented 2 years ago

Hello, the best approach is to start from a working example: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/75b2de139927da26d5871aef9fd839632f73b296/sb3_contrib/common/envs/invalid_actions_env.py#L39

That being said, there might be a bug too. Tagging @kronion and @vwxyzjn, as they have actually worked with it.

kronion commented 2 years ago

It might be a bug, but it's hard to say from the description. Could you share the code to reproduce it? And could you show an example of how the mask is being split weirdly? My initial impression is that the (128, 360) shape is intended, because each row corresponds to one env in the VecEnv.

araffin commented 2 years ago

> BUT: the shape of the mask is not (360,) or (1,360) but instead it is (128, 360)

This actually looks good to me: we need to retrieve one mask per env. Does it produce an error? If so, please provide a minimal example that reproduces the issue, along with the traceback.

(FYI, I think we expect a 1D mask from the env even for MultiDiscrete (see https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/80#issuecomment-1159013599); it will be reshaped by the algorithm afterward.)
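As an illustration of the 1D mask mentioned above, here is a minimal sketch of how a single environment might build one for a MultiDiscrete space by concatenating the per-dimension masks. The helper `flat_action_mask` and its `valid_choices` argument are hypothetical names, not part of sb3-contrib. A function returning such a mask is the kind of thing one would pass as the mask function to ActionMasker, though the exact wiring depends on the environment.

```python
def flat_action_mask(valid_choices, action_dims):
    """Build a 1D boolean mask of length sum(action_dims) (hypothetical helper).

    valid_choices: one set of allowed values per action dimension.
    action_dims: number of choices in each action dimension.
    """
    if len(valid_choices) != len(action_dims):
        raise ValueError("need exactly one set of valid choices per dimension")
    mask = []
    for valid, n in zip(valid_choices, action_dims):
        # Append this dimension's block: True where the choice is allowed.
        mask.extend(i in valid for i in range(n))
    return mask


# Two action dimensions with three choices each.
example_flat_mask = flat_action_mask([{0, 1}, {2}], [3, 3])
```

For MultiDiscrete([5] * 72) the same construction yields a mask of length 360 per environment, and stacking 128 of them gives the (128, 360) batch discussed in this thread.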