sail-sg / envpool

C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.
https://envpool.readthedocs.io
Apache License 2.0
1.08k stars 99 forks source link

[Feature Request] Compatibility with gym and SB3 wrapper #38

Closed araffin closed 2 years ago

araffin commented 2 years ago

Motivation

Related to #33

I was trying to make EnvPool work with SB3 and noticed several inconsistencies with classic Gym envs / Gym vector envs. I wrote a wrapper, but currently there is no way to properly handle terminal observations (as mentioned in #33) because I cannot step a particular env (`env.send()` does exist, but `env.recv()` does not guarantee that the result comes from the same env).

Solution

My current solution is below. I can also make a PR if you think it makes sense to integrate it directly into envpool (it would make it easier for people already using Gym / SB3 to adopt envpool ;))

import gym
import envpool
from gym.envs.registration import EnvSpec

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecEnvWrapper, VecMonitor
from stable_baselines3.common.env_util import make_vec_env

import numpy as np

from stable_baselines3.common.vec_env.base_vec_env import (
    VecEnv,
    VecEnvStepReturn,
    VecEnvWrapper,
)

# Experiment configuration.
env_id = "Pendulum-v0"  # Gym id; also used to build the EnvPool environment
num_envs = 4  # number of parallel environments
seed = 0
use_env_pool = True  # False -> use SB3's make_vec_env instead of EnvPool

class VecAdapter(VecEnvWrapper):
    """
    Convert an EnvPool object to a Stable-Baselines3 (SB3) VecEnv.

    SB3 expects ``step_wait`` to return one info dict per environment, to
    reset finished environments itself, and to keep the last observation of
    a finished episode under ``info["terminal_observation"]``. This adapter
    provides all three on top of EnvPool's batched ``step``/``reset``.

    :param venv: The envpool object (created with ``env_type="gym"``).
    """

    def __init__(self, venv):
        # VecEnvWrapper.__init__ reads ``venv.num_envs``; EnvPool exposes the
        # environment count on its spec config, so copy it onto the object.
        venv.num_envs = venv.spec.config.num_envs
        super().__init__(venv=venv)

    def step_async(self, actions: np.ndarray) -> None:
        # Only record the actions; step_wait performs the actual step.
        self.actions = actions

    def reset(self):
        return self.venv.reset()

    def seed(self, seed=None) -> None:
        # EnvPool can only be seeded through envpool.make(); do not forward
        # SB3's seed call to the wrapped EnvPool object.
        pass

    def step_wait(self) -> VecEnvStepReturn:
        obs, reward, done, info = self.venv.step(self.actions)
        # Only array-valued entries can be split into per-environment dicts.
        array_keys = [
            key for key, value in info.items() if isinstance(value, np.ndarray)
        ]
        infos = []
        for i in range(self.num_envs):
            env_info = {key: info[key][i] for key in array_keys}
            if done[i]:
                # When obs is a NumPy array, obs[i] is a view into it; copy so
                # the in-place reset below does not overwrite the terminal obs.
                env_info["terminal_observation"] = obs[i].copy()
                obs[i] = self.venv.reset(np.array([i]))
            infos.append(env_info)
        return obs, reward, done, infos

if not use_env_pool:
    env = make_vec_env(env_id, n_envs=num_envs)
else:
    env = envpool.make(env_id, env_type="gym", num_envs=num_envs, seed=seed)
    # NOTE(review): the EnvPool spec id is overwritten with the Gym id;
    # presumably a downstream wrapper reads env.spec.id — confirm.
    env.spec.id = env_id
    # Adapt to the SB3 VecEnv API, then wrap with VecMonitor.
    env = VecMonitor(VecAdapter(env))

# PPO with generalized State-Dependent Exploration (use_sde=True).
model = PPO(
    "MlpPolicy",
    env,
    seed=seed,
    verbose=1,
    n_steps=1024,
    learning_rate=1e-3,
    gamma=0.9,
    gae_lambda=0.95,
    use_sde=True,
    sde_sample_freq=4,
)
try:
    model.learn(100_000)
except KeyboardInterrupt:
    # Ctrl+C stops training early without a traceback.
    pass

Alternative

A better alternative would be to fix those inconsistencies directly in the C++ code.

Checklist

Trinkle23897 commented 2 years ago

because I cannot step in a particular env... (env.send() does exist but env.recv() does not garantee the result to be from the same env).

I'm not quite sure about your approach, but envpool now supports this feature when num_envs == batch_size: https://github.com/sail-sg/envpool/blob/3375b13a1a33d08758c9c26927abd29ab2914190/envpool/atari/atari_envpool_test.py#L80-L108 https://github.com/sail-sg/envpool/blob/3375b13a1a33d08758c9c26927abd29ab2914190/examples/env_step.py#L30

Indeed, this feature lacks documentation; I'll add it later...

I can also make a PR if you think it makes sense to integrate it directly into envpool (would make it easier for people already using gym / SB3 to adopt envpool ;))

Awesome! Looking forward to that.

araffin commented 2 years ago

Thanks for the heads-up =) Here is my updated code; I'll try to make a PR tomorrow ;). Do you want it inside the envpool package or in the examples folder?

from typing import Optional

import envpool
import gym
import numpy as np
import torch as th
from envpool.python.protocol import EnvPool
from gym.envs.registration import EnvSpec

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecEnvWrapper, VecMonitor
from stable_baselines3.common.vec_env.base_vec_env import (
    VecEnv,
    VecEnvObs,
    VecEnvStepReturn,
    VecEnvWrapper,
)
from stable_baselines3.common.evaluation import evaluate_policy

# Experiment configuration.
env_id = "Pendulum-v0"  # Gym id; also used to build the EnvPool environment
num_envs = 4  # number of parallel environments
seed = 0
use_env_pool = True  # False -> use SB3's make_vec_env instead of EnvPool

# Restrict PyTorch to a single thread: this makes things faster for
# simple environments such as this one.
th.set_num_threads(1)

class VecAdapter(VecEnvWrapper):
    """
    Convert EnvPool object to a Stable-Baselines3 (SB3) VecEnv.

    ``step_wait`` splits EnvPool's dict of per-batch arrays into one info dict
    per environment. For every environment that reports ``done`` it stores the
    final observation under ``"terminal_observation"``. It then replaces the
    returned observation with the one obtained from an explicit reset of that
    environment.

    :param venv: The envpool object.
    """

    def __init__(self, venv: EnvPool):
        # VecEnvWrapper.__init__ reads ``venv.num_envs``; EnvPool exposes the
        # environment count on its spec config, so copy it onto the object.
        venv.num_envs = venv.spec.config.num_envs
        super().__init__(venv=venv)

    def step_async(self, actions: np.ndarray) -> None:
        # Only record the actions; step_wait performs the actual step.
        self.actions = actions

    def reset(self) -> VecEnvObs:
        return self.venv.reset()

    def seed(self, seed: Optional[int] = None) -> None:
        # You can only seed EnvPool env by calling envpool.make()
        pass

    def step_wait(self) -> VecEnvStepReturn:
        obs, rewards, dones, info_dict = self.venv.step(self.actions)
        # Only array-valued entries can be split per environment; decide
        # which keys qualify once instead of once per environment.
        array_keys = [
            key for key, value in info_dict.items() if isinstance(value, np.ndarray)
        ]
        infos = []
        for i in range(self.num_envs):
            info = {key: info_dict[key][i] for key in array_keys}
            if dones[i]:
                # When obs is a NumPy array, obs[i] is a view into it. Without
                # the copy, the in-place reset below would overwrite the stored
                # terminal observation with the reset observation.
                info["terminal_observation"] = obs[i].copy()
                obs[i] = self.venv.reset(np.array([i]))
            infos.append(info)
        return obs, rewards, dones, infos

if not use_env_pool:
    env = make_vec_env(env_id, n_envs=num_envs)
else:
    env = envpool.make(env_id, env_type="gym", num_envs=num_envs, seed=seed)
    # NOTE(review): the EnvPool spec id is overwritten with the Gym id;
    # presumably a downstream wrapper reads env.spec.id — confirm.
    env.spec.id = env_id
    # Adapt to the SB3 VecEnv API, then wrap with VecMonitor.
    env = VecMonitor(VecAdapter(env))

# Tuned hyperparameters for Pendulum-v0, using generalized State-Dependent
# Exploration (use_sde=True).
model = PPO(
    "MlpPolicy",
    env,
    seed=seed,
    verbose=1,
    n_steps=1024,
    learning_rate=1e-3,
    gamma=0.9,
    gae_lambda=0.95,
    use_sde=True,
    sde_sample_freq=4,
)
# Alternative configuration without gSDE, kept for comparison:
# model = PPO(
#     "MlpPolicy",
#     env,
#     learning_rate=1e-3,
#     gae_lambda=0.95,
#     gamma=0.9,
#     verbose=1,
#     seed=seed,
# )
try:
    model.learn(100_000)
except KeyboardInterrupt:
    # Ctrl+C stops training early; the partially trained model is still evaluated.
    pass

# A policy trained on the EnvPool version should also perform well on the
# regular Gym environment.
test_env = gym.make(env_id)

# Evaluate on the EnvPool-backed VecEnv.
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20)
print(f"EnvPool - {env_id}\nMean Reward: {mean_reward:.2f} +/- {std_reward:.2f}")

# Evaluate on the plain Gym env; warn=False presumably silences SB3's
# "not wrapped with Monitor" warning for this unwrapped env.
mean_reward, std_reward = evaluate_policy(model, test_env, n_eval_episodes=20, warn=False)
print(f"Gym - {env_id}\nMean Reward: {mean_reward:.2f} +/- {std_reward:.2f}")
Trinkle23897 commented 2 years ago

Do you want it inside envpool package or in the example folder?

I think it should go in the examples folder. I don't want to put Tianshou's code in envpool/ either, so that the library's code stays clean.