Farama-Foundation / SuperSuit

A collection of wrappers for Gymnasium and PettingZoo environments (being merged into gymnasium.wrappers and pettingzoo.wrappers)

Possible problem related to `pettingzoo_env_to_vec_env_v1` and `reset()` function #169

Closed PieroMacaluso closed 2 years ago

PieroMacaluso commented 2 years ago

I am having trouble making a custom ParallelEnv I wrote with PettingZoo work. I am using SuperSuit's ss.pettingzoo_env_to_vec_env_v1(env) wrapper to vectorize the environment and make it work with Stable-Baselines3, as documented here.

You can find attached a summary of the most relevant part of the code:

from typing import Optional
from gym import spaces
import numpy as np
from pettingzoo import ParallelEnv
import supersuit as ss

def env(**kwargs):
    env_ = parallel_env(**kwargs)
    env_ = ss.pettingzoo_env_to_vec_env_v1(env_)
    return env_

petting_zoo = env

class parallel_env(ParallelEnv):
    metadata = {'render_modes': ['ansi'], "name": "PlayerEnv-v0"}

    def __init__(self, n_agents, new_step_api: bool = True) -> None:
        # [...]
        self.possible_agents = [
            f"player_{idx}" for idx in range(n_agents)]

        self.agents = self.possible_agents[:]

        self.agent_name_mapping = dict(
            zip(self.possible_agents, list(range(len(self.possible_agents))))
        )

        self.observation_spaces = spaces.Dict(
            {agent: spaces.Box(shape=(20,), dtype=np.float64, low=0.0, high=1.0)
             for agent in self.possible_agents}
        )

        self.action_spaces = spaces.Dict(
            {agent: spaces.Box(low=0, high=4, shape=(1,), dtype=np.int32)
             for agent in self.possible_agents}
        )

    def observation_space(self, agent):
        return self.observation_spaces[agent]

    def action_space(self, agent):
        return self.action_spaces[agent]

    def __calculate_observation(self, idx_player: int) -> np.ndarray:
        # Calculate the observation for the given player (just an example)
        observation = np.zeros(20)
        return observation

    def __calculate_observations(self) -> np.ndarray:
        """
        This method returns the observations for all players.
        """

        observations = {
            agent: self.__calculate_observation(
                idx_player=self.agent_name_mapping[agent])
            for agent in self.agents
        }
        return observations

    def observe(self, agent):
        i = self.agent_name_mapping[agent]
        return self.__calculate_observation(idx_player=i)

    def step(self, actions):
        observations = self.__calculate_observations()
        rewards = self.__calculate_rewards()  # As example
        self._episode_ended = self.__check_episode_ended()  # As example

        if self._episode_ended:
            infos = {agent: {} for agent in self.agents}
            dones = {agent: self._episode_ended for agent in self.agents}
            rewards = {
                self.agents[i]: rewards[i]
                for i in range(len(self.agents))
            }
            self.agents = []  # To satisfy `set(par_env.agents) == live_agents`

        else:
            infos = {agent: {"discount": 1.0} for agent in self.agents}
            dones = {agent: self._episode_ended for agent in self.agents}
            rewards = {
                self.agents[i]: rewards[i]
                for i in range(len(self.agents))
            }

        return observations, rewards, dones, infos

    def reset(self,
              seed: Optional[int] = None,
              return_info: bool = False,
              options: Optional[dict] = None,):
        # Reset the environment and get observation from each player.
        observations = self.__calculate_observations()
        return observations

Unfortunately when I try to test with the following main procedure:

import gym
from player_env import player_env
from stable_baselines3.common.env_checker import check_env
from pettingzoo.test import parallel_api_test

if __name__ == '__main__':
    # Environment initialization
    env = player_env.petting_zoo(n_agents=10)
    parallel_api_test(env)  # Works
    check_env(env)  # Throws error

I get the following error:

AssertionError: The observation returned by the `reset()` method does not match the given observation space

It seems that ss.pettingzoo_env_to_vec_env_v1(env) is capable of splitting the parallel environment into multiple vectorized ones, but not for the reset() function.

Does anyone know how to fix this problem?

jjshoots commented 2 years ago

For some pretty subtle reasons, when using pettingzoo_env_to_vec_env_v1, you must follow up by wrapping it in concat_vec_envs_v1.

PieroMacaluso commented 2 years ago

Thanks for the answer @jjshoots! I tried like this:

def env(**kwargs):
    env_ = parallel_env(**kwargs)
    env_ = ss.pettingzoo_env_to_vec_env_v1(env_)
    env_ = ss.concat_vec_envs_v1(env_, 1)
    return env_

petting_zoo = env

class parallel_env(ParallelEnv):
[...]

But I get this error:

cannot pickle 'SwigPyObject' object

jjshoots commented 2 years ago

You may need to use the EzPickle trick; see the KAZ (Knights Archers Zombies) environment's __init__ for an example. Essentially, your base environment class needs to inherit from the EzPickle class.
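For illustration, the idea behind EzPickle can be sketched with a simplified stand-in mixin (this is not the real gymnasium.utils.EzPickle, just a minimal reimplementation of its pattern): record the constructor arguments, and on unpickling rebuild the object by calling __init__ again, so unpicklable members (such as SWIG-wrapped C objects) never have to be serialized. The `EzPickleLike` and `DummyEnv` names here are hypothetical:

```python
import pickle

# Simplified stand-in for gymnasium.utils.EzPickle (not the real class):
# it stores the constructor arguments and, on unpickling, rebuilds the
# object by calling __init__ again, sidestepping unpicklable members.
class EzPickleLike:
    def __init__(self, *args, **kwargs):
        self._ezpickle_args = args
        self._ezpickle_kwargs = kwargs

    def __getstate__(self):
        # Only the constructor arguments get pickled.
        return {"args": self._ezpickle_args, "kwargs": self._ezpickle_kwargs}

    def __setstate__(self, state):
        # Re-run __init__ to recreate any unpicklable members from scratch.
        self.__init__(*state["args"], **state["kwargs"])

class DummyEnv(EzPickleLike):
    def __init__(self, n_agents):
        EzPickleLike.__init__(self, n_agents)
        self.n_agents = n_agents
        # A lambda is unpicklable, standing in for a SwigPyObject handle;
        # it is never serialized and gets recreated after unpickling.
        self._handle = lambda: n_agents

env = pickle.loads(pickle.dumps(DummyEnv(n_agents=10)))
assert env.n_agents == 10  # round-trip succeeds despite the lambda
```

With the real EzPickle the pattern is the same: inherit from it alongside ParallelEnv and call EzPickle.__init__(self, ...) with your constructor arguments.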

PieroMacaluso commented 2 years ago

Thanks! The problem related to pickle is solved, but there is still the initial error.

def env(**kwargs):
    env_ = parallel_env(**kwargs)
    env_ = ss.pettingzoo_env_to_vec_env_v1(env_)
    env_ = ss.concat_vec_envs_v1(env_, 1)
    return env_

petting_zoo = env

class parallel_env(ParallelEnv, EzPickle):
[...]

I still get:

The observation returned by the `reset()` method does not match the given observation space
jjshoots commented 2 years ago

You can probably remove the api check, it's only valid for unwrapped classes. :)

PieroMacaluso commented 2 years ago

Removing the line with check_env(env), the problem is still in the reset function. The error is:

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/user/.vscode/extensions/ms-python.python-2022.10.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/user/.vscode/extensions/ms-python.python-2022.10.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/user/.vscode/extensions/ms-python.python-2022.10.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/user/.vscode/extensions/ms-python.python-2022.10.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/user/.vscode/extensions/ms-python.python-2022.10.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/user/.vscode/extensions/ms-python.python-2022.10.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/user/dev/project/player-env/main_multi.py", line 23, in <module>
    model.learn(total_timesteps=10_000)
  File "/home/user/dev/project/player-env/.venv/lib/python3.8/site-packages/stable_baselines3/ppo/ppo.py", line 310, in learn
    return super().learn(
  File "/home/user/dev/project/player-env/.venv/lib/python3.8/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 239, in learn
    total_timesteps, callback = self._setup_learn(
  File "/home/user/dev/project/player-env/.venv/lib/python3.8/site-packages/stable_baselines3/common/base_class.py", line 446, in _setup_learn
    self._last_obs = self.env.reset()  # pytype: disable=annotation-type-mismatch
  File "/home/user/dev/project/player-env/.venv/lib/python3.8/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 64, in reset
    self._save_obs(env_idx, obs)
  File "/home/user/dev/project/player-env/.venv/lib/python3.8/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 94, in _save_obs
    self.buf_obs[key][env_idx] = obs
ValueError: could not broadcast input array from shape (20,20) into shape (20,)

The reset() function is still returning an array that stacks the observations of all the players (shape (20, 20)) instead of a single (20,)-shaped observation per vectorized sub-environment.
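Conceptually, the vector wrapper exposes each agent as one sub-environment, so the per-agent observation dict returned by reset() has to be split into one observation per sub-environment; SB3 then writes one row per sub-env into its buffer, and handing it the whole stacked block at a single index is what triggers the broadcast error above. A minimal sketch of that splitting step (the helper `unstack_obs` is hypothetical, not SuperSuit's actual internals):

```python
# Hypothetical sketch of what pettingzoo_env_to_vec_env_v1 must do with
# the per-agent observation dict: order it into one obs per sub-env.
possible_agents = [f"player_{i}" for i in range(3)]

def unstack_obs(obs_dict, agents):
    """Order the per-agent observation dict into one obs per sub-env."""
    return [obs_dict[agent] for agent in agents]

# reset() returns one 20-element observation per agent.
obs_dict = {agent: [0.0] * 20 for agent in possible_agents}
per_env = unstack_obs(obs_dict, possible_agents)

assert len(per_env) == 3                    # one entry per sub-environment
assert all(len(o) == 20 for o in per_env)   # each matches the (20,) space
```

If the dict is instead stacked into one (n_agents, 20) array and assigned to a single (20,) buffer slot, you get exactly the "could not broadcast" error shown in the traceback.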

jjshoots commented 2 years ago

Would it be OK for you to send me the files you're working on? I'll have a poke around and see what I can find out. I think that would be more productive than iterative bug squashing in this issue thread.

PieroMacaluso commented 2 years ago

I agree @jjshoots, thanks again for your time! I've just prepared a dummy repo with a very simple environment to reproduce the error. You can find it here: https://github.com/PieroMacaluso/dummy-env

jjshoots commented 2 years ago

Hi @PieroMacaluso, apologies for the delay, I finally got the time today to look at this.

I believe the fix to your problem is this:

  1. Update SuperSuit to the latest version.
  2. Pass the argument base_class="stable_baselines3" to concat_vec_envs_v1; apologies for not stating this earlier.

That said, the correct wrapper order for your environment should be:

    env_parallel = dummy.parallel_env()
    parallel_api_test(env_parallel)  # This works!

    env_parallel = ss.pettingzoo_env_to_vec_env_v1(env_parallel)
    env_parallel = ss.concat_vec_envs_v1(env_parallel, 1, base_class="stable_baselines3")

Let me know if you encounter any other issues. :)

PieroMacaluso commented 2 years ago

Hi @jjshoots! Finally good news! :partying_face: I followed your guidelines and these are the outcomes.

In the end, I think the base_class="stable_baselines3" argument is what made the difference. Only the small check_env problem remains to be reported, but I consider it trivial given that training works.

Thanks a lot for your help and support!