adam-crowther closed this issue 1 year ago.
I created a workaround using a shim that wraps the `SB3VecEnvWrapper`:
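```python
from stable_baselines3.common.vec_env import VecEnvWrapper
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs, VecEnvStepReturn


class Sb3ShimWrapper(VecEnvWrapper):
    metadata = {'render_modes': ['human', 'files', 'none'], "name": "Sb3ShimWrapper-v0"}

    def __init__(self, venv):
        super().__init__(venv)

    def reset(self, seed=None, options=None) -> VecEnvObs:
        # The wrapped env returns (obs, info); SB3 expects obs only.
        # seed/options are accepted but not forwarded.
        return self.venv.reset()[0]

    def step_wait(self) -> VecEnvStepReturn:
        # Required by the VecEnvWrapper base class; passes through unchanged
        return self.venv.step_wait()
```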
As you can see, it overrides `reset()` so that it returns only the first element (the observation) of the `(observation, info)` tuple.
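The same idea can be sketched without Stable-Baselines3 installed, which makes the tuple stripping easy to see in isolation. This is a minimal sketch under stated assumptions: `FakeTupleVecEnv` and `ResetObsOnlyShim` are hypothetical stand-ins, not SuperSuit or SB3 classes, and unlike the shim above this version keeps the infos in a hypothetical `last_reset_infos` attribute instead of discarding them.

```python
from typing import Any, Dict, List, Tuple


class FakeTupleVecEnv:
    """Hypothetical stand-in for a vectorized env whose reset() returns (obs, infos)."""

    num_envs = 2

    def reset(self) -> Tuple[List[List[float]], List[Dict[str, Any]]]:
        obs = [[0.0, 0.0] for _ in range(self.num_envs)]
        infos = [{"env_id": i} for i in range(self.num_envs)]
        return obs, infos


class ResetObsOnlyShim:
    """Hypothetical shim: reset() returns only the observation."""

    def __init__(self, venv: FakeTupleVecEnv) -> None:
        self.venv = venv
        self.last_reset_infos: List[Dict[str, Any]] = []

    def reset(self) -> List[List[float]]:
        obs, infos = self.venv.reset()
        # Keep the infos rather than silently dropping them
        self.last_reset_infos = infos
        return obs

    def __getattr__(self, name: str) -> Any:
        # Called only for attributes not found on the shim (e.g. num_envs)
        return getattr(self.venv, name)


shim = ResetObsOnlyShim(FakeTupleVecEnv())
obs = shim.reset()
print(obs)                    # [[0.0, 0.0], [0.0, 0.0]]
print(shim.last_reset_infos)  # [{'env_id': 0}, {'env_id': 1}]
print(shim.num_envs)          # 2
```

Storing the infos is a design choice: the learner only needs the observation, but reset-time info can still be useful for logging or debugging.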
I integrate it like this:
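```python
import supersuit as ss
from pettingzoo.test import parallel_api_test
from stable_baselines3 import PPO

import dummy  # the demo repo's module providing DummyParallelEnv

if __name__ == '__main__':
    env_parallel = dummy.DummyParallelEnv()
    parallel_api_test(env_parallel)
    # env_parallel = ss.flatten_v0(env_parallel)
    env_parallel = ss.pettingzoo_env_to_vec_env_v1(env_parallel)
    env_parallel = ss.concat_vec_envs_v1(env_parallel, 1, base_class="stable_baselines3")
    # Sb3ShimWrapper (defined above) strips the info from reset()
    env_parallel = Sb3ShimWrapper(env_parallel)
    model = PPO("MlpPolicy", env_parallel, verbose=1)
    model.learn(total_timesteps=10_000)
```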
I will push this change to my demo repo.
Now I have a new problem with render(), which I will document in a new Issue.
To my knowledge, this has been fixed by https://github.com/Farama-Foundation/SuperSuit/pull/226. I was getting the same error: Stable-Baselines3 expects `reset()` to return only an observation, whereas by default PettingZoo and Gymnasium return an `(observation, info)` tuple.
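If the same code has to run against SuperSuit releases both before and after that fix, a small normalizing helper can hide the difference. This is a minimal sketch, assuming the heuristic described in its docstring; `split_reset_output` is a hypothetical name, and the heuristic cannot tell a genuine `(obs, info)` pair apart from a `Tuple` observation whose second element happens to be a dict (or a list of dicts).

```python
from typing import Any, Optional, Tuple


def split_reset_output(result: Any) -> Tuple[Any, Optional[Any]]:
    """Return (obs, info) whether reset() gave obs alone or an (obs, info) pair.

    Heuristic (an assumption): a 2-tuple whose second item is a dict, or a
    list/tuple of dicts (one per sub-env), is treated as (obs, info).
    """
    if isinstance(result, tuple) and len(result) == 2:
        obs, second = result
        if isinstance(second, dict):
            return obs, second
        if isinstance(second, (list, tuple)) and all(isinstance(i, dict) for i in second):
            return obs, list(second)
    # Old-style API: the whole result is the observation, with no info
    return result, None


print(split_reset_output(([0.1, 0.2], {"seed": 0})))  # ([0.1, 0.2], {'seed': 0})
print(split_reset_output([0.1, 0.2]))                 # ([0.1, 0.2], None)
```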
Hi,
I'm having an issue similar to Issue #220; however, mine is not resolved by using `ss.flatten_v0()`.
I'm using the latest versions: stable-baselines3==2.1.0a0, pettingzoo==1.23.1, supersuit==3.8.1 and gymnasium==0.28.1. I have also tried stable-baselines3==2.0.0 and had the same issue.
I have adapted @PieroMacaluso's dummy project from Issue #169 to reproduce the issue: https://github.com/adam-crowther/test-supersuit-baseline3-pettingzoo-parallel-env
The ParallelEnv looks like this:
And is executed like this:
When I execute it, I get this error:
If I set a breakpoint in the stable_baselines3 module `on_policy_algorithm.py` at line 168, I see that `self._last_obs` is being set to the tuple of observation and info returned by the ParallelEnv `reset()` method, but `obs_as_tensor()` expects an `np.ndarray`.

Have I got something wrong, or is there a compatibility issue somewhere?
Thanks,
Adam