Farama-Foundation / Gymnasium

An API standard for single-agent reinforcement learning environments, with popular reference environments and related utilities (formerly Gym)
https://gymnasium.farama.org
MIT License

[Bug Report] AsyncVectorEnv observation not batched when observation space is not Box #1164

Closed. pkrack closed this issue 2 months ago.

pkrack commented 2 months ago

The AsyncVectorEnv class does not batch the observations when the underlying environment has an observation space of type Dict. Instead, the reset and step functions return a tuple of individual observations.

I think this is a bug because:

The reason for this behaviour is that, with shared_memory=True (the default), the observations are read only once at the beginning using the read_from_shared_memory functions, and the Dict version of that function returns a tuple of per-environment observations instead of the batched observation. By contrast, _read_base_from_shared_memory, for example, reshapes the buffer to fit the batched observation space.
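To make the difference concrete, here is a minimal sketch that mimics the two layouts with NumPy. It does not call Gymnasium: read_box, read_dict_batched and read_dict_unbatched are hypothetical stand-ins for _read_base_from_shared_memory and the Dict reader (before and after the proposed fix), and a bytearray stands in for the real shared-memory buffer. Both layouts are views, so a later write into the buffer is visible through them; only the batched one has the structure of the batched Dict space.

```python
import numpy as np


def read_box(buffer, shape, n):
    # Hypothetical stand-in for _read_base_from_shared_memory: view a flat
    # float32 buffer as a batch of n observations (no copy is made).
    return np.frombuffer(buffer, dtype=np.float32).reshape((n,) + shape)


def read_dict_batched(buffers, shapes, n):
    # Proposed behaviour: a dict of batched views, one per key.
    return {key: read_box(buffers[key], shape, n) for key, shape in shapes.items()}


def read_dict_unbatched(buffers, shapes, n):
    # Behaviour described in the issue: a tuple of n per-environment dicts.
    samples = read_dict_batched(buffers, shapes, n)
    return tuple({key: samples[key][i] for key in shapes} for i in range(n))


n = 2
shapes = {"key1": (1,), "key2": (3,)}
# bytearray stands in for the shared-memory buffers (4 bytes per float32).
buffers = {key: bytearray(4 * n * int(np.prod(shape))) for key, shape in shapes.items()}

batched = read_dict_batched(buffers, shapes, n)
unbatched = read_dict_unbatched(buffers, shapes, n)

# Simulate a worker writing new observations after the views were created:
# both layouts see the update because they are views into the same memory.
np.frombuffer(buffers["key2"], dtype=np.float32)[:] = np.arange(6, dtype=np.float32)

print(batched["key2"].shape)                      # (2, 3)
print(type(unbatched).__name__, len(unbatched))   # tuple 2
print(unbatched[1]["key2"])                       # [3. 4. 5.]
```

Because reading once and reusing the view is what makes the shared-memory path cheap, the fix only needs to change the structure that is returned, not when it is read.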

In AsyncVectorEnv (lines 306 and 397) there is then a case distinction, if not shared_memory: ..., so only the non-shared-memory path batches the observations; with shared memory they are returned unbatched.
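The branch can be sketched as follows, assuming a simplified model of AsyncVectorEnv rather than its actual code: concatenate_dict and collect_observations are hypothetical helpers, and the shared-memory view is assumed to have come from the tuple-returning Dict reader. The point it illustrates is that the two settings of shared_memory then yield different types for the same Dict space.

```python
import numpy as np


def concatenate_dict(per_env_obs):
    # Hypothetical stand-in for the non-shared-memory path: stack each key
    # across environments so every leaf gets a leading batch dimension.
    return {key: np.stack([obs[key] for obs in per_env_obs]) for key in per_env_obs[0]}


def collect_observations(per_env_obs, shared_view, shared_memory):
    # Mirrors the "if not shared_memory: ..." branch: only this path batches.
    if not shared_memory:
        return concatenate_dict(per_env_obs)
    # With shared memory, whatever was read once at the beginning is returned.
    return shared_view


per_env_obs = (
    {"key1": np.zeros(1, dtype=np.float32), "key2": np.zeros(3, dtype=np.float32)},
    {"key1": np.ones(1, dtype=np.float32), "key2": np.ones(3, dtype=np.float32)},
)
# Assumption: the shared-memory view has the tuple-of-dicts layout from the issue.
shared_view = per_env_obs

print(type(collect_observations(per_env_obs, shared_view, shared_memory=False)).__name__)  # dict
print(type(collect_observations(per_env_obs, shared_view, shared_memory=True)).__name__)   # tuple
```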

A possible fix for the error in the code sample (on tag v1.0.0.a2) is:

@@ -162,13 +162,10 @@ def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):

 @read_from_shared_memory.register(Dict)
 def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
-    subspace_samples = {
+    return {
         key: read_from_shared_memory(subspace, shared_memory[key], n=n)
         for (key, subspace) in space.spaces.items()
     }
-    return tuple(
-        {key: subspace_samples[key][i] for key in space.keys()} for i in range(n)
-    )

Is there a reason for returning a tuple instead of the batched observation? (For the code example below, subspace_samples is already correctly batched.)

EDIT: There is a similar issue with Tuple observation spaces: subspace_samples is batched, but then tuple(zip(*subspace_samples)) is returned. This "unbatches" the observations when the Tuple's subspaces are, for example, Box, and returns nonsense when they are, for example, Dict: *subspace_samples then unpacks to the key names of the individual dicts, so zipping them together gives pairs of key names.
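A short NumPy-only sketch of both effects, with hypothetical helper names: read_tuple_current applies the tuple(zip(*subspace_samples)) step quoted above, and read_tuple_fixed keeps the per-subspace batches as in the proposed fix. The inputs are assumed to be already-batched subspace samples for n = 2 environments, with the Dict subspaces already returned as dicts of batched arrays.

```python
import numpy as np


def read_tuple_current(subspace_samples):
    # Step quoted in the issue: regroups per-subspace batches into per-env tuples.
    return tuple(zip(*subspace_samples))


def read_tuple_fixed(subspace_samples):
    # Proposed behaviour: keep one batched entry per subspace.
    return tuple(subspace_samples)


n = 2
# Batched samples for a Tuple(Box(2,), Box(3,)) space.
box_samples = (np.zeros((n, 2)), np.ones((n, 3)))
# Batched samples for a Tuple of two Dict subspaces, each a dict of batched arrays.
dict_samples = (
    {"a": np.zeros((n, 1)), "b": np.zeros((n, 2))},
    {"a": np.ones((n, 1)), "b": np.ones((n, 2))},
)

print(read_tuple_current(box_samples)[0][1].shape)  # (3,): one environment's sample
print(read_tuple_fixed(box_samples)[1].shape)       # (2, 3): still batched
print(read_tuple_current(dict_samples))             # (('a', 'a'), ('b', 'b'))
```

Iterating a dict yields its keys, which is why zipping the Dict samples produces pairs of key names rather than observations.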

Similar fix:

@@ -153,22 +153,18 @@ def _read_base_from_shared_memory(

 @read_from_shared_memory.register(Tuple)
 def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
-    subspace_samples = tuple(
+    return tuple(
         read_from_shared_memory(subspace, memory, n=n)
         for (memory, subspace) in zip(shared_memory, space.spaces)
     )
-    return tuple(zip(*subspace_samples))

Code example

from functools import partial
from traceback import print_exception

import numpy as np

import gymnasium as gym

def main():
    wrappers = [
        partial(
            gym.wrappers.TransformObservation,
            func=lambda x: {"key1": x[0:1], "key2": x[1:]},
            observation_space=gym.spaces.Dict(
                {
                    "key1": gym.spaces.Box(
                        low=-np.inf, high=np.inf, shape=(1,)
                    ),
                    "key2": gym.spaces.Box(
                        low=-np.inf, high=np.inf, shape=(3,)
                    ),
                }
            )
        )
    ]
    env = gym.make_vec(
        "CartPole-v1",
        num_envs=2,
        vectorization_mode=gym.VectorizeMode.ASYNC,
        wrappers=wrappers
    )
    obs, info = env.reset()
    print(f"{obs in env.observation_space=}")
    print(f"{env.observation_space=}")
    print(f"{type(obs)=}")
    print(f"{obs=}")
    env = gym.wrappers.vector.VectorizeTransformObservation(
        env=env,
        wrapper=gym.wrappers.TransformObservation,
        func=lambda x: np.concatenate((x["key1"], x["key2"])),
        observation_space=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,))
    )
    try:
        _ = env.reset()
    except TypeError as exc:
        print_exception(exc)

if __name__ == "__main__":
    main()

$ python reproducible_bug.py
obs in env.observation_space=False
env.observation_space=Dict('key1': Box(-inf, inf, (2, 1), float32), 'key2': Box(-inf, inf, (2, 3), float32))
type(obs)=<class 'tuple'>
obs=({'key1': array([-0.04952637], dtype=float32), 'key2': array([-0.01968242, -0.0389384 ,  0.01207159], dtype=float32)}, {'key1': array([-0.03940215], dtype=float32), 'key2': array([-0.00562053,  0.04703879,  0.00567271], dtype=float32)})
Traceback (most recent call last):
  File "/home/pkrack/Desktop/Gymnasium/reproducible_bug.py", line 44, in main
    _ = env.reset()
        ^^^^^^^^^^^
  File "/home/pkrack/Desktop/Gymnasium/gymnasium/vector/vector_env.py", line 518, in reset
    return self.observations(observations), infos
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pkrack/Desktop/Gymnasium/gymnasium/wrappers/vector/vectorize_observation.py", line 166, in observations
    tuple(
  File "/home/pkrack/Desktop/Gymnasium/gymnasium/wrappers/vector/vectorize_observation.py", line 166, in <genexpr>
    tuple(
         ^
  File "/home/pkrack/Desktop/Gymnasium/gymnasium/vector/utils/space_utils.py", line 222, in _iterate_dict
    *[
     ^
  File "/home/pkrack/Desktop/Gymnasium/gymnasium/vector/utils/space_utils.py", line 223, in <listcomp>
    (key, iterate(subspace, items[key]))
                            ~~~~~^^^^^
TypeError: tuple indices must be integers or slices, not str

pseudo-rnd-thoughts commented 2 months ago

@pkrack Thanks for the detailed issue. Could you make a PR with your suggested fix, along with related tests that confirm the fix?

pseudo-rnd-thoughts commented 2 months ago

Ironically, I caused this bug in https://github.com/Farama-Foundation/Gymnasium/pull/941 by fixing a test that was actually wrong, which I never realised. I've fixed it by reverting the shared_memory code to what you suggested (this was the original code) and fixing the related test.

pkrack commented 2 months ago

Thanks for the quick answer and fix.