sail-sg / envpool

C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.
https://envpool.readthedocs.io
Apache License 2.0

[Feature Request] Auto reset between gym and envpool #257

Open hankknight opened 1 year ago

hankknight commented 1 year ago

Motivation

After reading the auto-reset documentation, I ran a simple comparison between gym and envpool and noticed that envpool gives different results from gym.

In my test env, every step increments obs by 1. When action=1, terminated is set to true and the episode ends, so the next reset sets obs back to 0. Otherwise, terminated stays false.

envpool:

User Call     Actual        Elapsed   Misc
env.reset()   env.reset()   0         obs=0, terminated=false
env.step(0)   env.step(0)   1         obs=1, terminated=false
env.step(1)   env.step(1)   2         obs=2, terminated=true
env.step(0)   env.reset()   3         obs=0, terminated=false (action 0 discarded)
env.step(0)   env.step(0)   4         obs=1, terminated=false

gym(SyncVectorEnv):

User Call     Actual        Elapsed   Misc
env.reset()   env.reset()   0         obs=0, terminated=false
env.step(0)   env.step(0)   1         obs=1, terminated=false
env.step(1)   env.reset()   2         obs=0, terminated=true
env.step(0)   env.step(0)   3         obs=1, terminated=false
env.step(0)   env.step(0)   4         obs=2, terminated=false

What should I do to align envpool's auto-reset behavior with gym's? In my actual use scenario, I don't want to discard any actions, because they are all meaningful.
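One possible workaround, sketched here under assumptions: wrap the pool so that any env that just finished is stepped once more right away, with a dummy action that envpool discards and turns into a reset, and then return that reset obs in place of the terminal obs, as gym's SyncVectorEnv does. This assumes a sync-mode pool whose `step(action, env_id)` steps only the listed envs and reports `info["env_id"]`; `ToyEnvPool` is a hypothetical pure-Python stand-in that mimics the envpool table above (it is not envpool), and `GymStyleAutoReset` and the `final_observation` key are illustrative names, not envpool API. The dummy action's dtype would need to match the real action spec.

```python
import numpy as np


class ToyEnvPool:
    """Hypothetical stand-in for the envpool pool in this issue (not envpool).

    Mimics envpool's auto-reset: the step right after a terminal step ignores
    its action and resets that env instead. obs counts steps since reset;
    action 1 terminates the episode.
    """

    def __init__(self, num_envs):
        self.num_envs = num_envs
        self.obs = np.zeros(num_envs, dtype=np.float32)
        self.done = np.zeros(num_envs, dtype=bool)

    def reset(self):
        self.obs[:] = 0.0
        self.done[:] = False
        return self.obs.copy(), {"env_id": np.arange(self.num_envs)}

    def step(self, action, env_id=None):
        ids = np.arange(self.num_envs) if env_id is None else np.asarray(env_id)
        rew = np.zeros(len(ids), dtype=np.float32)
        term = np.zeros(len(ids), dtype=bool)
        for k, i in enumerate(ids):
            if self.done[i]:  # pending auto-reset: the action is discarded
                self.obs[i], self.done[i] = 0.0, False
            else:
                self.obs[i] += 1.0
                self.done[i] = term[k] = action[k] == 1
                rew[k] = 1.0
        trunc = np.zeros(len(ids), dtype=bool)
        return self.obs[ids].copy(), rew, term, trunc, {"env_id": ids}


class GymStyleAutoReset:
    """Sketch: make an envpool-style pool auto-reset the way gym does."""

    def __init__(self, pool):
        self.pool = pool

    def reset(self):
        return self.pool.reset()

    def step(self, action):
        obs, rew, term, trunc, info = self.pool.step(action)
        done = term | trunc
        if done.any():
            ids = np.asarray(info["env_id"])
            info = dict(info, final_observation=obs.copy())  # pre-reset obs
            # Dummy action: the pool discards it and only resets these envs.
            dummy = np.zeros(int(done.sum()), dtype=np.int64)
            reset_obs, _, _, _, reset_info = self.pool.step(dummy, ids[done])
            # Put each reset obs back in the row of the env it belongs to.
            row = {int(e): k for k, e in enumerate(ids)}
            obs = obs.copy()
            for o, e in zip(reset_obs, reset_info["env_id"]):
                obs[row[int(e)]] = o
        return obs, rew, term, trunc, info


if __name__ == "__main__":
    env = GymStyleAutoReset(ToyEnvPool(num_envs=1))
    obs, _ = env.reset()
    for a in [0, 1, 0, 0]:
        obs, rew, term, trunc, info = env.step(np.array([a]))
        print(obs, term)
```

With the toy pool, the four steps from the tables yield obs 1, 0 (terminated), 1, 2, i.e. the gym trace, and no user action is spent on a reset. The price is an extra, smaller step call whenever some env finishes, which is the kind of overhead envpool's own design avoids.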

Additional context

gym == 0.26.2 envpool == 0.8.1 numpy == 1.24.2

hankknight commented 1 year ago

Envpool

C++

#include <iostream>

#include "envpool/core/env.h"
#include "envpool/core/py_envpool.h"
#include "envpool/core/async_envpool.h"

using namespace std;

namespace cust {

class CustEnvFns {
public:
    static decltype(auto) DefaultConfig() { return MakeDict(); }

    template <typename Config>
    static decltype(auto) StateSpec(const Config& conf) {
        return MakeDict("obs"_.Bind(Spec<float>({-1})));
    }

    template <typename Config>
    static decltype(auto) ActionSpec(const Config& conf) {
        return MakeDict("action"_.Bind(Spec<int>({-1}, {0, 1})));
    }
};

using CustEnvSpec = EnvSpec<CustEnvFns>;

class CustEnv : public Env<CustEnvSpec> {
protected:
    bool done_{true};
    float obs_{0.0F};

public:
    CustEnv(const Spec& spec, int env_id) : Env<CustEnvSpec>(spec, env_id) {}

    bool IsDone() override { return done_; }

    void Reset() override {
        obs_  = 0.0F;
        done_ = false;

        // debug message
        cout << "env " << env_id_ << " reset." << endl;

        WriteState(0.0);
    }

    void Step(const Action& action) override {
        ++obs_;
        int act = action["action"_];

        if (act == 1) {
            done_ = true;
        }

        WriteState(1.0);
    }

private:
    void WriteState(float reward) {
        auto state       = Allocate();
        state["reward"_] = reward;
        state["obs"_]    = obs_;
    }
};

using CustEnvPool = AsyncEnvPool<CustEnv>;

}  // namespace cust

using CustEnvSpec = PyEnvSpec<cust::CustEnvSpec>;
using CustEnvPool = PyEnvPool<cust::CustEnvPool>;

PYBIND11_MODULE(cust_envpool, m) { REGISTER(m, CustEnvSpec, CustEnvPool) }

Python

import numpy as np

from envpool.python.api import py_env
from cust_envpool import _CustEnvSpec, _CustEnvPool

(
    CustEnvSpec,
    CustDMEnvPool,
    CustGymEnvPool,
    CustGymnasiumEnvPool,
) = py_env(_CustEnvSpec, _CustEnvPool)

if __name__ == "__main__":
    print(" -- test envpool env -- ")
    elapsed = 0
    env_config = dict(zip(CustEnvSpec._config_keys,
                    CustEnvSpec._default_config_values))
    env_config["num_envs"] = 1
    vec_env = CustGymEnvPool(CustEnvSpec(tuple(env_config.values())))

    obs, _ = vec_env.reset()
    print(f"-*- Elapsed {elapsed} -*-")
    print(obs)
    elapsed += 1
    # env 0 reset.
    # -*- Elapsed 0 -*-
    # [0.]

    action = np.array([0])
    obs, rew, terminate, truncate, info = vec_env.step(action)
    print(f"-*- Elapsed {elapsed}, Action {action} -*-")
    print(obs, terminate)
    elapsed += 1
    # -*- Elapsed 1, Action [0] -*-
    # [1.] [False]

    action = np.array([1])
    obs, rew, terminate, truncate, info = vec_env.step(action)
    print(f"-*- Elapsed {elapsed}, Action {action} -*-")
    print(obs, terminate)
    elapsed += 1
    # -*- Elapsed 2, Action [1] -*-
    # [2.] [ True]

    action = np.array([0])
    obs, rew, terminate, truncate, info = vec_env.step(action)
    print(f"-*- Elapsed {elapsed}, Action {action} -*-")
    print(obs, terminate)
    elapsed += 1
    # env 0 reset.
    # -*- Elapsed 3, Action [0] -*-
    # [0.] [False]

    action = np.array([0])
    obs, rew, terminate, truncate, info = vec_env.step(action)
    print(f"-*- Elapsed {elapsed}, Action {action} -*-")
    print(obs, terminate)
    elapsed += 1
    # -*- Elapsed 4, Action [0] -*-
    # [1.] [False]

Gym

import gym
import numpy as np

from gym import spaces
from gym.vector import SyncVectorEnv

class CustGymEnv(gym.Env):

    def __init__(self, env_id) -> None:
        self._env_id = env_id
        self._done = False
        self._obs = np.array([0.0], dtype=np.float32)
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(
            np.array([-1.0], dtype=np.float32), 
            np.array([1.0], dtype=np.float32), 
            dtype=np.float32)

    def step(self, action):
        self._obs += 1.0
        if action == 1:
            self._done = True

        return self._obs, 1.0, self._done, False, {}

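    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)  # gym >= 0.26 reset signature; seeds self.np_random
        self._done = False
        self._obs = np.array([0.0], dtype=np.float32)
        print(f"env {self._env_id} reset")

        return self._obs, {}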

    def render(self):
        return None

    def close(self):
        return None

if __name__ == "__main__":
    print(" -- test gym env -- ")
    elapsed = 0
    vec_env = SyncVectorEnv([lambda: CustGymEnv(0)])

    obs, _ = vec_env.reset()
    print(f"-*- Elapsed {elapsed} -*-")
    print(obs)
    elapsed += 1
    # env 0 reset
    # -*- Elapsed 0 -*-
    # [[0.]]

    action = np.array([0])
    obs, rew, terminate, truncate, info = vec_env.step(action)
    print(f"-*- Elapsed {elapsed}, Action {action} -*-")
    print(obs, terminate)
    elapsed += 1
    # -*- Elapsed 1, Action [0] -*-
    # [[1.]] [False]

    action = np.array([1])
    obs, rew, terminate, truncate, info = vec_env.step(action)
    print(f"-*- Elapsed {elapsed}, Action {action} -*-")
    print(obs, terminate)
    elapsed += 1
    # env 0 reset
    # -*- Elapsed 2, Action [1] -*-
    # [[0.]] [ True]

    action = np.array([0])
    obs, rew, terminate, truncate, info = vec_env.step(action)
    print(f"-*- Elapsed {elapsed}, Action {action} -*-")
    print(obs, terminate)
    elapsed += 1
    # -*- Elapsed 3, Action [0] -*-
    # [[1.]] [False]

    action = np.array([0])
    obs, rew, terminate, truncate, info = vec_env.step(action)
    print(f"-*- Elapsed {elapsed}, Action {action} -*-")
    print(obs, terminate)
    elapsed += 1
    # -*- Elapsed 4, Action [0] -*-
    # [[2.]] [False]
walkacross commented 1 year ago

In short, auto-reset in gym:

reset_obs = env.reset()
action_result_in_done = policy(reset_obs)
# gym resets inside this step and returns the new episode's first obs
a_reset_obs, reward, done, info = env.step(action_result_in_done)

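In envpool:

reset_obs = envpool_env.reset()
action_result_in_done = policy(reset_obs)
# envpool returns the terminal obs here; the reset happens on the next step
not_a_reset_obs, reward1, done, info = envpool_env.step(action_result_in_done)

action = policy(not_a_reset_obs)
# this action is discarded: the step only resets the env
a_reset_obs, reward2, done, info = envpool_env.step(action)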

Envpool's auto-reset therefore produces two questionable transitions when the env is used to collect transitions:

reset_obs, action_result_in_done, reward1, not_a_reset_obs

not_a_reset_obs, action, reward2, a_reset_obs

These two transitions, especially the second, are somewhat abnormal and not logically valid (particularly when modeling state transitions in the model-based case), and should not be included in the ReplayBuffer.

Any suggestions for addressing these issues when a Collector follows the gym.vector auto-reset protocol?
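One common workaround, sketched under the assumption that a rollout is stored as `(T, num_envs)` arrays aligned by env id (for example in sync mode, or after re-indexing with `info["env_id"]`) and that collection starts right after `reset()`: keep the first transition, since it is the genuine terminal transition and is stored with done=True, and mask out the second one, i.e. every step that directly follows a done step of the same env. `autoreset_mask` and `valid_transitions` are hypothetical helper names, not part of envpool.

```python
import numpy as np


def autoreset_mask(terminated, truncated):
    """Return mask[t, i] = False where step t of env i was only envpool's reset.

    terminated, truncated: bool arrays of shape (T, num_envs) from consecutive
    pool.step() calls that started right after pool.reset(). Under envpool's
    auto-reset, the step right after a done step ignores its action and only
    resets, so that transition must not reach the replay buffer.
    """
    done = np.logical_or(terminated, truncated)
    mask = np.ones(done.shape, dtype=bool)
    mask[1:] = ~done[:-1]  # step t is invalid if step t-1 ended the episode
    return mask


def valid_transitions(obs, act, rew, next_obs, terminated, truncated):
    """Flatten a (T, num_envs, ...) rollout into (s, a, r, s', done) tuples,
    skipping the transitions that straddle an auto-reset."""
    mask = autoreset_mask(terminated, truncated)
    return [
        (obs[t, i], act[t, i], rew[t, i], next_obs[t, i], bool(terminated[t, i]))
        for t, i in zip(*np.nonzero(mask))
    ]
```

Applied to the envpool trace from the top of this issue (actions 0, 1, 0, 0), the mask is True, True, False, True: the step at Elapsed 3 (terminal obs 2 to reset obs 0, with a discarded action) is dropped, and the other three transitions are kept.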

Trinkle23897 commented 1 year ago

The motivation is https://github.com/sail-sg/envpool/issues/194#issuecomment-1293493553 for performance. Can we move this discussion to #194?

xiezhipeng-git commented 9 months ago

@Trinkle23897 I have another question related to reset. Can the automatic reset be turned off manually? During evaluation, environments that have already finished are reset automatically, which I don't want.

Trinkle23897 commented 9 months ago

You can pass env_id to env.step to decide which envs are stepped and which are not.

xiezhipeng-git commented 9 months ago

@Trinkle23897 Could you help me with the key code for controlling specific environments? I couldn't find how to step only selected environments. Do you call env.send(action, env_id) for each env one by one?

xiezhipeng-git commented 9 months ago

@Trinkle23897

def step(
    self,
    action: Union[Dict[str, Any], np.ndarray],
    env_id: Optional[np.ndarray] = None,

So I pass in the IDs of the envs to be updated directly, as the env_id array?

Trinkle23897 commented 9 months ago

https://github.com/sail-sg/envpool/blob/main/envpool/atari/atari_pretrain_test.py this is the evaluation script
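A minimal sketch of that evaluation pattern (an illustration of the idea, not a copy of the linked script): keep the ids of envs that are still playing, step only those, and drop an env once it is done, so envpool's auto-reset never gets a chance to start a second episode. It assumes a sync-mode pool whose `reset()` returns `info["env_id"]` and whose `step(action, env_id)` steps only the listed envs and reports `info["env_id"]` for each returned row; env ids are assumed to be 0..num_envs-1. `evaluate` and `CountdownPool` are hypothetical names, and `CountdownPool` is a toy mock, not envpool.

```python
import numpy as np


def evaluate(pool, policy, max_steps=10_000):
    """Play exactly one episode per env; finished envs are never stepped again."""
    obs, info = pool.reset()
    ids = np.asarray(info["env_id"])
    returns = np.zeros(len(ids))
    for _ in range(max_steps):
        if len(ids) == 0:  # every env has finished its first episode
            break
        obs, rew, term, trunc, info = pool.step(policy(obs), ids)
        ids = np.asarray(info["env_id"])  # which env each row belongs to
        returns[ids] += rew
        alive = ~(term | trunc)
        obs, ids = obs[alive], ids[alive]  # drop finished envs
    return returns


class CountdownPool:
    """Hypothetical mock pool: env i ends after i + 2 steps, reward 1 per step."""

    def __init__(self, num_envs):
        self.t = np.zeros(num_envs, dtype=np.int64)

    def reset(self):
        self.t[:] = 0
        ids = np.arange(len(self.t))
        return np.zeros(len(ids)), {"env_id": ids}

    def step(self, action, env_id):
        ids = np.asarray(env_id)
        self.t[ids] += 1
        term = self.t[ids] >= ids + 2
        n = len(ids)
        trunc = np.zeros(n, dtype=bool)
        return self.t[ids].astype(float), np.ones(n), term, trunc, {"env_id": ids}
```

For example, `evaluate(CountdownPool(3), lambda obs: np.zeros(len(obs), dtype=np.int64))` returns one score per env (2, 3 and 4 steps of reward 1). Because the rows are re-labelled from `info["env_id"]` after every call, the policy's inputs and the ids stay in sync even though the set of active envs shrinks.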

xiezhipeng-git commented 9 months ago

@Trinkle23897 Thanks

xiezhipeng-git commented 9 months ago

@Trinkle23897 This method works, but it has significant limitations. For example, if action generation and environment execution cannot happen at the same time, the actions can get out of sync with the env ids. Having to fetch the latest ids from info every time also makes the code hard to abstract. And if you add features such as frame skipping, an env may still be auto-reset in the middle of a run. This makes it hard to speed up evaluation while still getting correct results and scores. If there are plans for a new version, I suggest adding an option to disable automatic environment reset.

Trinkle23897 commented 9 months ago

Sorry, I think it's hard by design

xiezhipeng-git commented 9 months ago

Okay, I plan to use the following logic for evaluation: all environments keep running.

Keep a first_done flag per env, initially not done, and accumulate the reward. The reward of the step where first_done is reached is still counted, but nothing after it. Because the envs run in parallel, some envs may not have finished yet while others have already finished twice or more, so only the rewards up to each env's first done can be counted. Once every env has reached its first done, exit the loop.
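The first-done bookkeeping described above can be sketched as follows, assuming every batch returned by the pool's step is passed to `update()` together with `info["env_id"]` so rows can be matched to envs. `FirstDoneScore` is a hypothetical helper, not envpool API.

```python
import numpy as np


class FirstDoneScore:
    """Sum each env's reward up to and including its first done; ignore the rest.

    All envs keep running (and auto-resetting); feed every batch from
    pool.step() into update(), with env_id saying which env each row is.
    """

    def __init__(self, num_envs):
        self.score = np.zeros(num_envs)
        self.first_done = np.zeros(num_envs, dtype=bool)

    def update(self, rew, term, trunc, env_id):
        ids = np.asarray(env_id)
        counting = ~self.first_done[ids]  # envs still in their first episode
        self.score[ids[counting]] += rew[counting]
        self.first_done[ids] |= np.logical_or(term, trunc)
        return bool(self.first_done.all())  # True: leave the eval loop
```

Compared with stepping only the unfinished env ids, this keeps the step call uniform across all envs, at the cost of simulating steps whose rewards are then thrown away.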

I still hope a feature to disable auto-reset will be added in the future.