Farama-Foundation / SuperSuit

A collection of wrappers for Gymnasium and PettingZoo environments (being merged into gymnasium.wrappers and pettingzoo.wrappers)

StableBaselines3 and PettingZoo: Unrecognized type of observation <class 'tuple'> #222

Closed: adam-crowther closed this issue 1 year ago

adam-crowther commented 1 year ago

Hi,

I'm having an issue similar to Issue #220; however, mine is not resolved by using ss.flatten_v0().

I'm using the latest versions: stable-baselines3==2.1.0a0, pettingzoo==1.23.1, supersuit==3.8.1 and gymnasium==0.28.1. I have also tried stable-baselines3==2.0.0 and hit the same issue.

I have adapted @PieroMacaluso's dummy project from Issue #169 to reproduce the issue: https://github.com/adam-crowther/test-supersuit-baseline3-pettingzoo-parallel-env

The ParallelEnv looks like this:

```python
import random
from typing import Dict

import numpy as np
from gymnasium import spaces
from gymnasium.utils import EzPickle
from pettingzoo import ParallelEnv
from pettingzoo.utils.env import ObsDict, ActionDict

class DummyParallelEnv(ParallelEnv, EzPickle):
    metadata = {'render_modes': ['ansi'], "name": "TestParallelEnv-v0"}

    def __init__(self, n_agents: int = 20, new_step_api: bool = True) -> None:
        EzPickle.__init__(
            self,
            n_agents,
            new_step_api
        )

        self._terminated = False
        self.current_step = 0

        self.n_agents = n_agents
        self.possible_agents = [f"player_{idx}" for idx in range(n_agents)]
        self.agents = self.possible_agents[:]

        self.agent_name_mapping = dict(
            zip(self.possible_agents, list(range(len(self.possible_agents))))
        )

        self.observation_spaces = {
            agent: spaces.Box(shape=(len(self.agents),), dtype=np.float64, low=0.0, high=1.0)
            for agent in self.possible_agents
        }

        self.action_spaces = {
            agent: spaces.Discrete(4) for agent in self.possible_agents}

    def observation_space(self, agent):
        return self.observation_spaces[agent]

    def action_space(self, agent):
        return self.action_spaces[agent]

    def step(self, actions: ActionDict) \
            -> tuple[ObsDict, dict[str, float], dict[str, bool], dict[str, bool], dict[str, dict]]:
        self.current_step += 1
        self._terminated = self.current_step >= 100

        observations = self.__calculate_observations()
        rewards = {
            self.agents[agent]: random.randint(0, 100) for agent in range(len(self.agents))
        }
        terminated = {agent: self._terminated for agent in self.agents}
        truncated = {agent: False for agent in self.agents}
        infos = {agent: {} for agent in self.agents}

        if self._terminated:
            self.agents = []

        return observations, rewards, terminated, truncated, infos

    def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[ObsDict, dict[str, dict]]:
        self.agents = self.possible_agents[:]
        self._terminated = False
        self.current_step = 0
        observations = self.__calculate_observations()
        infos = {agent: {} for agent in self.agents}

        return observations, infos

    def __calculate_observations(self) -> Dict[str, np.ndarray]:
        return {
            agent: self.observation_space(agent).sample() for agent in self.agents
        }
```

And is executed like this:

```python
import supersuit as ss
from pettingzoo.test import parallel_api_test
from stable_baselines3 import PPO

from dummy_env import dummy

if __name__ == '__main__':
    env_parallel = dummy.DummyParallelEnv()
    parallel_api_test(env_parallel)

    # env_parallel = ss.flatten_v0(env_parallel)
    env_parallel = ss.pettingzoo_env_to_vec_env_v1(env_parallel)
    env_parallel = ss.concat_vec_envs_v1(env_parallel, 1, base_class="stable_baselines3")

    model = PPO("MlpPolicy", env_parallel, verbose=1)

    model.learn(total_timesteps=10_000)
```
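For context on what this wrapper stack produces: SuperSuit documents that `pettingzoo_env_to_vec_env_v1` exposes each agent as one sub-environment of a vector env, so with 20 agents and `concat_vec_envs_v1(..., 1, ...)` there are 20 sub-environments whose observations are batched in agent order. The stdlib-only sketch below imitates that batching on plain lists (standing in for `np.ndarray`). `stack_agent_obs` and `vectorized_reset` are hypothetical helpers, not SuperSuit's implementation; the point is only that a Gymnasium-style reset yields an `(observations, infos)` pair rather than observations alone.

```python
from typing import Dict, List, Sequence, Tuple

Obs = List[float]  # stand-in for one agent's np.ndarray observation


def stack_agent_obs(possible_agents: Sequence[str], obs: Dict[str, Obs]) -> List[Obs]:
    """Hypothetical helper: turn a per-agent obs dict into a batch.

    Row i of the batch belongs to possible_agents[i], mirroring the idea
    that each agent is one sub-environment of the vector env.
    """
    return [obs[agent] for agent in possible_agents]


def vectorized_reset(
    possible_agents: Sequence[str],
    parallel_reset: Tuple[Dict[str, Obs], Dict[str, dict]],
) -> Tuple[List[Obs], List[dict]]:
    """Hypothetical: batch a ParallelEnv-style (observations, infos) reset result."""
    observations, infos = parallel_reset
    batch = stack_agent_obs(possible_agents, observations)
    batch_infos = [infos[agent] for agent in possible_agents]
    # Gymnasium-style: the result is still an (obs, infos) pair, not obs alone.
    return batch, batch_infos


if __name__ == "__main__":
    agents = ["player_0", "player_1", "player_2"]
    reset_result = ({a: [0.1 * i] * 3 for i, a in enumerate(agents)}, {a: {} for a in agents})
    batch, infos = vectorized_reset(agents, reset_result)
    print(len(batch), batch[1])  # 3 [0.1, 0.1, 0.1]
```

A consumer that expects only the batch has to take element `[0]` of that pair, which is exactly the mismatch described below.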

When I execute I get this error:

```
Traceback (most recent call last):
  File "C:\dev\repo\test-supersuit-baseline3-pettingzoo-parallel-env\main_dummy.py", line 17, in <module>
    model.learn(total_timesteps=10_000)
  File "C:\Users\adamcc\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\ppo\ppo.py", line 308, in learn
    return super().learn(
           ^^^^^^^^^^^^^^
  File "C:\Users\adamcc\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\on_policy_algorithm.py", line 259, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\adamcc\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\on_policy_algorithm.py", line 168, in collect_rollouts
    obs_tensor = obs_as_tensor(self._last_obs, self.device)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\adamcc\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\utils.py", line 487, in obs_as_tensor
    raise Exception(f"Unrecognized type of observation {type(obs)}")
Exception: Unrecognized type of observation <class 'tuple'>

Process finished with exit code 1
```


If I set a breakpoint in stable_baselines3's on_policy_algorithm.py at line 168, I see that self._last_obs is set to the (observations, infos) tuple returned by the ParallelEnv's reset() method, whereas obs_as_tensor() expects an np.ndarray.
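The failure can be reproduced without SB3 at all. The sketch below uses a hypothetical `as_batch` function that, like the check in the traceback, accepts only a batch or a dict and raises on anything else with the same message. A plain list stands in for `np.ndarray` so it runs with the standard library alone; this is an illustration of the type check, not SB3's actual code.

```python
from typing import Any, Dict, List, Tuple, Union

Batch = List[List[float]]  # stand-in for a (n_envs, obs_dim) np.ndarray


def as_batch(obs: Any) -> Union[Batch, Dict[str, Any]]:
    """Hypothetical stand-in for the type check in SB3's obs_as_tensor.

    Only a batch (here: a list) or a dict of batches is accepted; anything
    else, including a tuple, is rejected with the message from the traceback.
    """
    if isinstance(obs, (list, dict)):
        return obs
    raise Exception(f"Unrecognized type of observation {type(obs)}")


# What the vec env returned from reset(): a Gymnasium-style (obs, infos) pair.
reset_result: Tuple[Batch, List[dict]] = ([[0.0, 1.0], [1.0, 0.0]], [{}, {}])

try:
    as_batch(reset_result)
except Exception as exc:
    print(exc)  # Unrecognized type of observation <class 'tuple'>

print(as_batch(reset_result[0]))  # [[0.0, 1.0], [1.0, 0.0]]
```

Passing the whole pair fails; passing only its first element, the observation batch, succeeds.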

Have I got something wrong or is there a compatibility issue here somewhere?

Thanks,

Adam

adam-crowther commented 1 year ago

I created a workaround using a shim that wraps the SB3VecEnvWrapper:

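```python
from stable_baselines3.common.vec_env import VecEnvWrapper
from stable_baselines3.common.vec_env.base_vec_env import VecEnvStepReturn


class Sb3ShimWrapper(VecEnvWrapper):
    metadata = {'render_modes': ['human', 'files', 'none'], "name": "Sb3ShimWrapper-v0"}

    def __init__(self, venv):
        super().__init__(venv)

    def reset(self, seed=None, options=None):
        # The wrapped env returns (observations, infos); SB3 expects observations only.
        return self.venv.reset()[0]

    def step_wait(self) -> VecEnvStepReturn:
        # Step results are passed through unchanged.
        return self.venv.step_wait()
```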

As you can see, it overrides reset() and returns only the first element of the tuple (the observations).
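One caveat with always taking `[0]`: if a later SuperSuit release returns only the observation batch from `reset()` (which is what SB3 expects), `reset()[0]` would silently return the first sub-environment's observation instead of the whole batch. A more defensive variant unwraps only when it actually receives a 2-tuple. `obs_only` is a hypothetical helper, and the sketch assumes observations are never themselves tuples, which holds for the `Box` spaces used here (SB3 rejects tuple observations anyway, as the traceback shows).

```python
from typing import Any


def obs_only(reset_result: Any) -> Any:
    """Hypothetical helper for a shim's reset(): strip a trailing info element.

    Gymnasium-style reset() returns (obs, info); SB3-style VecEnv.reset()
    returns obs alone. Only a 2-tuple is unwrapped, so a plain observation
    batch (array or dict) passes through untouched.
    """
    if isinstance(reset_result, tuple) and len(reset_result) == 2:
        obs, _info = reset_result
        return obs
    return reset_result


# Inside the shim this would replace `return self.venv.reset()[0]`:
#     def reset(self, seed=None, options=None):
#         return obs_only(self.venv.reset())

if __name__ == "__main__":
    batch = [[0.0, 1.0], [1.0, 0.0]]
    print(obs_only((batch, [{}, {}])) is batch)  # True: tuple-returning reset()
    print(obs_only(batch) is batch)              # True: obs-only reset()
```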

I integrate it like this:

```python
if __name__ == '__main__':
    env_parallel = dummy.DummyParallelEnv()
    parallel_api_test(env_parallel)

    # env_parallel = ss.flatten_v0(env_parallel)
    env_parallel = ss.pettingzoo_env_to_vec_env_v1(env_parallel)
    env_parallel = ss.concat_vec_envs_v1(env_parallel, 1, base_class="stable_baselines3")
    env_parallel = Sb3ShimWrapper(env_parallel)

    model = PPO("MlpPolicy", env_parallel, verbose=1)

    model.learn(total_timesteps=10_000)
```

I will push this change to my demo repo.

Now I have a new problem with render(), which I will document in a new Issue.

elliottower commented 1 year ago

To my knowledge, this has been fixed with https://github.com/Farama-Foundation/SuperSuit/pull/226. (I was getting the same issue. It happens because Stable-Baselines3's VecEnv API expects reset() to return only the observations, whereas PettingZoo and Gymnasium return observations and info by default.)