Closed araffin closed 2 years ago
because I cannot step in a particular env... (env.send() does exist but env.recv() does not guarantee the result to be from the same env).
Not quite sure about your approach, but envpool now supports this feature when num_envs == batch_size:
https://github.com/sail-sg/envpool/blob/3375b13a1a33d08758c9c26927abd29ab2914190/envpool/atari/atari_envpool_test.py#L80-L108
https://github.com/sail-sg/envpool/blob/3375b13a1a33d08758c9c26927abd29ab2914190/examples/env_step.py#L30
Indeed, this feature lacks documentation; I'll add it later...
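To make the ordering problem from the quote concrete, here is a minimal toy sketch. `ToyAsyncPool` is a hypothetical stand-in written for this illustration, not the EnvPool API. It only assumes that results come back in completion order together with the ids of the envs that produced them. When batch_size == num_envs, every env is present in each batch, so sorting by env id is enough to line the results up again.

```python
import random


class ToyAsyncPool:
    """Hypothetical stand-in for an async env pool (NOT the EnvPool API)."""

    def __init__(self, num_envs, batch_size, seed=0):
        self.num_envs = num_envs
        self.batch_size = batch_size
        self._rng = random.Random(seed)
        self._ready = []  # finished (env_id, result) pairs, in completion order

    def send(self, actions, env_ids):
        # Each env "finishes" at a random time, so completion order is shuffled.
        finished = [(i, f"env{i}:act{a}") for a, i in zip(actions, env_ids)]
        self._rng.shuffle(finished)
        self._ready.extend(finished)

    def recv(self):
        # Return the first batch_size finished results, whichever envs they are.
        batch = self._ready[: self.batch_size]
        self._ready = self._ready[self.batch_size :]
        env_ids = [i for i, _ in batch]
        results = [r for _, r in batch]
        return results, env_ids


# With batch_size == num_envs, one recv() covers every env exactly once.
pool = ToyAsyncPool(num_envs=4, batch_size=4)
pool.send(actions=[10, 11, 12, 13], env_ids=[0, 1, 2, 3])
results, env_ids = pool.recv()

# Reorder by env id so that ordered[k] belongs to env k again.
ordered = [r for _, r in sorted(zip(env_ids, results))]

# With a smaller batch, recv() only covers a subset of the envs.
small_pool = ToyAsyncPool(num_envs=4, batch_size=2)
small_pool.send(actions=[10, 11, 12, 13], env_ids=[0, 1, 2, 3])
partial_results, partial_ids = small_pool.recv()
```

With a smaller batch size, the caller cannot know in advance which envs a given recv() will return, which is why stepping or resetting one particular env is awkward in that mode.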
I can also make a PR if you think it makes sense to integrate it directly into envpool (would make it easier for people already using gym / SB3 to adopt envpool ;))
Awesome! Looking forward to that.
Thanks for the heads up =) Here is my updated code; I'll try to make a PR tomorrow ;). Do you want it inside envpool package or in the example folder?
from typing import Optional
import envpool
import gym
import numpy as np
import torch as th
from envpool.python.protocol import EnvPool
from gym.envs.registration import EnvSpec
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecMonitor
from stable_baselines3.common.vec_env.base_vec_env import (
    VecEnv,
    VecEnvObs,
    VecEnvStepReturn,
    VecEnvWrapper,
)
from stable_baselines3.common.evaluation import evaluate_policy

# Force PyTorch to use only one thread;
# this makes things faster for simple envs
th.set_num_threads(1)
num_envs = 4
env_id = "Pendulum-v0"
seed = 0
use_env_pool = True
class VecAdapter(VecEnvWrapper):
    """
    Convert EnvPool object to a Stable-Baselines3 (SB3) VecEnv.

    :param venv: The envpool object.
    """

    def __init__(self, venv: EnvPool):
        # SB3 expects a num_envs attribute; EnvPool stores it in spec.config
        venv.num_envs = venv.spec.config.num_envs
        super().__init__(venv=venv)

    def step_async(self, actions: np.ndarray) -> None:
        # Only store the actions; the actual step happens in step_wait()
        self.actions = actions

    def reset(self) -> VecEnvObs:
        return self.venv.reset()

    def seed(self, seed: Optional[int] = None) -> None:
        # You can only seed EnvPool env by calling envpool.make()
        pass

    def step_wait(self) -> VecEnvStepReturn:
        obs, rewards, dones, info_dict = self.venv.step(self.actions)
        infos = []
        # Convert dict to list of dict
        # and add terminal observation
        for i in range(self.num_envs):
            infos.append(
                {
                    key: info_dict[key][i]
                    for key in info_dict.keys()
                    if isinstance(info_dict[key], np.ndarray)
                }
            )
            if dones[i]:
                # Copy: obs[i] is a view and is overwritten in place on the next line
                infos[i]["terminal_observation"] = obs[i].copy()
                obs[i] = self.venv.reset(np.array([i]))
        return obs, rewards, dones, infos
if use_env_pool:
    env = envpool.make(env_id, env_type="gym", num_envs=num_envs, seed=seed)
    env.spec.id = env_id
    env = VecAdapter(env)
    env = VecMonitor(env)
else:
    env = make_vec_env(env_id, n_envs=num_envs)

# Tuned hyperparams for Pendulum-v0
model = PPO(
    "MlpPolicy",
    env,
    n_steps=1024,
    learning_rate=1e-3,
    use_sde=True,
    sde_sample_freq=4,
    gae_lambda=0.95,
    gamma=0.9,
    verbose=1,
    seed=seed,
)

# model = PPO(
#     "MlpPolicy",
#     env,
#     learning_rate=1e-3,
#     gae_lambda=0.95,
#     gamma=0.9,
#     verbose=1,
#     seed=seed,
# )

try:
    model.learn(100_000)
except KeyboardInterrupt:
    pass
# Agent trained on envpool version should also perform well on regular Gym env
test_env = gym.make(env_id)
# Test with EnvPool
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20)
print(f"EnvPool - {env_id}")
print(f"Mean Reward: {mean_reward:.2f} +/- {std_reward:.2f}")
# Test with Gym
mean_reward, std_reward = evaluate_policy(model, test_env, n_eval_episodes=20, warn=False)
print(f"Gym - {env_id}")
print(f"Mean Reward: {mean_reward:.2f} +/- {std_reward:.2f}")
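The info conversion in step_wait() can be tried in isolation with plain NumPy. The sketch below is only an illustration: `split_infos` is a hypothetical helper name, not part of envpool or SB3, and the arrays are made-up data. It also shows why the terminal observation is worth copying. Indexing a row of a NumPy array returns a view, so writing the reset observation into that row afterwards would otherwise change the stored terminal observation too.

```python
import numpy as np


def split_infos(obs, dones, info_dict):
    """Turn a dict of per-env arrays into a list of per-env dicts.

    Hypothetical helper for illustration; not part of envpool or SB3.
    """
    infos = []
    for i in range(len(dones)):
        # Keep only array-valued entries, which have one item per env
        info = {
            key: value[i]
            for key, value in info_dict.items()
            if isinstance(value, np.ndarray)
        }
        if dones[i]:
            # obs[i] is a view into obs, so copy it before obs[i] is overwritten
            info["terminal_observation"] = obs[i].copy()
        infos.append(info)
    return infos


# Made-up data: two envs, the second one has just finished its episode
obs = np.array([[1.0, 2.0], [3.0, 4.0]])
dones = np.array([False, True])
info_dict = {"env_id": np.array([0, 1]), "note": "not an array"}

infos = split_infos(obs, dones, info_dict)

# Stand-in for writing the reset observation of env 1 into the batch
obs[1] = np.zeros(2)
```

After the overwrite, infos[1]["terminal_observation"] still holds the last observation of the finished episode, while obs[1] holds the first observation of the new one.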
Do you want it inside envpool package or in the example folder?
I think it should be in example. I don't want to put tianshou's code in envpool/ either, so that the library's code is clean enough.
Motivation
Related to #33
I was trying to make envpool work with SB3 and I noticed several inconsistencies with classic gym envs / gym vector envs. I wrote a wrapper, but currently there is no way to properly handle terminal observations (as mentioned in #33) because I cannot step in a particular env... (env.send() does exist but env.recv() does not guarantee the result to be from the same env).
Solution
My current solution, I can also make a PR if you think it makes sense to integrate it directly into envpool (would make it easier for people already using gym / SB3 to adopt envpool ;))
Alternative
A better alternative would be to fix those inconsistencies directly in the C++ code.