Open zbenmo opened 1 year ago
The same issue as #148
If I use the env checker I get:
Traceback (most recent call last):
  File "/home/raff_an/USERDIR/projects/torchy-baselines/stable_baselines3/common/env_checker.py", line 219, in _check_returned_values
    _check_obs(obs[key], observation_space.spaces[key], "reset")
  File "/home/raff_an/USERDIR/projects/torchy-baselines/stable_baselines3/common/env_checker.py", line 164, in _check_obs
    ), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"
AssertionError: The observation returned by the `reset()` method should be a single value, not a tuple

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "test.py", line 42, in <module>
    check_env(env, warn=True)
  File "/home/raff_an/USERDIR/projects/torchy-baselines/stable_baselines3/common/env_checker.py", line 377, in check_env
    _check_returned_values(env, observation_space, action_space)
  File "/home/raff_an/USERDIR/projects/torchy-baselines/stable_baselines3/common/env_checker.py", line 221, in _check_returned_values
    raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
AssertionError: Error while checking key=player: The observation returned by the `reset()` method should be a single value, not a tuple
With the observation corrected, I get this instead:
Traceback (most recent call last):
  File "test.py", line 46, in <module>
    model.learn(32)
  File "/sb3_contrib/sb3_contrib/ppo_mask/ppo_mask.py", line 521, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
  File "/sb3_contrib/sb3_contrib/ppo_mask/ppo_mask.py", line 298, in collect_rollouts
    actions, values, log_probs = self.policy(obs_tensor, action_masks=action_masks)
  File "/volume/USERSTORE/raff_an/mambaforge/envs/th/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/sb3_contrib/sb3_contrib/common/maskable/policies.py", line 140, in forward
    distribution.apply_masking(action_masks)
  File "/sb3_contrib/sb3_contrib/common/maskable/distributions.py", line 246, in apply_masking
    distribution.apply_masking(mask)
  File "/sb3_contrib/sb3_contrib/common/maskable/distributions.py", line 58, in apply_masking
    self.masks = th.as_tensor(masks, dtype=th.bool, device=device).reshape(self.logits.shape)
RuntimeError: shape '[1, 8]' is invalid for input of size 32
~My guess is that we don't support multi-dimensional multi discrete space for masks.~
Please take a look at the built-in multi discrete env.
I get the same error for my env with a MultiDiscrete([3, 10, 10]) action space; my action mask is an array of 300 boolean values:
shape '[-1, 23]' is invalid for input of size 300
But I think the problem comes from the Dict observation space. I'm using the FlattenObservation wrapper for it, and I guess this doesn't work with MaskablePPO.
When I run it without the FlattenObservation wrapper, the .learn method instead fails with:
'dict' object has no attribute 'flatten'
When I use the MaskableMultiInputActorCriticPolicy, I also get the same error about the shape:
MaskablePPO(MaskableMultiInputActorCriticPolicy, env, verbose=2)
The problem is that
masks_tensor = masks_tensor.view(-1, sum(self.action_dims))
inside distributions.py takes the sum of the dims.
But my dims are (3, 10, 10), so it expects 23 values but received 300!
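To make the mismatch concrete, here is a minimal sketch comparing the two mask lengths for dims (3, 10, 10). It assumes, as the quoted `view(-1, sum(self.action_dims))` line suggests, that the expected mask is a concatenation of one boolean segment per sub-action. The helper names are hypothetical and are not part of sb3-contrib.

```python
from math import prod

# Dimensions of the MultiDiscrete action space discussed above.
action_dims = (3, 10, 10)


def per_dimension_mask_length(dims):
    """Length of a mask holding one boolean per choice of each sub-action."""
    return sum(dims)


def joint_mask_length(dims):
    """Length of a mask holding one boolean per joint (button, x, y) action."""
    return prod(dims)


def concat_masks(*segments):
    """Concatenate per-dimension masks into one flat list (hypothetical helper)."""
    return [bool(flag) for segment in segments for flag in segment]


# Illustrative mask: allow only button 0, any x, and y < 5.
button_mask = [True, False, False]
x_mask = [True] * 10
y_mask = [i < 5 for i in range(10)]
flat_mask = concat_masks(button_mask, x_mask, y_mask)

print(per_dimension_mask_length(action_dims))  # 23
print(joint_mask_length(action_dims))          # 300
print(len(flat_mask))                          # 23
```

Note that a per-dimension mask of this kind cannot express constraints that couple sub-actions, such as a button that is only valid on certain grid cells.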
For me, the action is pressing 1 of 3 buttons on a 10 x 10 grid.
I 'fixed' it by switching to a Discrete representation of size 300, where the first digit is the button and the 2nd and 3rd digits are the xy coords.
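The digit-based workaround can be sketched as follows. This is only an illustration of the index arithmetic, not the commenter's actual code; the function names and the example validity rule are assumptions made for the sketch.

```python
# Sizes taken from the description above: 3 buttons on a 10 x 10 grid.
N_BUTTONS, GRID_W, GRID_H = 3, 10, 10
N_ACTIONS = N_BUTTONS * GRID_W * GRID_H  # size of the flat Discrete space


def encode(button, x, y):
    """Map (button, x, y) to a flat index; for a 10 x 10 grid the digits are b, x, y."""
    return (button * GRID_W + x) * GRID_H + y


def decode(index):
    """Inverse of encode: recover (button, x, y) from a flat index."""
    button, rest = divmod(index, GRID_W * GRID_H)
    x, y = divmod(rest, GRID_H)
    return button, x, y


def joint_mask(is_valid):
    """Build one boolean per flat action from a validity predicate (hypothetical helper)."""
    return [bool(is_valid(*decode(i))) for i in range(N_ACTIONS)]


# Hypothetical rule: button 0 may only be pressed on the diagonal (x == y).
mask = joint_mask(lambda button, x, y: button != 0 or x == y)

print(encode(2, 3, 4))  # 234
print(decode(234))      # (2, 3, 4)
print(len(mask))        # 300
```

With this encoding the mask has one entry per joint action, so its length matches the size of the Discrete space.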
Describe the bug: Trying to mask actions for an environment with a Dict observation space and a MultiDiscrete action space.
Code example: I have opened a pull request with a failing test.
System Info
Python version: 3.8.10
PyTorch version: 1.13.1
Gym version: 0.21.0
sb3 version: 1.8.0a2 (also with 1.7.0)